Support atomic ops for CPU. #20

Merged: 1 commit, Jun 11, 2024
17 changes: 11 additions & 6 deletions python/test/unit/language/test_core.py
@@ -1348,6 +1348,7 @@ def kernel(X, Y, Z):
 # ---------------
 # test atomics
 # ---------------
+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize(
     "op, dtype_x_str, mode, sem",
@@ -1378,13 +1379,12 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
     if is_interpreter():
         if dtype_x_str == 'float16':
             pytest.skip("Only test atomic float16 ops on GPU")
-    else:
-        check_cuda_only(device)

-    capability = torch.cuda.get_device_capability()
-    if capability[0] < 7:
-        if dtype_x_str == 'float16':
-            pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
+    if is_cuda():
+        capability = torch.cuda.get_device_capability()
+        if capability[0] < 7:
+            if dtype_x_str == 'float16':
+                pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
     n_programs = 5

     # triton kernel
@@ -1434,6 +1434,7 @@ def kernel(X, Z):
     assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]


+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
 def test_atomic_rmw_predicate(num_ctas, device):
@@ -1449,6 +1450,7 @@ def kernel(X):
     assert x.item() == 63


+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("shape, axis, num_ctas", [(shape, axis, num_ctas)
                          for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
@@ -1481,6 +1483,7 @@ def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)


+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
 def test_tensor_atomic_rmw_block(num_ctas, device):
@@ -1500,6 +1503,7 @@ def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
     assert torch.min(x).item() == 0.0


+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
@@ -1537,6 +1541,7 @@ def serialized_add(data, Lock, SEM: tl.constexpr):
     assert f"atom.global.{sem_str}" in h.asm["ptx"]


+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
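The new @pytest.mark.cpu markers above make the atomic tests selectable on the CPU backend, e.g. with pytest -m cpu -k atomic (assuming the cpu marker is registered in the project's pytest configuration). For orientation, here is a minimal sketch, not taken from this PR, of the kind of kernel these tests exercise; tl.atomic_add is Triton's public atomic-RMW entry point, and the counter shape and grid size are illustrative assumptions:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def count_kernel(X):
        # Every program instance increments the same location; the atomic
        # RMW guarantees no increment is lost.
        tl.atomic_add(X, 1)

    x = torch.zeros(1, dtype=torch.int32)  # plain CPU tensor for the CPU backend
    count_kernel[(64,)](x)                 # launch 64 program instances
    assert x.item() == 64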
1 change: 1 addition & 0 deletions third_party/cpu/include/TritonCPUToLLVM/Passes.h
@@ -24,6 +24,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createFuncOpToLLVMPass();
 std::unique_ptr<OperationPass<ModuleOp>> createMemoryOpToLLVMPass();
 std::unique_ptr<OperationPass<ModuleOp>> createGetProgramIdOpToLLVMPass();
 std::unique_ptr<OperationPass<triton::FuncOp>> createLowerMultiReductionPass();
+std::unique_ptr<OperationPass<ModuleOp>> createAtomicOpsToLLVMPass();

 void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm);
 void registerTritonCPUToLLVMPipeline();
11 changes: 11 additions & 0 deletions third_party/cpu/include/TritonCPUToLLVM/Passes.td
@@ -54,4 +54,15 @@ def LowerMultiReduction : Pass<"triton-cpu-lower-multi-reduction", "mlir::triton
                            "mlir::triton::TritonDialect"];
 }

+def AtomicOpsToLLVM : Pass<"triton-cpu-atomic-ops-to-llvm", "mlir::ModuleOp"> {
+  let summary = "Convert Triton atomic operations to LLVM.";
+  let description = [{
+  }];
+  let constructor = "mlir::triton::cpu::createAtomicOpsToLLVMPass()";
+
+  let dependentDialects = ["mlir::vector::VectorDialect",
+                           "mlir::triton::cpu::TritonCPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 #endif
1 change: 1 addition & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.h
@@ -26,6 +26,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertControlFlowOps();
 std::unique_ptr<OperationPass<ModuleOp>> createConvertHistogramOp();
 std::unique_ptr<OperationPass<ModuleOp>> createConvertReductionOp();
 std::unique_ptr<OperationPass<ModuleOp>> createConvertScanOp();
+std::unique_ptr<OperationPass<ModuleOp>> createConvertAtomicOps();

 void tritonToTritonCPUPipelineBuilder(OpPassManager &pm);
 void registerTritonToTritonCPUPipeline();
14 changes: 14 additions & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.td
@@ -114,4 +114,18 @@ def ConvertScanOp : Pass<"triton-cpu-convert-scan", "mlir::ModuleOp"> {
                            "mlir::triton::cpu::TritonCPUDialect"];
 }

+def ConvertAtomicOps : Pass<"triton-cpu-convert-atomic-ops", "mlir::ModuleOp"> {
+  let summary = "Convert Triton atomic operations.";
+  let description = [{
+
+  }];
+  let constructor = "mlir::triton::cpu::createConvertAtomicOps()";
+
+  let dependentDialects = ["mlir::arith::ArithDialect",
+                           "mlir::vector::VectorDialect",
+                           "mlir::scf::SCFDialect",
+                           "mlir::triton::TritonDialect",
+                           "mlir::triton::cpu::TritonCPUDialect"];
+}
+
 #endif
154 changes: 154 additions & 0 deletions third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp
@@ -0,0 +1,154 @@
#include "TypeConverter.h"

#include "cpu/include/TritonCPUToLLVM/Passes.h"

#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"

#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"

namespace mlir {
namespace triton {
#define GEN_PASS_DEF_ATOMICOPSTOLLVM
#include "cpu/include/TritonCPUToLLVM/Passes.h.inc"
} // namespace triton
} // namespace mlir

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::cpu;

namespace {

class TritonLLVMConversionTarget : public ConversionTarget {
public:
  explicit TritonLLVMConversionTarget(MLIRContext &ctx)
      : ConversionTarget(ctx) {
    addLegalDialect<LLVM::LLVMDialect>();
    addLegalOp<mlir::UnrealizedConversionCastOp>();
  }
};

// Map Triton memory semantics to LLVM atomic orderings. Note that LLVM
// spells the C++ "relaxed" ordering "monotonic".
LLVM::AtomicOrdering getOrdering(MemSemantic sem) {
  switch (sem) {
  case MemSemantic::RELAXED:
    return LLVM::AtomicOrdering::monotonic;
  case MemSemantic::ACQUIRE:
    return LLVM::AtomicOrdering::acquire;
  case MemSemantic::RELEASE:
    return LLVM::AtomicOrdering::release;
  case MemSemantic::ACQUIRE_RELEASE:
    return LLVM::AtomicOrdering::acq_rel;
  default:
    llvm_unreachable("Unexpected atomic mem semantic");
  }
}

// TODO: use enums to access struct fields.
struct AtomicRMWOpConversion : public OpConversionPattern<AtomicRMWOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtomicRMWOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto opKind = getAtomicBinOp(op.getAtomicRmwOp(), op.getType());
    auto ptr = rewriter.getRemappedValue(op.getPtr());
    auto val = rewriter.getRemappedValue(op.getVal());
    auto ordering = getOrdering(op.getSem());
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(op, opKind, ptr, val,
                                                   ordering);
    return success();
  }

  LLVM::AtomicBinOp getAtomicBinOp(RMWOp op, Type type) const {
    switch (op) {
    case RMWOp::AND:
      return LLVM::AtomicBinOp::_and;
    case RMWOp::OR:
      return LLVM::AtomicBinOp::_or;
    case RMWOp::XOR:
      return LLVM::AtomicBinOp::_xor;
    case RMWOp::ADD:
      return LLVM::AtomicBinOp::add;
    case RMWOp::FADD:
      return LLVM::AtomicBinOp::fadd;
    case RMWOp::MAX:
      return type.isIntOrIndex() ? LLVM::AtomicBinOp::max
                                 : LLVM::AtomicBinOp::fmax;
    case RMWOp::MIN:
      return type.isIntOrIndex() ? LLVM::AtomicBinOp::min
                                 : LLVM::AtomicBinOp::fmin;
    case RMWOp::UMAX:
      return LLVM::AtomicBinOp::umax;
    case RMWOp::UMIN:
      return LLVM::AtomicBinOp::umin;
    case RMWOp::XCHG:
      return LLVM::AtomicBinOp::xchg;
    default:
      llvm_unreachable("Unexpected atomic op");
    }
  }
};

struct AtomicCASOpConversion : public OpConversionPattern<AtomicCASOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtomicCASOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ptr = rewriter.getRemappedValue(op.getPtr());
    auto cmp = rewriter.getRemappedValue(op.getCmp());
    auto val = rewriter.getRemappedValue(op.getVal());
    auto ordering = getOrdering(op.getSem());
    // The failure ordering of a cmpxchg may not be release or acq_rel, so
    // clamp anything stronger than monotonic to acquire on failure.
    auto failureOrdering = ordering != LLVM::AtomicOrdering::monotonic
                               ? LLVM::AtomicOrdering::acquire
                               : ordering;
    Value cmpXchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, ptr, cmp, val, ordering, failureOrdering);
    // cmpxchg yields {old value, success flag}; Triton's atomic_cas
    // returns the old value.
    Value oldVal = rewriter.create<LLVM::ExtractValueOp>(loc, cmpXchg, 0);
    rewriter.replaceOp(op, oldVal);
    return success();
  }
};

struct AtomicOpsToLLVM
    : public triton::impl::AtomicOpsToLLVMBase<AtomicOpsToLLVM> {
  using AtomicOpsToLLVMBase::AtomicOpsToLLVMBase;

  AtomicOpsToLLVM() : AtomicOpsToLLVMBase() {}

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp mod = getOperation();

    mlir::LowerToLLVMOptions option(context);
    TritonCPUToLLVMTypeConverter typeConverter(context, option);
    TritonLLVMConversionTarget convTarget(*context);

    RewritePatternSet patterns(context);
    patterns.add<AtomicRMWOpConversion>(typeConverter, context);
    patterns.add<AtomicCASOpConversion>(typeConverter, context);

    if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
      return signalPassFailure();
  }
};

} // anonymous namespace

namespace mlir {
namespace triton {
namespace cpu {

std::unique_ptr<OperationPass<ModuleOp>> createAtomicOpsToLLVMPass() {
  return std::make_unique<AtomicOpsToLLVM>();
}

} // namespace cpu
} // namespace triton
} // namespace mlir
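The CAS lowering above extracts element 0 of the llvm.cmpxchg result, i.e. the value observed at the pointer before the exchange, which is what Triton's atomic_cas returns to user code. As a Python-level illustration of the behavior this enables, a hedged sketch of the spin-lock pattern used by the serialized_add test touched in this PR (signature simplified; the real test also takes a SEM constexpr and a tensor of pointers):

    import triton
    import triton.language as tl

    @triton.jit
    def serialized_add(data, Lock):
        # Spin until we swap the lock from 0 to 1; atomic_cas returns the
        # old value, so seeing 1 means another program holds the lock.
        while tl.atomic_cas(Lock, 0, 1) == 1:
            pass
        tl.store(data, tl.load(data) + 1)
        tl.atomic_xchg(Lock, 0)  # release the lock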
1 change: 1 addition & 0 deletions third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_triton_library(TritonCPUToLLVM
+  AtomicOpsToLLVM.cpp
   FuncOpToLLVM.cpp
   GetProgramIdOpToLLVM.cpp
   LowerMultiReduction.cpp
1 change: 1 addition & 0 deletions third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp
@@ -11,6 +11,7 @@ void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm) {
   pm.addPass(mlir::triton::cpu::createFuncOpToLLVMPass());
   pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass());
   pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass());
+  pm.addPass(mlir::triton::cpu::createAtomicOpsToLLVMPass());
   // pm.addPass(mlir::createReconcileUnrealizedCastsPass());
 }
1 change: 1 addition & 0 deletions third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_triton_library(TritonToTritonCPU
+  ConvertAtomicOps.cpp
   ConvertControlFlowOps.cpp
   ConvertDotOp.cpp
   ConvertElementwiseOps.cpp