diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml
index f9f87e6ba1f7..8c9bcca7cf28 100644
--- a/.github/workflows/build-test.yml
+++ b/.github/workflows/build-test.yml
@@ -70,4 +70,4 @@ jobs:

       - name: Run python unit tests
         run: |
-          python -m pytest -n 32 --device cpu python/test/unit/language/test_core.py -m cpu
+          python -m pytest -s -n 32 --device cpu python/test/unit/language/test_core.py -m cpu
diff --git a/python/src/ir.cc b/python/src/ir.cc
index b07f0a8d1744..ea62b48ec281 100644
--- a/python/src/ir.cc
+++ b/python/src/ir.cc
@@ -1679,6 +1679,8 @@ void init_triton_ir(py::module &&m) {
         });

         ::llvm::DebugFlag = true;
+        // For release builds, setCurrentDebugTypes is a macro, so avoid the
+        // namespace prefix.
         using namespace llvm;
         setCurrentDebugTypes(debugTypes.data(), debugTypes.size());
       }
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index d1787bd21043..1efcb9eeb585 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1642,6 +1642,15 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
     if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"):
         pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.')

+    if is_cpu() and (dtype_x in torch_float8_dtypes or dtype_z in torch_float8_dtypes):
+        pytest.skip(f'test_cast{(dtype_x, dtype_z)} is not supported on CPU.')
+
+    # fptrunc fp32->fp16 is broken in LLVM for large vectors:
+    # https://github.com/llvm/llvm-project/issues/95274
+    # TODO: remove the change after the bug is fixed.
+    if is_cpu() and dtype_x == "float32" and dtype_z == "float16":
+        size = 512
+
     # bf16 vector cast is broken in LLVM for large vectors:
     # https://github.com/llvm/llvm-project/issues/92471
     # TODO: Remove the change after the bug is fixed.
@@ -2201,6 +2210,12 @@ def kernel(X, Z, BLOCK: tl.constexpr):
 def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device):
     check_type_supported(dtype_str, device)  # bfloat16 on cc < 80 will not be tested

+    # fpext fp16->fp32 is broken in LLVM for large vectors:
+    # https://github.com/llvm/llvm-project/issues/95278
+    # TODO: remove the change after the bug is fixed.
+    if is_cpu() and dtype_str == "float16":
+        shape = (min(shape[0], 512), min(shape[1], 512))
+
     @triton.jit
     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr,
                AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr, USE_I1: tl.constexpr):
diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py
index c20a36aab10e..e96d04614661 100644
--- a/python/tutorials/03-matrix-multiplication-cpu.py
+++ b/python/tutorials/03-matrix-multiplication-cpu.py
@@ -154,13 +154,13 @@
 import triton
 import triton.language as tl

-
 BLOCK_SIZE_M = 32
 BLOCK_SIZE_N = 32
 BLOCK_SIZE_K = 32
 GROUP_SIZE_M = 8
 USE_GPU = True

+
 @triton.jit
 def matmul_kernel(
         # Pointers to matrices
@@ -227,7 +227,7 @@ def matmul_kernel(
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
-    
+
     # Convert the accumulator to the output matrix C's type if needed.
     c = accumulator

@@ -236,14 +236,13 @@ def matmul_kernel(
     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-    
+
     #TODO: Currently masked load is not supported yet.
     #c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
     #tl.store(c_ptrs, c, mask=c_mask)
     tl.store(c_ptrs, c)


-
 # %%
 # We can now create a convenience wrapper function that only takes two input tensors,
 # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
@@ -256,9 +255,10 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
     M, K = a.shape
     K, N = b.shape
     #TODO: Currently masked load is not supported yet.
-    assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
+    assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
+        K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
     if c is None:
-        # Allocates output. 
+        # Allocates output.
         c = torch.empty((M, N), device=a.device, dtype=a.dtype)
     else:
         assert c.shape == (M, N), "Incompatible dimensions"
@@ -270,9 +270,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
         a.stride(0), a.stride(1),  #
         b.stride(0), b.stride(1),  #
         c.stride(0), c.stride(1),  #
-        BLOCK_SIZE_M=BLOCK_SIZE_M,
-        BLOCK_SIZE_N=BLOCK_SIZE_N,
-        BLOCK_SIZE_K=BLOCK_SIZE_K,  #
+        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,  #
         GROUP_SIZE_M=GROUP_SIZE_M,  #
     )
     return c
@@ -298,7 +296,8 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
 if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
     print("✅ TritonCPU and TorchCPU match")
 else:
-    print("❌ TritonCPU and TorchCPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}')
+    print("❌ TritonCPU and TorchCPU differ, the maximum difference is "
+          f'{torch.max(torch.abs(triton_output - torch_output))}')

 # %%
 # Benchmark
@@ -326,13 +325,13 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
     if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
         print("✅ TritonGPU and TorchGPU match")
     else:
-        print("❌ TritonGPU and TorchGPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}')
+        print("❌ TritonGPU and TorchGPU differ, the maximum difference is "
+              f'{torch.max(torch.abs(triton_output - torch_output))}')

     LINE_VALS += ['triton-gpu', 'torch-gpu']
     LINE_NAMES += ['TritonGPU', 'TorchGPU']
     LINE_STYLES += [('yellow', '-'), ('red', '-')]

-
 # %%
 # Seems like we're good to go!

@@ -359,7 +358,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
         f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})',
         args={},  # Values for function arguments not in `x_names` and `y_name`.
     ))
-
 def benchmark(M, N, K, provider):
     import os

@@ -383,7 +381,8 @@ def benchmark(M, N, K, provider):
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles)
     elif provider == 'torch-cpu':
         c = torch.empty((M, N), device=a.device, dtype=a.dtype)
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, is_cpu=True)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles,
+                                                     is_cpu=True)
     elif provider == 'triton-cpu-single':
         c = torch.empty((M, N), device=a.device, dtype=a.dtype)
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
diff --git a/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h b/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h
index b2edc5e98b36..ba2d64d8f5f0 100644
--- a/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h
+++ b/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h
@@ -225,10 +225,12 @@ struct ReduceScanOpConversionBase : public OpConversionPattern {
   createShuffleDummies(Location loc, ValueRange inputs,
                        ConversionPatternRewriter &rewriter) const {
     if (shuffleDummies.empty()) {
+      SmallVector<int64_t> dummyShape({1});
       for (auto val : inputs) {
         auto ty = cast<VectorType>(val.getType());
         shuffleDummies.push_back(rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getZeroAttr(ty.cloneWith(1, ty.getElementType()))));
+            loc, rewriter.getZeroAttr(
+                     ty.cloneWith(dummyShape, ty.getElementType()))));
       }
     }
     return shuffleDummies;