Matmul tutorial - external preprocessing (triton-lang#15)
Adds an optional flag to move matmul input preprocessing
outside of the benchmarked kernel.
This makes it possible to exclude preprocessing overhead
from performance measurements.
adam-smnk authored and Devjiu committed Jan 20, 2025
1 parent 2930ef7 commit ad183bb
Showing 1 changed file with 2 additions and 2 deletions.
python/tutorials/03-matrix-multiplication-cpu.py
@@ -167,12 +167,14 @@
PAD_B_ONLY = True
USE_BLOCK_POINTERS = os.getenv("USE_BLOCK_POINTERS", "1") != "0"
GROUP_SIZE_M = 8

USE_GPU = False
USE_BLOCK_POINTERS = False
DATA_TYPE = torch.float32
K_DIM_PADDING = False
DYNAMIC_K_BLOCK = False
CACHE_PADDING = False
PREPROCESS_EXTERNAL = False

@triton.jit
def pad_kernel(in_ptr, out_ptr, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, PADDING: tl.constexpr):
@@ -290,15 +292,13 @@ def matmul_kernel(
c_tile_ptr = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(c_tile_ptr, c)


# %%
# We can now create a convenience wrapper function that takes the input and output tensors,
# and (1) checks any shape constraints; (2) preprocesses the inputs; (3) launches the above kernel.

# Global scratch buffers for the preprocessed (padded) inputs.
a_scratch = torch.empty((), dtype=DTYPE)
b_scratch = torch.empty((), dtype=DTYPE)


def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
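As a rough sketch of the pattern this flag enables, the snippet below hoists a preprocessing step out of the timed callable when it runs externally. The helpers pad_input, run_kernel, and bench are hypothetical stand-ins for illustration only; the tutorial's actual preprocessing, kernel launch, and benchmark harness differ.

import time
import torch

PREPROCESS_EXTERNAL = True  # the flag added by this commit (semantics assumed)


def pad_input(a: torch.Tensor, pad: int = 32) -> torch.Tensor:
    # Hypothetical stand-in for the tutorial's padding/copy preprocessing.
    m, k = a.shape
    out = torch.zeros((m, k + pad), dtype=a.dtype)
    out[:, :k] = a
    return out


def run_kernel(a_padded: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the actual Triton kernel launch.
    return a_padded @ a_padded.T


def bench(fn, reps: int = 100) -> float:
    # Average wall-clock time per call, in milliseconds.
    fn()  # warmup
    start = time.perf_counter()
    for _ in range(reps):
        fn()
    return (time.perf_counter() - start) / reps * 1e3


a = torch.randn(512, 512)

if PREPROCESS_EXTERNAL:
    # Preprocessing runs once, outside the timed region.
    a_padded = pad_input(a)
    ms = bench(lambda: run_kernel(a_padded))
else:
    # Preprocessing overhead is included in every timed iteration.
    ms = bench(lambda: run_kernel(pad_input(a)))

print(f"{ms:.3f} ms per call")

With the flag off, every measured iteration pays the preprocessing cost; with it on, only the kernel itself is timed.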
