Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TUTORIAL] Add the non-persistent softmax and make it for CPU #22

Merged
merged 3 commits into from
Jun 20, 2024

Conversation

minjang
Copy link
Collaborator

@minjang minjang commented Jun 13, 2024

The updated fused-softmax has only persistent thread approach, which isn't straightforward for CPU. So, at least for now, bring back the old version and make it for CPU.

> % python3 python/tutorials/02-fused-softmax-cpu.py
softmax-performance:
         N  TritonCPU 1  TritonCPU  TorchCPU (native)  TorchCPU (jit)    TritonGPU  TorchGPU (native)  TorchGPU (jit)
0    256.0     2.045810   0.542863           0.550007        0.105642   923.042225         907.072700      289.023160
1    384.0     1.427173   0.807165           0.810944        0.177129  1110.779696        1095.309199      366.122895
2    512.0     1.784980   1.068528           1.088198        0.211705  1297.742602        1263.344616      429.744271
3    640.0     1.041241   1.301883           1.346970        0.262986  1489.454534        1418.528150      488.345762
4    768.0     1.230687   1.609295           1.618619        0.317283  1536.000004        1542.023579      517.389457
5    896.0     1.416729   1.849389           1.897273        0.333948  1632.569476        1644.272402      552.713269
6   1024.0     1.576671   2.081733           2.149679        0.462035  1730.323387        1716.163678      568.333887
7   1152.0     0.907632   2.351888           2.485568        0.472176  1673.259543        1351.257714      578.542444
8   1280.0     0.979819   2.566572           2.737079        0.530910  1724.631523        1471.066260      575.634579
9   1408.0     1.074241   2.874134           2.970841        0.581537  1754.004924        1373.135287      577.409710
10  1536.0     1.234496   3.098663           3.206416        0.632289  1793.459554        1446.976986      564.357367
11  1664.0     1.248459   3.253711           3.514285        0.667555  1836.137885        1380.823376      553.046423
12  1792.0     1.308964   3.570367           3.644222        0.778013  1857.295629        1454.047605      552.546838
13  1920.0     1.397980   3.754436           4.049148        0.858990  1892.281048        1413.429190      537.033607
14  2048.0     1.098481   3.738032           3.799553        0.852281  1917.834370        1800.130436      520.126988
15  2176.0     0.292744   3.409505           4.282864        0.829781  1894.748314        1455.404295      513.297370
16  2304.0     0.299672   3.639198           4.355468        0.811003  1921.250756        1501.779785      506.395380
17  2432.0     0.334191   3.866809           4.669806        0.894382  1933.515485        1510.229245      499.572316
18  2560.0     0.336801   3.938183           4.194150        0.918556  1940.370035        1565.973662      497.049667
19  2688.0     0.402710   4.385851           5.425224        1.087770  1953.521621        1432.853753      494.522457
20  2816.0     0.415767   4.416612           5.027677        1.039629  1971.007567        1478.003110      495.801922
21  2944.0     0.457665   4.677714           5.401216        0.985688  1974.234465        1474.159433      492.510389
22  3072.0     0.413956   5.058929           6.468632        1.114623  1977.201751        1534.501456      492.520434
23  3200.0     0.554844   5.242629           6.017523        1.247334  1983.535101        1442.888555      490.575643
24  3328.0     0.572705   5.504096           6.268642        1.155721  1990.579442        1483.618615      489.918352
25  3456.0     0.641787   5.673712           6.742990        1.198830  2005.067363        1475.174591      488.669415
26  3584.0     0.656487   5.869218           6.482186        1.476252  2022.047420        1519.675365      489.172411
27  3712.0     0.679076   5.982188           6.830568        1.240226  2032.667340        1475.005008      487.194065
28  3840.0     0.786011   6.294764           7.042874        1.376185  2040.560452        1510.046061      488.861834
29  3968.0     0.853159   5.944800           6.563339        1.309741  2045.937575        1506.574686      484.901465
30  4096.0     0.923387   6.258306           6.739743        1.435479  2039.039315        1975.649459      482.103901
31  4224.0     0.406333   4.971108           7.311503        1.452058  2049.941231        1854.792404      482.204690
32  4352.0     0.217094   4.467071           6.976144        1.488152  2055.557217        1873.244256      485.504749
33  4480.0     0.232029   4.908665           7.119896        1.392533  2058.106796        1852.794835      486.326707
34  4608.0     0.227503   5.463585           7.254451        1.494221  2063.223480        1870.972161      485.901754
35  4736.0     0.251055   4.627794           7.686124        1.617630  2068.087057        1798.503138      485.597659
36  4864.0     0.252149   5.424953           7.345514        1.719156  2073.578774        1815.137066      485.120886
37  4992.0     0.275326   4.835612           6.765786        1.642713  2077.126276        1797.400842      486.421917
38  5120.0     0.235966   5.317037           9.137089        1.788218  2082.987623        1822.976348      485.541776
39  5248.0     0.297392   6.633583           9.538905        1.727937  2086.161444        1747.626724      485.627328
40  5376.0     0.299619   7.118532           8.234546        1.659288  2090.780078        1774.097261      486.481461
41  5504.0     0.320007   6.008573           8.128889        1.602451  2092.480490        1759.630378      487.130160
42  5632.0     0.314133   6.351171           8.050021        1.792623  2092.586334        1778.892039      486.106541
43  5760.0     0.337048   5.757307           8.092474        1.688131  2087.133685        1739.893821      487.256497
44  5888.0     0.339773   6.242073           7.904060        1.838005  2086.989189        1761.411634      487.473169
45  6016.0     0.362206   6.302467           9.525986        1.739385  2069.326109        1745.151224      487.410710
46  6144.0     0.327055   6.859109           8.240472        1.959596  2068.537081        2015.846193      485.302067
47  6272.0     0.387260   6.386997           7.970573        1.781531  2065.121588        1976.770663      488.071118
48  6400.0     0.390301   6.781471           8.607681        1.922527  2068.686929        1979.939542      488.928692
49  6528.0     0.419166   7.119655           8.357271        1.831618  2072.124067        1958.017621      489.396867

