
Triton 3.1.0 failed with a simple tl.dot and then tl.store example #5557

Open
annyan09023 opened this issue Jan 9, 2025 · 5 comments
annyan09023 commented Jan 9, 2025

Describe the bug

Code

```python
import torch
import triton
import triton.language as tl
print(triton.__version__)

K = 16

@triton.jit
def matmul_kernel(p1, p2, rp, K: tl.constexpr):
    region = tl.arange(0, K)[:, None] * K + tl.arange(0, K)[None, :]
    m1 = tl.load(p1 + region)
    m2 = tl.load(p2 + region)
    r = tl.dot(m1, m2)
    tl.store(rp + region, r)


m1 = torch.rand([K, K], dtype=torch.float32).cuda()
m2 = torch.rand([K, K], dtype=torch.float32).cuda()
torch_result = m1 @ m2
print(f"Torch: {torch_result}")

triton_result = torch.empty([K, K], dtype=torch.float32).cuda()
matmul_kernel[(1,)](m1, m2, triton_result, K)
print(f"Triton: {triton_result}")
print(torch.allclose(torch_result, triton_result, atol=1e-2))
```

Version: 3.1.0

Error output:

```
File .../all_notebook_lib_jupyter_exedir/triton/runtime/jit.py:662, in JITFunction.run(self, grid, warmup, *args, **kwargs)
    660     # compile the kernel
    661     src = self.ASTSource(self, signature, constants, configs[0])
--> 662     kernel = self.compile(
    663         src,
    664         target=target,
    665         options=options.__dict__,
    666     )
    667     self.cache[device][key] = kernel
    669 # Check that used global values have not changed.

File .../all_notebook_lib_jupyter_exedir/triton/compiler/compiler.py:282, in compile(src, target, options)
    280 use_ttgir_loc = os.environ.get("USE_TTGIR_LOC", "0") == "1"
    281 for ext, compile_ir in list(stages.items())[first_stage:]:
--> 282     next_module = compile_ir(module, metadata)
    283     ir_filename = f"{src.name}.{ext}"
    284     metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)

File .../all_notebook_lib_jupyter_exedir/triton/backends/nvidia/compiler.py:318, in CUDABackend.add_stages.<locals>.<lambda>(src, metadata)
    316 stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
    317 stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
--> 318 stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
    319 stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability)
    320 stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)

File .../notebooks/all_notebook_lib_jupyter_exedir/triton/backends/nvidia/compiler.py:216, in CUDABackend.make_llir(src, metadata, options, capability)
    214 if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
    215     passes.llvmir.add_di_scope(pm)
--> 216 pm.run(mod)
    217 # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
    218 llvm.init_targets()

IndexError: map::at
```

Environment details

Triton: 3.1.0
GPU: Nvidia-T4

The code was taken from a previously opened issue (#4230); according to its author, it worked on earlier versions.

@annyan09023 annyan09023 added the bug label Jan 9, 2025
@annyan09023 annyan09023 changed the title Triton 3.1.0 failed with a simple tl.dot and then tl.load example Triton 3.1.0 failed with a simple tl.dot and then tl.store example Jan 9, 2025
Jokeren commented Jan 9, 2025

Cannot reproduce using triton main

annyan09023 (Author) commented:

> Cannot reproduce using triton main

If I change float32 to float16, it works on the T4. Do you happen to know why?
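For reference, a minimal sketch of the float16 variant mentioned above. The kernel name `matmul_kernel_fp16` and the explicit downcast before the store are my additions, not from the thread; `tl.dot` on fp16 inputs accumulates in float32 by default, so the result is cast back to float16 before storing into the fp16 output buffer:

```python
import torch
import triton
import triton.language as tl

K = 16

@triton.jit
def matmul_kernel_fp16(p1, p2, rp, K: tl.constexpr):
    # Same addressing as the fp32 repro: a K x K grid of flat offsets.
    region = tl.arange(0, K)[:, None] * K + tl.arange(0, K)[None, :]
    m1 = tl.load(p1 + region)
    m2 = tl.load(p2 + region)
    # tl.dot accumulates fp16 inputs in float32, so cast back down
    # before storing into the float16 output buffer.
    r = tl.dot(m1, m2)
    tl.store(rp + region, r.to(tl.float16))

if torch.cuda.is_available():
    m1 = torch.rand([K, K], dtype=torch.float16).cuda()
    m2 = torch.rand([K, K], dtype=torch.float16).cuda()
    out = torch.empty([K, K], dtype=torch.float16).cuda()
    matmul_kernel_fp16[(1,)](m1, m2, out, K)
    print(torch.allclose(m1 @ m2, out, atol=1e-2))
```

One plausible (unconfirmed) explanation for the dtype difference: Turing tensor cores accept fp16 inputs but not fp32, and the tf32 fallback used for fp32 `tl.dot` only exists on Ampere and later, so the fp32 path may hit an unsupported lowering on this backend.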

Jokeren commented Jan 9, 2025

  1. T4 is Turing, so in theory we deprecated the support already
  2. Can you confirm if there's a problem using triton/main?

annyan09023 (Author) commented:

>   1. T4 is Turing, so in theory we deprecated the support already
>   2. Can you confirm if there's a problem using triton/main?

Yes, can reproduce with main and nightly on T4. I tried yesterday.

@Jokeren
Copy link
Contributor

Jokeren commented Jan 9, 2025

OK, then I'll keep the "bug" label, but at low priority since Turing is not our focus.
