Storing FP32 values compilation error #5497

Open
iamdanialkamali opened this issue Dec 26, 2024 · 1 comment

@iamdanialkamali

Describe the bug

I am trying to write a batched matrix multiplication kernel in Triton, inspired by triton_transformer.

```python
import torch
import triton
import triton.language as tl
from torch import nn
import os
from torch.autograd import Function

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
DEVICE = torch.device("cuda")

# Define activation functions
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)

@triton.jit
def relu(x):
    return tl.where(x >= 0, x, 0)

@triton.jit
def squared_relu(x):
    return tl.where(x > 0, x * x, 0.)

@triton.jit
def sigmoid(x):
    return 1. / (1. + tl.exp(-x))

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64,  'BLOCK_SIZE_N': 64,  'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32,  'BLOCK_SIZE_N': 32,  'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def bmm_kernel(
    x_ptr, y_ptr, o_ptr,
    M, N, K,
    stride_al, stride_am, stride_ak,
    stride_bl, stride_bk, stride_bn,
    stride_ol, stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    activation: tl.constexpr
):
    pid_batch = tl.program_id(0)
    pid = tl.program_id(1)

    # Map the linear program id to a (pid_m, pid_n) tile using grouped ordering for L2 reuse
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    x_ptrs = x_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + pid_batch * stride_al)
    y_ptrs = y_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + pid_batch * stride_bl)

    # Accumulate in float32, iterating over K one block at a time
    o = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_SIZE_K):
        x = tl.load(x_ptrs, mask=offs_k[None, :] < min(BLOCK_SIZE_K, K - k), other=0.0)
        y = tl.load(y_ptrs, mask=offs_k[:, None] < min(BLOCK_SIZE_K, K - k), other=0.0)
        o += tl.dot(x, y)

        x_ptrs += BLOCK_SIZE_K * stride_ak
        y_ptrs += BLOCK_SIZE_K * stride_bk

    if activation == "relu":  # ReLU
        o = relu(o)
    elif activation == "leaky_relu":  # Leaky ReLU
        o = leaky_relu(o)
    elif activation == "squared_relu":  # Squared ReLU
        o = squared_relu(o)
    elif activation == "sigmoid":  # Sigmoid
        o = sigmoid(o)

    # Write the output tile back, masking out-of-bounds rows and columns
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    o_ptrs = o_ptr + stride_om * offs_m[:, None] + stride_on * offs_n[None, :] + stride_ol * pid_batch
    tl.store(o_ptrs, o, mask=mask)


def bmm(x, y, activation=None):
    unsqueeze = False
    if x.ndim == 2:
        x = x.unsqueeze(0)
        unsqueeze = True
    B, M, K = x.shape

    if y.ndim == 2:
        y = y.unsqueeze(0).expand(B, -1, -1)

    _, K_y, N = y.shape
    assert K == K_y, "Inner dimensions of matrices must match for multiplication."

    o = torch.empty((B, M, N), device=x.device, dtype=x.dtype)

    grid = lambda META: (B, triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']))

    bmm_kernel[grid](
        x, y, o,
        M, N, K,
        x.stride(0), x.stride(1), x.stride(2),
        y.stride(0), y.stride(1), y.stride(2),
        o.stride(0), o.stride(1), o.stride(2),
        activation=activation,
    )

    torch.cuda.synchronize()
    if unsqueeze:
        return o.squeeze(0)
    else:
        return o
```

When I run the code with float16 inputs, it works:

```python
B, M, K, N = 4, 8, 16, 8  # Batch size, matrix dimensions

x = torch.randn((B, M, K), requires_grad=True, device=DEVICE, dtype=torch.float16)
y = torch.randn((B, K, N), requires_grad=True, device=DEVICE, dtype=torch.float16)
bmm(x, y)
```
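
To confirm that "works" also means numerically correct, here is a quick sanity-check sketch against torch.matmul (loose tolerances, since everything is float16):

```python
# Sanity check: compare the Triton result with PyTorch's reference matmul.
ref = torch.matmul(x, y)
out = bmm(x, y)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```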

However, when I run the code with float32 inputs, it fails:

```python
B, M, K, N = 4, 8, 16, 8  # Batch size, matrix dimensions

x = torch.randn((B, M, K), requires_grad=True, device=DEVICE, dtype=torch.float32)
y = torch.randn((B, K, N), requires_grad=True, device=DEVICE, dtype=torch.float32)
bmm(x, y)
```

```
python3.9/site-packages/triton/backends/nvidia/compiler.py:216, in CUDABackend.make_llir(src, metadata, options, capability)
    214 if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
    215     passes.llvmir.add_di_scope(pm)
--> 216 pm.run(mod)
    217 # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
    218 llvm.init_targets()

IndexError: map::at
```

What's the source of this issue?

Environment details

Triton: 3.0.0
CUDA: 12.0
Python: 3.9.21
Torch: 2.4.1+cu121
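
Until this is fixed, a possible workaround sketch on my side (assuming float16 precision is acceptable; bmm_fp32_workaround is just a hypothetical helper name) is to downcast the inputs before launching the kernel and cast the result back:

```python
def bmm_fp32_workaround(x, y, activation=None):
    # Hypothetical helper, not a fix for the compiler error itself:
    # run the Triton kernel in float16, then cast back to the input dtype.
    out = bmm(x.to(torch.float16), y.to(torch.float16), activation=activation)
    return out.to(x.dtype)
```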

@peterbell10
Contributor

Doesn't reproduce on main, could you try building triton from source?
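
For reference, a quick way to confirm which Triton and Torch builds Python is actually importing (useful before and after reinstalling from source):

```python
import torch
import triton

# Print the versions and GPU visible to this environment.
print("Triton:", triton.__version__)
print("Torch:", torch.__version__, "CUDA:", torch.version.cuda)
print("GPU:", torch.cuda.get_device_name(0))
```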
