Describe the issue
Background
I found there is a performance drop when the input shape for layer_norm hits certain specific sizes, such as multiples of 16. Here is an image showing the perf landscape for torch.float32 in the layer-norm tutorial (the default tutorial code uses torch.float16).
Here is the micro reproduction code snippet (it benchmarks a simple softmax kernel), followed by its output on my local machine (Triton: 3.2.0+gitb1301d66 + RTX 3080):
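A minimal sketch of the snippet, modeled on the Triton fused-softmax tutorial; the row count (4096), the GB/s formula (one read plus one write of the matrix), and the default launch parameters are assumptions, not necessarily the exact code behind the numbers below:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(out_ptr, in_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # One program per row; the whole row is kept in registers.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(in_ptr + row * stride + cols, mask=mask, other=-float("inf"))
    x = x - tl.max(x, axis=0)  # numerically stable softmax
    num = tl.exp(x)
    y = num / tl.sum(num, axis=0)
    tl.store(out_ptr + row * stride + cols, y, mask=mask)


def softmax(x: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    softmax_kernel[(n_rows,)](y, x, x.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE)
    return y


if __name__ == "__main__":
    n_rows = 4096  # assumed row count; the issue only reports the column sweep
    for n in range(8170, 8193):
        x = torch.randn(n_rows, n, device="cuda", dtype=torch.float32)
        ms = triton.testing.do_bench(lambda: softmax(x))
        # One read + one write of the matrix, reported in GB/s.
        gbps = 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
        print(f"for n as {n} of softmax: {gbps:f} gb/s")
```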
for n as 8170 of softmax: 487.761209 gb/s
for n as 8171 of softmax: 468.252137 gb/s
for n as 8172 of softmax: 495.668768 gb/s
for n as 8173 of softmax: 418.591544 gb/s
for n as 8174 of softmax: 434.931814 gb/s
for n as 8175 of softmax: 495.759715 gb/s
for n as 8176 of softmax: 379.837408 gb/s
for n as 8177 of softmax: 488.543707 gb/s
for n as 8178 of softmax: 494.512492 gb/s
for n as 8179 of softmax: 496.355028 gb/s
for n as 8180 of softmax: 450.843957 gb/s
for n as 8181 of softmax: 496.100066 gb/s
for n as 8182 of softmax: 496.281183 gb/s
for n as 8183 of softmax: 410.368286 gb/s
for n as 8184 of softmax: 450.676321 gb/s
for n as 8185 of softmax: 495.122895 gb/s
for n as 8186 of softmax: 445.927929 gb/s
for n as 8187 of softmax: 403.300486 gb/s
for n as 8188 of softmax: 452.297592 gb/s
for n as 8189 of softmax: 497.622479 gb/s
for n as 8190 of softmax: 466.168626 gb/s
for n as 8191 of softmax: 402.629787 gb/s
for n as 8192 of softmax: 331.701871 gb/s
We can see that for N=8176 and N=8192 the performance is dramatically worse (~330 and ~380 GB/s vs. roughly 400-500 GB/s for the other sizes). What sets these two sizes apart is that both 8176 and 8192 are multiples of 16, so they are compiled to a kernel with vectorized loads/stores; one way to confirm that is sketched below.
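A hedged sketch of the check: it reuses `softmax_kernel` from the reproduction above and relies on `JITFunction.warmup`, `_init_handles`, and the `asm` dict the way the fused-softmax tutorial does, so treat those exact signatures as assumptions about the Triton 3.2 API. It dumps the PTX for two neighbouring sizes and counts vectorized global accesses:

```python
import re
import torch
import triton


def ptx_for(n_cols, n_rows=4096):
    x = torch.randn(n_rows, n_cols, device="cuda", dtype=torch.float32)
    y = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # softmax_kernel is the @triton.jit kernel from the reproduction sketch above.
    k = softmax_kernel.warmup(y, x, x.stride(0), n_cols,
                              BLOCK_SIZE=BLOCK_SIZE, grid=(1,))
    k._init_handles()  # makes the compiled artifacts available
    return k.asm["ptx"]


for n in (8191, 8192):
    vec = re.findall(r"(?:ld|st)\.global\.v\d", ptx_for(n))
    print(n, "vectorized global accesses in PTX:", len(vec))
```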
Investigation and summary
I suspect the root cause is cache thrashing. When N=8191, the Triton-compiled code uses a lot of registers, so its occupancy is lower; with fewer blocks resident on an SM, the SM's L1 cache can hold all the data those blocks want to reuse.
For N=8192, the compiled code uses fewer registers, so more blocks are resident on an SM at once; that causes frequent L1 cache evictions, and the resulting cache thrashing hurts performance.
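To put numbers behind the register/occupancy part of the hypothesis, something like the following could work. Again a sketch reusing `softmax_kernel` from the reproduction above; `n_regs`, `n_spills`, and `metadata.shared` are the attributes the fused-softmax tutorial reads, so their availability on Triton 3.2 is an assumption, and the occupancy estimate is the same rough register-limited formula the tutorial uses:

```python
import torch
import triton

NUM_REGS = 65536   # 32-bit registers per SM on GA102 (RTX 3080)
WARP_SIZE = 32
NUM_WARPS = 4      # default num_warps used by the launch above

for n in (8191, 8192):
    x = torch.randn(4096, n, device="cuda", dtype=torch.float32)
    y = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(n)
    k = softmax_kernel.warmup(y, x, x.stride(0), n,
                              BLOCK_SIZE=BLOCK_SIZE, grid=(1,))
    k._init_handles()
    # Register-limited blocks per SM, as estimated in the fused-softmax tutorial.
    blocks_per_sm = NUM_REGS // (k.n_regs * WARP_SIZE * NUM_WARPS)
    print(f"N={n}: n_regs={k.n_regs} n_spills={k.n_spills} "
          f"shared={k.metadata.shared} blocks/SM (reg-limited) ~ {blocks_per_sm}")
```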
I have drafted a doc with more details about this issue; if you're interested, please feel free to leave your comments in the doc.
Next steps
If you think the hypothesis makes sense, we could try to introduce a heuristic that uses shared memory for cases like this, where a lot of data is loaded many times. WDYT? Any insights would be greatly appreciated!
Environment details
Triton: 3.2.0+gitb1301d66
GPU: RTX 3080