diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py
index 07d1b188707e..754fd88f6bbe 100644
--- a/python/tutorials/01-vector-add.py
+++ b/python/tutorials/01-vector-add.py
@@ -25,7 +25,7 @@
 
 GPU_BLOCK_SIZE = 1024
 CPU_BLOCK_SIZE = 4096
-USE_GPU = True
+USE_GPU = False
 
 
 @triton.jit
@@ -89,7 +89,7 @@ def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, is_cpu):
 triton.runtime.driver.set_active_to_cpu()
 x = torch.rand(size, device='cpu')
 y = torch.rand(size, device='cpu')
-output_torch_cpu = x + y
+output_torch_cpu = torch.add(x, y)
 output_triton_cpu = add(x, y, None, is_cpu=True)
 print(output_torch_cpu)
 print(output_triton_cpu)
@@ -98,7 +98,7 @@ def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, is_cpu):
 
 LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu']
 LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU']
-LINE_STYLES = [('blue', '-'), ('green', '-'), ('cyan', '-')]
+LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '-')]
 
 if USE_GPU and triton.runtime.driver.get_active_gpus():
     triton.runtime.driver.set_active_to_gpu()
@@ -150,31 +150,35 @@ def benchmark(size, provider):
     y = torch.rand(size, device=device, dtype=torch.float32)
 
     if device == 'cpu':
+        is_cpu = True
         triton.runtime.driver.set_active_to_cpu()
         if 'single' in provider:
             os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
         else:
             os.unsetenv('TRITON_CPU_SINGLE_CORE')
     else:
+        is_cpu = False
         triton.runtime.driver.set_active_to_gpu()
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'torch-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=is_cpu)
     elif provider == 'triton-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles, is_cpu=is_cpu)
     elif provider == 'torch-cpu':
         # Note that we preallocate the output buffer here to only measure the kernel performance
         # without a large chunk of memory allocation.
         output = torch.empty_like(x)
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles,
-                                                     is_cpu=True)
+                                                     is_cpu=is_cpu)
     elif provider == 'triton-cpu-single':
         output = torch.empty_like(x)
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles,
+                                                     is_cpu=is_cpu)
     elif provider == 'triton-cpu':
         output = torch.empty_like(x)
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles,
+                                                     is_cpu=is_cpu)
 
     gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
     return gbps(ms), gbps(max_ms), gbps(min_ms)
diff --git a/python/tutorials/02-fused-softmax-cpu.py b/python/tutorials/02-fused-softmax-cpu.py
new file mode 100644
index 000000000000..1fce4c78345f
--- /dev/null
+++ b/python/tutorials/02-fused-softmax-cpu.py
@@ -0,0 +1,244 @@
+"""
+Fused Softmax
+=============
+
+In this tutorial, you will write a fused softmax operation that is significantly faster
+than PyTorch's native op for a particular class of matrices: those whose rows can fit in
+the GPU's SRAM.
+
+In doing so, you will learn about:
+
+* The benefits of kernel fusion for bandwidth-bound operations.
+
+* Reduction operators in Triton.
+
+"""
+
+# %%
+# Motivations
+# -----------
+#
+# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
+# Let us consider instead the case of a simple (numerically stabilized) softmax operation:
+
+import torch
+
+import triton
+import triton.language as tl
+
+USE_GPU = False
+
+
+@torch.jit.script
+def naive_softmax(x):
+    """Compute row-wise softmax of X using native pytorch
+
+    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
+    this shift.
+    """
+    # read MN elements ; write M elements
+    x_max = x.max(dim=1)[0]
+    # read MN + M elements ; write MN elements
+    z = x - x_max[:, None]
+    # read MN elements ; write MN elements
+    numerator = torch.exp(z)
+    # read MN elements ; write M elements
+    denominator = numerator.sum(dim=1)
+    # read MN + M elements ; write MN elements
+    ret = numerator / denominator[:, None]
+    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
+    return ret
+
+
+# %%
+# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
+# requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
+# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
+# X once and does all the necessary computations on-chip.
+# Doing so would require reading and writing back only :math:`MN` bytes, so we could
+# expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
+# The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically
+# but, as we will see later, it is still far from ideal.
+
+# %%
+# Compute Kernel
+# --------------
+#
+# Our softmax kernel works as follows: each program loads a row of the input matrix X,
+# normalizes it and writes back the result to the output Y.
+#
+# Note that one important limitation of Triton is that each block must have a
+# power-of-two number of elements, so we need to internally "pad" each row and guard the
+# memory operations properly if we want to handle any possible input shapes:
+
+
+@triton.jit
+def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
+    # The rows of the softmax are independent, so we parallelize across those
+    row_idx = tl.program_id(0)
+    # The stride represents how much we need to increase the pointer to advance 1 row
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    # The block size is the next power of two greater than n_cols, so we can fit each
+    # row in a single block
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    input_ptrs = row_start_ptr + col_offsets
+    # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
+    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
+    # Subtract maximum for numerical stability
+    row_minus_max = row - tl.max(row, axis=0)
+    # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
+    numerator = tl.exp(row_minus_max)
+    denominator = tl.sum(numerator, axis=0)
+    softmax_output = numerator / denominator
+    # Write back output to DRAM
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+    output_ptrs = output_row_start_ptr + col_offsets
+    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
+
+
+# %%
+# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
+
+
+def softmax(x, y=None):
+    n_rows, n_cols = x.shape
+    # The block size is the smallest power of two greater than the number of columns in `x`
+    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+    # Another trick we can use is to ask the compiler to use more threads per row by
+    # increasing the number of warps (`num_warps`) over which each row is distributed.
+    # You will see in the next tutorial how to auto-tune this value in a more natural
+    # way so you don't have to come up with manual heuristics yourself.
+    num_warps = 4
+    if BLOCK_SIZE >= 2048:
+        num_warps = 8
+    if BLOCK_SIZE >= 4096:
+        num_warps = 16
+    # Allocate output
+    if y is None:
+        y = torch.empty_like(x)
+    # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of
+    # the input matrix
+    softmax_kernel[(n_rows, )](
+        y,
+        x,
+        x.stride(0),
+        y.stride(0),
+        n_cols,
+        num_warps=num_warps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return y
+
+
+# %%
+# Unit Test
+# ---------
+
+# %%
+# We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
+# This will allow us to verify that our padding mechanism works.
+
+triton.runtime.driver.set_active_to_cpu()
+
+torch.manual_seed(0)
+x = torch.randn(1823, 781, device='cpu')
+y_triton_cpu = softmax(x)
+y_torch_cpu = torch.softmax(x, axis=1)
+assert torch.allclose(y_triton_cpu, y_torch_cpu), (y_triton_cpu, y_torch_cpu)
+
+LINE_VALS = [
+    'triton-cpu-single',
+    'triton-cpu',
+    'torch-cpu-compile',
+    'torch-cpu-jit',
+    'torch-cpu-native',
+]
+LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (compile)', 'TorchCPU (jit)', 'TorchCPU (native)']
+LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '-'), ('green', '--'), ('green', '-.')]
+
+if USE_GPU and triton.runtime.driver.get_active_gpus():
+    triton.runtime.driver.set_active_to_gpu()
+    x = x.to('cuda')
+    y_triton_gpu = softmax(x)
+    y_torch_gpu = torch.softmax(x, axis=1)
+    assert torch.allclose(y_triton_gpu, y_torch_gpu), (y_triton_gpu, y_torch_gpu)
+    LINE_VALS += ['triton-gpu', 'torch-gpu-native', 'torch-gpu-jit']
+    LINE_NAMES += ['TritonGPU', 'TorchGPU (native)', 'TorchGPU (jit)']
+    LINE_STYLES += [('yellow', '-'), ('red', '-'), ('red', '--')]
+
+# %%
+# As expected, the results are identical.
+
+# %%
+# Benchmark
+# ---------
+#
+# Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
+# We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=['N'],  # argument names to use as an x-axis for the plot
+        x_vals=[128 * i for i in range(2, 52, 2)],  # different possible values for `x_name`
+        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
+        line_vals=LINE_VALS,  # Possible values for `line_arg`.
+        line_names=LINE_NAMES,  # Label name for the lines.
+        styles=LINE_STYLES,  # Line styles.
+        ylabel="GB/s",  # label name for the y-axis
+        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
+        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
+    ))
+def benchmark(M, N, provider):
+    import os
+
+    # Currently compilation time is very long. Let's show the progress.
+    print(f"Running {provider} with {M} x {N}...")
+
+    device = 'cpu' if 'cpu' in provider else 'cuda'
+    x = torch.randn(M, N, device=device, dtype=torch.float32)
+
+    if device == 'cpu':
+        is_cpu = True
+        y = torch.empty_like(x)
+        triton.runtime.driver.set_active_to_cpu()
+        if 'single' in provider:
+            os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
+        else:
+            os.unsetenv('TRITON_CPU_SINGLE_CORE')
+    else:
+        is_cpu = False
+        y = None
+        triton.runtime.driver.set_active_to_gpu()
+
+    quantiles = [0.5, 0.2, 0.8]
+    if provider == 'torch-cpu-native':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
+                                                     is_cpu=is_cpu)
+    if provider == 'torch-cpu-jit':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, is_cpu=is_cpu)
+    if provider == 'torch-cpu-compile':
+        compiled = torch.compile(naive_softmax)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles, is_cpu=is_cpu)
+    if provider == 'triton-cpu-single':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, is_cpu=is_cpu)
+    if provider == 'triton-cpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, is_cpu=is_cpu)
+    if provider == 'triton-gpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, is_cpu=is_cpu)
+    if provider == 'torch-gpu-native':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
+                                                     is_cpu=is_cpu)
+    if provider == 'torch-gpu-jit':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, is_cpu=is_cpu)
+    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
+    return gbps(ms), gbps(max_ms), gbps(min_ms)
+
+
+benchmark.run(show_plots=True, print_data=True)
+
+# %%
+# In the above plot, we can see that:
+# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
+# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
+# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.
diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py
index e96d04614661..937cbc652ba7 100644
--- a/python/tutorials/03-matrix-multiplication-cpu.py
+++ b/python/tutorials/03-matrix-multiplication-cpu.py
@@ -158,7 +158,7 @@
 BLOCK_SIZE_N = 32
 BLOCK_SIZE_K = 32
 GROUP_SIZE_M = 8
-USE_GPU = True
+USE_GPU = False
 
 
 @triton.jit
@@ -217,9 +217,9 @@ def matmul_kernel(
 
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.
-        #TODO: Currently masked load is not supported yet.
-        #a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
-        #b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        # TODO: Currently masked load is not supported yet.
+        # a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         a = tl.load(a_ptrs)
         b = tl.load(b_ptrs)
         # We accumulate along the K dimension.
@@ -237,9 +237,9 @@ def matmul_kernel(
 
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-    #TODO: Currently masked load is not supported yet.
-    #c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
-    #tl.store(c_ptrs, c, mask=c_mask)
+    # TODO: Currently masked load is not supported yet.
+    # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    # tl.store(c_ptrs, c, mask=c_mask)
     tl.store(c_ptrs, c)
 
 
@@ -309,9 +309,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
 # We can now compare the performance of our kernel against that of Pytorch. Here we focus on square matrices,
 # but feel free to arrange this script as you wish to benchmark any other matrix shape.
 
-LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu']
-LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU']
-LINE_STYLES = [('blue', '-'), ('green', '-'), ('cyan', '-')]
+LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu-native', 'torch-cpu-compile']
+LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (native)', 'TorchCPU (compile)']
+LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '--'), ('green', '-')]
 
 if USE_GPU and triton.runtime.driver.get_active_gpus():
     triton.runtime.driver.set_active_to_gpu()
@@ -366,12 +366,14 @@ def benchmark(M, N, K, provider):
     b = torch.randn((K, N), device=device, dtype=torch.float32)
 
     if device == 'cpu':
+        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
         triton.runtime.driver.set_active_to_cpu()
        if 'single' in provider:
            os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
        else:
            os.unsetenv('TRITON_CPU_SINGLE_CORE')
     else:
+        c = None
         triton.runtime.driver.set_active_to_gpu()
 
     quantiles = [0.5, 0.2, 0.8]
@@ -379,15 +381,15 @@ def benchmark(M, N, K, provider):
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     elif provider == 'triton-gpu':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles)
-    elif provider == 'torch-cpu':
-        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+    elif provider == 'torch-cpu-native':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles,
                                                      is_cpu=True)
+    elif provider == 'torch-cpu-compile':
+        compiled = torch.compile(torch.matmul)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles, is_cpu=True)
     elif provider == 'triton-cpu-single':
-        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
     elif provider == 'triton-cpu':
-        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
     perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)
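
The three tutorials above share one CPU benchmarking pattern: preallocate the output buffer so only kernel time is measured, pin Triton to a single core via TRITON_CPU_SINGLE_CORE when requested, and pass is_cpu to triton.testing.do_bench. Below is a minimal standalone sketch of that pattern (not part of the patch). It assumes the triton-cpu runtime helpers used in this diff (set_active_to_cpu and the is_cpu keyword on do_bench are not upstream Triton) and a hypothetical vector_add module exposing the add() helper defined in 01-vector-add.py.

import os

import torch
import triton

from vector_add import add  # hypothetical import of add(x, y, output, is_cpu) from 01-vector-add.py

triton.runtime.driver.set_active_to_cpu()

size = 2**20  # arbitrary illustrative size
x = torch.rand(size, device='cpu')
y = torch.rand(size, device='cpu')
output = torch.empty_like(x)  # preallocated so allocation cost stays out of the timing

quantiles = [0.5, 0.2, 0.8]
for single_core in (True, False):
    if single_core:
        os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
    else:
        os.unsetenv('TRITON_CPU_SINGLE_CORE')
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
    gbps = 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    print(f"single_core={single_core}: {gbps:.2f} GB/s (median)")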