Skip to content

Commit

Permalink
Add an option to choose between default reduction lowering and our ow…
Browse files Browse the repository at this point in the history
…n. (triton-lang#98)

Signed-off-by: Ilya Enkovich <[email protected]>
  • Loading branch information
ienkovich authored Sep 20, 2024
1 parent 3b540b1 commit fbfaff7
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 14 deletions.
2 changes: 1 addition & 1 deletion third_party/cpu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def make_ttcir(mod, metadata, opt):
cpu.passes.ttcpuir.add_convert_elem_manip_ops(pm)
cpu.passes.ttcpuir.add_convert_dot_op(pm)
cpu.passes.ttcpuir.add_convert_histogram_op(pm)
cpu.passes.ttcpuir.add_convert_reduction_op(pm)
cpu.passes.ttcpuir.add_convert_reduction_op(pm, False)
cpu.passes.ttcpuir.add_convert_scan_op(pm)
cpu.passes.ttcpuir.add_convert_cf_ops(pm)
cpu.passes.ttcpuir.add_convert_atomic_ops(pm)
Expand Down
2 changes: 2 additions & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertDotOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertControlFlowOps();
std::unique_ptr<OperationPass<ModuleOp>> createConvertHistogramOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertReductionOp();
std::unique_ptr<OperationPass<ModuleOp>>
createConvertReductionOp(bool useMultiDimReductionOp);
std::unique_ptr<OperationPass<ModuleOp>> createConvertScanOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertAtomicOps();
std::unique_ptr<OperationPass<ModuleOp>> createConvertDebugOps();
Expand Down
6 changes: 6 additions & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp">
}];
let constructor = "mlir::triton::cpu::createConvertReductionOp()";

let options = [
Option<"useMultiDimReductionOp", "use-multidim-reduction-op",
"bool", /*default*/"false",
"Use vector::MultiDimReductionOp and its default lowering when possible.">,
];

let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::vector::VectorDialect",
"mlir::scf::SCFDialect",
Expand Down
37 changes: 27 additions & 10 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

namespace mlir {
namespace triton {
namespace cpu {
#define GEN_PASS_DEF_CONVERTREDUCTIONOP
#include "cpu/include/TritonToTritonCPU/Passes.h.inc"
} // namespace cpu
} // namespace triton
} // namespace mlir

Expand All @@ -42,16 +44,20 @@ class ReductionConversionTarget : public ConversionTarget {
struct ReduceOpConversion
: public ReduceScanOpConversionBase<triton::ReduceOp,
triton::ReduceReturnOp> {
using ReduceScanOpConversionBase::ReduceScanOpConversionBase;
// Forwards the type converter and context to the base conversion pattern and
// stores the flag that matchAndRewrite checks before it tries
// mapToMultiDimReductionOp.
ReduceOpConversion(bool useMultiDimReductionOp,
                   const TypeConverter &typeConverter, MLIRContext *context)
    : ReduceScanOpConversionBase(typeConverter, context),
      useMultiDimReductionOp(useMultiDimReductionOp) {}

LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// More simple cases with a single input and a single combine
// operation can utilize target-specific reduction operations like
// horizontal vector operations. We detect such cases here and map
// them to the vector::MultiDimReductionOp.
if (succeeded(mapToMultiDimReductionOp(op, rewriter)))
// Simpler cases with a single input and a single combine operation
// can be mapped to a vector::MultiDimReductionOp. The resulting code
// depends on the quality of the LLVM backend and is not always optimal.
if (useMultiDimReductionOp &&
succeeded(mapToMultiDimReductionOp(op, rewriter)))
return success();

return ReduceScanOpConversionBase::matchAndRewrite(op, adaptor, rewriter);
Expand Down Expand Up @@ -249,13 +255,18 @@ struct ReduceOpConversion

return rewriter.create<arith::ConstantOp>(loc, resTy, initVal);
}

private:
bool useMultiDimReductionOp;
};

struct ConvertReductionOp
: public triton::impl::ConvertReductionOpBase<ConvertReductionOp> {
using ConvertReductionOpBase::ConvertReductionOpBase;
: public triton::cpu::impl::ConvertReductionOpBase<ConvertReductionOp> {
ConvertReductionOp() = default;

ConvertReductionOp() : ConvertReductionOpBase() {}
ConvertReductionOp(bool useMultiDimReductionOp) {
this->useMultiDimReductionOp = useMultiDimReductionOp;
}

void runOnOperation() override {
MLIRContext *context = &getContext();
Expand All @@ -264,7 +275,8 @@ struct ConvertReductionOp
TritonToTritonCPUTypeConverter typeConverter;
ReductionConversionTarget convTarget(*context, typeConverter);
RewritePatternSet patterns(context);
patterns.add<ReduceOpConversion>(typeConverter, context);
patterns.add<ReduceOpConversion>(useMultiDimReductionOp, typeConverter,
context);

if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
Expand All @@ -281,6 +293,11 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertReductionOp() {
return std::make_unique<ConvertReductionOp>();
}

/// Creates the reduction conversion pass, passing \p useMultiDimReductionOp
/// to the ConvertReductionOp constructor.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertReductionOp(bool useMultiDimReductionOp) {
  auto reductionPass =
      std::make_unique<ConvertReductionOp>(useMultiDimReductionOp);
  return reductionPass;
}

} // namespace cpu
} // namespace triton
} // namespace mlir
8 changes: 5 additions & 3 deletions third_party/cpu/triton_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) {
m.def("add_convert_histogram_op", [](mlir::PassManager &pm) {
pm.addPass(mlir::triton::cpu::createConvertHistogramOp());
});
m.def("add_convert_reduction_op", [](mlir::PassManager &pm) {
pm.addPass(mlir::triton::cpu::createConvertReductionOp());
});
// Python binding "add_convert_reduction_op": adds the reduction conversion
// pass to the given pass manager. The bool argument is forwarded unchanged to
// createConvertReductionOp(bool).
m.def("add_convert_reduction_op",
[](mlir::PassManager &pm, bool use_multidim_reduction_op) {
pm.addPass(mlir::triton::cpu::createConvertReductionOp(
use_multidim_reduction_op));
});
m.def("add_convert_scan_op", [](mlir::PassManager &pm) {
pm.addPass(mlir::triton::cpu::createConvertScanOp());
});
Expand Down

0 comments on commit fbfaff7

Please sign in to comment.