From dc8dfb6d28c4bc7cc83a7eb5defd1279ff093d4c Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 13 Jan 2025 16:23:47 -0600 Subject: [PATCH] Support VNNI pre-encoded input in AMX lowering. (#210) Signed-off-by: Ilya Enkovich --- ...d-matmul-fp32.py => cpu-blocked-matmul.py} | 137 +++++++---- test/TritonCPU/dot-to-amx.mlir | 223 ++++++++++++++++++ third_party/cpu/language/cpu/__init__.py | 3 + third_party/cpu/language/cpu/utils.py | 22 ++ .../ConvertDotOp/ConvertDotCommon.cpp | 73 +++++- .../ConvertDotOp/ConvertDotCommon.h | 20 +- .../ConvertDotOp/ConvertDotToAMX.cpp | 11 +- .../TritonToTritonCPU/ConvertElemManipOps.cpp | 21 +- 8 files changed, 435 insertions(+), 75 deletions(-) rename python/tutorials/{cpu-blocked-matmul-fp32.py => cpu-blocked-matmul.py} (73%) create mode 100644 third_party/cpu/language/cpu/utils.py diff --git a/python/tutorials/cpu-blocked-matmul-fp32.py b/python/tutorials/cpu-blocked-matmul.py similarity index 73% rename from python/tutorials/cpu-blocked-matmul-fp32.py rename to python/tutorials/cpu-blocked-matmul.py index 8f0f0ebce41a..e8f274d6c552 100644 --- a/python/tutorials/cpu-blocked-matmul-fp32.py +++ b/python/tutorials/cpu-blocked-matmul.py @@ -18,11 +18,13 @@ import os DTYPE = os.getenv("DTYPE", "float32") +in_dtype = getattr(torch, DTYPE) +out_dtype = torch.float32 if in_dtype.is_floating_point else torch.int32 # Choose block size depending on dtype. We have more register # capacity for bfloat16/float16 compared to float32. BLOCK_SIZE_M = 8 if DTYPE == "float32" else 32 BLOCK_SIZE_N = 32 -BLOCK_SIZE_K = 8 if DTYPE == "float32" else 32 +BLOCK_SIZE_K = 8 if DTYPE == "float32" else 64 // in_dtype.itemsize GROUP_SIZE_M = 8 @@ -38,6 +40,9 @@ # tensor are transposed. It provides contiguos placement for a column # of blocks. # +# If PACKED_B is set to True then B is VNNI encoded. Only works when +# BLOCKED_B is True. +# # If TRANSPOSED_BLOCK_A is set to True then tail dimensions of the LHS # tensor are transposed. 
Transposed LHS block better matches FMA lowering # used by Triton CPU backend which processes RHS block row-by-row and LHS @@ -46,7 +51,7 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, BLOCKED_B: tl.constexpr, - TRANSPOSED_B: tl.constexpr): + TRANSPOSED_B: tl.constexpr, PACKED_B: tl.constexpr): tl.static_assert(BLOCKED_A or not TRANSPOSED_BLOCK_A) tl.static_assert(BLOCKED_B or not TRANSPOSED_B) pid = tl.program_id(axis=0) @@ -85,9 +90,11 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M, N, K, BLOCK_SIZ tl.store(a_out_ptr, val) if BLOCKED_B: + B_PACKED_NUM: tl.constexpr = 32 // in_b.type.element_ty.primitive_bitwidth if PACKED_B else 1 + PACKED_BLOCK_SIZE_K: tl.constexpr = BLOCK_SIZE_K // B_PACKED_NUM if PACKED_B else BLOCK_SIZE_K + PACKED_BLOCK_SIZE_N: tl.constexpr = BLOCK_SIZE_N * B_PACKED_NUM if PACKED_B else BLOCK_SIZE_N B_OUT_BLOCKS_K = N // BLOCK_SIZE_N if TRANSPOSED_B else K // BLOCK_SIZE_K B_OUT_BLOCKS_N = K // BLOCK_SIZE_K if TRANSPOSED_B else N // BLOCK_SIZE_N - B_OUT_STRIDE_K: tl.constexpr = BLOCK_SIZE_N B_OUT_STRIDE_BLOCK_K = (K * BLOCK_SIZE_N if TRANSPOSED_B else BLOCK_SIZE_K * N) B_OUT_STRIDE_BLOCK_N: tl.constexpr = BLOCK_SIZE_K * BLOCK_SIZE_N for in_block_k in tl.range(in_block_m, K // BLOCK_SIZE_K, M // BLOCK_SIZE_M): @@ -95,15 +102,28 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M, N, K, BLOCK_SIZ b_out_block_n = in_block_k if TRANSPOSED_B else in_block_n b_in_ptr = tl.make_block_ptr(base=in_b, shape=(K, N), strides=(N, 1), offsets=(in_block_k * BLOCK_SIZE_K, in_block_n * BLOCK_SIZE_N), - block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0)) - b_out_ptr = tl.make_block_ptr(base=out_b, - shape=(B_OUT_BLOCKS_K, B_OUT_BLOCKS_N, BLOCK_SIZE_K, BLOCK_SIZE_N), - strides=(B_OUT_STRIDE_BLOCK_K, B_OUT_STRIDE_BLOCK_N, B_OUT_STRIDE_K, 1), - offsets=(b_out_block_k, b_out_block_n, 0, 0), - block_shape=(1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N), order=(3, 2, 1, 0)) - val = tl.load(b_in_ptr) - val = tl.reshape(val, (1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N)) - tl.store(b_out_ptr, val) + block_shape=(1, BLOCK_SIZE_N), order=(1, 0)) + b_out_ptr = tl.make_block_ptr( + base=out_b, shape=(B_OUT_BLOCKS_K, B_OUT_BLOCKS_N, PACKED_BLOCK_SIZE_K, PACKED_BLOCK_SIZE_N), + strides=(B_OUT_STRIDE_BLOCK_K, B_OUT_STRIDE_BLOCK_N, PACKED_BLOCK_SIZE_N, 1), + offsets=(b_out_block_k, b_out_block_n, 0, 0), block_shape=(1, 1, 1, PACKED_BLOCK_SIZE_N), + order=(3, 2, 1, 0)) + for i in tl.range(0, BLOCK_SIZE_K // B_PACKED_NUM): + row1 = tl.load(b_in_ptr).reshape((BLOCK_SIZE_N, )) + if B_PACKED_NUM > 1: + b_in_ptr = tl.advance(b_in_ptr, (1, 0)) + row2 = tl.load(b_in_ptr).reshape((BLOCK_SIZE_N, )) + if B_PACKED_NUM > 2: + b_in_ptr = tl.advance(b_in_ptr, (1, 0)) + row3 = tl.load(b_in_ptr).reshape((BLOCK_SIZE_N, )) + b_in_ptr = tl.advance(b_in_ptr, (1, 0)) + row4 = tl.load(b_in_ptr).reshape((BLOCK_SIZE_N, )) + row1 = tl.ravel(tl.join(row1, row3)) + row2 = tl.ravel(tl.join(row2, row4)) + row1 = tl.ravel(tl.join(row1, row2)) + tl.store(b_out_ptr, row1.reshape((1, 1, 1, PACKED_BLOCK_SIZE_N))) + b_in_ptr = tl.advance(b_in_ptr, (1, 0)) + b_out_ptr = tl.advance(b_out_ptr, (0, 0, 1, 0)) # Matmul kernel that computes a single output block [BLOCK_SIZE_M, BLOCK_SIZE_N]. 
LHS can be in the @@ -125,7 +145,7 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOC BLOCK_SIZE_K: tl.constexpr, # number of blocks in a group GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, - BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr): + BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr, PACKED_B: tl.constexpr, OUT_DTYPE: tl.constexpr): # TRANSPOSED_BLOCK_A means that each block in A is transposed. # It is allowed only for blocked input. assert (BLOCKED_A or not TRANSPOSED_BLOCK_A) @@ -151,37 +171,43 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOC a_stride_block_k = A_BLOCK_SIZE_M * A_BLOCK_SIZE_K if BLOCKED_A else A_BLOCK_SIZE_K a_stride_block_m = BLOCK_SIZE_M * K + B_PACKED_NUM: tl.constexpr = 32 // b_ptr.type.element_ty.primitive_bitwidth if PACKED_B else 1 + PACKED_BLOCK_SIZE_K: tl.constexpr = BLOCK_SIZE_K // B_PACKED_NUM if PACKED_B else BLOCK_SIZE_K + PACKED_BLOCK_SIZE_N: tl.constexpr = BLOCK_SIZE_N * B_PACKED_NUM if PACKED_B else BLOCK_SIZE_N + assert BLOCKED_B or not TRANSPOSED_B b_stride_n = 1 - b_stride_k = BLOCK_SIZE_N if BLOCKED_B else N + b_stride_k = PACKED_BLOCK_SIZE_N if BLOCKED_B else N * B_PACKED_NUM if TRANSPOSED_B: b_stride_block_n = BLOCK_SIZE_N * K b_stride_block_k = BLOCK_SIZE_K * BLOCK_SIZE_N else: - b_stride_block_n = BLOCK_SIZE_K * BLOCK_SIZE_N if BLOCKED_B else BLOCK_SIZE_N + b_stride_block_n = BLOCK_SIZE_K * BLOCK_SIZE_N if BLOCKED_B else PACKED_BLOCK_SIZE_N b_stride_block_k = BLOCK_SIZE_K * N a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(A_BLOCKS_M, A_BLOCKS_K, A_BLOCK_SIZE_M, A_BLOCK_SIZE_K), strides=(a_stride_block_m, a_stride_block_k, a_stride_m, a_stride_k), offsets=(block_m, 0, 0, 0), block_shape=(1, 1, A_BLOCK_SIZE_M, A_BLOCK_SIZE_K), order=(3, 2, 1, 0)) - b_block_ptr = tl.make_block_ptr(base=b_ptr, - shape=(K // BLOCK_SIZE_K, N // BLOCK_SIZE_N, BLOCK_SIZE_K, BLOCK_SIZE_N), - strides=(b_stride_block_k, b_stride_block_n, b_stride_k, b_stride_n), - offsets=(0, block_n, 0, 0), block_shape=(1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N), - order=(3, 2, 1, 0)) + b_block_ptr = tl.make_block_ptr( + base=b_ptr, shape=(K // BLOCK_SIZE_K, N // BLOCK_SIZE_N, PACKED_BLOCK_SIZE_K, PACKED_BLOCK_SIZE_N), + strides=(b_stride_block_k, b_stride_block_n, b_stride_k, b_stride_n), offsets=(0, block_n, 0, 0), + block_shape=(1, 1, PACKED_BLOCK_SIZE_K, PACKED_BLOCK_SIZE_N), order=(3, 2, 1, 0)) c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(N, 1), offsets=(block_m * BLOCK_SIZE_M, block_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0)) - c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=OUT_DTYPE) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_block_ptr).reshape((A_BLOCK_SIZE_M, A_BLOCK_SIZE_K)) - b = tl.load(b_block_ptr).reshape((BLOCK_SIZE_K, BLOCK_SIZE_N)) + b = tl.load(b_block_ptr).reshape((PACKED_BLOCK_SIZE_K, PACKED_BLOCK_SIZE_N)) if TRANSPOSED_BLOCK_A: a = a.T - c += tl.dot(a, b, out_dtype=tl.float32) + if PACKED_B: + b = tl.extra.cpu.vnni_decode(b) + + c += tl.dot(a, b, out_dtype=OUT_DTYPE) a_block_ptr = tl.advance(a_block_ptr, (0, 1, 0, 0)) b_block_ptr = tl.advance(b_block_ptr, (1, 0, 0, 0)) @@ -190,7 +216,7 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOC def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, bb: torch.Tensor, M, N, K, PREPACKED, - BLOCKED_A, TRANSPOSED_BLOCK_A, 
BLOCKED_B, TRANSPOSED_B, num_threads=0): + BLOCKED_A, TRANSPOSED_BLOCK_A, BLOCKED_B, TRANSPOSED_B, PACKED_B, num_threads=0): #TODO: Currently masked load is not supported yet. assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and ( K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" @@ -203,7 +229,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # GROUP_SIZE_M=GROUP_SIZE_M, # BLOCKED_A=BLOCKED_A, TRANSPOSED_BLOCK_A=TRANSPOSED_BLOCK_A, # - BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B) + BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B, PACKED_B=PACKED_B) if BLOCKED_A: a = ab if BLOCKED_B: @@ -214,7 +240,8 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # GROUP_SIZE_M=GROUP_SIZE_M, # BLOCKED_A=BLOCKED_A, TRANSPOSED_BLOCK_A=TRANSPOSED_BLOCK_A, # - BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B, num_threads=num_threads) + BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B, PACKED_B=PACKED_B, # + OUT_DTYPE=tl.float32 if a.dtype.is_floating_point else tl.int32, num_threads=num_threads) return c @@ -227,13 +254,17 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, triton.runtime.driver.set_active_to_cpu() -a = torch.randn((512, 512), device='cpu', dtype=torch.float32) -b = torch.randn((512, 512), device='cpu', dtype=torch.float32) -c = torch.empty((512, 512), device='cpu', dtype=torch.float32) -torch_output = torch.matmul(a, b) +if in_dtype.is_floating_point: + a = torch.randn((512, 512), device='cpu', dtype=in_dtype) + b = torch.randn((512, 512), device='cpu', dtype=in_dtype) +else: + a = torch.randint(0, 5, (512, 512), device='cpu', dtype=in_dtype) + b = torch.randint(0, 5, (512, 512), device='cpu', dtype=in_dtype) +c = torch.empty((512, 512), device='cpu', dtype=out_dtype) +torch_output = torch.matmul(a.to(out_dtype), b.to(out_dtype)) rtol = 0 -a_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_M) * (512 // BLOCK_SIZE_K) * 64), device='cpu', dtype=torch.float32) -b_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_K) * (512 // BLOCK_SIZE_N) * 64), device='cpu', dtype=torch.float32) +a_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_M) * (512 // BLOCK_SIZE_K) * 64), device='cpu', dtype=in_dtype) +b_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_K) * (512 // BLOCK_SIZE_N) * 64), device='cpu', dtype=in_dtype) triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, True, False, False, False, False, False) if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): print("✅ TritonCPU and TorchCPU match") @@ -241,7 +272,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, print("❌ TritonCPU and TorchCPU differ, the maximum difference is " f'{torch.max(torch.abs(triton_output - torch_output))}') assert False -triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, False, True, True, True, True, True) +triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, False, True, True, True, True, DTYPE != "float32") if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): print("✅ TritonCPU pre-packed and TorchCPU match") else: @@ -260,13 +291,13 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, # but feel free to arrange this script as you wish to benchmark any other matrix shape. 
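A quick aside on the PACKED_B layout introduced above: below is a minimal host-side sketch, assuming a standalone PyTorch helper named vnni_pack_reference that is not part of this patch, of what VNNI packing does to one [K, N] block of B. Consecutive K-rows are interleaved along N (pairs of rows for bfloat16/float16, groups of four rows for int8), so the block becomes [K // vnni, N * vnni]; the matmul kernel later undoes this with tl.extra.cpu.vnni_decode before tl.dot, and the AMX lowering recognizes that decode sequence so it can load the packed data directly.

import torch

def vnni_pack_reference(block: torch.Tensor) -> torch.Tensor:
    # Pack one [K, N] block so each 32-bit lane holds `vnni` consecutive
    # K-elements of a single column (2 for 16-bit types, 4 for int8).
    k, n = block.shape
    vnni = 4 // block.element_size()
    assert k % vnni == 0
    # (K, N) -> (K//vnni, vnni, N) -> (K//vnni, N, vnni) -> (K//vnni, N*vnni)
    return block.reshape(k // vnni, vnni, n).transpose(1, 2).reshape(k // vnni, n * vnni).contiguous()

b_block = torch.randn(32, 32, dtype=torch.bfloat16)
packed = vnni_pack_reference(b_block)  # shape (16, 64)
assert packed[3, 2 * 5 + 1] == b_block[7, 5]  # packed[k // 2, 2 * n + k % 2] == block[k, n]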
-def encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype): - assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' - return f"triton-cpu{'-ba' if blocked_a else ''}{'-ta' if transposed_a else ''}{'-bb' if blocked_b else ''}{'-tb' if transposed_b else ''}{'-prepack' if prepack else ''}{'-st' if single_thread else ''}-{dtype}" +def encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, packed_b, prepack, single_thread, dtype): + assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' or dtype == 'int8' + return f"triton-cpu{'-ba' if blocked_a else ''}{'-ta' if transposed_a else ''}{'-bb' if blocked_b else ''}{'-tb' if transposed_b else ''}{'-pb' if packed_b else ''}{'-prepack' if prepack else ''}{'-st' if single_thread else ''}-{dtype}" def encode_torch_provider(single_thread, dtype): - assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' + assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' or dtype == 'int8' return f"torch-cpu-native{'-st' if single_thread else ''}-{dtype}" @@ -277,28 +308,30 @@ def decode_provider(provider): dtype = torch.float16 elif '-float32' in provider: dtype = torch.float32 + elif '-int8' in provider: + dtype = torch.int8 if 'triton-cpu' in provider: backend = 'triton-cpu' elif 'torch-cpu-native' in provider: backend = 'torch-cpu-native' elif 'torch-cpu-compile' in provider: backend = 'torch-cpu-compile' - return backend, '-ba' in provider, '-ta' in provider, '-bb' in provider, '-tb' in provider, '-prepack' in provider, '-st' in provider, dtype + return backend, '-ba' in provider, '-ta' in provider, '-bb' in provider, '-tb' in provider, '-pb' in provider, '-prepack' in provider, '-st' in provider, dtype BLOCK_TRANSPOSE_A_OPTS = [(False, False)] -BLOCK_TRANSPOSE_B_OPTS = [(True, True), (False, False)] +BLOCK_TRANSPOSE_PACK_B_OPTS = [(True, True, True), (True, True, False), (False, False, False)] PREPACK_OPTS = [False, True] SINGLE_THREAD_OPTS = [False] DTYPE_OPTS = [DTYPE] LINE_VALS = [ - encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype) + encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, packed_b, prepack, single_thread, dtype) for single_thread in SINGLE_THREAD_OPTS for blocked_a, transposed_a in BLOCK_TRANSPOSE_A_OPTS - for blocked_b, transposed_b in BLOCK_TRANSPOSE_B_OPTS + for blocked_b, transposed_b, packed_b in BLOCK_TRANSPOSE_PACK_B_OPTS for prepack in PREPACK_OPTS for dtype in DTYPE_OPTS - if blocked_a or blocked_b or not prepack + if (blocked_a or blocked_b or not prepack) and (not packed_b or dtype != "float32") ] + [encode_torch_provider(single_thread, dtype) for dtype in DTYPE_OPTS for single_thread in SINGLE_THREAD_OPTS] LINE_NAMES = LINE_VALS LINE_STYLES = None @@ -323,9 +356,14 @@ def decode_provider(provider): def benchmark(M, N, K, provider): device = 'cpu' if 'cpu' in provider else 'cuda' - backend, blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype = decode_provider(provider) - a = torch.randn((M, K), device=device, dtype=dtype) - b = torch.randn((K, N), device=device, dtype=dtype) + backend, blocked_a, transposed_a, blocked_b, transposed_b, packed_b, prepack, single_thread, dtype = decode_provider( + provider) + if dtype.is_floating_point: + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((K, N), device=device, dtype=dtype) + else: + a = torch.randint(0, 5, (M, K), device=device, 
dtype=dtype) + b = torch.randint(0, 5, (K, N), device=device, dtype=dtype) if single_thread: torch.set_num_threads(1) @@ -333,10 +371,10 @@ def benchmark(M, N, K, provider): torch.set_num_threads(default_num_threads) if backend == 'triton-cpu': - c = torch.zeros((M, N), device=a.device, dtype=torch.float32) + c = torch.zeros((M, N), device=a.device, dtype=out_dtype) a_tmp = torch.zeros((M * K + (M // BLOCK_SIZE_M) * (K // BLOCK_SIZE_K) * 64), device=device, dtype=dtype) b_tmp = torch.zeros((K * N + (K // BLOCK_SIZE_K) * (N // BLOCK_SIZE_N) * 64), device=device, dtype=dtype) - c = torch.zeros((M, N), device=a.device, dtype=torch.float32) + c = torch.zeros((M, N), device=a.device, dtype=out_dtype) if prepack and (blocked_a or blocked_b): grid = ((M // BLOCK_SIZE_M) * (N // BLOCK_SIZE_N), ) block_transpose_combined_kernel[grid]( @@ -345,7 +383,7 @@ def benchmark(M, N, K, provider): BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # GROUP_SIZE_M=GROUP_SIZE_M, # BLOCKED_A=blocked_a, TRANSPOSED_BLOCK_A=transposed_a, # - BLOCKED_B=blocked_b, TRANSPOSED_B=transposed_b) + BLOCKED_B=blocked_b, TRANSPOSED_B=transposed_b, PACKED_B=packed_b) if blocked_a: a = a_tmp if blocked_b: @@ -362,7 +400,8 @@ def benchmark(M, N, K, provider): elif backend == 'triton-cpu': ms, min_ms, max_ms = triton.testing.do_bench( lambda: matmul(a, b, c, a_tmp, b_tmp, M, N, K, prepack, blocked_a, transposed_a, blocked_b, transposed_b, - num_threads=int(single_thread)), quantiles=quantiles, measure_time_with_hooks=True, rep=1000) + packed_b, num_threads=int(single_thread)), quantiles=quantiles, measure_time_with_hooks=True, + rep=1000) perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/test/TritonCPU/dot-to-amx.mlir b/test/TritonCPU/dot-to-amx.mlir index da501849f723..5b6a020306e6 100644 --- a/test/TritonCPU/dot-to-amx.mlir +++ b/test/TritonCPU/dot-to-amx.mlir @@ -236,3 +236,226 @@ module { tt.return loc(#loc) } loc(#loc) } loc(#loc) + +// ----- + +// A case with VNNI pre-encoded RHS that can be directly accessed from the input memory. +// We expect both LHS and RHS tiles to be directly loaded from the input mmemory. 
+ +// CHECK-LABEL: @test_loop_pre_encoded_direct +// CHECK: %[[LHS_MEMREF:.+]] = triton_cpu.extract_memref +// CHECK-NEXT: %[[LHS_INDICES:.+]]:4 = triton_cpu.extract_indices +// CHECK: %[[RHS_MEMREF:.+]] = triton_cpu.extract_memref +// CHECK-NEXT: %[[RHS_INDICES:.+]]:4 = triton_cpu.extract_indices +// CHECK: amx.tile_load %[[LHS_MEMREF]][%[[LHS_INDICES]]#0, %[[LHS_INDICES]]#1, %[[LHS_INDICES]]#2, %[[LHS_INDICES]]#3] +// CHECK: amx.tile_load %[[RHS_MEMREF]][%[[RHS_INDICES]]#0, %[[RHS_INDICES]]#1, %[[RHS_INDICES]]#2, %[[RHS_INDICES]]#3] +#loc = loc(unknown) +module { + tt.func public @test_loop_pre_encoded_direct(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) + %c31_i32 = arith.constant 31 : i32 loc(#loc) + %c1024_i64 = arith.constant 1024 : i64 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<32x32xf32> loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c64_i64 = arith.constant 64 : i64 loc(#loc) + %c16_i64 = arith.constant 16 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c32_i32 = arith.constant 32 : i32 loc(#loc) + %c1_i32 = arith.constant 1 : i32 loc(#loc) + %0 = tt.get_program_id x : i32 loc(#loc) + %1 = arith.divsi %arg4, %c32_i32 : i32 loc(#loc) + %2 = arith.divsi %0, %1 : i32 loc(#loc) + %3 = arith.remsi %0, %1 : i32 loc(#loc) + %4 = arith.muli %arg5, %c32_i32 : i32 loc(#loc) + %5 = arith.divsi %arg3, %c32_i32 : i32 loc(#loc) + %6 = arith.divsi %arg5, %c32_i32 : i32 loc(#loc) + %7 = arith.extsi %5 : i32 to i64 loc(#loc) + %8 = arith.extsi %6 : i32 to i64 loc(#loc) + %9 = arith.extsi %4 : i32 to i64 loc(#loc) + %10 = arith.extsi %arg5 : i32 to i64 loc(#loc) + %11 = tt.make_tensor_ptr %arg0, [%7, %8, %c32_i64, %c32_i64], [%9, %c32_i64, %10, %c1_i64], [%2, %c0_i32, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %12 = arith.extsi %1 : i32 to i64 loc(#loc) + %13 = tt.make_tensor_ptr %arg1, [%8, %12, %c16_i64, %c64_i64], [%c1024_i64, %9, %c64_i64, %c1_i64], [%c0_i32, %3, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %14 = arith.muli %2, %c32_i32 : i32 loc(#loc) + %15 = arith.muli %3, %c32_i32 : i32 loc(#loc) + %16 = arith.extsi %arg3 : i32 to i64 loc(#loc) + %17 = arith.extsi %arg4 : i32 to i64 loc(#loc) + %18 = tt.make_tensor_ptr %arg2, [%16, %17], [%17, %c1_i64], [%14, %15] {order = array} : > loc(#loc) + %19 = arith.addi %arg5, %c31_i32 : i32 loc(#loc) + %20 = arith.divsi %19, %c32_i32 : i32 loc(#loc) + %21:3 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %cst_0, %arg8 = %11, %arg9 = %13) -> (vector<32x32xf32>, !tt.ptr>, !tt.ptr>) : i32 { + %24 = triton_cpu.extract_memref %arg8 : > -> memref> loc(#loc) + %25:4 = triton_cpu.extract_indices %arg8 : > -> index, index, index, index loc(#loc) + %26 = vector.transfer_read %24[%25#0, %25#1, %25#2, %25#3], %cst {in_bounds = [true, true]} : memref>, vector<32x32xbf16> loc(#loc) + %27 = triton_cpu.extract_memref %arg9 : > -> memref> loc(#loc) + %28:4 = triton_cpu.extract_indices %arg9 : > -> index, index, index, index loc(#loc) + %29 = vector.transfer_read %27[%28#0, %28#1, %28#2, %28#3], %cst {in_bounds = [true, true]} : memref>, vector<16x64xbf16> loc(#loc) + %res1, %res2 = vector.deinterleave %29 : vector<16x64xbf16> -> vector<16x32xbf16> 
loc(#loc) + %30 = vector.transpose %res1, [1, 0] : vector<16x32xbf16> to vector<32x16xbf16> loc(#loc) + %31 = vector.transpose %res2, [1, 0] : vector<16x32xbf16> to vector<32x16xbf16> loc(#loc) + %32 = vector.interleave %30, %31 : vector<32x16xbf16> -> vector<32x32xbf16> loc(#loc) + %33 = vector.transpose %32, [1, 0] : vector<32x32xbf16> to vector<32x32xbf16> loc(#loc) + %34 = triton_cpu.dot %26, %33, %arg7, inputPrecision = tf32 : vector<32x32xbf16> * vector<32x32xbf16> -> vector<32x32xf32> loc(#loc) + %35 = tt.advance %arg8, [%c0_i32, %c1_i32, %c0_i32, %c0_i32] : > loc(#loc) + %36 = tt.advance %arg9, [%c1_i32, %c0_i32, %c0_i32, %c0_i32] : > loc(#loc) + scf.yield %34, %35, %36 : vector<32x32xf32>, !tt.ptr>, !tt.ptr> loc(#loc) + } loc(#loc) + %22 = triton_cpu.extract_memref %18 : > -> memref> loc(#loc) + %23:2 = triton_cpu.extract_indices %18 : > -> index, index loc(#loc) + vector.transfer_write %21#0, %22[%23#0, %23#1] {in_bounds = [true, true]} : vector<32x32xf32>, memref> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) + +// ----- + +// A case with VNNI pre-encoded RHS that cannot be directly accessed from the input memory. +// We expect LHS to be directly loaded from the input mmemory and RHS to be loaded through +// a temporary buffer without additional encoding. + + +// CHECK-LABEL: @test_loop_pre_encoded_indirect +// CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<16x64xbf16> +// CHECK: %[[LHS_MEMREF:.+]] = triton_cpu.extract_memref +// CHECK-NEXT: %[[LHS_INDICES:.+]]:4 = triton_cpu.extract_indices +// CHECK: %[[RHS_MEMREF:.+]] = triton_cpu.extract_memref +// CHECK-NEXT: %[[RHS_INDICES:.+]]:4 = triton_cpu.extract_indices +// CHECK-NEXT: %[[RHS:.+]] = vector.transfer_read %[[RHS_MEMREF]][%[[RHS_INDICES]]#0, %[[RHS_INDICES]]#1, %[[RHS_INDICES]]#2, %[[RHS_INDICES]]#3] +// CHECK: vector.transfer_write %[[RHS]], %[[RHS_BUF]][%c0, %c0] {in_bounds = [true, true]} +// CHECK: amx.tile_load %[[LHS_MEMREF]][%[[LHS_INDICES]]#0, %[[LHS_INDICES]]#1, %[[LHS_INDICES]]#2, %[[LHS_INDICES]]#3] +// CHECK: amx.tile_load %[[RHS_BUF]][%c0, %c0] +#loc = loc(unknown) +module { + tt.func public @test_loop_pre_encoded_indirect(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) + %c31_i32 = arith.constant 31 : i32 loc(#loc) + %c1024_i64 = arith.constant 1024 : i64 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<32x32xf32> loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c64_i64 = arith.constant 64 : i64 loc(#loc) + %c16_i64 = arith.constant 16 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c32_i32 = arith.constant 32 : i32 loc(#loc) + %c1_i32 = arith.constant 1 : i32 loc(#loc) + %0 = tt.get_program_id x : i32 loc(#loc) + %1 = arith.divsi %arg4, %c32_i32 : i32 loc(#loc) + %2 = arith.divsi %0, %1 : i32 loc(#loc) + %3 = arith.remsi %0, %1 : i32 loc(#loc) + %4 = arith.muli %arg5, %c32_i32 : i32 loc(#loc) + %5 = arith.divsi %arg3, %c32_i32 : i32 loc(#loc) + %6 = arith.divsi %arg5, %c32_i32 : i32 loc(#loc) + %7 = arith.extsi %5 : i32 to i64 loc(#loc) + %8 = arith.extsi %6 : i32 to i64 loc(#loc) + %9 = arith.extsi %4 : i32 to i64 loc(#loc) + %10 = arith.extsi %arg5 : i32 to i64 loc(#loc) + %11 = 
tt.make_tensor_ptr %arg0, [%7, %8, %c32_i64, %c32_i64], [%9, %c32_i64, %10, %c1_i64], [%2, %c0_i32, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %12 = arith.extsi %1 : i32 to i64 loc(#loc) + %13 = tt.make_tensor_ptr %arg1, [%8, %12, %c16_i64, %c64_i64], [%c1024_i64, %9, %c64_i64, %c1_i64], [%c0_i32, %3, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %14 = arith.muli %2, %c32_i32 : i32 loc(#loc) + %15 = arith.muli %3, %c32_i32 : i32 loc(#loc) + %16 = arith.extsi %arg3 : i32 to i64 loc(#loc) + %17 = arith.extsi %arg4 : i32 to i64 loc(#loc) + %18 = tt.make_tensor_ptr %arg2, [%16, %17], [%17, %c1_i64], [%14, %15] {order = array} : > loc(#loc) + %19 = arith.addi %arg5, %c31_i32 : i32 loc(#loc) + %20 = arith.divsi %19, %c32_i32 : i32 loc(#loc) + %21:3 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %cst_0, %arg8 = %11, %arg9 = %13) -> (vector<32x32xf32>, !tt.ptr>, !tt.ptr>) : i32 { + %24 = triton_cpu.extract_memref %arg8 : > -> memref> loc(#loc) + %25:4 = triton_cpu.extract_indices %arg8 : > -> index, index, index, index loc(#loc) + %26 = vector.transfer_read %24[%25#0, %25#1, %25#2, %25#3], %cst {in_bounds = [true, true]} : memref>, vector<32x32xbf16> loc(#loc) + %27 = triton_cpu.extract_memref %arg9 : > -> memref> loc(#loc) + %28:4 = triton_cpu.extract_indices %arg9 : > -> index, index, index, index loc(#loc) + %29 = vector.transfer_read %27[%28#0, %28#1, %28#2, %28#3], %cst {in_bounds = [false, false]} : memref>, vector<16x64xbf16> loc(#loc) + %res1, %res2 = vector.deinterleave %29 : vector<16x64xbf16> -> vector<16x32xbf16> loc(#loc) + %30 = vector.transpose %res1, [1, 0] : vector<16x32xbf16> to vector<32x16xbf16> loc(#loc) + %31 = vector.transpose %res2, [1, 0] : vector<16x32xbf16> to vector<32x16xbf16> loc(#loc) + %32 = vector.interleave %30, %31 : vector<32x16xbf16> -> vector<32x32xbf16> loc(#loc) + %33 = vector.transpose %32, [1, 0] : vector<32x32xbf16> to vector<32x32xbf16> loc(#loc) + %34 = triton_cpu.dot %26, %33, %arg7, inputPrecision = tf32 : vector<32x32xbf16> * vector<32x32xbf16> -> vector<32x32xf32> loc(#loc) + %35 = tt.advance %arg8, [%c0_i32, %c1_i32, %c0_i32, %c0_i32] : > loc(#loc) + %36 = tt.advance %arg9, [%c1_i32, %c0_i32, %c0_i32, %c0_i32] : > loc(#loc) + scf.yield %34, %35, %36 : vector<32x32xf32>, !tt.ptr>, !tt.ptr> loc(#loc) + } loc(#loc) + %22 = triton_cpu.extract_memref %18 : > -> memref> loc(#loc) + %23:2 = triton_cpu.extract_indices %18 : > -> index, index loc(#loc) + vector.transfer_write %21#0, %22[%23#0, %23#1] {in_bounds = [true, true]} : vector<32x32xf32>, memref> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) + +// ----- + +// A case with int8 VNNI pre-encoded RHS that can be directly accessed from the input memory. +// We expect both LHS and RHS tiles to be directly loaded from the input mmemory. 
+ +// CHECK-LABEL: @test_loop_int8_pre_encoded +// CHECK: %[[LHS_MEMREF:.+]] = triton_cpu.extract_memref +// CHECK-NEXT: %[[LHS_INDICES:.+]]:4 = triton_cpu.extract_indices +// CHECK: %[[RHS_MEMREF:.+]] = triton_cpu.extract_memref +// CHECK-NEXT: %[[RHS_INDICES:.+]]:4 = triton_cpu.extract_indices +// CHECK: amx.tile_load %[[LHS_MEMREF]][%[[LHS_INDICES]]#0, %[[LHS_INDICES]]#1, %[[LHS_INDICES]]#2, %[[LHS_INDICES]]#3] +// CHECK: amx.tile_load %[[RHS_MEMREF]][%[[RHS_INDICES]]#0, %[[RHS_INDICES]]#1, %[[RHS_INDICES]]#2, %[[RHS_INDICES]]#3] +#loc = loc(unknown) +module { + tt.func public @test_loop_int8_pre_encoded(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i8 = arith.constant 0 : i8 loc(#loc) + %c31_i32 = arith.constant 31 : i32 loc(#loc) + %c1024_i64 = arith.constant 1024 : i64 loc(#loc) + %cst = arith.constant dense<0> : vector<32x32xi32> loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c128_i64 = arith.constant 128 : i64 loc(#loc) + %c8_i64 = arith.constant 8 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c32_i32 = arith.constant 32 : i32 loc(#loc) + %c1_i32 = arith.constant 1 : i32 loc(#loc) + %0 = tt.get_program_id x : i32 loc(#loc) + %1 = arith.divsi %arg4, %c32_i32 : i32 loc(#loc) + %2 = arith.divsi %0, %1 : i32 loc(#loc) + %3 = arith.remsi %0, %1 : i32 loc(#loc) + %4 = arith.muli %arg5, %c32_i32 : i32 loc(#loc) + %5 = arith.divsi %arg3, %c32_i32 : i32 loc(#loc) + %6 = arith.divsi %arg5, %c32_i32 : i32 loc(#loc) + %7 = arith.extsi %5 : i32 to i64 loc(#loc) + %8 = arith.extsi %6 : i32 to i64 loc(#loc) + %9 = arith.extsi %4 : i32 to i64 loc(#loc) + %10 = arith.extsi %arg5 : i32 to i64 loc(#loc) + %11 = tt.make_tensor_ptr %arg0, [%7, %8, %c32_i64, %c32_i64], [%9, %c32_i64, %10, %c1_i64], [%2, %c0_i32, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %12 = arith.extsi %1 : i32 to i64 loc(#loc) + %13 = tt.make_tensor_ptr %arg1, [%8, %12, %c8_i64, %c128_i64], [%c1024_i64, %9, %c128_i64, %c1_i64], [%c0_i32, %3, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %14 = arith.muli %2, %c32_i32 : i32 loc(#loc) + %15 = arith.muli %3, %c32_i32 : i32 loc(#loc) + %16 = arith.extsi %arg3 : i32 to i64 loc(#loc) + %17 = arith.extsi %arg4 : i32 to i64 loc(#loc) + %18 = tt.make_tensor_ptr %arg2, [%16, %17], [%17, %c1_i64], [%14, %15] {order = array} : > loc(#loc) + %19 = arith.addi %arg5, %c31_i32 : i32 loc(#loc) + %20 = arith.divsi %19, %c32_i32 : i32 loc(#loc) + %21:3 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %11, %arg9 = %13) -> (vector<32x32xi32>, !tt.ptr>, !tt.ptr>) : i32 { + %24 = triton_cpu.extract_memref %arg8 : > -> memref> loc(#loc) + %25:4 = triton_cpu.extract_indices %arg8 : > -> index, index, index, index loc(#loc) + %26 = vector.transfer_read %24[%25#0, %25#1, %25#2, %25#3], %c0_i8 {in_bounds = [true, true]} : memref>, vector<32x32xi8> loc(#loc) + %27 = triton_cpu.extract_memref %arg9 : > -> memref> loc(#loc) + %28:4 = triton_cpu.extract_indices %arg9 : > -> index, index, index, index loc(#loc) + %30 = vector.transfer_read %27[%28#0, %28#1, %28#2, %28#3], %c0_i8 {in_bounds = [true, true]} : memref>, vector<8x128xi8> loc(#loc) + %res1, %res2 = vector.deinterleave %30 : vector<8x128xi8> -> vector<8x64xi8> loc(#loc) + %31 = vector.transpose 
%res1, [1, 0] : vector<8x64xi8> to vector<64x8xi8> loc(#loc) + %32 = vector.transpose %res2, [1, 0] : vector<8x64xi8> to vector<64x8xi8> loc(#loc) + %33 = vector.interleave %31, %32 : vector<64x8xi8> -> vector<64x16xi8> loc(#loc) + %34 = vector.transpose %33, [1, 0] : vector<64x16xi8> to vector<16x64xi8> loc(#loc) + %res1_0, %res2_1 = vector.deinterleave %34 : vector<16x64xi8> -> vector<16x32xi8> loc(#loc) + %35 = vector.transpose %res1_0, [1, 0] : vector<16x32xi8> to vector<32x16xi8> loc(#loc) + %36 = vector.transpose %res2_1, [1, 0] : vector<16x32xi8> to vector<32x16xi8> loc(#loc) + %37 = vector.interleave %35, %36 : vector<32x16xi8> -> vector<32x32xi8> loc(#loc) + %38 = vector.transpose %37, [1, 0] : vector<32x32xi8> to vector<32x32xi8> loc(#loc) + %39 = triton_cpu.dot %26, %38, %arg7, inputPrecision = tf32 : vector<32x32xi8> * vector<32x32xi8> -> vector<32x32xi32> loc(#loc) + %40 = tt.advance %arg8, [%c0_i32, %c1_i32, %c0_i32, %c0_i32] : > loc(#loc) + %41 = tt.advance %arg9, [%c1_i32, %c0_i32, %c0_i32, %c0_i32] : > loc(#loc) + scf.yield %39, %40, %41 : vector<32x32xi32>, !tt.ptr>, !tt.ptr> loc(#loc) + } loc(#loc) + %22 = triton_cpu.extract_memref %18 : > -> memref> loc(#loc) + %23:2 = triton_cpu.extract_indices %18 : > -> index, index loc(#loc) + vector.transfer_write %21#0, %22[%23#0, %23#1] {in_bounds = [true, true]} : vector<32x32xi32>, memref> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) diff --git a/third_party/cpu/language/cpu/__init__.py b/third_party/cpu/language/cpu/__init__.py index e69de29bb2d1..d0618d5cd3e9 100644 --- a/third_party/cpu/language/cpu/__init__.py +++ b/third_party/cpu/language/cpu/__init__.py @@ -0,0 +1,3 @@ +from .utils import vnni_decode + +__all__ = ["vnni_decode"] diff --git a/third_party/cpu/language/cpu/utils.py b/third_party/cpu/language/cpu/utils.py new file mode 100644 index 000000000000..82538971a971 --- /dev/null +++ b/third_party/cpu/language/cpu/utils.py @@ -0,0 +1,22 @@ +from triton import jit +import triton.language as tl +from triton.language.core import builtin + + +@jit +def _vnni_decode(arg0): + tl.static_assert(len(arg0.shape) == 2) + tmp = arg0.reshape((arg0.shape[0], arg0.shape[1] // 2, 2)) + tmp1, tmp2 = tl.split(tmp) + return tl.join(tmp1.T, tmp2.T).reshape((arg0.shape[1] // 2, arg0.shape[0] * 2)).T + + +@builtin +def vnni_decode(arg0, _builder=None, _generator=None): + bitwidth = arg0.dtype.primitive_bitwidth + if bitwidth > 16: + raise ValueError("Expected 8-bit or 16-bit values for vnni_decode") + decoded = _generator.call_JitFunction(_vnni_decode, (arg0, ), kwargs={}) + if bitwidth == 8: + decoded = _generator.call_JitFunction(_vnni_decode, (decoded, ), kwargs={}) + return decoded diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp index 4ad5de863fb4..8fc432b9734e 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp @@ -56,15 +56,80 @@ Value getInitAccValue(Value val) { return forOp.getInitArgs()[initValIdx]; } -MemBuffer findInputBuffer(Value val, bool allowTransposed) { +namespace { + +// Check if val is a result of transpose operation. If it is, then return +// a source of that transpose operation. Otherwise, return nullptr. 
+Value getTransposedSrc(Value val) { + auto transposeOp = val.getDefiningOp(); + if (transposeOp) + return transposeOp.getVector(); + return nullptr; +} + +// We are looking for the following sequence: +// %tmp1, %tmp2 = vector.deinterleave %src +// %tmp3 = vector.transpose %tmp1, [1, 0] +// %tmp4 = vector.transpose %tmp2, [1, 0] +// %tmp5 = vector.interleave %tmp3, %tmp4 +// %val = vector.transpose %tmp5, [1, 0] +// and return %src if pattern matching succeeds. +Value getVnniSrcImpl(Value val) { + auto transposedVal = getTransposedSrc(val); + if (!transposedVal) + return nullptr; + + auto interleave = transposedVal.getDefiningOp(); + if (!interleave) + return nullptr; + + auto tmp1 = getTransposedSrc(interleave.getLhs()); + auto tmp2 = getTransposedSrc(interleave.getRhs()); + if (!tmp1 || !tmp2) + return nullptr; + + auto deinterleave1 = tmp1.getDefiningOp(); + auto deinterleave2 = tmp2.getDefiningOp(); + if (!deinterleave1 || deinterleave1 != deinterleave2 || + deinterleave1.getResult(0) != tmp1 || deinterleave2.getResult(1) != tmp2) + return nullptr; + + return deinterleave1.getSource(); +} + +} // namespace + +Value getVnniSrc(Value val) { + Type elemTy = getElementTypeOrSelf(val.getType()); + + // VNNI encoding is used for 8-bit and 16-bit values only. + if (elemTy.getIntOrFloatBitWidth() > 16) + return nullptr; + + // For 16-bit values VNNI encoding is a single interleave of + // subsequenct rows. For 8-bit values, it's applied twice. + Value encoded = getVnniSrcImpl(val); + if (encoded && elemTy.getIntOrFloatBitWidth() == 8) + encoded = getVnniSrcImpl(encoded); + + return encoded; +} + +MemBuffer findInputBuffer(Value val, bool allowTransposed, bool allowVnni) { MemBuffer buf; if (allowTransposed) { - auto transposeOp = val.getDefiningOp(); - if (transposeOp) { - val = transposeOp.getVector(); + auto transposed = getTransposedSrc(val); + if (transposed) { + val = transposed; buf.transposed = true; } + } else if (allowVnni) { + auto vnniVal = getVnniSrc(val); + if (vnniVal) { + val = vnniVal; + buf.vnni = true; + } } auto valLoad = val.getDefiningOp(); diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h index e26529d91882..2760ebd14fbb 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h @@ -24,6 +24,9 @@ struct MemBuffer { SmallVector step; // True if buffer holds transposed value. bool transposed = false; + // Ttue if buffer holds value in VNNI (interleaved to groups of 32bit) + // encoding. + bool vnni = false; bool empty() const { return !memRef; } }; @@ -48,10 +51,17 @@ template bool hasMaskOrBoundsCheck(T op) { return hasBoundsCheck || mask; } -// Search for a buffer holding required value. If allowTransposed is true, -// then buffer is allowed to hold both transposed and not transposed value. +// Search for a buffer holding required value. +// +// If allowTransposed is true, then buffer is allowed to hold both transposed +// and not transposed value. +// +// If allowVnni then buffer is allowed to hold value in both original and +// VNNI-encoded form. This flag is ignored if allowTransposed is true. +// // Return empty buffer if no memory holding value was found. 
-MemBuffer findInputBuffer(Value val, bool allowTransposed = false); +MemBuffer findInputBuffer(Value val, bool allowTransposed = false, + bool allowVnni = false); // Cast vector to a specified element type using ext or trunc // operations. Return the original value if it already matches @@ -67,6 +77,10 @@ MemBuffer allocateTmpBuffer(Location loc, VectorType vecTy, Value shiftIndex(Location loc, Value index, int64_t offs, PatternRewriter &rewriter); +// Check if val is a result of a sequence that performs VNNI decoding. +// If it is, then return the original encoded value. Otherwise, return nullptr. +Value getVnniSrc(Value val); + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp index 1b6dd9269ac1..11ce852e7570 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp @@ -402,9 +402,9 @@ MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, LDBG("Preparing buffer (interleave=" << interleave << ") for a vector: " << val); auto vecTy = cast(val.getType()); - MemBuffer inputBuf = findInputBuffer(val); + MemBuffer inputBuf = findInputBuffer(val, false, interleave); if (!inputBuf.empty()) { - if (interleave) { + if (interleave && !inputBuf.vnni) { LDBG(" Copying from the original memref with interleave: " << inputBuf.memRef); auto tmpBuf = allocateTmpBuffer(loc, getSwizzledRhsTileType(vecTy), @@ -426,7 +426,12 @@ MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, MemBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); if (interleave) { - interleaveAndStore(loc, val, buf.memRef, rewriter); + auto interleavedVal = getVnniSrc(val); + if (interleavedVal) { + LDBG(" Using pre-encoding value: " << interleavedVal); + op_write(interleavedVal, buf.memRef, buf.indices); + } else + interleaveAndStore(loc, val, buf.memRef, rewriter); } else { op_write(val, buf.memRef, buf.indices); } diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp index a39a93e42446..cc8ccfeb5374 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp @@ -177,31 +177,20 @@ struct SplitOpConversion : public OpConversionPattern { auto src = rewriter.getRemappedValue(op.getSrc()); auto srcTy = cast(src.getType()); auto resTy = getTypeConverter()->convertType(op.getType(0)); + assert(srcTy.getShape().back() == 2); SmallVector results; if (srcTy.getRank() == 1) { results.push_back(rewriter.create(loc, src, 0)); results.push_back(rewriter.create(loc, src, 1)); + rewriter.replaceOp(op, results); } else { - SmallVector tmpShape({srcTy.getNumElements()}); + SmallVector tmpShape(srcTy.getShape().drop_back()); + tmpShape.back() *= 2; auto tmp = rewriter.create( loc, VectorType::get(tmpShape, srcTy.getElementType()), src); - - SmallVector evenIndices; - SmallVector oddIndices; - for (int64_t i = 0; i < srcTy.getNumElements(); i += 2) { - evenIndices.push_back(i); - oddIndices.push_back(i + 1); - } - - Value res1 = - rewriter.create(loc, tmp, tmp, evenIndices); - Value res2 = - rewriter.create(loc, tmp, tmp, oddIndices); - results.push_back(rewriter.create(loc, resTy, res1)); - results.push_back(rewriter.create(loc, resTy, res2)); + 
rewriter.replaceOpWithNewOp<vector::DeinterleaveOp>(op, tmp); } - rewriter.replaceOp(op, results); return success(); } };
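For completeness, here is a rough host-side counterpart of the new _vnni_decode helper, written with PyTorch ops that mirror its reshape/split/join sequence; vnni_decode_reference is a hypothetical name and this sketch is not part of the patch. The builtin applies one such decode level for 16-bit inputs and two levels for 8-bit inputs.

import torch

def vnni_decode_reference(x: torch.Tensor) -> torch.Tensor:
    # One decode level, mirroring _vnni_decode: (R, C) -> (2*R, C//2), pulling
    # interleaved column pairs apart into separate rows.
    r, c = x.shape
    tmp = x.reshape(r, c // 2, 2)
    t1, t2 = tmp[..., 0], tmp[..., 1]        # tl.split
    out = torch.stack((t1.T, t2.T), dim=-1)  # tl.join of the transposes
    return out.reshape(c // 2, r * 2).T.contiguous()

# 16-bit check: interleaving pairs of rows of B and then decoding recovers B.
# For 8-bit inputs the builtin applies the decode a second time.
b = torch.randn(32, 32, dtype=torch.bfloat16)
packed = torch.stack((b[0::2], b[1::2]), dim=-1).reshape(16, 64)
assert torch.equal(vnni_decode_reference(packed), b)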