Fixes for x86 CI workflow #26

Merged: 8 commits, Jun 18, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build-test.yml
@@ -70,4 +70,4 @@ jobs:

- name: Run python unit tests
run: |
python -m pytest -n 32 --device cpu python/test/unit/language/test_core.py -m cpu
python -m pytest -s -n 32 --device cpu python/test/unit/language/test_core.py -m cpu
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
@@ -1 +1 @@
3a8316216807d64a586b971f51695e23883331f7
765206e050453018e861637a08a4520f29238074
5 changes: 4 additions & 1 deletion python/src/ir.cc
@@ -1615,7 +1615,10 @@ void init_triton_ir(py::module &&m) {
});

::llvm::DebugFlag = true;
::llvm::setCurrentDebugTypes(debugTypes.data(), debugTypes.size());
// For release build setCurrentDebugTypes is a macro, so avoid
// namespace prefix
using namespace llvm;
setCurrentDebugTypes(debugTypes.data(), debugTypes.size());
}

if (failed(self.run(mod.getOperation())))
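For context (not part of this PR's diff): in release (NDEBUG) builds, llvm/Support/Debug.h defines setCurrentDebugTypes as a function-like macro rather than a function in the llvm namespace, so a qualified call such as ::llvm::setCurrentDebugTypes(...) stops compiling once the macro is active. A minimal standalone sketch of the pitfall, using mock names rather than LLVM's actual definitions:

// Mock illustration: a name that is a namespaced function in one build
// configuration and a function-like macro in another.
#include <cstdio>

namespace lib {
void set_debug(int level) { std::printf("function: %d\n", level); }
} // namespace lib

// Elsewhere (e.g. in a release configuration) the same name is a macro:
#define set_debug(level) std::printf("macro: %d\n", (level))

int main() {
  // ::lib::set_debug(1);  // the macro expands inside the qualifier:
  //                       // ::lib::std::printf(...)  -> compile error
  using namespace lib;
  set_debug(2);            // fine whether set_debug is the macro or the function
  return 0;
}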
15 changes: 15 additions & 0 deletions python/test/unit/language/test_core.py
@@ -1601,6 +1601,15 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"):
pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.')

if is_cpu() and (dtype_x in torch_float8_dtypes or dtype_z in torch_float8_dtypes):
pytest.skip(f'test_cast{(dtype_x, dtype_z)} is not supported on CPU.')

# fptrunc fp32->fp16 is broken in LLVM for large vectors:
# https://github.com/llvm/llvm-project/issues/95274
# TODO: remove the change after the bug is fixed.
if is_cpu() and dtype_x == "float32" and dtype_z == "float16":
size = 512

# bf16 vector cast is broken in LLVM for large vectors:
# https://github.com/llvm/llvm-project/issues/92471
# TODO: Remove the change after the bug is fixed.
@@ -2138,6 +2147,12 @@ def kernel(X, Z, BLOCK: tl.constexpr):
def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device):
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested

# fpext fp16->fp32 is broken in LLVM for large vectors:
# https://github.com/llvm/llvm-project/issues/95278
# TODO: remove the change after the bug is fixed.
if is_cpu() and dtype_str == "float16":
shape = (min(shape[0], 512), min(shape[1], 512))

@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr,
AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr):
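To make the capped sizes above concrete, here is a hypothetical standalone sketch (not part of the test suite) of the fp32 -> fp16 cast path that the workarounds keep at 512 elements on the CPU backend until the linked LLVM issues are fixed:

# Hypothetical sketch: a small fp32 -> fp16 cast on the CPU backend, limited to
# 512 elements to stay clear of the LLVM large-vector fptrunc/fpext bugs
# (llvm/llvm-project issues 95274 and 95278) referenced above.
import torch
import triton
import triton.language as tl


@triton.jit
def cast_kernel(x_ptr, z_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    tl.store(z_ptr + offs, x.to(tl.float16))  # fp32 -> fp16 truncation


x = torch.randn(512, device='cpu', dtype=torch.float32)
z = torch.empty(512, device='cpu', dtype=torch.float16)
cast_kernel[(1, )](x, z, BLOCK=512)
assert torch.allclose(z.float(), x, atol=1e-2, rtol=0)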
27 changes: 13 additions & 14 deletions python/tutorials/03-matrix-multiplication-cpu.py
@@ -154,13 +154,13 @@
import triton
import triton.language as tl


BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 32
BLOCK_SIZE_K = 32
GROUP_SIZE_M = 8
USE_GPU = True


@triton.jit
def matmul_kernel(
# Pointers to matrices
@@ -227,7 +227,7 @@ def matmul_kernel(
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk

# Convert the accumulator to the output matrix C's type if needed.
c = accumulator

@@ -236,14 +236,13 @@ def matmul_kernel(
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]

#TODO: Currently masked load is not supported yet.
#c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
#tl.store(c_ptrs, c, mask=c_mask)
tl.store(c_ptrs, c)



# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
@@ -256,9 +255,10 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
M, K = a.shape
K, N = b.shape
#TODO: Currently masked load is not supported yet.
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
if c is None:
# Allocates output.
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
else:
assert c.shape == (M, N), "Incompatible dimensions"
@@ -270,9 +270,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K, #
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, #
GROUP_SIZE_M=GROUP_SIZE_M, #
)
return c
@@ -298,7 +296,8 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
print("✅ TritonCPU and TorchCPU match")
else:
print("❌ TritonCPU and TorchCPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}')
print("❌ TritonCPU and TorchCPU differ, the maximum difference is "
f'{torch.max(torch.abs(triton_output - torch_output))}')

# %%
# Benchmark
@@ -326,13 +325,13 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
print("✅ TritonGPU and TorchGPU match")
else:
print("❌ TritonGPU and TorchGPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}')
print("❌ TritonGPU and TorchGPU differ, the maximum difference is "
f'{torch.max(torch.abs(triton_output - torch_output))}')

LINE_VALS += ['triton-gpu', 'torch-gpu']
LINE_NAMES += ['TritonGPU', 'TorchGPU']
LINE_STYLES += [('yellow', '-'), ('red', '-')]


# %%
# Seems like we're good to go!

@@ -359,7 +358,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})',
args={}, # Values for function arguments not in `x_names` and `y_name`.
))

def benchmark(M, N, K, provider):
import os

@@ -383,7 +381,8 @@ def benchmark(M, N, K, provider):
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles)
elif provider == 'torch-cpu':
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, is_cpu=True)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles,
is_cpu=True)
elif provider == 'triton-cpu-single':
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
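As a usage note for the assert in the matmul wrapper above (illustrative, not part of this PR's diff): masked loads and stores are not supported yet on the CPU backend, so the wrapper only accepts shapes whose dimensions are multiples of the 32-wide blocks. A minimal call that satisfies the constraint:

# Illustrative use of the tutorial's matmul() wrapper on CPU: M, N, and K are
# all multiples of BLOCK_SIZE_M/N/K (32), as the assert requires while masked
# loads/stores are unsupported.
import torch

a = torch.randn((256, 384), device='cpu', dtype=torch.float32)
b = torch.randn((384, 512), device='cpu', dtype=torch.float32)

c = matmul(a, b, None)  # the wrapper allocates the output when c is None
torch.testing.assert_close(c, torch.matmul(a, b), atol=1e-2, rtol=0)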
4 changes: 3 additions & 1 deletion third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h
@@ -225,10 +225,12 @@ struct ReduceScanOpConversionBase : public OpConversionPattern<OpT> {
createShuffleDummies(Location loc, ValueRange inputs,
ConversionPatternRewriter &rewriter) const {
if (shuffleDummies.empty()) {
SmallVector<int64_t, 1> dummyShape({1});
for (auto val : inputs) {
auto ty = cast<VectorType>(val.getType());
shuffleDummies.push_back(rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(ty.cloneWith(1, ty.getElementType()))));
loc, rewriter.getZeroAttr(
ty.cloneWith(dummyShape, ty.getElementType()))));
}
}
return shuffleDummies;
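For reference (illustrative, not from this PR): ShapedType::cloneWith takes the new shape as an ArrayRef<int64_t>, so the one-element shape is kept in a named SmallVector rather than passed as a bare integer literal. A minimal sketch of the corrected call shape, with hypothetical helper and variable names:

// Hypothetical helper mirroring the corrected call above.
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

static mlir::Type makeOneElementVector(mlir::VectorType ty) {
  // cloneWith(shape, elementType) expects an ArrayRef<int64_t> shape, so the
  // single-element shape lives in named storage.
  llvm::SmallVector<int64_t, 1> dummyShape({1});
  return ty.cloneWith(dummyShape, ty.getElementType());
}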