Describe the bug

I am trying to write batched matrix multiplication code in Triton, inspired by triton_transformer.
import torch
import triton
import triton.language as tl
from torch import nn
import os
from torch.autograd import Function

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
DEVICE = torch.device("cuda")

# Define activation functions
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)

@triton.jit
def relu(x):
    return tl.where(x >= 0, x, 0)

@triton.jit
def squared_relu(x):
    return tl.where(x > 0, x * x, 0.)

@triton.jit
def sigmoid(x):
    return 1. / (1. + tl.exp(-x))

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def bmm_kernel(
    x_ptr, y_ptr, o_ptr,
    M, N, K,
    stride_al, stride_am, stride_ak,
    stride_bl, stride_bk, stride_bn,
    stride_ol, stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
    activation: tl.constexpr
):
    pid_batch = tl.program_id(0)
    pid = tl.program_id(1)

    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    x_ptrs = x_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + pid_batch * stride_al)
    y_ptrs = y_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + pid_batch * stride_bl)

    o = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_SIZE_K):
        x = tl.load(x_ptrs, mask=offs_k[None, :] < min(BLOCK_SIZE_K, K - k), other=0.0)
        y = tl.load(y_ptrs, mask=offs_k[:, None] < min(BLOCK_SIZE_K, K - k), other=0.0)
        o += tl.dot(x, y)
        x_ptrs += BLOCK_SIZE_K * stride_ak
        y_ptrs += BLOCK_SIZE_K * stride_bk

    if activation == "relu":  # ReLU
        o = relu(o)
    elif activation == "leaky_relu":  # Leaky ReLU
        o = leaky_relu(o)
    elif activation == "squared_relu":  # Squared ReLU
        o = squared_relu(o)
    elif activation == "sigmoid":  # Sigmoid
        o = sigmoid(o)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    o_ptrs = o_ptr + stride_om * offs_m[:, None] + stride_on * offs_n[None, :] + stride_ol * pid_batch
    tl.store(o_ptrs, o, mask=mask)


def bmm(x, y, activation=None):
    unsqueeze = False
    if x.ndim == 2:
        x = x.unsqueeze(0)
        unsqueeze = True
    B, M, K = x.shape
    if y.ndim == 2:
        y = y.unsqueeze(0).expand(B, -1, -1)
    _, K_y, N = y.shape
    assert K == K_y, "Inner dimensions of matrices must match for multiplication."

    o = torch.empty((B, M, N), device=x.device, dtype=x.dtype)

    grid = lambda META: (B, triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']))
    bmm_kernel[grid](
        x, y, o,
        M, N, K,
        x.stride(0), x.stride(1), x.stride(2),
        y.stride(0), y.stride(1), y.stride(2),
        o.stride(0), o.stride(1), o.stride(2),
        activation=activation,
    )
    torch.cuda.synchronize()

    if unsqueeze:
        return o.squeeze(0)
    else:
        return o
When I run the code using float16, it works as shown below:
B, M, K, N = 4, 8, 16, 8  # Batch size, matrix dimensions
x = torch.randn((B, M, K), requires_grad=True, device=DEVICE, dtype=torch.float16)
y = torch.randn((B, K, N), requires_grad=True, device=DEVICE, dtype=torch.float16)
bmm(x, y)
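For reference, one way to verify the float16 path against PyTorch (a minimal sketch; the check itself is not part of the original code, and the tolerance judgment is left to the reader since float16 accumulation differs between the two paths):

# Sketch: compare the Triton kernel's output with torch.bmm on the same inputs.
out_triton = bmm(x, y)
out_torch = torch.bmm(x, y)
print(torch.max(torch.abs(out_triton - out_torch)))  # expected to be small if the kernel is correct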
However, when I run the code with float32, it fails:
B, M, K, N = 4, 8, 16, 8  # Batch size, matrix dimensions
x = torch.randn((B, M, K), requires_grad=True, device=DEVICE, dtype=torch.float32)
y = torch.randn((B, K, N), requires_grad=True, device=DEVICE, dtype=torch.float32)
bmm(x, y)
python3.9/site-packages/triton/backends/nvidia/compiler.py:216, in CUDABackend.make_llir(src, metadata, options, capability)
    214 if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
    215     passes.llvmir.add_di_scope(pm)
--> 216 pm.run(mod)
    217 # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
    218 llvm.init_targets()

IndexError: map::at
What's the source of this issue?
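In the meantime, a possible stopgap is to route float32 inputs through torch.bmm, or cast them to float16 for the Triton path if the precision loss is acceptable. A sketch, where bmm_safe is a hypothetical wrapper name and the 2-D handling of the original bmm wrapper is omitted:

def bmm_safe(x, y, activation=None):
    # Workaround sketch: avoid the failing float32 Triton compile until the crash is resolved.
    if x.dtype == torch.float32 and activation is None:
        return torch.bmm(x, y)  # plain PyTorch fallback for 3-D float32 inputs
    # Alternative (loses precision): run the Triton kernel in float16 and cast back:
    # return bmm(x.half(), y.half(), activation).to(x.dtype)
    return bmm(x, y, activation)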
Environment details

Triton: 3.0.0
CUDA: 12.0
Python: 3.9.21
Torch: 2.4.1+cu121
This doesn't reproduce on main; could you try building Triton from source?
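After reinstalling, one way to confirm which Triton build the interpreter actually picks up (a minimal sketch):

import triton
print(triton.__version__)  # e.g. "3.0.0" for the release wheel
print(triton.__file__)     # shows whether the import resolves to the source checkout or site-packages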