Commit

Add torch.compile cases
minjang committed Jun 19, 2024
1 parent 61c1829 commit 2e92d57
Showing 3 changed files with 56 additions and 34 deletions.
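The pattern this commit adds is the same in each tutorial: wrap the eager PyTorch reference in torch.compile once, then time the compiled callable with triton.testing.do_bench, threading the benchmark's is_cpu flag through every call. Below is a minimal standalone sketch of that pattern; it is not part of the diff, the softmax body is a simplified stand-in for the tutorial's naive_softmax, and the is_cpu keyword is assumed to be the triton-cpu fork's extension to do_bench (stock Triton's do_bench does not take it).

import torch
import triton.testing


def naive_softmax(x):
    # Simplified eager reference; the tutorial's version is decorated with @torch.jit.script.
    z = x - x.max(dim=1, keepdim=True)[0]
    num = torch.exp(z)
    return num / num.sum(dim=1, keepdim=True)


x = torch.randn(1024, 1024, device='cpu', dtype=torch.float32)

# Compile once outside the timed region (requires PyTorch 2.x), then benchmark the compiled callable.
compiled = torch.compile(naive_softmax)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles, is_cpu=True)
print(f"torch.compile softmax: median {ms:.3f} ms (p20 {min_ms:.3f} ms, p80 {max_ms:.3f} ms)")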
20 changes: 12 additions & 8 deletions python/tutorials/01-vector-add.py
@@ -25,7 +25,7 @@

GPU_BLOCK_SIZE = 1024
CPU_BLOCK_SIZE = 4096
-USE_GPU = True
+USE_GPU = False


@triton.jit
@@ -89,7 +89,7 @@ def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, is_cpu):
triton.runtime.driver.set_active_to_cpu()
x = torch.rand(size, device='cpu')
y = torch.rand(size, device='cpu')
-output_torch_cpu = x + y
+output_torch_cpu = torch.add(x, y)
output_triton_cpu = add(x, y, None, is_cpu=True)
print(output_torch_cpu)
print(output_triton_cpu)
@@ -98,7 +98,7 @@ def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, is_cpu):

LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu']
LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU']
-LINE_STYLES = [('blue', '-'), ('green', '-'), ('cyan', '-')]
+LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '-')]

if USE_GPU and triton.runtime.driver.get_active_gpus():
triton.runtime.driver.set_active_to_gpu()
@@ -150,31 +150,35 @@ def benchmark(size, provider):
y = torch.rand(size, device=device, dtype=torch.float32)

if device == 'cpu':
+is_cpu = True
triton.runtime.driver.set_active_to_cpu()
if 'single' in provider:
os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
else:
os.unsetenv('TRITON_CPU_SINGLE_CORE')
else:
+is_cpu = False
triton.runtime.driver.set_active_to_gpu()

quantiles = [0.5, 0.2, 0.8]
if provider == 'torch-gpu':
-ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=is_cpu)
elif provider == 'triton-gpu':
-ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles, is_cpu=is_cpu)
elif provider == 'torch-cpu':
# Note that we preallocate the output buffer here to only measure the kernel performance
# without a large chunk of memory allocation.
output = torch.empty_like(x)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles,
-is_cpu=True)
+is_cpu=is_cpu)
elif provider == 'triton-cpu-single':
output = torch.empty_like(x)
-ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles,
+is_cpu=is_cpu)
elif provider == 'triton-cpu':
output = torch.empty_like(x)
-ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles,
+is_cpu=is_cpu)
gbps = lambda ms: 12 * size / ms * 1e-6
return gbps(ms), gbps(max_ms), gbps(min_ms)

40 changes: 28 additions & 12 deletions python/tutorials/02-fused-softmax-cpu.py
@@ -26,7 +26,7 @@
import triton
import triton.language as tl

-USE_GPU = True
+USE_GPU = False


@torch.jit.script
@@ -145,9 +145,15 @@ def softmax(x):
y_torch_cpu = torch.softmax(x, axis=1)
assert torch.allclose(y_triton_cpu, y_torch_cpu), (y_triton_cpu, y_torch_cpu)

-LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu-native', 'torch-cpu-jit']
-LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (native)', 'TorchCPU (jit)']
-LINE_STYLES = [('blue', '-'), ('blue', '--'), ('green', '-'), ('green', '--')]
+LINE_VALS = [
+'triton-cpu-single',
+'triton-cpu',
+'torch-cpu-compile',
+'torch-cpu-jit',
+'torch-cpu-native',
+]
+LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (compile)', 'TorchCPU (jit)', 'TorchCPU (native)']
+LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '-'), ('green', '--'), ('green', '-.')]

if USE_GPU and triton.runtime.driver.get_active_gpus():
triton.runtime.driver.set_active_to_gpu()
@@ -173,7 +179,7 @@ def softmax(x):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
-x_vals=[128 * i for i in range(2, 52)], # different possible values for `x_name`
+x_vals=[128 * i for i in range(2, 52, 2)], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=LINE_VALS, # Possible values for `line_arg`.
line_names=LINE_NAMES, # Label name for the lines.
@@ -185,33 +191,43 @@ def softmax(x):
def benchmark(M, N, provider):
import os

+# Currently compilation time is very long. Let's show the progress.
+print(f"Running {provider} with {M} x {N}...")
+
device = 'cpu' if 'cpu' in provider else 'cuda'
x = torch.randn(M, N, device=device, dtype=torch.float32)

if device == 'cpu':
+is_cpu = True
triton.runtime.driver.set_active_to_cpu()
if 'single' in provider:
os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
else:
os.unsetenv('TRITON_CPU_SINGLE_CORE')
else:
+is_cpu = False
triton.runtime.driver.set_active_to_gpu()

quantiles = [0.5, 0.2, 0.8]
if provider == 'torch-cpu-native':
-ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
+is_cpu=is_cpu)
if provider == 'torch-cpu-jit':
-ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, is_cpu=is_cpu)
+if provider == 'torch-cpu-compile':
+compiled = torch.compile(naive_softmax)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles, is_cpu=is_cpu)
if provider == 'torch-gpu-native':
-ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
+is_cpu=is_cpu)
if provider == 'torch-gpu-jit':
-ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, is_cpu=is_cpu)
if provider == 'triton-cpu-single':
-ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, is_cpu=is_cpu)
if provider == 'triton-cpu':
-ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, is_cpu=is_cpu)
if provider == 'triton-gpu':
-ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, is_cpu=is_cpu)
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)

30 changes: 16 additions & 14 deletions python/tutorials/03-matrix-multiplication-cpu.py
@@ -158,7 +158,7 @@
BLOCK_SIZE_N = 32
BLOCK_SIZE_K = 32
GROUP_SIZE_M = 8
-USE_GPU = True
+USE_GPU = False


@triton.jit
@@ -217,9 +217,9 @@ def matmul_kernel(
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.

-#TODO: Currently masked load is not supported yet.
-#a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
-#b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+# TODO: Currently masked load is not supported yet.
+# a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+# b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
# We accumulate along the K dimension.
Expand All @@ -237,9 +237,9 @@ def matmul_kernel(
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]

-#TODO: Currently masked load is not supported yet.
-#c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
-#tl.store(c_ptrs, c, mask=c_mask)
+# TODO: Currently masked load is not supported yet.
+# c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+# tl.store(c_ptrs, c, mask=c_mask)
tl.store(c_ptrs, c)


@@ -309,9 +309,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
# We can now compare the performance of our kernel against that of Pytorch. Here we focus on square matrices,
# but feel free to arrange this script as you wish to benchmark any other matrix shape.

-LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu']
-LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU']
-LINE_STYLES = [('blue', '-'), ('green', '-'), ('cyan', '-')]
+LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu-native', 'torch-cpu-compile']
+LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (native)', 'TorchCPU (compile)']
+LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '--'), ('green', '-')]

if USE_GPU and triton.runtime.driver.get_active_gpus():
triton.runtime.driver.set_active_to_gpu()
@@ -366,28 +366,30 @@ def benchmark(M, N, K, provider):
b = torch.randn((K, N), device=device, dtype=torch.float32)

if device == 'cpu':
+c = torch.empty((M, N), device=a.device, dtype=a.dtype)
triton.runtime.driver.set_active_to_cpu()
if 'single' in provider:
os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
else:
os.unsetenv('TRITON_CPU_SINGLE_CORE')
else:
+c = None
triton.runtime.driver.set_active_to_gpu()

quantiles = [0.5, 0.2, 0.8]
if provider == 'torch-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
elif provider == 'triton-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles)
-elif provider == 'torch-cpu':
-c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+elif provider == 'torch-cpu-native':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles,
is_cpu=True)
+elif provider == 'torch-cpu-compile':
+compiled = torch.compile(torch.matmul)
+ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu-single':
-c = torch.empty((M, N), device=a.device, dtype=a.dtype)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu':
-c = torch.empty((M, N), device=a.device, dtype=a.dtype)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
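
One detail worth noting from the benchmark changes above: the CPU providers preallocate the output tensor outside the timed lambda so that do_bench measures only the kernel, not the allocation of the result (see the comment in 01-vector-add.py). Below is a tiny standalone sketch of that idea, using timeit instead of the fork's do_bench so it runs on stock PyTorch; the tensor size and iteration count are illustrative only.

import timeit

import torch

size = 1 << 20
x = torch.rand(size, device='cpu')
y = torch.rand(size, device='cpu')
out = torch.empty_like(x)  # allocated once, outside the timed region

# With a preallocated output, only the add kernel itself is timed.
t_prealloc = timeit.timeit(lambda: torch.add(x, y, out=out), number=1000)

# Without it, every call also pays for allocating a fresh result tensor.
t_alloc = timeit.timeit(lambda: torch.add(x, y), number=1000)

print(f"preallocated out=: {t_prealloc * 1e3:.2f} ms for 1000 calls")
print(f"fresh allocation:  {t_alloc * 1e3:.2f} ms for 1000 calls")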