@minjang minjang marked this pull request as ready for review June 13, 2024 05:00
@minjang minjang requested a review from ptillet as a code owner June 13, 2024 05:00
@minjang minjang requested review from gshimansky and ienkovich June 13, 2024 05:01
@ienkovich
Copy link
Collaborator

Thanks, Minjang! Interesting numbers. I'm surprised that TritonCPU is quite close to the native torch without any optimizations added for the CPU so far. Can we also have a torch-inductor column in addition to torch-native and torch-jit on CPU? For GPU it's probably not relevant because it's also Triton based but for CPU it's something we want to be compared to.

@minjang minjang force-pushed the add-fused-softmax-cpu branch from 6ceba1b to 2e92d57 Compare June 19, 2024 19:27
@minjang
Copy link
Collaborator Author

minjang commented Jun 19, 2024

(Sorry for the late update. I was busy for handling internal oncall issues.)

Added torch.compile columns if possible, which uses TorchInductor.

We still see some big drops, but this looks a good starting baseline.

> % taskset -c 0-31 python3 python/tutorials/02-fused-softmax-cpu.py
softmax-performance:
         N  TritonCPU 1  TritonCPU  TorchCPU (compile)  TorchCPU (jit)  TorchCPU (native)
0    256.0     1.893370  26.797909           43.473083       27.090951          90.345849
1    512.0     1.743850  37.646113           64.511089        3.581964         119.810337
2    768.0     1.252370  30.318409           73.473425        4.943790         137.891151
3   1024.0     1.644789  41.430566           88.233102       63.480742         135.692999
4   1280.0     1.023737  27.411674           90.937467       58.451413         139.056957
5   1536.0     1.220970  32.536733          143.768512        3.848108         226.020733
6   1792.0     1.375308  35.754880          148.295864       74.563703         219.283022
7   2048.0     1.516434  34.618771           25.372825        4.833615          15.196132
8   2304.0     0.770304  21.778825           12.489594        4.816089          14.817564
9   2560.0     0.840771  23.021519           11.375798        4.815758          14.603736
10  2816.0     0.905981  26.346897           13.585207        4.697075          14.613452
11  3072.0     0.988469  28.634150           12.233268        4.816486          13.798405
12  3328.0     1.063231  31.076541           14.483119       12.773678          13.475917
13  3584.0     1.126588  33.615553           13.887661       13.583502          16.522656
14  3840.0     1.212261  35.386718           14.756030        5.264437          13.833688
15  4096.0     1.290812  38.116423           14.933730        4.507049          16.085682
16  4352.0     0.429899  12.688686           13.330427        5.018789          14.184997
17  4608.0     0.454297  13.381622           15.674941        5.094559          15.002737
18  4864.0     0.476519  13.812249           13.959362        4.533724          15.026777
19  5120.0     0.511426  14.565795           13.435608        4.856423          16.197014
20  5376.0     0.524268  15.698727           14.819116        4.767064          13.714059
21  5632.0     0.547353  16.653203           16.036805        5.524448          15.910593
22  5888.0     0.573365  16.963346           14.366816        4.244346          12.888016
23  6144.0     0.589853  17.922215           13.714745        4.615137          14.722438
24  6400.0     0.609925  19.176807           13.836896        4.767489          14.738774

