From 9279fb9759d50f370d7ab08234f67267b33b2ff6 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 25 Jun 2024 01:05:17 -0500 Subject: [PATCH] [CPU] Add conversion for unsupported BF16 ops via target-specific stage (#27) * Remove unused code. Signed-off-by: Ilya Enkovich * Fma is always allowed on CPU. Signed-off-by: Ilya Enkovich * Add unsupported op conversions for BF16 type. Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- bin/RegisterTritonDialects.h | 2 + .../triton/Dialect/TritonCPU/CMakeLists.txt | 1 - .../TritonCPU/Transforms/CMakeLists.txt | 3 - .../Dialect/TritonCPU/Transforms/Passes.h | 16 -- .../Dialect/TritonCPU/Transforms/Passes.td | 6 - .../Transforms/TritonCPUConversion.h | 31 --- lib/Conversion/CMakeLists.txt | 3 - lib/Conversion/TritonCPUToLLVM/CMakeLists.txt | 20 -- .../TritonCPUToLLVM/CPUTargetInfo.cpp | 49 ----- .../TritonCPUToLLVM/ControlFlowOpToLLVM.cpp | 37 ---- .../TritonCPUToLLVM/FuncOpToLLVM.cpp | 54 ----- .../TritonCPUToLLVM/PrintOpToLLVM.cpp | 131 ----------- .../TritonCPUToLLVM/SPMDOpToLLVM.cpp | 39 ---- .../TritonCPUToLLVM/TritonCPUToLLVM.cpp | 117 ---------- .../TritonCPUToLLVM/TypeConverter.cpp | 31 --- .../TritonToTritonCPU/CMakeLists.txt | 15 -- .../TritonToTritonCPU/TritonCPUConversion.cpp | 108 ---------- .../TritonToTritonCPUPass.cpp | 41 ---- lib/Dialect/TritonCPU/CMakeLists.txt | 1 - .../TritonCPU/Transforms/CMakeLists.txt | 13 -- python/setup.py | 2 +- python/src/llvm.cc | 15 ++ python/test/unit/language/test_core.py | 3 - third_party/cpu/CMakeLists.txt | 2 +- third_party/cpu/backend/compiler.py | 17 ++ third_party/cpu/include/CMakeLists.txt | 1 + .../TritonCPUTransforms/CMakeLists.txt | 3 + .../cpu/include/TritonCPUTransforms/Passes.h | 32 +++ .../cpu/include/TritonCPUTransforms/Passes.td | 39 ++++ third_party/cpu/lib/CMakeLists.txt | 1 + .../lib/TritonCPUTransforms/CMakeLists.txt | 7 + .../ConvertUnsupportedOps.cpp | 204 ++++++++++++++++++ .../DecomposeFpConversions.cpp | 81 +++++++ .../cpu/lib/TritonCPUTransforms/OptCommon.h | 46 ++++ third_party/cpu/triton_cpu.cc | 7 + 35 files changed, 457 insertions(+), 721 deletions(-) delete mode 100644 include/triton/Dialect/TritonCPU/Transforms/CMakeLists.txt delete mode 100644 include/triton/Dialect/TritonCPU/Transforms/Passes.h delete mode 100644 include/triton/Dialect/TritonCPU/Transforms/Passes.td delete mode 100644 include/triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h delete mode 100644 lib/Conversion/TritonCPUToLLVM/CMakeLists.txt delete mode 100644 lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp delete mode 100644 lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp delete mode 100644 lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp delete mode 100644 lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp delete mode 100644 lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp delete mode 100644 lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp delete mode 100644 lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp delete mode 100644 lib/Conversion/TritonToTritonCPU/CMakeLists.txt delete mode 100644 lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp delete mode 100644 lib/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.cpp delete mode 100644 lib/Dialect/TritonCPU/Transforms/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonCPUTransforms/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonCPUTransforms/Passes.h create mode 100644 third_party/cpu/include/TritonCPUTransforms/Passes.td create mode 100644 
third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt create mode 100644 third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp create mode 100644 third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp create mode 100644 third_party/cpu/lib/TritonCPUTransforms/OptCommon.h diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 9ad8d758e67d..9758ced3a5a8 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -17,6 +17,7 @@ #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" #include "cpu/include/TritonCPUToLLVM/Passes.h" +#include "cpu/include/TritonCPUTransforms/Passes.h" #include "cpu/include/TritonToTritonCPU/Passes.h" #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" @@ -70,6 +71,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { // CPU passes mlir::triton::cpu::registerTritonToTritonCPUPasses(); mlir::triton::cpu::registerTritonToTritonCPUPipeline(); + mlir::triton::cpu::registerTritonCPUTransformsPasses(); mlir::triton::cpu::registerTritonCPUToLLVMPasses(); mlir::triton::cpu::registerTritonCPUToLLVMPipeline(); diff --git a/include/triton/Dialect/TritonCPU/CMakeLists.txt b/include/triton/Dialect/TritonCPU/CMakeLists.txt index 9f57627c321f..f33061b2d87c 100644 --- a/include/triton/Dialect/TritonCPU/CMakeLists.txt +++ b/include/triton/Dialect/TritonCPU/CMakeLists.txt @@ -1,2 +1 @@ add_subdirectory(IR) -add_subdirectory(Transforms) diff --git a/include/triton/Dialect/TritonCPU/Transforms/CMakeLists.txt b/include/triton/Dialect/TritonCPU/Transforms/CMakeLists.txt deleted file mode 100644 index 6aa946f64932..000000000000 --- a/include/triton/Dialect/TritonCPU/Transforms/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonCPU) -add_public_tablegen_target(TritonCPUTransformsIncGen) diff --git a/include/triton/Dialect/TritonCPU/Transforms/Passes.h b/include/triton/Dialect/TritonCPU/Transforms/Passes.h deleted file mode 100644 index f31e47317080..000000000000 --- a/include/triton/Dialect/TritonCPU/Transforms/Passes.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef TRITON_DIALECT_TRITONCPU_TRANSFORMS_PASSES_H_ -#define TRITON_DIALECT_TRITONCPU_TRANSFORMS_PASSES_H_ - -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace triton { -namespace cpu {} // namespace cpu -} // namespace triton - -/// Generate the code for registering passes. -#define GEN_PASS_REGISTRATION -#include "triton/Dialect/TritonCPU/Transforms/Passes.h.inc" - -} // namespace mlir -#endif diff --git a/include/triton/Dialect/TritonCPU/Transforms/Passes.td b/include/triton/Dialect/TritonCPU/Transforms/Passes.td deleted file mode 100644 index a1d5271ee6e7..000000000000 --- a/include/triton/Dialect/TritonCPU/Transforms/Passes.td +++ /dev/null @@ -1,6 +0,0 @@ -#ifndef TRITONCPU_PASSES -#define TRITONCPU_PASSES - -include "mlir/Pass/PassBase.td" - -#endif diff --git a/include/triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h b/include/triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h deleted file mode 100644 index 01c24e19c60e..000000000000 --- a/include/triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h +++ /dev/null @@ -1,31 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Defines utilities to use while converting to the TritonCPU dialect. 
-// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_DIALECT_TRITONCPU_TRANSFORMS_TRITONCPUCONVERSION_H_ -#define TRITON_DIALECT_TRITONCPU_TRANSFORMS_TRITONCPUCONVERSION_H_ - -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir { - -class TritonCPUTypeConverter : public TypeConverter { -public: - TritonCPUTypeConverter(MLIRContext *context); - -private: - MLIRContext *context; -}; - -class TritonCPUConversionTarget : public ConversionTarget { - -public: - explicit TritonCPUConversionTarget(MLIRContext &ctx, - TritonCPUTypeConverter &typeConverter); -}; - -} // namespace mlir - -#endif // TRITON_DIALECT_TRITONCPU_TRANSFORMS_TRITONCPUCONVERSION_H_ diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 426b22a42ef6..143a4375a811 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,5 +1,2 @@ -# TODO(minjang): I will remove these scratches soon. -# add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) -# add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) diff --git a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt deleted file mode 100644 index db507557fb22..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -add_triton_library(TritonCPUToLLVM - ControlFlowOpToLLVM.cpp - CPUTargetInfo.cpp - FuncOpToLLVM.cpp - PrintOpToLLVM.cpp - SPMDOpToLLVM.cpp - TypeConverter.cpp - TritonCPUToLLVM.cpp - - DEPENDS - TritonCPUConversionPassIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - TritonAnalysis - TritonIR - TritonCPUIR - TritonCPUTransforms -) diff --git a/lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp b/lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp deleted file mode 100644 index 8dd050b80bbf..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h" -#include "triton/Conversion/TritonCPUToLLVM/Utility.h" - -namespace { -LLVM::LLVMFuncOp getPrintfDeclaration(ConversionPatternRewriter &rewriter) { - auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); - StringRef funcName("printf"); - Operation *funcOp = moduleOp.lookupSymbol(funcName); - if (funcOp) - return cast(*funcOp); - - auto *context = rewriter.getContext(); - - // int printf(char* format, ...) - SmallVector argsType{ptr_ty(context)}; - auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType, true); - - ConversionPatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - - return rewriter.create(UnknownLoc::get(context), funcName, - funcType); -} -} // namespace - -namespace mlir::triton::cpu { - -Value CPUTargetInfo::programId(ConversionPatternRewriter &rewriter, - Location loc, LLVM::LLVMFuncOp funcOp, - int axis) const { - assert(axis >= 0 && axis < 3); - - // program_id for CPU is provided as function arguments. The last three - // arguments are __grid0 to __grid2 of i32. 
- assert(funcOp && funcOp.getArguments().size() >= 3); - return funcOp.getArgument(funcOp.getArguments().size() - 3 + axis); -} - -void CPUTargetInfo::printf(ConversionPatternRewriter &rewriter, - Value formatStrStart, int /*formatStrByteCount*/, - ValueRange args) const { - auto loc = UnknownLoc::get(rewriter.getContext()); - SmallVector formatStrAndArgs{formatStrStart}; - for (auto arg : args) { - formatStrAndArgs.push_back(arg); - } - call(getPrintfDeclaration(rewriter), formatStrAndArgs); -} -} // namespace mlir::triton::cpu diff --git a/lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp deleted file mode 100644 index a270c0d60845..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" -#include "triton/Conversion/TritonCPUToLLVM/Utility.h" -#include "llvm/Support/ErrorHandling.h" - -namespace { - -using namespace mlir; -using namespace mlir::triton; - -struct ReturnOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto funcOp = op->getParentOfType(); - if (funcOp->hasAttr("cpu.kernel")) { - if (op.getNumOperands() > 0) { - return rewriter.notifyMatchFailure( - op, "Kernel functions do not support return with operands"); - } - rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), - op->getAttrs()); - } else { - llvm_unreachable("Not implemented"); - } - return success(); - } -}; - -} // namespace - -void mlir::triton::cpu::populateControlFlowOpToLLVMPattern( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); -} diff --git a/lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp deleted file mode 100644 index 9ecd470345ad..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp +++ /dev/null @@ -1,54 +0,0 @@ -#include "mlir/Support/LogicalResult.h" -#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" -#include "triton/Conversion/TritonCPUToLLVM/Utility.h" - -namespace mlir { -FailureOr -convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, - ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &converter); -} - -namespace { - -using namespace mlir; -using namespace mlir::triton; - -struct FuncOpConversion : public ConvertOpToLLVMPattern { - FuncOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit) {} - - LogicalResult - matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!LLVM::isKernel(funcOp)) { - llvm_unreachable("Not implemented"); - } - - LLVM::LLVMFuncOp newFuncOp = - *mlir::convertFuncOpToLLVMFuncOp(funcOp, rewriter, *getTypeConverter()); - if (!newFuncOp) { - return failure(); - } - - auto ctx = funcOp->getContext(); - if (LLVM::isKernel(funcOp)) { - // Set an attribute to indicate this function is a kernel entry. 
- newFuncOp->setAttr("cpu.kernel", - rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); - } else { - llvm_unreachable("Not implemented"); - } - - rewriter.eraseOp(funcOp); - return success(); - } -}; - -} // namespace - -void mlir::triton::cpu::populateFuncOpConversionPattern( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); -} diff --git a/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp deleted file mode 100644 index b424cf8e37b7..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp +++ /dev/null @@ -1,131 +0,0 @@ -#include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/PatternMatch.h" -#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h" -#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" -#include "triton/Conversion/TritonCPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" - -namespace { - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -struct PrintOpConversion : public ConvertOpToLLVMPattern { - explicit PrintOpConversion(LLVMTypeConverter &typeConverter, - const CPUTargetInfo &targetInfo, - PatternBenefit benefit) - : mlir::ConvertOpToLLVMPattern(typeConverter, benefit), - targetInfo(targetInfo) {} - - LogicalResult - matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - - auto getPid = [&](int axis) { - return targetInfo.programId( - rewriter, loc, op->getParentOfType(), axis); - }; - SmallVector values = {getPid(0), getPid(1), getPid(2)}; - - std::string formatStr; - llvm::raw_string_ostream os(formatStr); - os << "pid (" << getFormatSubstr(values[0]) << ", " - << getFormatSubstr(values[1]) << ", " << getFormatSubstr(values[2]) - << ")" << op.getPrefix(); - - for (size_t i = 0; i < op.getNumOperands(); i++) { - auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); - if (dyn_cast(op.getOperand(i).getType())) { - llvm_unreachable("Not implemented for tensor types"); - } - - // Only support scalars for now. - assert(elems.size() == 1); - if (i != 0) { - os << ", "; - } - os << getFormatSubstr(elems[0]); - values.push_back(elems[0]); - } - - llPrintf(formatStr, values, rewriter); - rewriter.eraseOp(op); - return success(); - } - - // TODO: This code is the same as the GPU-backend code. Consider refactoring. - std::string getFormatSubstr(Value value, bool hex = false, - std::optional width = std::nullopt) const { - Type type = value.getType(); - if (isa(type)) { - return "%p"; - } - // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the - // type (so 4 for fp16, 8 for int32, 16 for int64). - if (hex) { - // Ignore `width` for `hex` values, pad to typeWidth. 
- std::string ret = - "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); - if (type.getIntOrFloatBitWidth() > 32) { - ret += "ll"; - } - ret += "x"; - return ret; - } - - std::string prefix = "%"; - if (width.has_value()) { - prefix += std::to_string(*width); - } else if (hex) { - prefix += "0"; - prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4); - } - - if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { - return prefix + "f"; - } else if (type.isSignedInteger()) { - if (type.getIntOrFloatBitWidth() == 64) - return prefix + "lli"; - else - return prefix + "i"; - } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { - if (type.getIntOrFloatBitWidth() == 64) - return prefix + "llu"; - else - return prefix + "u"; - } - assert(false && "not supported type"); - return ""; - } - - Value llPrintf(StringRef msg, ValueRange args, - ConversionPatternRewriter &rewriter, - int *formatStrByteCount = nullptr) const { - assert(!msg.empty() && "printf with empty string not supported"); - llvm::SmallString<64> msgNewline(msg); - msgNewline.push_back('\n'); - msgNewline.push_back('\0'); - Value msgValue = - LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), - rewriter, "printfFormat_", msgNewline); - targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); - if (formatStrByteCount) - *formatStrByteCount = msgNewline.size_in_bytes(); - return msgValue; - } - -protected: - const CPUTargetInfo &targetInfo; -}; - -} // namespace - -void mlir::triton::cpu::populatePrintOpToLLVMPattern( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - const CPUTargetInfo &targetInfo, PatternBenefit benefit) { - patterns.add(typeConverter, targetInfo, benefit); -} diff --git a/lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp deleted file mode 100644 index 65fef7a7d0d5..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp +++ /dev/null @@ -1,39 +0,0 @@ -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" -#include "triton/Conversion/TritonCPUToLLVM/Utility.h" - -namespace { - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -struct GetProgramIdOpConversion - : public ConvertOpToLLVMPattern { - explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, - const CPUTargetInfo &targetInfo, - PatternBenefit benefit = 1) - : ConvertOpToLLVMPattern(typeConverter, benefit), - targetInfo(targetInfo) {} - - LogicalResult - matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value programId = targetInfo.programId( - rewriter, op->getLoc(), op->getParentOfType(), - op.getAxisAsInt()); - rewriter.replaceOp(op, programId); - return success(); - } - -private: - const CPUTargetInfo &targetInfo; -}; - -} // namespace - -void mlir::triton::cpu::populateSPMDOpToLLVMPattern( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - const CPUTargetInfo &targetInfo, PatternBenefit benefit) { - patterns.add(typeConverter, targetInfo, benefit); -} diff --git a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp deleted file mode 100644 index cb15f87ee206..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp +++ /dev/null @@ -1,117 +0,0 @@ -#include "mlir/Analysis/DataFlowFramework.h" -#include 
"mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" -#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonCPUToLLVM/Passes.h" -#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" -#include "triton/Conversion/TritonCPUToLLVM/TypeConverter.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_CONVERTTRITONCPUTOLLVM -#include "triton/Conversion/TritonCPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; - -namespace { - -class TritonLLVMFunctionConversionTarget : public ConversionTarget { -public: - explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalOp(); - } -}; - -class TritonLLVMConversionTarget : public ConversionTarget { -public: - explicit TritonLLVMConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addIllegalDialect(); - addIllegalDialect(); - addLegalOp(); - } -}; - -struct ConvertTritonCPUToLLVM - : public triton::impl::ConvertTritonCPUToLLVMBase { - using ConvertTritonCPUToLLVMBase< - ConvertTritonCPUToLLVM>::ConvertTritonCPUToLLVMBase; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - ConvertTritonCPUToLLVM() : ConvertTritonCPUToLLVMBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - mlir::LowerToLLVMOptions option(context); - option.overrideIndexBitwidth(32); - TritonCPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMConversionTarget convTarget(*context); - - // Lower functions - { - mlir::LowerToLLVMOptions option(context); - TritonCPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMFunctionConversionTarget funcTarget(*context); - RewritePatternSet funcPatterns(context); - mlir::triton::cpu::populateFuncOpConversionPattern( - typeConverter, funcPatterns, - mlir::triton::cpu::patternBenefitDefault); - mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, - funcPatterns); - if (failed( - applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) - return signalPassFailure(); - } - - RewritePatternSet patterns(context); - mlir::triton::cpu::CPUTargetInfo targetInfo; - int benefit = - mlir::triton::cpu::patternBenefitPrioritizeOverLLVMConversions; - mlir::triton::cpu::populateControlFlowOpToLLVMPattern(typeConverter, - patterns, benefit); - mlir::triton::cpu::populatePrintOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); - mlir::triton::cpu::populateSPMDOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // anonymous namespace - -namespace 
mlir { -namespace triton { - -std::unique_ptr> createConvertTritonCPUToLLVMPass() { - return std::make_unique(); -} - -} // namespace triton -} // namespace mlir diff --git a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp deleted file mode 100644 index 72ef796fdabb..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include "triton/Conversion/TritonCPUToLLVM/TypeConverter.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "triton/Conversion/MLIRTypes.h" -#include "llvm/Support/ErrorHandling.h" - -using namespace mlir; -using namespace mlir::triton; - -TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( - MLIRContext *ctx, LowerToLLVMOptions &option, - const DataLayoutAnalysis *analysis) - : LLVMTypeConverter(ctx, option, analysis) { - addConversion([&](triton::PointerType type) -> std::optional { - return convertTritonPointerType(type); - }); - - // Internally store bfloat16 as int16 - addConversion([&](BFloat16Type type) -> std::optional { - return IntegerType::get(type.getContext(), 16); - }); -} - -Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( - triton::PointerType type) { - auto ctx = type.getContext(); - auto pointeeType = type.getPointeeType(); - if (isa(pointeeType)) { - llvm_unreachable("Not implemented"); - } - return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); -} diff --git a/lib/Conversion/TritonToTritonCPU/CMakeLists.txt b/lib/Conversion/TritonToTritonCPU/CMakeLists.txt deleted file mode 100644 index f1b612b9c291..000000000000 --- a/lib/Conversion/TritonToTritonCPU/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -add_triton_library(TritonToTritonCPU - TritonCPUConversion.cpp - TritonToTritonCPUPass.cpp - - DEPENDS - TritonConversionToCPUPassIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRTransforms - TritonIR - TritonCPUIR - TritonCPUTransforms -) diff --git a/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp b/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp deleted file mode 100644 index 97948404bdbf..000000000000 --- a/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include "triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h" - -#include "mlir/IR/IRMapping.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" -#include -#include - -using namespace mlir; -using namespace mlir::triton::cpu; - -// -// TypeConverter -// -TritonCPUTypeConverter::TritonCPUTypeConverter(MLIRContext *context) - : context(context) { - addConversion([](Type type) { return type; }); - - // Add encoding for tensor - addConversion([this](RankedTensorType tensorType) -> RankedTensorType { - // TODO: - return tensorType; - }); - - // Add encoding for tensor pointer - addConversion([this](triton::PointerType ptrType) -> triton::PointerType { - // Check whether tensor pointer `tt.ptr>` - auto pointeeTensorType = - dyn_cast(ptrType.getPointeeType()); - if (pointeeTensorType == nullptr) - return ptrType; - - // Add layout into the tensor - auto convertedTensorType = convertType(pointeeTensorType); - return triton::PointerType::get(convertedTensorType, - ptrType.getAddressSpace()); - }); - - // - // Materializations - // - // This will be called when (newArgType != origArgType) - // This will create newArg, and map(origArg, newArg) - addArgumentMaterialization([&](OpBuilder &builder, - RankedTensorType tensorType, ValueRange inputs, - Location loc) -> 
std::optional { - llvm_unreachable("Argument rematerialization should not happen in Triton " - "-> TritonCPU conversion"); - return std::nullopt; - }); - - // If the origValue still has live user(s), use this to - // convert origValue to newValue - addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, - ValueRange inputs, - Location loc) -> std::optional { - llvm_unreachable("Source rematerialization should not happen in Triton -> " - "TritonCPU Conversion"); - return std::nullopt; - }); - - // This will be called when (desiredType != newOperandType) - // where, desiredType = typeConverter->convertType(origType) - // NOTE: only for remapped values. - addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, - ValueRange inputs, Location loc) { - llvm_unreachable("Source rematerialization should not happen in Triton -> " - "TritonCPU Conversion"); - return std::nullopt; - }); -} - -// -// TritonCPUConversion -// -TritonCPUConversionTarget::TritonCPUConversionTarget( - MLIRContext &context, TritonCPUTypeConverter &typeConverter) - : ConversionTarget(context) { - // TODO: we should also verify ops of TritonCPUDialect - addLegalDialect(); - - // Some ops from SCF are illegal - addIllegalOp(); - - addDynamicallyLegalDialect([&](Operation *op) { - bool hasLegalRegions = true; - for (auto ®ion : op->getRegions()) { - hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); - } - if (hasLegalRegions && typeConverter.isLegal(op)) { - return true; - } - return false; - }); - - // We have requirements for the data layouts - addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { - Attribute aEncoding = - cast(dotOp.getA().getType()).getEncoding(); - Attribute bEncoding = - cast(dotOp.getB().getType()).getEncoding(); - // TODO: - return false; - }); -} diff --git a/lib/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.cpp b/lib/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.cpp deleted file mode 100644 index 44c41636a3f3..000000000000 --- a/lib/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.cpp +++ /dev/null @@ -1,41 +0,0 @@ -#include "triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "triton/Analysis/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h" -#include "llvm/ADT/APSInt.h" -#include - -#define GEN_PASS_CLASSES -#include "triton/Conversion/TritonToTritonCPU/Passes.h.inc" - -namespace { - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -class ConvertTritonToTritonCPU - : public ConvertTritonToTritonCPUBase { -public: - ConvertTritonToTritonCPU() = default; - - void runOnOperation() override { - // TODO: - } -}; - -} // namespace - -std::unique_ptr> -mlir::triton::createConvertTritonToTritonCPUPass() { - return std::make_unique<::ConvertTritonToTritonCPU>(); -} diff --git a/lib/Dialect/TritonCPU/CMakeLists.txt b/lib/Dialect/TritonCPU/CMakeLists.txt index 9f57627c321f..f33061b2d87c 100644 --- a/lib/Dialect/TritonCPU/CMakeLists.txt +++ b/lib/Dialect/TritonCPU/CMakeLists.txt @@ -1,2 +1 @@ add_subdirectory(IR) 
-add_subdirectory(Transforms) diff --git a/lib/Dialect/TritonCPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonCPU/Transforms/CMakeLists.txt deleted file mode 100644 index 1714215b9434..000000000000 --- a/lib/Dialect/TritonCPU/Transforms/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_triton_library(TritonCPUTransforms - - DEPENDS - TritonCPUTransformsIncGen - - LINK_LIBS PUBLIC - MLIRTransforms - MLIRTransformUtils - TritonAnalysis - TritonIR - TritonCPUIR - MLIRTransformUtils -) diff --git a/python/setup.py b/python/setup.py index 0dcdf9a91517..21f97553dd64 100644 --- a/python/setup.py +++ b/python/setup.py @@ -683,7 +683,7 @@ def get_entry_points(): "autopep8", "flake8", "isort", - "numpy", + "numpy<2.0.0", "pytest", "scipy>=1.7.1", "llnl-hatchet", diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 24f0310d0960..6f9ba2bffe0b 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -538,6 +538,21 @@ void init_triton_llvm(py::module &&m) { } } }); + + m.def("get_cpu_tripple", []() { return llvm::sys::getProcessTriple(); }); + + m.def("get_cpu_name", []() { return llvm::sys::getHostCPUName().str(); }); + + m.def("get_cpu_features", []() { + auto features = llvm::sys::getHostCPUFeatures(); + + std::set res; + for (auto &f : features) { + if (f.second) + res.insert(f.first().str()); + } + return res; + }); } void triton_stacktrace_signal_handler(void *) { diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 246f61d8cd38..9dc637a643da 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5440,9 +5440,6 @@ def mul_add(data): if is_cuda(): found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None assert found_fma == enable_fp_fusion - elif is_cpu(): - found_fma = re.search(r'vfma', h.asm["asm"].decode('utf-8')) is not None - assert found_fma == enable_fp_fusion # ----------------------- diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index 1b08addbc9b7..b107e2434e1e 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -3,6 +3,6 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) - add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM) + add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms) target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm) endif() diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 0a98532eceba..6a4e3d08535c 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -43,6 +43,9 @@ def supports_target(target: GPUTarget): def __init__(self, target: tuple) -> None: super().__init__(target) self.binary_ext = "asm" + self.cpu_arch = llvm.get_cpu_tripple().split("-")[0] + self.cpu_name = llvm.get_cpu_name() + self.cpu_features = llvm.get_cpu_features() def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} @@ -86,6 +89,19 @@ def make_ttcir(mod, metadata, opt): metadata["cluster_dims"] = (opt.cluster_dims[0], opt.cluster_dims[1], opt.cluster_dims[2]) return mod + def make_tttcir(self, mod, metadata, opt): + # TTCIR -> Target TTCIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + if self.cpu_arch == "x86_64" and 
"avx512bf16" not in self.cpu_features: + cpu.passes.ttcpuir.add_convert_unsupported_ops(pm) + cpu.passes.ttcpuir.add_decompose_fp_conversions(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + passes.common.add_canonicalizer(pm) + pm.run(mod) + return mod + @staticmethod def make_llir(src, metadata, options): # warp-specialization mutates num_warps @@ -144,6 +160,7 @@ def make_asm(src, metadata, options): def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options) + stages["tttcir"] = lambda src, metadata: self.make_tttcir(src, metadata, options) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) stages["asm"] = lambda src, metadata: self.make_asm(src, metadata, options) diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt index fc9a19e52b0d..b4c91e794072 100644 --- a/third_party/cpu/include/CMakeLists.txt +++ b/third_party/cpu/include/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonCPUTransforms) add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/include/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/include/TritonCPUTransforms/CMakeLists.txt new file mode 100644 index 000000000000..cb2cb234172d --- /dev/null +++ b/third_party/cpu/include/TritonCPUTransforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUTransforms) +add_public_tablegen_target(TritonCPUTransformsPassIncGen) diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h new file mode 100644 index 000000000000..035122aa98fb --- /dev/null +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -0,0 +1,32 @@ +#ifndef TritonCPUTransforms_CONVERSION_PASSES_H +#define TritonCPUTransforms_CONVERSION_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { +namespace cpu { + +#define GEN_PASS_DECL +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" + +std::unique_ptr> createConvertUnsupportedOps(); +std::unique_ptr> createDecomposeFpConversions(); + +#define GEN_PASS_REGISTRATION +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" + +} // namespace cpu +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td new file mode 100644 index 000000000000..0eb4910394fd --- /dev/null +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -0,0 +1,39 @@ +#ifndef TRITONCPUOPT_CONVERSION_PASSES +#define TRITONCPUOPT_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertUnsupportedOps : Pass<"triton-cpu-add-casts-for-unsupported-ops", "mlir::ModuleOp"> { + let summary = "Convert operations on unsupported types."; + let description = [{ + This pass converts various operations on data types that are not supported + by the target natively. Operations are converted to a supported data type + with casts added for inputs and the result. + }]; + // TODO: add options to specify which operations to convert. 
+ let constructor = "mlir::triton::cpu::createConvertUnsupportedOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def DecomposeFpConversions : Pass<"triton-cpu-decompose-fp-conversions", "mlir::ModuleOp"> { + let summary = "Decompose fp conversion ops."; + let description = [{ + This pass is used for targets lacking native instructions to convert FP + vectors. By default, LLVM would decompose them using scalar FP conversion + intrinsics. This pass transforms such conversions into vector code + instead. + }]; + // TODO: add options to specify which FP conversions to decompose. + let constructor = "mlir::triton::cpu::createDecomposeFpConversions()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +#endif diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt index 1db64c58ec20..fad51ab86ea9 100644 --- a/third_party/cpu/lib/CMakeLists.txt +++ b/third_party/cpu/lib/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(Analysis) add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonCPUTransforms) add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt new file mode 100644 index 000000000000..5a52aa7e86b6 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt @@ -0,0 +1,7 @@ +add_triton_library(TritonCPUTransforms + ConvertUnsupportedOps.cpp + DecomposeFpConversions.cpp + + DEPENDS + TritonCPUTransformsPassIncGen +) diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp new file mode 100644 index 000000000000..dec4970d1ccd --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp @@ -0,0 +1,204 @@ +#include "OptCommon.h" + +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTUNSUPPORTEDOPS +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +template +struct ConvertBf16ToFp32 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpT op, + PatternRewriter &rewriter) const override { + // TODO: support mixed-type ops? + if (!isAllBf16(op->getOperandTypes()) || !isAllBf16(op->getResultTypes())) + return failure(); + + Location loc = op.getLoc(); + OperationState newState(loc, OpT::getOperationName()); + // Convert operands to fp32 and generate fp32 op. 
+ for (auto operand : op->getOperands()) { + Value newOperand = rewriter.create( + loc, toFp32(operand.getType()), operand); + newState.operands.push_back(newOperand); + } + newState.types = toFp32(op->getResultTypes()); + newState.attributes = op->getAttrs(); + auto newOp = rewriter.create(newState); + + // Convert op results back to Bf16 + SmallVector results; + for (auto res : llvm::enumerate(newOp->getResults())) + results.push_back(rewriter.create( + loc, op->getResult(res.index()).getType(), res.value())); + rewriter.replaceOp(op, results); + + return success(); + } + + bool isAllBf16(TypeRange types) const { + return std::all_of(types.begin(), types.end(), + [this](auto ty) { return isBf16(ty); }); + } + + SmallVector toFp32(TypeRange types) const { + SmallVector res; + for (auto ty : types) + res.push_back(::toFp32(ty)); + return res; + } +}; + +template +struct ConvertIToBf16ToFp32 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpT op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType())) + return failure(); + + Location loc = op.getLoc(); + Value fp32Val = + rewriter.create(loc, toFp32(op.getType()), op.getOperand()); + Value res = rewriter.create(loc, op.getType(), fp32Val); + rewriter.replaceOp(op, res); + return success(); + } +}; + +Value convertMemRefToI16(Value memRef, PatternRewriter &rewriter) { + // Memory references for masked operations are always built + // with PtrToMemRefOp. + auto def = memRef.getDefiningOp(); + assert(def); + auto insPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(def); + MemRefType memRefTy = cast(memRef.getType()); + Type newMemRefTy = + MemRefType::get(memRefTy.getShape(), rewriter.getI16Type(), + memRefTy.getLayout(), memRefTy.getMemorySpace()); + Value res = rewriter.create(memRef.getLoc(), newMemRefTy, + def.getSrc()); + rewriter.restoreInsertionPoint(insPoint); + return res; +} + +struct ConvertBf16MaskedLoadOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MaskedLoadOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType())) + return failure(); + + Location loc = op.getLoc(); + Value newBase = convertMemRefToI16(op.getBase(), rewriter); + Value newPassThru = rewriter.create( + loc, toInt16(op.getPassThru().getType()), op.getPassThru()); + Value intVal = rewriter.create( + loc, toInt16(op.getType()), newBase, op.getIndices(), op.getMask(), + newPassThru); + Value res = rewriter.create(loc, op.getType(), intVal); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertBf16MaskedStoreOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MaskedStoreOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getValueToStore().getType())) + return failure(); + + Location loc = op.getLoc(); + Value newBase = convertMemRefToI16(op.getBase(), rewriter); + Value intVal = rewriter.create( + loc, toInt16(op.getValueToStore().getType()), op.getValueToStore()); + rewriter.replaceOpWithNewOp( + op, newBase, op.getIndices(), op.getMask(), intVal); + return success(); + } +}; + +struct ConvertBf16Abs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::AbsFOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType()) || !isBf16(op.getOperand().getType())) + return failure(); + + Location loc = 
op.getLoc(); + Value src = op.getOperand(); + Value intSrc = + rewriter.create(loc, toInt16(op.getType()), src); + TypedAttr maskAttr = rewriter.getI16IntegerAttr(0x7fff); + if (auto vecTy = dyn_cast(intSrc.getType())) + maskAttr = SplatElementsAttr::get(vecTy, maskAttr); + Value mask = rewriter.create(loc, maskAttr); + Value res = rewriter.create(loc, intSrc, mask); + res = rewriter.create(loc, op.getType(), res); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertUnsupportedOps + : public triton::impl::ConvertUnsupportedOpsBase { + using ConvertUnsupportedOpsBase::ConvertUnsupportedOpsBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + RewritePatternSet patterns(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add(context); + patterns.add(context); + + patterns.add(context); + + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertUnsupportedOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp new file mode 100644 index 000000000000..a82958b3ae2b --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp @@ -0,0 +1,81 @@ +#include "OptCommon.h" + +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_DECOMPOSEFPCONVERSIONS +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +struct Fp32ToBf16Conversion : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::TruncFOp op, + PatternRewriter &rewriter) const override { + Value src = op.getIn(); + if (!isBf16(op.getType()) || !isFp32(src.getType())) + return failure(); + + Location loc = op.getLoc(); + Value i32Src = + rewriter.create(loc, toInt32(src.getType()), src); + TypedAttr shiftValAttr = rewriter.getI32IntegerAttr(16); + if (auto vecTy = dyn_cast(i32Src.getType())) + shiftValAttr = SplatElementsAttr::get(vecTy, shiftValAttr); + Value shiftedSrc = rewriter.create( + loc, i32Src, rewriter.create(loc, shiftValAttr)); + Value i16Res = rewriter.create(loc, toInt16(src.getType()), + shiftedSrc); + Value res = rewriter.create(loc, op.getType(), i16Res); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct DecomposeFpConversions + : public triton::impl::DecomposeFpConversionsBase { + using DecomposeFpConversionsBase::DecomposeFpConversionsBase; + + DecomposeFpConversions() : DecomposeFpConversionsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + RewritePatternSet patterns(context); + patterns.add(context); + + if 
(failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createDecomposeFpConversions() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/OptCommon.h b/third_party/cpu/lib/TritonCPUTransforms/OptCommon.h new file mode 100644 index 000000000000..a9fe054b8ede --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/OptCommon.h @@ -0,0 +1,46 @@ +#ifndef TRITONCPU_CONVERSION_TRITONCPUOPT_OPTCOMMON_H +#define TRITONCPU_CONVERSION_TRITONCPUOPT_OPTCOMMON_H + +#include "mlir/IR/BuiltinTypes.h" + +namespace mlir { +namespace triton { +namespace cpu { + +inline bool isTyOrVectorOf(mlir::Type ty, mlir::Type elemTy) { + if (auto vecTy = dyn_cast(ty)) + return vecTy.getElementType() == elemTy; + return ty == elemTy; +} + +inline bool isBf16(mlir::Type ty) { + return isTyOrVectorOf(ty, mlir::BFloat16Type::get(ty.getContext())); +} + +inline bool isFp32(mlir::Type ty) { + return isTyOrVectorOf(ty, mlir::Float32Type::get(ty.getContext())); +} + +inline mlir::Type toTyOrVectorOf(mlir::Type ty, mlir::Type elemTy) { + if (auto vecTy = dyn_cast(ty)) + return vecTy.cloneWith(std::nullopt, elemTy); + return elemTy; +} + +inline mlir::Type toInt16(mlir::Type ty) { + return toTyOrVectorOf(ty, mlir::IntegerType::get(ty.getContext(), 16)); +} + +inline mlir::Type toInt32(mlir::Type ty) { + return toTyOrVectorOf(ty, mlir::IntegerType::get(ty.getContext(), 32)); +} + +inline mlir::Type toFp32(mlir::Type ty) { + return toTyOrVectorOf(ty, mlir::Float32Type::get(ty.getContext())); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index fa4eb818dce5..06bfef0b299a 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -1,4 +1,5 @@ #include "TritonCPUToLLVM/Passes.h" +#include "TritonCPUTransforms/Passes.h" #include "TritonToTritonCPU/Passes.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" @@ -34,6 +35,12 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) { mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm); }); + m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps()); + }); + m.def("add_decompose_fp_conversions", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createDecomposeFpConversions()); + }); m.def("add_vector_to_scf", [](mlir::PassManager &pm, bool full_unroll, unsigned target_rank, bool lower_tensors) { mlir::VectorTransferToSCFOptions opts;
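
For reference, the BF16 handling introduced by this patch reduces to simple bit manipulation on the fp32 encoding: `DecomposeFpConversions` lowers `arith.truncf` f32 -> bf16 by bitcasting to i32, shifting right by 16, truncating to i16, and bitcasting to bf16 (no rounding step), while `ConvertBf16Abs` clears the sign bit of the i16 pattern with a `0x7fff` mask. The following is a minimal Python sketch of that arithmetic, for illustration only; it is not part of the patch, and the helper names are made up here.

```python
# Sketch of the bf16 bit tricks used by DecomposeFpConversions / ConvertBf16Abs.
# bf16 keeps the top 16 bits of the fp32 encoding (sign, exponent, 7 mantissa
# bits), so truncation is a shift; note this drops the low mantissa bits
# without round-to-nearest-even, matching the pass's simple lowering.
import struct

def fp32_to_bf16_bits(x: float) -> int:
    # Reinterpret the fp32 value as a 32-bit integer, keep the top 16 bits.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

def bf16_bits_to_fp32(bits: int) -> float:
    # Widen the bf16 bit pattern back to an fp32 value.
    (val,) = struct.unpack("<f", struct.pack("<I", bits << 16))
    return val

print(bf16_bits_to_fp32(fp32_to_bf16_bits(3.14159265)))          # -> 3.140625
# Abs on bf16 without converting to fp32: clear the sign bit (mask 0x7fff).
print(bf16_bits_to_fp32(fp32_to_bf16_bits(-2.5) & 0x7FFF))       # -> 2.5
```

This is also why `make_tttcir` only schedules these passes when the host lacks `avx512bf16`: on such targets the casts and shifts above are the cheapest legal form, whereas hardware with native bf16 support needs neither the extf/truncf wrapping nor the decomposed conversions.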