max(tensor, axis = 1) doesn't work without TRITON_INTERPRET=1 #5522

Open
wa008 opened this issue Jan 3, 2025 · 3 comments

Comments


wa008 commented Jan 3, 2025

Describe the bug

Without TRITON_INTERPRET=1, the method form tensor.max(axis = 1) fails to compile inside a kernel, while the function form tl.max(tensor, axis = 1) works fine.

Actually, it's not a big deal...

key code:

    new_max_val = input_val.max(axis = 1) # it doesn't work
    # new_max_val = tl.max(input_val, axis = 1) # it works well

error information:

RuntimeError: Cannot call @triton.jit'd outside of the scope of a kernel

The above exception was the direct cause of the following exception:

CompilationError                          Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py in make_ir(self, options, codegen_fns, context)
    111 
    112     def make_ir(self, options, codegen_fns, context):
--> 113         return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
    114 
    115     def parse_options(self):

CompilationError: at 23:26:
    target = tl.load(target_ptr + offsets_rows, mask = offsets_rows < M, other = 0.0)
    max_val = tl.full(target.shape, -float("inf"), dtype = tl.float32)
    sumexp = tl.zeros_like(max_val)

    allcurx = tl.zeros_like(max_val)
    for index in tl.range(0, N, BLOCK_SIZE_N):
        offsets_input = offsets_rows[:, None] * N + (offsets_cols + index)[None, :]
        mask_input = (offsets_rows[:, None] < M) & ((offsets_cols + index)[None, :] < N)
        input_val = tl.load(input_ptr + offsets_input, mask=mask_input, other = -float("inf"))

        if index == 0:
            new_max_val = input_val.max(axis = 1) # it doesn't work

full code: https://drive.google.com/file/d/1UBXqpBhELn8amEOU8h5zj1i6WvihZKJX/view?usp=drive_link
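
For anyone checking the kernel's output, here is a minimal NumPy sketch of what the tiled loop in the traceback appears to compute for the max: a row-wise maximum accumulated one column tile at a time, with out-of-range columns padded with -inf like the masked tl.load. The name tiled_row_max and the block_n parameter are made up for illustration and are not from the linked code; this is a host-side reference under those assumptions, not the kernel itself.

```python
import numpy as np


def tiled_row_max(x: np.ndarray, block_n: int) -> np.ndarray:
    """Row-wise max of a 2D array, computed one column tile at a time.

    Hypothetical helper: mirrors the loop structure quoted in the
    traceback, not the actual kernel from the linked file.
    """
    m, n = x.shape
    running_max = np.full(m, -np.inf, dtype=np.float32)
    for start in range(0, n, block_n):
        cols = start + np.arange(block_n)
        in_bounds = cols < n
        # Mimic a masked load: columns past N are filled with -inf
        tile = np.full((m, block_n), -np.inf, dtype=np.float32)
        tile[:, in_bounds] = x[:, cols[in_bounds]]
        # Same reduction as input_val.max(axis = 1) / tl.max(input_val, axis = 1)
        tile_max = tile.max(axis=1)
        running_max = np.maximum(running_max, tile_max)
    return running_max


x = np.arange(12, dtype=np.float32).reshape(3, 4)
result = tiled_row_max(x, block_n=3)  # 4 columns with tiles of 3 exercises the padding
```

Comparing the kernel's max against x.max(axis=1) on the same input should show whether the method form and the function form agree once both compile.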

Environment details

Triton: 3.1.0
Environment: colab

@wa008 wa008 added the bug label Jan 3, 2025
@wa008 wa008 changed the title from "max(tensor, axis = 0) doesn't work when TRITON_INTERPRET=1" to "max(tensor, axis = 0) doesn't work without TRITON_INTERPRET=1" Jan 3, 2025
Contributor

Jokeren commented Jan 3, 2025

Please attach a reproducer

@wa008 wa008 changed the title from "max(tensor, axis = 0) doesn't work without TRITON_INTERPRET=1" to "max(tensor, axis = 1) doesn't work without TRITON_INTERPRET=1" Jan 3, 2025
Author

wa008 commented Jan 3, 2025

> Please attach a reproducer

I updated the key code and the full code in the description. Thanks for your attention!

Contributor

Jokeren commented Jan 3, 2025

I cannot reproduce it on my end. Very possibly you're not using the latest version. Please consider building Triton from source and retrying.
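
Before rebuilding, it may help to confirm which Triton version the environment actually imports. A small sketch using only the standard library; the helper name installed_version is made up for illustration, and it returns None when the package is not installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str):
    """Return the installed version string for a package, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


triton_version = installed_version("triton")
print("triton:", triton_version or "not installed")
```

On the Colab setup described in this issue, this should print the 3.1.0 reported under "Environment details".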
