Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] Move all llvm dependent code into triton_cpu.cc, add LLJIT, support macos and arm #25

Closed
wants to merge 32 commits into from

Conversation

Kuigesi
Copy link
Collaborator

@Kuigesi Kuigesi commented Jun 14, 2024

  1. Move all llvm dependent code into triton_cpu.cc to avoid double registration of llvm passes. Previous we also link llvm library into cpu_utils.so, which may lead to some static variables be initialized twice, causing the double registration issue.
  2. Use off-the -shelf LLJIT to load bitcode and just-in-time compile it. The old version has bugs in the codegen pass with certain arm cpu (neoverse-v2).
  3. Support mac m1
  4. Support arm cpu (neoverse-v2) for linux.

bertmaher and others added 27 commits May 14, 2024 18:25
* [BACKEND][CPU] Implement the empty cpu backend

* Run clang-format

* Fix yadf error

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
A quick addition on how to use it.
Summary: This is stll a kind of the boilerplate and basic lowering for the first milestone (compiling vector addition). This PR firstly lowers `tt.func` and `tt.return`.


Test Plan: It can safely compile an empty kernel.

```
@triton.jit
def add_kernel(x_ptr,  y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    return
```

> TRITON_ENABLE_LLVM_DEBUG=1 TRITON_CPU_BACKEND=1 python3 empty_kerne.py

```
//===-------------------------------------------===//
Legalizing operation : 'tt.func'(0x73be2a0) {
  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'tt.func -> ()' {
Trying to match "(anonymous namespace)::FuncOpConversion"
    ** Insert  : 'llvm.func'(0x6c04c70)
    ** Insert Block into : 'llvm.func'(0x6c04c70)
    ** Insert Block into : 'llvm.func'(0x6c04c70)
    ** Erase   : 'tt.func'(0x73be2a0)
"(anonymous namespace)::FuncOpConversion" result 1

    //===-------------------------------------------===//
    Legalizing operation : 'llvm.func'(0x6c04c70) {
    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

...

//===-------------------------------------------===//
Legalizing operation : 'tt.return'(0x73efeb0) {
  "tt.return"() : () -> ()

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'tt.return -> ()' {
Trying to match "(anonymous namespace)::ReturnOpConversion"
    ** Insert  : 'llvm.return'(0x73c0f00)
    ** Replace : 'tt.return'(0x73efeb0)
"(anonymous namespace)::ReturnOpConversion" result 1

    //===-------------------------------------------===//
    Legalizing operation : 'llvm.return'(0x73c0f00) {
      "llvm.return"() : () -> ()

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//
  } -> SUCCESS : pattern applied successfully
```
…riton-lang#1)

Summary: As title, `tl.program_id` needs to be supported first. As of now, we think pid will be provided as additional function arguments to the kernel. So, getting program_id is mapped to reading one of the last three arguments.

I also quickly implemented `tl.device_print` or `print`, only for scalar types for a quick "Hello, World!" testing.

Test Plan: Tested with a simple example:

```
@triton.jit
def add_kernel(...):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    foo = pid + 42
    tl.device_print("Hello, World!", foo, pid)
```

The resulting .llir is valid:
```
@printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00"

declare !dbg !3 i32 @printf(ptr, ...)

define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 {
  %8 = add i32 %4, 42, !dbg !8
  %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4)
  ret void, !dbg !9
}
```

Tried to compile with a fake main function:
```
> % cat main.c
extern void add_kernel(float*, float*, float*, int, int, int, int);

int main() {
    add_kernel(0, 0, 0, 4, 5, 6, 7);
}

> % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c
> % ./a.out
pid (5, 6, 7) Hello, World!: 47, 5
```
…owering

Support basic lowering through vector dialect in CPU backend.
…ion flows (triton-lang#6)

* Support basic lowering through vector dialect in CPU backend.

Signed-off-by: Ilya Enkovich <[email protected]>

* Use axis info in memory op lowering.

Signed-off-by: Ilya Enkovich <[email protected]>

* Mark test_ptx_cast as enabled for CPU.

Signed-off-by: Ilya Enkovich <[email protected]>

* Support umulhi operation.

Signed-off-by: Ilya Enkovich <[email protected]>

* Support tl.clamp, tl.minimum, tl.maximum.

Signed-off-by: Ilya Enkovich <[email protected]>

* Add enable_fp_fusion opt for CPU (only affects ASM dump now).

Signed-off-by: Ilya Enkovich <[email protected]>

* Fix kernel args passing for propagated constants.

Signed-off-by: Ilya Enkovich <[email protected]>

* Add permutations support.

Signed-off-by: Ilya Enkovich <[email protected]>

* Support 2-D transfer_read/transfer_write lowering.

Signed-off-by: Ilya Enkovich <[email protected]>

* Introduce shape info analysis and use it for loads/stores by block pointers.

Delay scalar pointers lowering.

Signed-off-by: Ilya Enkovich <[email protected]>

* Support 'other' arg for loads.

Signed-off-by: Ilya Enkovich <[email protected]>

* Support tl.join.

Signed-off-by: Ilya Enkovich <[email protected]>

* Minor renaming.

Signed-off-by: Ilya Enkovich <[email protected]>

---------

Signed-off-by: Ilya Enkovich <[email protected]>
…ent (triton-lang#8)

* [BACKEND][CPU] Make it buildable and runnable in a different environment

* Revert seemingly inconsistent python code formatting
Signed-off-by: Ilya Enkovich <[email protected]>
Co-authored-by: Minjang Kim <[email protected]>
…iton-lang#11)

* [CPU] Support flexible active driver + update vector-add tutorial

* Update vector-add to run CPU always + optional GPU

* Update do_bench for CPU
…ng#17)

* Fixed yaml syntax

Signed-off-by: Gregory Shimansky <[email protected]>

* Removed cpu label from run-on

Signed-off-by: Gregory Shimansky <[email protected]>

* Added missing zlib-dev

Signed-off-by: Gregory Shimansky <[email protected]>

* Added missing apt-get update

Signed-off-by: Gregory Shimansky <[email protected]>

* Remove pip cache because on self-hosted runner it slows things down

Signed-off-by: Gregory Shimansky <[email protected]>

* Corrected path to tests

Signed-off-by: Gregory Shimansky <[email protected]>

* Added installation of torch==2.1.2

Signed-off-by: Gregory Shimansky <[email protected]>

---------

Signed-off-by: Gregory Shimansky <[email protected]>
* [CPU] Add OpenMP launcher

* Address the comments

* Fix induction variable type

* Always use preallocated output buffer for CPU with torch.add
* [CPU] Dump human-readable asm code in TRITON_CACHE_DIR

* Don't touch the main compiler.py
…-lang#23)

* add un-masked tiled matrix-multiplication for triton-cpu

* clean and add comment

* move test under tutorials
@Kuigesi Kuigesi requested a review from ptillet as a code owner June 14, 2024 23:56
@Kuigesi Kuigesi changed the title Move all llvm dependent code into triton_cpu.cc and add LLJIT Move all llvm dependent code into triton_cpu.cc, add LLJIT, and support macos and arm Jun 14, 2024
@Kuigesi Kuigesi changed the title Move all llvm dependent code into triton_cpu.cc, add LLJIT, and support macos and arm Move all llvm dependent code into triton_cpu.cc, add LLJIT, support macos and arm Jun 14, 2024
@Kuigesi Kuigesi changed the title Move all llvm dependent code into triton_cpu.cc, add LLJIT, support macos and arm [CPU] Move all llvm dependent code into triton_cpu.cc, add LLJIT, support macos and arm Jun 15, 2024
@Kuigesi
Copy link
Collaborator Author

Kuigesi commented Jun 15, 2024

Perf result of running un-masked matrix-multiplication example on arm neoverse-v2:

(base) [[email protected] ~/ruiqi/triton-cpu (add-lljit)]$ python3 python/tutorials/03-matrix-multiplication-cpu.py 
triton_cpu_output_with_torch.float32_inputs=tensor([[ 34.1671,  -0.6265,   2.1590,  ..., -24.7347,  45.6670,  12.1024],
        [-12.4534, -18.2695,   5.6328,  ...,   3.7349, -24.7166,  18.7389],
        [ 10.1346,   7.3148, -15.2706,  ...,  33.5417, -18.3999, -61.9606],
        ...,
        [-25.3489,  16.8383,  25.1225,  ...,  22.5323, -34.2459,  22.0001],
        [  9.3619,  47.0127,  12.2615,  ..., -32.7080, -19.9729,   7.6856],
        [-12.7559, -20.7132,  14.4755,  ...,  23.1968, -14.5273,  19.9374]])
torch_cpu_output_with_torch.float32_inputs=tensor([[ 34.1671,  -0.6265,   2.1590,  ..., -24.7347,  45.6670,  12.1024],
        [-12.4534, -18.2695,   5.6328,  ...,   3.7349, -24.7166,  18.7390],
        [ 10.1346,   7.3148, -15.2706,  ...,  33.5417, -18.3999, -61.9606],
        ...,
        [-25.3489,  16.8383,  25.1225,  ...,  22.5323, -34.2459,  22.0001],
        [  9.3619,  47.0127,  12.2615,  ..., -32.7080, -19.9729,   7.6856],
        [-12.7559, -20.7132,  14.4755,  ...,  23.1968, -14.5273,  19.9374]])
✅ TritonCPU and TorchCPU match
matmul-performance-fp32 (BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32, GROUP_SIZE_M=8):
         M       N       K  TritonCPU 1   TritonCPU     TorchCPU
0    256.0   256.0   256.0     3.498078   49.695028   119.304407
1    384.0   384.0   384.0     3.566519   87.075222   129.384039
2    512.0   512.0   512.0     3.533186   95.162003   328.395512
3    640.0   640.0   640.0     3.518112  157.181118  1167.856332
4    768.0   768.0   768.0     3.478274  122.914360   874.587610
5    896.0   896.0   896.0     3.436000  203.076546  1171.739504
6   1024.0  1024.0  1024.0     3.439377  224.565503  1906.239565
7   1152.0  1152.0  1152.0     3.471396  199.340181  2920.473969
8   1280.0  1280.0  1280.0     3.384347  144.589168    61.136822
9   1408.0  1408.0  1408.0     3.329353  159.479037  2655.169267
10  1536.0  1536.0  1536.0     3.421107  150.505174  2323.931550
11  1664.0  1664.0  1664.0     3.367326  237.912040  2942.447208
12  1792.0  1792.0  1792.0     3.376185  162.921996    79.129514
13  1920.0  1920.0  1920.0     3.308813  176.119590   103.867430
14  2048.0  2048.0  2048.0     3.366153  161.267611   110.504712
15  2176.0  2176.0  2176.0     3.382672  239.158536   570.182579
16  2304.0  2304.0  2304.0     3.378930  241.075954  4500.999939
17  2432.0  2432.0  2432.0     3.255841  208.675597  3314.278067
18  2560.0  2560.0  2560.0     3.331287  244.037487   558.830061

@minjang minjang requested review from minjang and ienkovich June 15, 2024 01:33
@ienkovich
Copy link
Collaborator

The patch looks good! But I don't see a reason to have driver.cpp now. Can we completely remove it?

@Kuigesi
Copy link
Collaborator Author

Kuigesi commented Jun 20, 2024

The patch looks good! But I don't see a reason to have driver.cpp now. Can we completely remove it?

Thx ilya for the review. Sorry for the late reply, I was busy with writing review doc the past days. Yes, I think the driver.cpp can be completely removed. I also see that there is a new static compilation PR opened. I will test them on my arm servers and mac.

minjang pushed a commit that referenced this pull request Jun 24, 2024
When running
[convert_blocked1d_to_slice0](https://github.com/triton-lang/triton/blob/0ba5f0c3cd029d5c3d1f01b9bf29dac32c27345e/test/Conversion/tritongpu_to_llvm.mlir#L924)
Triton ends up computing a rank of a matrix with 0 columns during linear
layout lowering, which trips up f2reduce, and causes undefined behavior,
detectable through
[UBSAN](https://clang.llvm.org/docs/UndefinedBehaviorSanitizer.html).

Fix this by returning the rank (0) early in these cases, without calling
f2reduce.

<details><summary>Stack trace</summary>
<p>

```
third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30: runtime error: shift exponent 18446744073709551615 is too large for 64-bit type 'unsigned long long'
    #0 0x556ee2fea3be in inplace_rref_small third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30
    #1 0x556ee2fea3be in f2reduce::inplace_rref_strided(unsigned long*, unsigned long, unsigned long, unsigned long) third_party/triton/third_party/f2reduce/f2reduce.cpp:470:9
    #2 0x556ee2ea70da in getMatrixRank third_party/triton/lib/Tools/LinearLayout.cpp:125:3
    #3 0x556ee2ea70da in mlir::triton::LinearLayout::checkInvariants(bool) third_party/triton/lib/Tools/LinearLayout.cpp:299:7
    #4 0x556ee2ea656d in mlir::triton::LinearLayout::tryCreate(llvm::MapVector<mlir::StringAttr, std::__u::vector<std::__u::vector<int, std::__u::allocator<int>>, std::__u::allocator<std::__u::vector<int, std::__u::allocator<int>>>>, llvm::DenseMap<mlir::StringAttr, unsigned int, llvm::DenseMapInfo<mlir::StringAttr, void>, llvm::detail::DenseMapPair<mlir::StringAttr, unsigned int>>, llvm::SmallVector<std::__u::pair<mlir::StringAttr, std::__u::vector<std::__u::vector<int, std::__u::allocator<int>>, std::__u::allocator<std::__u::vector<int, std::__u::allocator<int>>>>>, 0u>>, llvm::ArrayRef<std::__u::pair<mlir::StringAttr, int>>, bool) third_party/triton/lib/Tools/LinearLayout.cpp:190:41
    #5 0x556ee2eb2150 in mlir::triton::LinearLayout::divideRight(mlir::triton::LinearLayout const&) third_party/triton/lib/Tools/LinearLayout.cpp:654:51
    #6 0x556ee2ee1c39 in mlir::cvtNeedsSharedMemory(mlir::RankedTensorType, mlir::RankedTensorType) third_party/triton/lib/Analysis/Utility.cpp:652:14
    #7 0x556ee2cf38fd in mlir::triton::getRepShapeForCvtLayout(mlir::triton::gpu::ConvertLayoutOp) third_party/triton/lib/Analysis/Allocation.cpp:66:8
    #8 0x556ee2cf3efa in mlir::triton::getScratchConfigForCvtLayout(mlir::triton::gpu::ConvertLayoutOp, unsigned int&, unsigned int&) third_party/triton/lib/Analysis/Allocation.cpp:95:19
    #9 0x556ee2cf6057 in mlir::triton::AllocationAnalysis::getScratchValueSize(mlir::Operation*) third_party/triton/lib/Analysis/Allocation.cpp:272:24
    #10 0x556ee2cf5499 in operator() third_party/triton/lib/Analysis/Allocation.cpp:343:7
    #11 0x556ee2cf5499 in void llvm::function_ref<void (mlir::Operation*)>::callback_fn<mlir::triton::AllocationAnalysis::getValuesAndSizes()::'lambda'(mlir::Operation*)>(long, mlir::Operation*) third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
    #12 0x556edeeee7a9 in operator() third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
    #13 0x556edeeee7a9 in void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:174:5
    #14 0x556edeeee87c in void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:182:9
    #15 0x556ee2cf49e7 in walk<(mlir::WalkOrder)0, mlir::ForwardIterator, (lambda at third_party/triton/lib/Analysis/Allocation.cpp:341:42), mlir::Operation *, void> third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:313:10
    #16 0x556ee2cf49e7 in walk<(mlir::WalkOrder)0, mlir::ForwardIterator, (lambda at third_party/triton/lib/Analysis/Allocation.cpp:341:42), void> third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h:794:12
    #17 0x556ee2cf49e7 in mlir::triton::AllocationAnalysis::getValuesAndSizes() third_party/triton/lib/Analysis/Allocation.cpp:341:16
    #18 0x556ee2cf4852 in run third_party/triton/lib/Analysis/Allocation.cpp:182:5
    #19 0x556ee2cf4852 in AllocationAnalysis third_party/triton/lib/Analysis/Allocation.cpp:169:5
    #20 0x556ee2cf4852 in mlir::Allocation::run(llvm::DenseMap<mlir::FunctionOpInterface, mlir::Allocation, llvm::DenseMapInfo<mlir::FunctionOpInterface, void>, llvm::detail::DenseMapPair<mlir::FunctionOpInterface, mlir::Allocation>>&) third_party/triton/lib/Analysis/Allocation.cpp:627:3
    #21 0x556ee1677402 in operator() third_party/triton/include/triton/Analysis/Allocation.h:227:26
    #22 0x556ee1677402 in void mlir::CallGraph<mlir::Allocation>::doWalk<(mlir::WalkOrder)0, (mlir::WalkOrder)1, mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::CallOpInterface, mlir::FunctionOpInterface), mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::FunctionOpInterface)>(mlir::FunctionOpInterface, llvm::DenseSet<mlir::FunctionOpInterface, llvm::DenseMapInfo<mlir::FunctionOpInterface, void>>&, mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::CallOpInterface, mlir::FunctionOpInterface), mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::FunctionOpInterface)) third_party/triton/include/triton/Analysis/Utility.h:350:7
    #23 0x556ee16756b3 in walk<(mlir::WalkOrder)0, (mlir::WalkOrder)1, (lambda at third_party/triton/include/triton/Analysis/Allocation.h:222:9), (lambda at third_party/triton/include/triton/Analysis/Allocation.h:224:9)> third_party/triton/include/triton/Analysis/Utility.h:242:7
    #24 0x556ee16756b3 in mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp) third_party/triton/include/triton/Analysis/Allocation.h:220:5
    #25 0x556ee2c2bf18 in (anonymous namespace)::AllocateSharedMemory::runOnOperation() third_party/triton/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp:26:22
...
UndefinedBehaviorSanitizer: invalid-shift-exponent third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30 
```
</p>
</details>
@Kuigesi
Copy link
Collaborator Author

Kuigesi commented Jun 25, 2024

PR Closed after static compilation is merged.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants