I'm trying to implement GEMM for FP8 using Triton on an RTX 4090. It performs well for FP16, but for FP8 it is slower than cutlass (torch._scaled_mm), and adjusting the autotune configs doesn't improve performance.
Here is the reproduction:
import torch

import triton
import triton.language as tl

DEVICE = "cuda"


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_hip_mi200():
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == 'hip' and target.arch == 'gfx90a'


def get_cuda_autotune_config():
    return [
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
num_warps=2),
        # Good config for fp8 inputs.
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4)
]


def get_hip_autotune_config():
    return [
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
num_warps=4, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
num_warps=8, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
num_warps=8, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
num_warps=4, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
num_warps=4, num_stages=2),
]


def get_autotune_config():
    if is_cuda():
        return get_cuda_autotune_config()
    else:
        return get_hip_autotune_config()


# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#     meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#     provided configs
@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
        GROUP_SIZE_M: tl.constexpr,  #
        ACTIVATION: tl.constexpr  #
):
"""Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) """# -----------------------------------------------------------# Map program ids `pid` to the block of C it should compute.# This is done in a grouped ordering to promote L2 data reuse.# See above `L2 Cache Optimizations` section for details.pid=tl.program_id(axis=0)
num_pid_m=tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n=tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group=GROUP_SIZE_M*num_pid_ngroup_id=pid//num_pid_in_groupfirst_pid_m=group_id*GROUP_SIZE_Mgroup_size_m=min(num_pid_m-first_pid_m, GROUP_SIZE_M)
pid_m=first_pid_m+ ((pid%num_pid_in_group) %group_size_m)
pid_n= (pid%num_pid_in_group) //group_size_m# ----------------------------------------------------------# Create pointers for the first blocks of A and B.# We will advance this pointer as we move in the K direction# and accumulate# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers# See above `Pointer Arithmetic` section for detailsoffs_am= (pid_m*BLOCK_SIZE_M+tl.arange(0, BLOCK_SIZE_M)) %Moffs_bn= (pid_n*BLOCK_SIZE_N+tl.arange(0, BLOCK_SIZE_N)) %Noffs_k=tl.arange(0, BLOCK_SIZE_K)
a_ptrs=a_ptr+ (offs_am[:, None] *stride_am+offs_k[None, :] *stride_ak)
b_ptrs=b_ptr+ (offs_k[:, None] *stride_bk+offs_bn[None, :] *stride_bn)
# -----------------------------------------------------------# Iterate to compute a block of the C matrix.# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block# of fp32 values for higher accuracy.# `accumulator` will be converted back to fp16 after the loop.accumulator=tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
forkinrange(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator = tl.dot(a, b, accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)


# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation  #
    )
    return c


# %%
# Unit Test
# ---------
#
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).
torch.manual_seed(0)
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp16_inputs={triton_output}")
print(f"torch_output_with_fp16_inputs={torch_output}")
# Bigger tolerance for AMD MI200 devices.
# MI200 devices use reduced precision fp16 and bf16 and flush input and
# output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
rtol = 1e-2 if is_hip_mi200() else 0
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")

TORCH_HAS_FP8 = hasattr(torch, "float8_e4m3fn")
if TORCH_HAS_FP8 and is_cuda():
    torch.manual_seed(0)
    a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
    b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
    a = a.to(torch.float8_e4m3fn)
    # pre-transpose b for efficiency.
    b = b.T
    b = b.to(torch.float8_e4m3fn)
    triton_output = matmul(a, b)
    scale_a = torch.tensor(1., device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1., device="cuda", dtype=torch.float32)
    torch_output = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)
    print(f"triton_output_with_fp8_inputs={triton_output}")
    print(f"torch_output_with_fp8_inputs={torch_output}")
    if torch.allclose(triton_output, torch_output, atol=0.125, rtol=0):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")

# %%
# Benchmark
# ---------
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We can now compare the performance of our kernel against that of cuBLAS or rocBLAS. Here we focus on square matrices,
# but feel free to arrange this script as you wish to benchmark any other matrix shape.

ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS'

configs = []
for fp8_inputs in [False, True]:
    if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
        continue
    configs.append(
        triton.testing.Benchmark(
            x_names=["M", "N", "K"],  # Argument names to use as an x-axis for the plot
            x_vals=[128 * i for i in range(16, 33)],  # Different possible values for `x_name`
            line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
            # Possible values for `line_arg`
            line_vals=[ref_lib.lower(), "triton"],
            # Label name for the lines
            line_names=[ref_lib, "Triton"],
            # Line styles
            styles=[("green", "-"), ("blue", "-")],
            ylabel="TFLOPS",  # Label name for the y-axis
            plot_name="matmul-performance-" +
            ("fp16" if not fp8_inputs else "fp8"),  # Name for the plot, used also as a file name for saving the plot.
            args={"fp8_inputs": fp8_inputs},
        ))


@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider, fp8_inputs):
    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
    if TORCH_HAS_FP8 and fp8_inputs:
        a = a.to(torch.float8_e4m3fn)
        b = b.T
        b = b.to(torch.float8_e4m3fn)
    quantiles = [0.5, 0.2, 0.8]
    if provider == ref_lib.lower():
        if TORCH_HAS_FP8 and fp8_inputs:
            scale_a = torch.tensor(1., device="cuda", dtype=torch.float32)
            scale_b = torch.tensor(1., device="cuda", dtype=torch.float32)
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch._scaled_mm(a, b, scale_a, scale_b),
                                                         quantiles=quantiles)
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


benchmark.run(show_plots=True, print_data=True)
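
For reference, appending something like the following to the end of the script shows which config the autotuner actually settles on for the FP8 case. This is a sketch: it assumes Triton's Autotuner exposes best_config (which it does in recent releases); recent versions can reportedly also print the selection when the TRITON_PRINT_AUTOTUNING=1 environment variable is set.

if TORCH_HAS_FP8 and is_cuda():
    # Run one FP8 matmul at a benchmark-sized shape so the autotuner has been
    # exercised for that (M, N, K) key, then print the winning config.
    M = N = K = 4096
    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16).to(torch.float8_e4m3fn)
    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16).T.to(torch.float8_e4m3fn)
    matmul(a, b)
    print("fp8 best config:", matmul_kernel.best_config)  # assumes Autotuner.best_config is available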
The results (benchmark plots attached in the original issue): FP16 performance is on par with cuBLAS, but FP8 falls behind torch._scaled_mm.
Environment details
triton: 3.1.0
CUDA Version: 12.4
GPU: RTX 4090
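
For completeness, the environment details above can be reproduced with something along these lines (torch.version.cuda reports the toolkit version PyTorch was built with; nvidia-smi may show a different driver-side version):

import torch
import triton

print(f"triton: {triton.__version__}")
print(f"CUDA Version: {torch.version.cuda}")
print(f"GPU: {torch.cuda.get_device_name(0)}")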