
[RFC] Improve performance for layer-norm in tutorial #5712

Open
ywq880611 opened this issue Jan 27, 2025 · 0 comments

Comments

ywq880611 commented Jan 27, 2025

Describe the issue

Background

I found a performance drop when the input shape for layer_norm is a specific size, such as a multiple of 16.

Here is an image showing the performance landscape for torch.float32 in the layer-norm tutorial (the default tutorial code uses torch.float16):

[Image: performance landscape of the layer-norm tutorial benchmark with torch.float32]

Here is the micro reproduction code snippet.
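Roughly, it does the following (a minimal sketch, assuming the layer_norm wrapper from Triton's 05-layer-norm tutorial is in scope; the sweep range, shapes, and GB/s formula here are approximations based on the output below):

 import torch
 import triton

 # Assumes the layer_norm wrapper from Triton's 05-layer-norm tutorial is
 # defined in this module (signature assumed:
 # layer_norm(x, normalized_shape, weight, bias, eps)).
 M = 4096
 for N in range(8170, 8193):
     x = torch.randn(M, N, dtype=torch.float32, device="cuda")
     weight = torch.ones(N, dtype=torch.float32, device="cuda")
     bias = torch.zeros(N, dtype=torch.float32, device="cuda")

     ms = triton.testing.do_bench(lambda: layer_norm(x, (N,), weight, bias, 1e-5))
     # Forward pass reads x and writes y: roughly 2 * M * N * 4 bytes.
     gbps = 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
     print(f" for n as {N} of softmax: {gbps:.6f} gb/s")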

Its output looks like this on my local machine (Triton: 3.2.0+gitb1301d66 + RTX 3080):

 for n as 8170 of softmax: 487.761209 gb/s
 for n as 8171 of softmax: 468.252137 gb/s
 for n as 8172 of softmax: 495.668768 gb/s
 for n as 8173 of softmax: 418.591544 gb/s
 for n as 8174 of softmax: 434.931814 gb/s
 for n as 8175 of softmax: 495.759715 gb/s
 for n as 8176 of softmax: 379.837408 gb/s
 for n as 8177 of softmax: 488.543707 gb/s
 for n as 8178 of softmax: 494.512492 gb/s
 for n as 8179 of softmax: 496.355028 gb/s
 for n as 8180 of softmax: 450.843957 gb/s
 for n as 8181 of softmax: 496.100066 gb/s
 for n as 8182 of softmax: 496.281183 gb/s
 for n as 8183 of softmax: 410.368286 gb/s
 for n as 8184 of softmax: 450.676321 gb/s
 for n as 8185 of softmax: 495.122895 gb/s
 for n as 8186 of softmax: 445.927929 gb/s
 for n as 8187 of softmax: 403.300486 gb/s
 for n as 8188 of softmax: 452.297592 gb/s
 for n as 8189 of softmax: 497.622479 gb/s
 for n as 8190 of softmax: 466.168626 gb/s
 for n as 8191 of softmax: 402.629787 gb/s
 for n as 8192 of softmax: 331.701871 gb/s

We can see that when N=8176 and N=8192, performance is dramatically worse (~330 GB/s vs. ~450 GB/s) than at other sizes. What distinguishes these sizes is that both 8176 and 8192 are multiples of 16, so the kernel is compiled with vectorized loads/stores.

Investigation and summary

I suspect the root cause is cache thrashing. When N=8191, the Triton-compiled code uses lots of registers, so its occupancy is lower and fewer blocks run on each SM; as a result, the SM's L1 cache can cover all the data those blocks want to use.

For N=8192, the compiled code uses fewer registers, so many blocks run on each SM. This leads to frequent L1 cache evictions, cache thrashing happens, and performance suffers.
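As a back-of-the-envelope check of this working-set argument (the 128 KB figure below is an assumption for the GA102 SM's combined L1/shared-memory capacity, and the point of reuse is that the tutorial's fused forward kernel reads each row more than once: separate passes for mean/variance and for normalization):

 # Rough working-set estimate; capacity and block counts are assumptions.
 row_bytes = 8192 * 4        # one float32 row at N = 8192 -> 32 KB
 l1_bytes = 128 * 1024       # assumed combined L1/shared capacity per SM (GA102)
 for blocks_per_sm in (1, 2, 4, 8):
     working_set = blocks_per_sm * row_bytes
     print(f"{blocks_per_sm} resident blocks -> {working_set // 1024} KB, "
           f"fits in L1: {working_set <= l1_bytes}")

With only a few resident blocks the rows being re-read stay in L1, but once occupancy rises the combined working set exceeds it, which matches the thrashing hypothesis above.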

I drafted a doc with more details about this issue; if you're interested, please feel free to leave your comments in that doc.

Next steps

If you think the hypothesis makes sense, we could try to introduce a heuristic that uses shared memory for cases where lots of data is loaded many times; a sketch of the kind of condition I have in mind follows below. WDYT? Any insights would be very appreciated!
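Purely as an illustration, the check could be shape-based (should_stage_in_shared and its parameters are made up for this sketch, not an existing Triton API):

 # Hypothetical heuristic: stage re-read tiles in shared memory when the same
 # global data is loaded more than once per program and the per-program
 # working set fits in the shared-memory budget. Names/values are illustrative.
 def should_stage_in_shared(n_cols: int, elem_size: int, reload_count: int,
                            shared_budget_bytes: int = 48 * 1024) -> bool:
     row_bytes = n_cols * elem_size
     return reload_count > 1 and row_bytes <= shared_budget_bytes

 # e.g. an N = 8192 float32 row (32 KB) read several times by the fused
 # forward kernel fits in a 48 KB budget:
 print(should_stage_in_shared(8192, 4, 3))  # True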

Environment details

Triton: 3.2.0+gitb1301d66
GPU: RTX 3080
