Question: Split-K atomic add memory semantics #17

Open
michael-diggin opened this issue Jan 21, 2025 · 0 comments

Hi!
I've been writing some fused kernels containing matmuls of shape (M, K) @ (K, N) where K >> M, N, and a split-K algorithm like the one in https://github.com/triton-lang/kernels/blob/main/kernels/matmul.py has been really advantageous for them.
I've found that tl.atomic_add(..., sem="relaxed") gives the best results. My understanding is that this is still safe for this type of kernel, since the atomic isn't used to communicate between different threads/blocks: each program only accumulates its partial tile into C, and C is only read back after the kernel has finished.
Is that correct? The relaxed semantic isn't used in the implementation in this repository, so I'm checking in case I've missed something.
I'd be happy to add this in a PR if it's correct and useful for the kernel linked above.
Thanks!
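
For concreteness, the change I have in mind only touches the SPLIT_K > 1 write-back; a minimal sketch (variable names such as C, acc and mask assumed to match the linked matmul.py):

    # SPLIT_K > 1 write-back: each of the SPLIT_K programs adds its partial tile into C.
    # sem="relaxed" only drops the default acquire/release ordering; the add itself is
    # still atomic, so the accumulated value is unchanged.
    tl.atomic_add(C, acc, mask=mask, sem="relaxed")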

I ran some benchmarks of the different semantics (torch 2.5.1, Triton 3.1, A100 40GB) for a fairly contrived example with M = N = 256; this is the plot I generated:

[Plot: achieved TFLOPS vs. K for each atomic_add memory semantic (relaxed, release, acquire, acq_rel)]

Benchmarking code
import torch
import triton
import triton.language as tl

@triton.jit
def _splitk_kernel(A, B, C, M, N, K,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, SEM: tl.constexpr,
            ):
    
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Grid axes 0 and 1 hold the number of M and N tiles (axis 2 is the SPLIT_K axis).
    num_pid_m, num_pid_n = tl.num_programs(0), tl.num_programs(1)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_M)
    pid_k = tl.program_id(2)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
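    # Compiler hints: tl.multiple_of / tl.max_contiguous tell Triton that these index
    # blocks are aligned and contiguous, which enables vectorized, coalesced loads.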
    rm = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_K), BLOCK_K)
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
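    # Split-K main loop: this program starts at the pid_k-th BLOCK_K chunk of K (via rk)
    # and strides through K in steps of BLOCK_K * SPLIT_K, so the reduction dimension is
    # partitioned across the SPLIT_K programs on grid axis 2.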
    for k in tl.range(0, tl.cdiv(K, BLOCK_K*SPLIT_K)):
        k_remaining = K - k * (BLOCK_K*SPLIT_K)
        a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
        b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
        acc += tl.dot(a, b)
        A += BLOCK_K*SPLIT_K * stride_ak
        B += BLOCK_K*SPLIT_K * stride_bk

    acc = acc.to(C.dtype.element_ty)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
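    # Each of the SPLIT_K programs atomically adds its partial tile into C. The adds are
    # commutative (up to floating-point rounding order) and C is only read after the kernel
    # completes, so the atomic is not used to order other memory traffic between programs;
    # this is why sem="relaxed" is expected to be safe here.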
    tl.atomic_add(C, acc, sem=SEM)

def _zero_output(*args, **kwargs):
    # Pre-run hook: zero the whole output tensor C (args[2]) so the SPLIT_K partial
    # tiles accumulated via tl.atomic_add start from zero.
    if kwargs["SPLIT_K"] != 1:
        args[2].zero_()

_splitk_kernel.add_pre_run_hook(_zero_output)

def splitk_kernel(x, y, sem):
    m, k = x.shape
    _, n = y.shape
    z = torch.empty((m, n), dtype=x.dtype, device=x.device)
    grid = lambda meta: (triton.cdiv(m, meta["BLOCK_M"]), triton.cdiv(n, meta["BLOCK_N"]), meta["SPLIT_K"])
    _splitk_kernel[grid](x, y, z,
                         m, n, k,
                         x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), z.stride(1),
                         # On an A100 40GB I've found this to get the best results,
                         # as it launches 96 thread blocks, reasonably close to the A100's 108 SMs.
                         BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, SPLIT_K=6, SEM=sem,
                         num_warps=4, num_stages=4,
                         GROUP_M=4)
    return z

semantics = ["relaxed", "release", "acquire", "acq_rel"]

configs = [triton.testing.Benchmark(
            x_names=["K"],
            x_vals=[128 * i for i in range(2, 200, 10)],
            line_arg="semantic",
            line_vals=semantics,
            line_names=semantics,
            styles=[("green", "-"), ("blue", "-"), ("red", "-"), ("orange", "-")],
            ylabel="TFLOPS",
            plot_name="bench",
            args={},
)]

@triton.testing.perf_report(configs)
def benchmark(K, semantic):
    N = 256
    M = 256
    X = torch.rand((M, K), device="cuda", dtype=torch.float16).contiguous()
    Y = torch.rand((K, N), device="cuda", dtype=torch.float16).contiguous()
    quantiles = [0.5, 0.2, 0.8]
    
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: splitk_kernel(X, Y, semantic), quantiles=quantiles)

    perf = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


benchmark.run(show_plots=True, print_data=False)
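
As a quick sanity check that the relaxed semantic doesn't change the result (beyond fp16 rounding noise from the non-deterministic SPLIT_K accumulation order), something along these lines could be appended; the tolerances below are loose, assumed values rather than anything principled:

# Compare each semantic against torch.matmul; tolerances are deliberately loose because
# the output is fp16 and the SPLIT_K atomic adds happen in arbitrary order.
X = torch.rand((256, 4096), device="cuda", dtype=torch.float16)
Y = torch.rand((4096, 256), device="cuda", dtype=torch.float16)
ref = torch.matmul(X, Y)
for sem in semantics:
    out = splitk_kernel(X, Y, sem)
    assert torch.allclose(out, ref, atol=0.5, rtol=1e-2), f"mismatch for sem={sem}"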
