Preallocate output buffer for softmax tutorial
minjang committed Jun 20, 2024
1 parent 2e92d57 commit 01543a8
Showing 1 changed file with 11 additions and 8 deletions: python/tutorials/02-fused-softmax-cpu.py
@@ -100,7 +100,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
 
 
-def softmax(x):
+def softmax(x, y=None):
     n_rows, n_cols = x.shape
     # The block size is the smallest power of two greater than the number of columns in `x`
     BLOCK_SIZE = triton.next_power_of_2(n_cols)
@@ -114,7 +114,8 @@ def softmax(x):
     if BLOCK_SIZE >= 4096:
         num_warps = 16
     # Allocate output
-    y = torch.empty_like(x)
+    if y is None:
+        y = torch.empty_like(x)
     # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of
     # the input matrix
     softmax_kernel[(n_rows, )](
@@ -199,13 +200,15 @@ def benchmark(M, N, provider):
 
     if device == 'cpu':
         is_cpu = True
+        y = torch.empty_like(x)
         triton.runtime.driver.set_active_to_cpu()
         if 'single' in provider:
             os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
         else:
             os.unsetenv('TRITON_CPU_SINGLE_CORE')
     else:
         is_cpu = False
+        y = None
         triton.runtime.driver.set_active_to_gpu()
 
     quantiles = [0.5, 0.2, 0.8]
@@ -217,17 +220,17 @@ def benchmark(M, N, provider):
     if provider == 'torch-cpu-compile':
         compiled = torch.compile(naive_softmax)
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles, is_cpu=is_cpu)
+    if provider == 'triton-cpu-single':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, is_cpu=is_cpu)
+    if provider == 'triton-cpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, is_cpu=is_cpu)
+    if provider == 'triton-gpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, is_cpu=is_cpu)
     if provider == 'torch-gpu-native':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
                                                      is_cpu=is_cpu)
     if provider == 'torch-gpu-jit':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, is_cpu=is_cpu)
-    if provider == 'triton-cpu-single':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, is_cpu=is_cpu)
-    if provider == 'triton-cpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, is_cpu=is_cpu)
-    if provider == 'triton-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, is_cpu=is_cpu)
     gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
     return gbps(ms), gbps(max_ms), gbps(min_ms)
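
With the optional `y` parameter, a caller can allocate the output tensor once and reuse it across invocations instead of paying `torch.empty_like` on every call. A minimal sketch of both calling conventions (shapes are hypothetical; it assumes the tutorial's `softmax` wrapper and its `torch`/`triton` imports are in scope and the CPU driver is active):

```python
import torch

x = torch.randn(4096, 1024, device='cpu')

# Old convention still works: the wrapper allocates a fresh output
# internally via torch.empty_like(x) on every call.
out = softmax(x)

# New convention: preallocate once, then reuse the buffer so repeated
# calls (e.g. inside a benchmark loop) perform no per-call allocation.
y = torch.empty_like(x)
softmax(x, y)

# Sanity check against the PyTorch reference.
torch.testing.assert_close(y, torch.softmax(x, axis=-1))
```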
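
The benchmark hunk applies the same idea: on CPU, the `torch.empty_like` call inside the timed lambda can be a noticeable fraction of a memory-bound kernel's runtime, so the buffer is now allocated once before timing, while the GPU path keeps the old behavior by passing `y = None`. A sketch of the CPU providers' timed region, mirroring the calls in the diff (note the `is_cpu` keyword on `do_bench` is specific to this triton-cpu fork):

```python
quantiles = [0.5, 0.2, 0.8]

# Allocation is paid once, outside the timed region.
y = torch.empty_like(x)

# do_bench now measures only the kernel enqueue and execution; the
# output allocation no longer runs inside the benchmarked lambda.
ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y),
                                             quantiles=quantiles, is_cpu=True)
```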
