Update matrix-multiplication-cpu tutorial, use preallocated output buffer for CPU. (triton-lang#24)

Kuigesi authored and Devjiu committed Aug 13, 2024
1 parent ce270d8 commit 8058f35
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions python/tutorials/03-matrix-multiplication-cpu.py

@@ -249,15 +249,19 @@ def matmul_kernel(
 # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
 
 
-def matmul(a, b):
+def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     assert a.is_contiguous(), "Matrix A must be contiguous"
     M, K = a.shape
     K, N = b.shape
     #TODO: Currently masked load is not supported yet.
     assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
-    # Allocates output.
-    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+    if c is None:
+        # Allocates output.
+        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+    else:
+        assert c.shape == (M, N), "Incompatible dimensions"
     # 1D launch kernel where each block gets its own program.
     grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), )
     matmul_kernel[grid](
@@ -284,10 +288,9 @@ def matmul(a, b):
 
 triton.runtime.driver.set_active_to_cpu()
 
-
 a = torch.randn((512, 512), device='cpu', dtype=torch.float32)
 b = torch.randn((512, 512), device='cpu', dtype=torch.float32)
-triton_output = matmul(a, b)
+triton_output = matmul(a, b, None)
 torch_output = torch.matmul(a, b)
 print(f"triton_cpu_output_with_{a.dtype}_inputs={triton_output}")
 print(f"torch_cpu_output_with_{a.dtype}_inputs={torch_output}")
@@ -315,7 +318,7 @@ def matmul(a, b):
 triton.runtime.driver.set_active_to_gpu()
 a = a.to('cuda')
 b = b.to('cuda')
-triton_output = matmul(a, b)
+triton_output = matmul(a, b, None)
 torch_output = torch.matmul(a, b)
 print(f"triton_gpu_output_with_{a.dtype}_inputs={triton_output}")
 print(f"torch_gpu_output_with_{a.dtype}_inputs={torch_output}")
@@ -377,13 +380,16 @@ def benchmark(M, N, K, provider):
     if provider == 'torch-gpu':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     elif provider == 'triton-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles)
     elif provider == 'torch-cpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
+        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, is_cpu=True)
     elif provider == 'triton-cpu-single':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
+        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
     elif provider == 'triton-cpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
+        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
     perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)
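
For reference, a minimal usage sketch of the updated entry point, assuming the tutorial's matmul (which returns its output buffer) and its block-size constants are in scope and the CPU driver is active; the sizes and loop below are illustrative only and not part of the commit:

import torch

# Illustrative sizes; they must be multiples of the tutorial's block sizes,
# since masked loads are not supported yet (per the assert in matmul).
M, K, N = 512, 512, 512
a = torch.randn((M, K), device='cpu', dtype=torch.float32)
b = torch.randn((K, N), device='cpu', dtype=torch.float32)

# Passing None keeps the old behavior: matmul allocates the output itself.
out = matmul(a, b, None)

# Preallocating once and reusing the buffer, as the CPU benchmark providers
# now do, keeps the allocation cost out of the timed region.
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
for _ in range(10):
    matmul(a, b, c)  # the kernel writes into the caller's buffer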