softmax

matmul-performance-fp32 (BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32, GROUP_SIZE_M=8).png
matmul-fp32

Copy link
Collaborator

@ienkovich ienkovich left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding Torch Inductor!

@minjang minjang merged commit 9538724 into triton-lang:main Jun 20, 2024
2 of 4 checks passed
minjang added a commit to minjang/triton-cpu that referenced this pull request Jun 22, 2024
…-lang#22)

* [TUTORIAL] Add 02-fused-softmax with the previous non-persistent implementation

* Add torch.compile cases

* Preallocate output buffer for softmax tutorial
@minjang minjang deleted the add-fused-softmax-cpu branch June 23, 2024 07:13
minjang pushed a commit that referenced this pull request Jun 24, 2024
When running
[convert_blocked1d_to_slice0](https://github.com/triton-lang/triton/blob/0ba5f0c3cd029d5c3d1f01b9bf29dac32c27345e/test/Conversion/tritongpu_to_llvm.mlir#L924)
Triton ends up computing a rank of a matrix with 0 columns during linear
layout lowering, which trips up f2reduce, and causes undefined behavior,
detectable through
[UBSAN](https://clang.llvm.org/docs/UndefinedBehaviorSanitizer.html).

Fix this by returning the rank (0) early in these cases, without calling
f2reduce.

<details><summary>Stack trace</summary>
<p>

```
third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30: runtime error: shift exponent 18446744073709551615 is too large for 64-bit type 'unsigned long long'
    #0 0x556ee2fea3be in inplace_rref_small third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30
    #1 0x556ee2fea3be in f2reduce::inplace_rref_strided(unsigned long*, unsigned long, unsigned long, unsigned long) third_party/triton/third_party/f2reduce/f2reduce.cpp:470:9
    #2 0x556ee2ea70da in getMatrixRank third_party/triton/lib/Tools/LinearLayout.cpp:125:3
    #3 0x556ee2ea70da in mlir::triton::LinearLayout::checkInvariants(bool) third_party/triton/lib/Tools/LinearLayout.cpp:299:7
    #4 0x556ee2ea656d in mlir::triton::LinearLayout::tryCreate(llvm::MapVector<mlir::StringAttr, std::__u::vector<std::__u::vector<int, std::__u::allocator<int>>, std::__u::allocator<std::__u::vector<int, std::__u::allocator<int>>>>, llvm::DenseMap<mlir::StringAttr, unsigned int, llvm::DenseMapInfo<mlir::StringAttr, void>, llvm::detail::DenseMapPair<mlir::StringAttr, unsigned int>>, llvm::SmallVector<std::__u::pair<mlir::StringAttr, std::__u::vector<std::__u::vector<int, std::__u::allocator<int>>, std::__u::allocator<std::__u::vector<int, std::__u::allocator<int>>>>>, 0u>>, llvm::ArrayRef<std::__u::pair<mlir::StringAttr, int>>, bool) third_party/triton/lib/Tools/LinearLayout.cpp:190:41
    #5 0x556ee2eb2150 in mlir::triton::LinearLayout::divideRight(mlir::triton::LinearLayout const&) third_party/triton/lib/Tools/LinearLayout.cpp:654:51
    #6 0x556ee2ee1c39 in mlir::cvtNeedsSharedMemory(mlir::RankedTensorType, mlir::RankedTensorType) third_party/triton/lib/Analysis/Utility.cpp:652:14
    #7 0x556ee2cf38fd in mlir::triton::getRepShapeForCvtLayout(mlir::triton::gpu::ConvertLayoutOp) third_party/triton/lib/Analysis/Allocation.cpp:66:8
    #8 0x556ee2cf3efa in mlir::triton::getScratchConfigForCvtLayout(mlir::triton::gpu::ConvertLayoutOp, unsigned int&, unsigned int&) third_party/triton/lib/Analysis/Allocation.cpp:95:19
    #9 0x556ee2cf6057 in mlir::triton::AllocationAnalysis::getScratchValueSize(mlir::Operation*) third_party/triton/lib/Analysis/Allocation.cpp:272:24
    #10 0x556ee2cf5499 in operator() third_party/triton/lib/Analysis/Allocation.cpp:343:7
    #11 0x556ee2cf5499 in void llvm::function_ref<void (mlir::Operation*)>::callback_fn<mlir::triton::AllocationAnalysis::getValuesAndSizes()::'lambda'(mlir::Operation*)>(long, mlir::Operation*) third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
    #12 0x556edeeee7a9 in operator() third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
    #13 0x556edeeee7a9 in void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:174:5
    #14 0x556edeeee87c in void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:182:9
    #15 0x556ee2cf49e7 in walk<(mlir::WalkOrder)0, mlir::ForwardIterator, (lambda at third_party/triton/lib/Analysis/Allocation.cpp:341:42), mlir::Operation *, void> third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:313:10
    #16 0x556ee2cf49e7 in walk<(mlir::WalkOrder)0, mlir::ForwardIterator, (lambda at third_party/triton/lib/Analysis/Allocation.cpp:341:42), void> third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h:794:12
    #17 0x556ee2cf49e7 in mlir::triton::AllocationAnalysis::getValuesAndSizes() third_party/triton/lib/Analysis/Allocation.cpp:341:16
    #18 0x556ee2cf4852 in run third_party/triton/lib/Analysis/Allocation.cpp:182:5
    #19 0x556ee2cf4852 in AllocationAnalysis third_party/triton/lib/Analysis/Allocation.cpp:169:5
    #20 0x556ee2cf4852 in mlir::Allocation::run(llvm::DenseMap<mlir::FunctionOpInterface, mlir::Allocation, llvm::DenseMapInfo<mlir::FunctionOpInterface, void>, llvm::detail::DenseMapPair<mlir::FunctionOpInterface, mlir::Allocation>>&) third_party/triton/lib/Analysis/Allocation.cpp:627:3
    #21 0x556ee1677402 in operator() third_party/triton/include/triton/Analysis/Allocation.h:227:26
    #22 0x556ee1677402 in void mlir::CallGraph<mlir::Allocation>::doWalk<(mlir::WalkOrder)0, (mlir::WalkOrder)1, mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::CallOpInterface, mlir::FunctionOpInterface), mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::FunctionOpInterface)>(mlir::FunctionOpInterface, llvm::DenseSet<mlir::FunctionOpInterface, llvm::DenseMapInfo<mlir::FunctionOpInterface, void>>&, mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::CallOpInterface, mlir::FunctionOpInterface), mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::FunctionOpInterface)) third_party/triton/include/triton/Analysis/Utility.h:350:7
    #23 0x556ee16756b3 in walk<(mlir::WalkOrder)0, (mlir::WalkOrder)1, (lambda at third_party/triton/include/triton/Analysis/Allocation.h:222:9), (lambda at third_party/triton/include/triton/Analysis/Allocation.h:224:9)> third_party/triton/include/triton/Analysis/Utility.h:242:7
    #24 0x556ee16756b3 in mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp) third_party/triton/include/triton/Analysis/Allocation.h:220:5
    #25 0x556ee2c2bf18 in (anonymous namespace)::AllocateSharedMemory::runOnOperation() third_party/triton/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp:26:22
...
UndefinedBehaviorSanitizer: invalid-shift-exponent third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30 
```
</p>
</details>
minjang added a commit that referenced this pull request Jun 24, 2024
* [TUTORIAL] Add 02-fused-softmax with the previous non-persistent implementation

* Add torch.compile cases

* Preallocate output buffer for softmax tutorial
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Aug 13, 2024
…-lang#22)

* [TUTORIAL] Add 02-fused-softmax with the previous non-persistent implementation

* Add torch.compile cases

* Preallocate output buffer for softmax tutorial
int3 pushed a commit that referenced this pull request Aug 29, 2024
* [TUTORIAL] Add 02-fused-softmax with the previous non-persistent implementation

* Add torch.compile cases

* Preallocate output buffer for softmax tutorial
minjang added a commit that referenced this pull request Sep 22, 2024
* [TUTORIAL] Add 02-fused-softmax with the previous non-persistent implementation

* Add torch.compile cases

* Preallocate output buffer for softmax tutorial
minjang added a commit that referenced this pull request Oct 22, 2024
* [TUTORIAL] Add 02-fused-softmax with the previous non-persistent implementation

* Add torch.compile cases

* Preallocate output buffer for softmax tutorial
minjang added a commit that referenced this pull request Oct 24, 2024
* [TUTORIAL] Add 02-fused-softmax with the previous non-persistent implementation

* Add torch.compile cases

* Preallocate output buffer for softmax tutorial
int3 pushed a commit that referenced this pull request Dec 6, 2024
* [TUTORIAL] Add 02-fused-softmax with the previous non-persistent implementation

* Add torch.compile cases

* Preallocate output buffer for softmax tutorial
ienkovich pushed a commit that referenced this pull request Dec 6, 2024
* [TUTORIAL] Add 02-fused-softmax with the previous non-persistent implementation

* Add torch.compile cases

* Preallocate output buffer for softmax tutorial
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants