From 7f807d9cd8c1126a2ed0d45843d5c60d61c51e1f Mon Sep 17 00:00:00 2001
From: Ilya Enkovich
Date: Fri, 7 Jun 2024 14:32:19 -0700
Subject: [PATCH] Support atomic ops for CPU.

Signed-off-by: Ilya Enkovich
---
 python/test/unit/language/test_core.py        |  17 +-
 .../cpu/include/TritonCPUToLLVM/Passes.h      |   1 +
 .../cpu/include/TritonCPUToLLVM/Passes.td     |  11 +
 .../cpu/include/TritonToTritonCPU/Passes.h    |   1 +
 .../cpu/include/TritonToTritonCPU/Passes.td   |  14 ++
 .../lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp   | 154 +++++++++++++
 .../cpu/lib/TritonCPUToLLVM/CMakeLists.txt    |   1 +
 .../cpu/lib/TritonCPUToLLVM/Pipeline.cpp      |   1 +
 .../cpu/lib/TritonToTritonCPU/CMakeLists.txt  |   1 +
 .../TritonToTritonCPU/ConvertAtomicOps.cpp    | 218 ++++++++++++++++++
 .../cpu/lib/TritonToTritonCPU/Pipeline.cpp    |   1 +
 11 files changed, 414 insertions(+), 6 deletions(-)
 create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp
 create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 8c56d3458766..50b5248e36e9 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1348,6 +1348,7 @@ def kernel(X, Y, Z):
 # ---------------
 # test atomics
 # ---------------
+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize(
     "op, dtype_x_str, mode, sem",
@@ -1378,13 +1379,12 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
     if is_interpreter():
         if dtype_x_str == 'float16':
             pytest.skip("Only test atomic float16 ops on GPU")
-    else:
-        check_cuda_only(device)
-        capability = torch.cuda.get_device_capability()
-        if capability[0] < 7:
-            if dtype_x_str == 'float16':
-                pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
+
+    if is_cuda():
+        capability = torch.cuda.get_device_capability()
+        if capability[0] < 7:
+            if dtype_x_str == 'float16':
+                pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
 
     n_programs = 5
 
     # triton kernel
@@ -1434,6 +1434,7 @@ def kernel(X, Z):
         assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]
 
 
+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
 def test_atomic_rmw_predicate(num_ctas, device):
@@ -1449,6 +1450,7 @@ def kernel(X):
     assert x.item() == 63
 
 
+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("shape, axis, num_ctas", [(shape, axis, num_ctas)
                          for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] for axis in [0, 1]
@@ -1481,6 +1483,7 @@ def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
 
 
+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
 def test_tensor_atomic_rmw_block(num_ctas, device):
@@ -1500,6 +1503,7 @@ def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
     assert torch.min(x).item() == 0.0
 
 
+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
@@ -1537,6 +1541,7 @@ def serialized_add(data, Lock, SEM: tl.constexpr):
         assert f"atom.global.{sem_str}" in h.asm["ptx"]
 
 
+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
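The hunks above only add the new @pytest.mark.cpu marker to the existing atomics tests. As a reading aid, here is a minimal sketch (not part of the patch) of the kind of masked atomic-RMW kernel those tests exercise. It assumes a triton-cpu build in which kernels dispatch on torch CPU tensors; the kernel itself uses only documented triton-language APIs (tl.atomic_add with mask= and sem=).

import torch
import triton
import triton.language as tl


@triton.jit
def histogram_kernel(hist_ptr, x_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0)
    # Masked atomic RMW: lanes with mask == False must not touch memory.
    tl.atomic_add(hist_ptr + x, 1, mask=mask, sem="relaxed")


def run(device="cpu", n=1000, bins=16):
    x = torch.randint(0, bins, (n, ), device=device, dtype=torch.int32)
    hist = torch.zeros(bins, device=device, dtype=torch.int32)
    histogram_kernel[(triton.cdiv(n, 128), )](hist, x, n, BLOCK=128)
    # Every in-bounds element increments exactly one bin.
    assert hist.sum().item() == n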
diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h
index a1fbce2e4892..7d739f1c32fe 100644
--- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h
+++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h
@@ -24,6 +24,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createFuncOpToLLVMPass();
 std::unique_ptr<OperationPass<ModuleOp>> createMemoryOpToLLVMPass();
 std::unique_ptr<OperationPass<ModuleOp>> createGetProgramIdOpToLLVMPass();
 std::unique_ptr<OperationPass<ModuleOp>> createLowerMultiReductionPass();
+std::unique_ptr<OperationPass<ModuleOp>> createAtomicOpsToLLVMPass();
 
 void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm);
 void registerTritonCPUToLLVMPipeline();
diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td
index 2abe88338dcf..0759ddbf7925 100644
--- a/third_party/cpu/include/TritonCPUToLLVM/Passes.td
+++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td
@@ -54,4 +54,15 @@ def LowerMultiReduction : Pass<"triton-cpu-lower-multi-reduction", "mlir::triton
                            "mlir::triton::TritonDialect"];
 }
 
+def AtomicOpsToLLVM : Pass<"triton-cpu-atomic-ops-to-llvm", "mlir::ModuleOp"> {
+  let summary = "Convert Triton atomic operations to LLVM.";
+  let description = [{
+  }];
+  let constructor = "mlir::triton::cpu::createAtomicOpsToLLVMPass()";
+
+  let dependentDialects = ["mlir::vector::VectorDialect",
+                           "mlir::triton::cpu::TritonCPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 #endif
diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h
index c7d072ab9175..b5107e5e78a3 100644
--- a/third_party/cpu/include/TritonToTritonCPU/Passes.h
+++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h
@@ -26,6 +26,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertControlFlowOps();
 std::unique_ptr<OperationPass<ModuleOp>> createConvertHistogramOp();
 std::unique_ptr<OperationPass<ModuleOp>> createConvertReductionOp();
 std::unique_ptr<OperationPass<ModuleOp>> createConvertScanOp();
+std::unique_ptr<OperationPass<ModuleOp>> createConvertAtomicOps();
 
 void tritonToTritonCPUPipelineBuilder(OpPassManager &pm);
 void registerTritonToTritonCPUPipeline();
diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td
index 28ad258c38c0..5dd3bf903440 100644
--- a/third_party/cpu/include/TritonToTritonCPU/Passes.td
+++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td
@@ -114,4 +114,18 @@ def ConvertScanOp : Pass<"triton-cpu-convert-scan", "mlir::ModuleOp"> {
                            "mlir::triton::cpu::TritonCPUDialect"];
 }
 
+def ConvertAtomicOps : Pass<"triton-cpu-convert-atomic-ops", "mlir::ModuleOp"> {
+  let summary = "Convert Triton atomic operations.";
+  let description = [{
+
+  }];
+  let constructor = "mlir::triton::cpu::createConvertAtomicOps()";
+
+  let dependentDialects = ["mlir::arith::ArithDialect",
+                           "mlir::vector::VectorDialect",
+                           "mlir::scf::SCFDialect",
+                           "mlir::triton::TritonDialect",
+                           "mlir::triton::cpu::TritonCPUDialect"];
+}
+
 #endif
diff --git a/third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp
new file mode 100644
index 000000000000..9a2c183e1c4c
--- /dev/null
+++ b/third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp
@@ -0,0 +1,154 @@
+#include "TypeConverter.h"
+
+#include "cpu/include/TritonCPUToLLVM/Passes.h"
+
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonCPU/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+#define GEN_PASS_DEF_ATOMICOPSTOLLVM
+#include "cpu/include/TritonCPUToLLVM/Passes.h.inc"
+} // namespace triton
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::triton;
+using namespace mlir::triton::cpu;
+
+namespace {
+
+class TritonLLVMConversionTarget : public ConversionTarget {
+public:
+  explicit TritonLLVMConversionTarget(MLIRContext &ctx)
+      : ConversionTarget(ctx) {
+    addLegalDialect<LLVM::LLVMDialect>();
+    addLegalOp<mlir::UnrealizedConversionCastOp>();
+  }
+};
+
+LLVM::AtomicOrdering getOrdering(MemSemantic sem) {
+  switch (sem) {
+  case MemSemantic::RELAXED:
+    return LLVM::AtomicOrdering::monotonic;
+  case MemSemantic::ACQUIRE:
+    return LLVM::AtomicOrdering::acquire;
+  case MemSemantic::RELEASE:
+    return LLVM::AtomicOrdering::release;
+  case MemSemantic::ACQUIRE_RELEASE:
+    return LLVM::AtomicOrdering::acq_rel;
+  default:
+    llvm_unreachable("Unexpected atomic mem semantic");
+  }
+}
+
+// TODO: use enums to access struct fields.
+struct AtomicRMWOpConversion : public OpConversionPattern<AtomicRMWOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtomicRMWOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto opKind = getAtomicBinOp(op.getAtomicRmwOp(), op.getType());
+    auto ptr = rewriter.getRemappedValue(op.getPtr());
+    auto val = rewriter.getRemappedValue(op.getVal());
+    auto ordering = getOrdering(op.getSem());
+    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(op, opKind, ptr, val,
+                                                   ordering);
+    return success();
+  }
+
+  LLVM::AtomicBinOp getAtomicBinOp(RMWOp op, Type type) const {
+    switch (op) {
+    case RMWOp::AND:
+      return LLVM::AtomicBinOp::_and;
+    case RMWOp::OR:
+      return LLVM::AtomicBinOp::_or;
+    case RMWOp::XOR:
+      return LLVM::AtomicBinOp::_xor;
+    case RMWOp::ADD:
+      return LLVM::AtomicBinOp::add;
+    case RMWOp::FADD:
+      return LLVM::AtomicBinOp::fadd;
+    case RMWOp::MAX:
+      return type.isIntOrIndex() ? LLVM::AtomicBinOp::max
+                                 : LLVM::AtomicBinOp::fmax;
+    case RMWOp::MIN:
+      return type.isIntOrIndex() ? LLVM::AtomicBinOp::min
+                                 : LLVM::AtomicBinOp::fmin;
+    case RMWOp::UMAX:
+      return LLVM::AtomicBinOp::umax;
+    case RMWOp::UMIN:
+      return LLVM::AtomicBinOp::umin;
+    case RMWOp::XCHG:
+      return LLVM::AtomicBinOp::xchg;
+    default:
+      llvm_unreachable("Unexpected atomic op");
+    }
+  }
+};
+
+struct AtomicCASOpConversion : public OpConversionPattern<AtomicCASOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtomicCASOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ptr = rewriter.getRemappedValue(op.getPtr());
+    auto cmp = rewriter.getRemappedValue(op.getCmp());
+    auto val = rewriter.getRemappedValue(op.getVal());
+    auto ordering = getOrdering(op.getSem());
+    auto failureOrdering = ordering != LLVM::AtomicOrdering::monotonic
+                               ? LLVM::AtomicOrdering::acquire
+                               : ordering;
+    Value cmpXchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
+        loc, ptr, cmp, val, ordering, failureOrdering);
+    Value oldVal = rewriter.create<LLVM::ExtractValueOp>(loc, cmpXchg, 0);
+    rewriter.replaceOp(op, oldVal);
+    return success();
+  }
+};
+
+struct AtomicOpsToLLVM
+    : public triton::impl::AtomicOpsToLLVMBase<AtomicOpsToLLVM> {
+  using AtomicOpsToLLVMBase::AtomicOpsToLLVMBase;
+
+  AtomicOpsToLLVM() : AtomicOpsToLLVMBase() {}
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ModuleOp mod = getOperation();
+
+    mlir::LowerToLLVMOptions option(context);
+    TritonCPUToLLVMTypeConverter typeConverter(context, option);
+    TritonLLVMConversionTarget convTarget(*context);
+
+    RewritePatternSet patterns(context);
+    patterns.add<AtomicRMWOpConversion>(typeConverter, context);
+    patterns.add<AtomicCASOpConversion>(typeConverter, context);
+
+    if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // anonymous namespace
+
+namespace mlir {
+namespace triton {
+namespace cpu {
+
+std::unique_ptr<OperationPass<ModuleOp>> createAtomicOpsToLLVMPass() {
+  return std::make_unique<AtomicOpsToLLVM>();
+}
+
+} // namespace cpu
+} // namespace triton
+} // namespace mlir
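AtomicOpsToLLVM.cpp above maps scalar tt.atomic_rmw to llvm.atomicrmw and tt.atomic_cas to llvm.cmpxchg, extracting field 0 of the cmpxchg result, i.e. the previously stored value. As an illustration of why the old-value semantics matters (this sketch closely mirrors the serialized_add test marked for CPU above; it is not code from the patch), a spin lock built on tl.atomic_cas relies on exactly that return value:

import triton
import triton.language as tl


@triton.jit
def serialized_add(data_ptr, lock_ptr, SEM: tl.constexpr):
    ptrs = data_ptr + tl.arange(0, 128)
    # Spin until the CAS observes 0 and installs 1; the op returns the old
    # value, so seeing 1 means another program instance still holds the lock.
    while tl.atomic_cas(lock_ptr, 0, 1, SEM) == 1:
        pass
    tl.store(ptrs, tl.load(ptrs) + 1.0)
    # Release the lock with a plain exchange.
    tl.atomic_xchg(lock_ptr, 0)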
diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt
index 0cf83bc03b06..9e5f71f8d4e5 100644
--- a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt
+++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_triton_library(TritonCPUToLLVM
+  AtomicOpsToLLVM.cpp
   FuncOpToLLVM.cpp
   GetProgramIdOpToLLVM.cpp
   LowerMultiReduction.cpp
diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp
index 914f56e668f8..0263a1e65214 100644
--- a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp
+++ b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp
@@ -11,6 +11,7 @@ void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm) {
   pm.addPass(mlir::triton::cpu::createFuncOpToLLVMPass());
   pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass());
   pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass());
+  pm.addPass(mlir::triton::cpu::createAtomicOpsToLLVMPass());
   // pm.addPass(mlir::createReconcileUnrealizedCastsPass());
 }
 
diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt
index fc22e12b867d..636ea039e718 100644
--- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt
+++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_triton_library(TritonToTritonCPU
+  ConvertAtomicOps.cpp
   ConvertControlFlowOps.cpp
   ConvertDotOp.cpp
   ConvertElementwiseOps.cpp
diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp
new file mode 100644
index 000000000000..61d3ac65e2fc
--- /dev/null
+++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp
@@ -0,0 +1,218 @@
+#include "TypeConverter.h"
+
+#include "cpu/include/TritonToTritonCPU/Passes.h"
+
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonCPU/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+#define GEN_PASS_DEF_CONVERTATOMICOPS
+#include "cpu/include/TritonToTritonCPU/Passes.h.inc"
+} // namespace triton
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::triton;
+using namespace mlir::triton::cpu;
+
+namespace {
+
+class AtomicConversionTarget : public ConversionTarget {
+public:
+  explicit AtomicConversionTarget(MLIRContext &ctx, TypeConverter &converter)
+      : ConversionTarget(ctx) {
+    addLegalDialect<vector::VectorDialect>();
+    addLegalDialect<arith::ArithDialect>();
+    addLegalDialect<scf::SCFDialect>();
+    addLegalDialect<cf::ControlFlowDialect>();
+    addLegalDialect<TritonDialect>();
+    addLegalDialect<TritonCPUDialect>();
+    addLegalOp<mlir::UnrealizedConversionCastOp>();
+
+    addDynamicallyLegalOp<triton::AtomicRMWOp>(
+        [&](triton::AtomicRMWOp op) -> std::optional<bool> {
+          return converter.isLegal(op) && !op.getMask();
+        });
+    addDynamicallyLegalOp<triton::AtomicCASOp>(
+        [&](triton::AtomicCASOp op) -> std::optional<bool> {
+          return converter.isLegal(op);
+        });
+  }
+};
+
+struct AtomicRMWOpConversion : public OpConversionPattern<triton::AtomicRMWOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto mask =
+        op.getMask() ? rewriter.getRemappedValue(op.getMask()) : nullptr;
+    arith::ConstantOp maskCst = mask ? getConstMaskDef(mask) : nullptr;
+    auto rmwOp = op.getAtomicRmwOp();
+    auto ptrs = rewriter.getRemappedValue(op.getPtr());
+    auto vals = rewriter.getRemappedValue(op.getVal());
+    auto sem = op.getSem();
+    auto scope = op.getScope();
+
+    if (mask && !isa<VectorType>(mask.getType())) {
+      auto res = lowerScalarMaskToCF(loc, rmwOp, ptrs, vals, mask, sem, scope,
+                                     rewriter);
+      rewriter.replaceOp(op, res);
+      return success();
+    }
+
+    auto ptrTy = cast<RankedTensorType>(op.getPtr().getType()).getElementType();
+    auto vecTy = cast<VectorType>(vals.getType());
+    auto strides = computeStrides(vecTy.getShape());
+    auto res =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));
+    int64_t numElems = vecTy.getNumElements();
+    for (int64_t idx = 0; idx < numElems; ++idx) {
+      auto indices = delinearize(idx, strides);
+      Value ptr = rewriter.create<vector::ExtractOp>(loc, ptrs, indices);
+      ptr = rewriter.create<IntToPtrOp>(loc, ptrTy, ptr);
+      Value val = rewriter.create<vector::ExtractOp>(loc, vals, indices);
+      Value resElem;
+
+      if (mask && !maskCst) {
+        // Non-const mask values are lowered to CF.
+        Value maskVal = rewriter.create<vector::ExtractOp>(loc, mask, indices);
+        resElem = lowerScalarMaskToCF(loc, rmwOp, ptr, val, maskVal, sem, scope,
+                                      rewriter);
+      } else if (!mask ||
+                 (maskCst && cast<DenseElementsAttr>(maskCst.getValue())
+                                 .getValues<bool>()[idx])) {
+        // Const true mask case.
+        resElem = rewriter.create<triton::AtomicRMWOp>(
+            loc, val.getType(), rmwOp, ptr, val, nullptr, sem, scope);
+      }
+
+      // Elements with const false mask are skipped.
+      if (resElem) {
+        rewriter.create<vector::InsertOp>(loc, resElem, res, indices);
+      }
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+  Value lowerScalarMaskToCF(Location loc, RMWOp rmwOp, Value ptr, Value val,
+                            Value mask, MemSemantic sem, MemSyncScope scope,
+                            ConversionPatternRewriter &rewriter) const {
+    // Check for constant mask.
+    if (auto maskDef = mask.getDefiningOp<arith::ConstantOp>()) {
+      auto maskVal = cast<IntegerAttr>(maskDef.getValue());
+      if (maskVal.getValue().isZero()) {
+        return rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getZeroAttr(val.getType()));
+      } else {
+        return rewriter.create<triton::AtomicRMWOp>(
+            loc, val.getType(), rmwOp, ptr, val, nullptr, sem, scope);
+      }
+    }
+
+    Block *headerBlock = rewriter.getBlock();
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(val.getType()));
+    Block *condBlock =
+        rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint());
+    rewriter.setInsertionPointToStart(condBlock);
+    Value resVal = rewriter.create<triton::AtomicRMWOp>(
+        loc, val.getType(), rmwOp, ptr, val, nullptr, sem, scope);
+    Block *footerBlock =
+        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
+    Value res = footerBlock->addArgument(resVal.getType(), resVal.getLoc());
+    rewriter.setInsertionPointToEnd(headerBlock);
+    rewriter.create<cf::CondBranchOp>(loc, mask, condBlock, footerBlock, zero);
+    rewriter.setInsertionPointToEnd(condBlock);
+    rewriter.create<cf::BranchOp>(loc, footerBlock, resVal);
+    rewriter.setInsertionPointToStart(footerBlock);
+
+    return res;
+  }
+
+  arith::ConstantOp getConstMaskDef(Value mask) const {
+    while (auto cast = mask.getDefiningOp<UnrealizedConversionCastOp>())
+      mask = cast.getOperand(0);
+    return mask.getDefiningOp<arith::ConstantOp>();
+  }
+};
+
+struct AtomicCASOpConversion : public OpConversionPattern<triton::AtomicCASOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ptrs = rewriter.getRemappedValue(op.getPtr());
+    auto cmpVals = rewriter.getRemappedValue(op.getCmp());
+    auto vals = rewriter.getRemappedValue(op.getVal());
+    auto sem = op.getSem();
+    auto scope = op.getScope();
+    auto ptrTy = cast<RankedTensorType>(op.getPtr().getType()).getElementType();
+    auto vecTy = cast<VectorType>(vals.getType());
+    auto strides = computeStrides(vecTy.getShape());
+    auto res =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));
+    int64_t numElems = vecTy.getNumElements();
+    for (int64_t idx = 0; idx < numElems; ++idx) {
+      auto indices = delinearize(idx, strides);
+      Value ptr = rewriter.create<vector::ExtractOp>(loc, ptrs, indices);
+      ptr = rewriter.create<IntToPtrOp>(loc, ptrTy, ptr);
+      Value val = rewriter.create<vector::ExtractOp>(loc, vals, indices);
+      Value cmpVal = rewriter.create<vector::ExtractOp>(loc, cmpVals, indices);
+      Value resElem = rewriter.create<triton::AtomicCASOp>(
+          loc, val.getType(), ptr, cmpVal, val, sem, scope);
+      rewriter.create<vector::InsertOp>(loc, resElem, res, indices);
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+};
+
+struct ConvertAtomicOps
+    : public triton::impl::ConvertAtomicOpsBase<ConvertAtomicOps> {
+  using ConvertAtomicOpsBase::ConvertAtomicOpsBase;
+
+  ConvertAtomicOps() : ConvertAtomicOpsBase() {}
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ModuleOp mod = getOperation();
+
+    TritonToTritonCPUTypeConverter typeConverter;
+    AtomicConversionTarget convTarget(*context, typeConverter);
+    RewritePatternSet patterns(context);
+    patterns.add<AtomicRMWOpConversion>(typeConverter, context);
+    patterns.add<AtomicCASOpConversion>(typeConverter, context);
+
+    if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace triton {
+namespace cpu {
+
+std::unique_ptr<OperationPass<ModuleOp>> createConvertAtomicOps() {
+  return std::make_unique<ConvertAtomicOps>();
+}
+
+} // namespace cpu
+} // namespace triton
+} // namespace mlir
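ConvertAtomicOps.cpp above unrolls a tensor atomic op into per-element scalar atomics: lanes with a constant-false mask never touch memory and keep the zero from the zero-initialized result vector, while non-constant masks become a conditional branch around a single scalar atomic. The short reference model below is offered only as a reading aid, under the assumption of a 1-D block; the names and the sequential execution are illustrative, not generated code.

import numpy as np


def atomic_rmw_reference(memory, ptrs, vals, mask, op=lambda a, b: a + b):
    """Sequential model of a masked tt.atomic_rmw over a 1-D block.

    memory, ptrs, vals, mask are NumPy arrays; op is the RMW combine function.
    """
    old = np.zeros_like(vals)
    for i in range(len(vals)):
        if not mask[i]:
            continue              # masked-off lanes are skipped, result stays 0
        old[i] = memory[ptrs[i]]  # each scalar atomic returns the previous value
        memory[ptrs[i]] = op(memory[ptrs[i]], vals[i])
    return old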
diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp
index 2b26cec34248..ec7c62f72f52 100644
--- a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp
+++ b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp
@@ -16,6 +16,7 @@ void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) {
   pm.addPass(mlir::triton::cpu::createConvertReductionOp());
   pm.addPass(mlir::triton::cpu::createConvertScanOp());
   pm.addPass(mlir::triton::cpu::createConvertControlFlowOps());
+  pm.addPass(mlir::triton::cpu::createConvertAtomicOps());
   // pm.addPass(mlir::createReconcileUnrealizedCastsPass());
 }
 