From 613bdda6a700082f2554b868ca7f2ba439481b9c Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 20 Jun 2024 14:09:54 -0500 Subject: [PATCH] Move byte manipulation ops from elwise ops conversion. (#28) Signed-off-by: Ilya Enkovich --- .../cpu/include/TritonToTritonCPU/Passes.h | 1 + .../cpu/include/TritonToTritonCPU/Passes.td | 14 ++ .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 1 + .../TritonToTritonCPU/ConvertElemManipOps.cpp | 208 ++++++++++++++++++ .../ConvertElementwiseOps.cpp | 129 ----------- .../cpu/lib/TritonToTritonCPU/Pipeline.cpp | 1 + 6 files changed, 225 insertions(+), 129 deletions(-) create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index b5107e5e78a3..14df893f0bac 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -19,6 +19,7 @@ namespace cpu { #include "cpu/include/TritonToTritonCPU/Passes.h.inc" std::unique_ptr> createConvertElementwiseOps(); +std::unique_ptr> createConvertElemManipOps(); std::unique_ptr> createConvertMemoryOps(); std::unique_ptr> createConvertPtrOps(); std::unique_ptr> createConvertDotOp(); diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index 5dd3bf903440..dfac926a9f5b 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -31,6 +31,20 @@ def ConvertElementwiseOps : Pass<"triton-cpu-convert-elementwise-ops", "mlir::Mo "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertElemManipOps : Pass<"triton-cpu-convert-elem-manip-ops", "mlir::ModuleOp"> { + let summary = "Convert elements manipulation ops (transpose, shuffle, etc.)."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertElemManipOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + def ConvertPtrOps : Pass<"triton-cpu-convert-ptr-ops", "mlir::ModuleOp"> { let summary = "Convert Triton ops related to pointer arithmetics."; let description = [{ diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index 636ea039e718..dc34c5bd0199 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -3,6 +3,7 @@ add_triton_library(TritonToTritonCPU ConvertControlFlowOps.cpp ConvertDotOp.cpp ConvertElementwiseOps.cpp + ConvertElemManipOps.cpp ConvertHistogramOp.cpp ConvertMemoryOps.cpp ConvertPtrOps.cpp diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp new file mode 100644 index 000000000000..99211ea90e41 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp @@ -0,0 +1,208 @@ +#include "OpTypeConversion.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTELEMMANIPOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ElemManipOpConversionTarget : public ConversionTarget { +public: + explicit ElemManipOpConversionTarget(MLIRContext &ctx, + TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + } +}; + +struct ReshapeOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(isa(op.getType())); + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcShape = dyn_cast(src.getType()).getShape(); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto dstShape = resTy.getShape(); + auto elemTy = resTy.getElementType(); + + // There are restrictions on how shape can be modified by ShapeCastOp + // when rank is changed. For now, we simply detect it and handle through + // a cast to 1D vector. Better solution may be required later. + if (canCastShape(srcShape, dstShape)) { + rewriter.replaceOpWithNewOp( + op, VectorType::get(dstShape, elemTy), src); + } else { + SmallVector tmpShape({resTy.getNumElements()}); + auto tmp = rewriter.create( + loc, VectorType::get(tmpShape, elemTy), src); + rewriter.replaceOpWithNewOp( + op, VectorType::get(dstShape, elemTy), tmp); + } + return success(); + } + +private: + bool canCastShape(ArrayRef src, ArrayRef dst) const { + if (src.size() == dst.size()) + return true; + if (src.size() > dst.size()) + return canCastShape(dst, src); + + size_t srcIdx = 0; + size_t dstIdx = 0; + while (srcIdx < src.size() && dstIdx < dst.size()) { + if (src[srcIdx] == 1) { + ++srcIdx; + } else { + // Source dim size should be a product of continuous dest dim sizes. + int64_t srcSize = src[srcIdx++]; + int64_t dstSize = dst[dstIdx++]; + while (dstSize < srcSize && dstIdx < dst.size()) + dstSize *= dst[dstIdx++]; + if (dstSize != srcSize) + return false; + } + } + + // Skip trailing 1s. + while (srcIdx < src.size() && src[srcIdx] == 1) + ++srcIdx; + while (dstIdx < dst.size() && dst[dstIdx] == 1) + ++dstIdx; + + return srcIdx == src.size() && dstIdx == dst.size(); + } +}; + +struct TransOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto val = rewriter.getRemappedValue(op.getSrc()); + auto order = op.getOrder(); + SmallVector permutation(order.begin(), order.end()); + rewriter.replaceOpWithNewOp(op, val, permutation); + return success(); + } +}; + +struct JoinOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getLhs()); + auto rhs = rewriter.getRemappedValue(op.getRhs()); + auto interleave = rewriter.create(loc, lhs, rhs); + // JoinOp creates a new dimension, but InterleaveOp doubles the final one. + // Use ShapeCastOp to get the required shape. + auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, interleave); + return success(); + } +}; + +struct CatOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getLhs()); + auto rhs = rewriter.getRemappedValue(op.getRhs()); + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + SmallVector indices(lhsTy.getShape()[0] + rhsTy.getShape()[0]); + std::iota(indices.begin(), indices.end(), 0); + rewriter.replaceOpWithNewOp(op, lhs, rhs, indices); + return success(); + } +}; + +struct ConvertElemManipOps + : public triton::impl::ConvertElemManipOpsBase { + using ConvertElemManipOpsBase::ConvertElemManipOpsBase; + + ConvertElemManipOps() : ConvertElemManipOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ElemManipOpConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertElemManipOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index cadec818910b..7edf15f2e921 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -51,16 +51,10 @@ class ElementwiseOpConversionTarget : public ConversionTarget { addDynamicallyLegalOp( [](triton::BitcastOp op) { return isa(op.getType()); }); - addIllegalOp(); - addIllegalOp(); addIllegalOp(); addIllegalOp(); - addIllegalOp(); addIllegalOp(); addIllegalOp(); - addIllegalOp(); - addIllegalOp(); - addIllegalOp(); } }; @@ -84,70 +78,6 @@ struct ConstantOpConversion : public OpConversionPattern { } }; -struct ReshapeOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(isa(op.getType())); - auto loc = op.getLoc(); - auto src = rewriter.getRemappedValue(op.getSrc()); - auto srcShape = dyn_cast(src.getType()).getShape(); - auto resTy = - dyn_cast(getTypeConverter()->convertType(op.getType())); - auto dstShape = resTy.getShape(); - auto elemTy = resTy.getElementType(); - - // There are restrictions on how shape can be modified by ShapeCastOp - // when rank is changed. For now, we simply detect it and handle through - // a cast to 1D vector. Better solution may be required later. - if (canCastShape(srcShape, dstShape)) { - rewriter.replaceOpWithNewOp( - op, VectorType::get(dstShape, elemTy), src); - } else { - SmallVector tmpShape({resTy.getNumElements()}); - auto tmp = rewriter.create( - loc, VectorType::get(tmpShape, elemTy), src); - rewriter.replaceOpWithNewOp( - op, VectorType::get(dstShape, elemTy), tmp); - } - return success(); - } - -private: - bool canCastShape(ArrayRef src, ArrayRef dst) const { - if (src.size() == dst.size()) - return true; - if (src.size() > dst.size()) - return canCastShape(dst, src); - - size_t srcIdx = 0; - size_t dstIdx = 0; - while (srcIdx < src.size() && dstIdx < dst.size()) { - if (src[srcIdx] == 1) { - ++srcIdx; - } else { - // Source dim size should be a product of continuous dest dim sizes. - int64_t srcSize = src[srcIdx++]; - int64_t dstSize = dst[dstIdx++]; - while (dstSize < srcSize && dstIdx < dst.size()) - dstSize *= dst[dstIdx++]; - if (dstSize != srcSize) - return false; - } - } - - // Skip trailing 1s. - while (srcIdx < src.size() && src[srcIdx] == 1) - ++srcIdx; - while (dstIdx < dst.size() && dst[dstIdx] == 1) - ++dstIdx; - - return srcIdx == src.size() && dstIdx == dst.size(); - } -}; - struct MulhiUIOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -204,57 +134,6 @@ struct ClampFOpConversion : public OpConversionPattern { } }; -struct TransOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto val = rewriter.getRemappedValue(op.getSrc()); - auto order = op.getOrder(); - SmallVector permutation(order.begin(), order.end()); - rewriter.replaceOpWithNewOp(op, val, permutation); - return success(); - } -}; - -struct JoinOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto lhs = rewriter.getRemappedValue(op.getLhs()); - auto rhs = rewriter.getRemappedValue(op.getRhs()); - auto interleave = rewriter.create(loc, lhs, rhs); - // JoinOp creates a new dimension, but InterleaveOp doubles the final one. - // Use ShapeCastOp to get the required shape. - auto resTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resTy, interleave); - return success(); - } -}; - -struct CatOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto lhs = rewriter.getRemappedValue(op.getLhs()); - auto rhs = rewriter.getRemappedValue(op.getRhs()); - auto lhsTy = dyn_cast(lhs.getType()); - auto rhsTy = dyn_cast(rhs.getType()); - SmallVector indices(lhsTy.getShape()[0] + rhsTy.getShape()[0]); - std::iota(indices.begin(), indices.end(), 0); - rewriter.replaceOpWithNewOp(op, lhs, rhs, indices); - return success(); - } -}; - struct ConvertElementwiseOps : public triton::impl::ConvertElementwiseOpsBase { using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; @@ -326,20 +205,12 @@ struct ConvertElementwiseOps patterns.add>( typeConverter, context); - patterns.add>( - typeConverter, context); - patterns.add>( - typeConverter, context); patterns.add>( typeConverter, context); patterns.add>( typeConverter, context); - patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp index ec7c62f72f52..c7e7de72eecf 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp @@ -11,6 +11,7 @@ void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertMemoryOps()); pm.addPass(mlir::triton::cpu::createConvertPtrOps()); pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); + pm.addPass(mlir::triton::cpu::createConvertElemManipOps()); pm.addPass(mlir::triton::cpu::createConvertDotOp()); pm.addPass(mlir::triton::cpu::createConvertHistogramOp()); pm.addPass(mlir::triton::cpu::createConvertReductionOp());