Hi!
I've been writing some fused kernels that contain matmuls of shape (M, K) @ (K, N) where K >> M, N, and I've found using a split-K algorithm like the one in https://github.com/triton-lang/kernels/blob/main/kernels/matmul.py to be really advantageous.
I've found that using tl.atomic_add(..., sem="relaxed") gets the best results. My understanding is that this is still safe for this type of kernel, since the atomic is only used to accumulate partial results into the output and never to communicate data between different threads/blocks.
Is that correct? Relaxed semantics aren't used in the implementation in this repository, so I'm checking in case I've missed something.
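To make concrete what I mean by "not used to communicate", here's a minimal, self-contained sketch (my own toy example, not from the linked repo): every program folds its slice of a vector into a partial sum and atomically adds it into a single output slot. No program ever reads the output back, so only atomicity matters, not ordering, and the result is the same under "relaxed" as under the default "acq_rel":

import torch
import triton
import triton.language as tl

@triton.jit
def _accumulate(out_ptr, x_ptr, N, BLOCK: tl.constexpr, SEM: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs, mask=offs < N, other=0.0)
    # every program adds its partial sum into the same slot; nothing is read back,
    # so no ordering between programs is required, only atomicity
    tl.atomic_add(out_ptr, tl.sum(x), sem=SEM)

x = torch.ones(1 << 20, device="cuda", dtype=torch.float32)
for sem in ("acq_rel", "relaxed"):
    out = torch.zeros(1, device="cuda", dtype=torch.float32)
    _accumulate[(triton.cdiv(x.numel(), 1024),)](out, x, x.numel(), BLOCK=1024, SEM=sem)
    assert out.item() == x.numel()  # exact: all partial sums are small integers in fp32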
I'd be happy to add this in a PR if it is correct/useful for the kernel I've linked above.
Thanks!
I ran some benchmarks of the different semantics (using torch 2.5.1, Triton 3.1, and an A100 40 GB GPU) for a fairly contrived example with M = N = 256, and this is the plot I generated:
Benchmarking code
import torch
import triton
import triton.language as tl
@triton.jit
def _splitk_kernel(A, B, C, M, N, K,
                   stride_am, stride_ak,
                   stride_bk, stride_bn,
                   stride_cm, stride_cn,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                   GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, SEM: tl.constexpr,
                   ):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # grid sizes along the M and N axes (launch-grid axes 0 and 1)
    num_pid_m, num_pid_n = tl.num_programs(0), tl.num_programs(1)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_M)
    pid_k = tl.program_id(2)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rm = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    # each split-K program starts at its own BLOCK_K-sized offset along K
    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_K), BLOCK_K)
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        k_remaining = K - k * (BLOCK_K * SPLIT_K)
        a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
        b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
        acc += tl.dot(a, b)
        # advance by BLOCK_K * SPLIT_K so the SPLIT_K programs interleave along K
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    # accumulate this program's partial result into the pre-zeroed output;
    # SEM controls the memory ordering of the atomic
    tl.atomic_add(C, acc, sem=SEM)
def _zero_output(*args, **kwargs):
    # split-K accumulates into C with atomic adds, so the output tensor
    # (args[2], i.e. C) must be zeroed before every launch
    if kwargs["SPLIT_K"] != 1:
        args[2].zero_()

_splitk_kernel.add_pre_run_hook(_zero_output)
def splitk_kernel(x, y, sem):
    m, k = x.shape
    _, n = y.shape
    z = torch.empty((m, n), dtype=x.dtype, device=x.device)
    grid = lambda meta: (triton.cdiv(m, meta["BLOCK_M"]), triton.cdiv(n, meta["BLOCK_N"]), meta["SPLIT_K"])
    _splitk_kernel[grid](x, y, z,
                         m, n, k,
                         x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), z.stride(1),
                         # On an A100 40GB I've found this to get the best results,
                         # as it launches 96 thread blocks (4 x 4 x 6), which is
                         # reasonably close to the 108 SMs the A100 has
                         BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, SPLIT_K=6, SEM=sem,
                         num_warps=4, num_stages=4,
                         GROUP_M=4)
    return z
semantics = ["relaxed", "release", "acquire", "acq_rel"]
configs = [triton.testing.Benchmark(
    x_names=["K"],
    x_vals=[128 * i for i in range(2, 200, 10)],
    line_arg="semantic",
    line_vals=semantics,
    line_names=semantics,
    styles=[("green", "-"), ("blue", "-"), ("red", "-"), ("orange", "-")],
    ylabel="TFLOPS",
    plot_name="bench",
    args={},
)]
@triton.testing.perf_report(configs)
def benchmark(K, semantic):
    N = 256
    M = 256
    X = torch.rand((M, K), device="cuda", dtype=torch.float16).contiguous()
    Y = torch.rand((K, N), device="cuda", dtype=torch.float16).contiguous()
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: splitk_kernel(X, Y, semantic), quantiles=quantiles)
    perf = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)

benchmark.run(show_plots=True, print_data=False)
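As a quick sanity check (my addition, not part of the benchmark above), appending the following to the script verifies that the relaxed-atomic split-K result matches torch.matmul. Note the kernel above doesn't mask the output tiles, so the shapes are kept as multiples of BLOCK_M / BLOCK_N / BLOCK_K, and the tolerances are loose because split-K rounds each partial result to fp16 before the atomic add:

x = torch.rand((256, 4096), device="cuda", dtype=torch.float16)
y = torch.rand((4096, 256), device="cuda", dtype=torch.float16)
out = splitk_kernel(x, y, "relaxed")
ref = x @ y
# allow a little more slack than a fused fp32-accumulated matmul would need,
# since each of the SPLIT_K partials is converted to fp16 before accumulation
torch.testing.assert_close(out, ref, rtol=2e-2, atol=1e-1)
print("relaxed split-K matches torch.matmul")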