From feec8adca007a47d8796a389d3e203a36f3b03ae Mon Sep 17 00:00:00 2001
From: Ilya Enkovich
Date: Tue, 2 Jul 2024 12:28:34 -0500
Subject: [PATCH] Add conversions for mixed precision matmuls. (#32)

Signed-off-by: Ilya Enkovich
---
 python/test/unit/language/test_core.py        |  13 +++
 third_party/cpu/backend/compiler.py           |   7 +-
 .../cpu/include/TritonCPUTransforms/Passes.h  |   3 +
 .../cpu/include/TritonCPUTransforms/Passes.td |  11 +-
 .../ConvertUnsupportedOps.cpp                 | 103 ++++++++++++++++--
 third_party/cpu/triton_cpu.cc                 |   9 +-
 6 files changed, 129 insertions(+), 17 deletions(-)

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 3290f1253aaf..75dbd1b6bf5a 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -3152,6 +3152,7 @@ def convert_fp8_to_fp32(x, device, dtype_str):
     assert "Unsupported float8 dtype"


+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize(
     "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack",
@@ -3182,6 +3183,18 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
     if is_interpreter():
         if in_dtype == 'bfloat16':
             pytest.skip("bfloat16 is not supported in the interpreter")
+    elif is_cpu():
+        if input_precision != "ieee":
+            pytest.skip(f"{input_precision} not supported on CPU")
+        if in_dtype == 'float8e4nv' or in_dtype == 'float8e5':
+            pytest.skip("float8e4nv and float8e5 not supported on CPU")
+        # This test kernel runs in a single thread and can take a long time
+        # for bigger sizes with the current codegen on CPU. Limit input sizes
+        # by default to get more reasonable test execution times.
+        if os.environ.get('TRITON_CPU_TEST_DOT_FULL_SIZE', '0') != '1':
+            M = min(M, 64)
+            N = min(N, 64)
+            K = min(K, 32)
     else:
         if is_cuda():
             capability = torch.cuda.get_device_capability()
diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py
index 6a4e3d08535c..6ed2a6f6e111 100644
--- a/third_party/cpu/backend/compiler.py
+++ b/third_party/cpu/backend/compiler.py
@@ -93,8 +93,11 @@ def make_tttcir(self, mod, metadata, opt):
         # TTCIR -> Target TTCIR
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
-        if self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features:
-            cpu.passes.ttcpuir.add_convert_unsupported_ops(pm)
+        promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features
+        # We don't have any lowering for mixed-precision matmuls, so always use casts for now.
+        convert_mixed_precision_matmul = True
+        cpu.passes.ttcpuir.add_convert_unsupported_ops(pm, promote_bf16_to_fp32, convert_mixed_precision_matmul)
+        if promote_bf16_to_fp32:
             cpu.passes.ttcpuir.add_decompose_fp_conversions(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h
index 035122aa98fb..213161fecc8a 100644
--- a/third_party/cpu/include/TritonCPUTransforms/Passes.h
+++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h
@@ -19,6 +19,9 @@ namespace cpu {
 #include "cpu/include/TritonCPUTransforms/Passes.h.inc"

 std::unique_ptr<OperationPass<ModuleOp>> createConvertUnsupportedOps();
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertUnsupportedOps(bool promoteBf16ToFp32,
+                            bool convertMixedPrecisionMatmul);
 std::unique_ptr<OperationPass<ModuleOp>> createDecomposeFpConversions();

 #define GEN_PASS_REGISTRATION
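The compiler change above always enables convert-mixed-precision-matmul because there is no direct TTCIR lowering for a dot whose operands and accumulator disagree on element type. For reference, a kernel along the following lines triggers exactly that case: bf16 inputs feeding tl.dot with an fp32 accumulator. This is an illustrative sketch, not part of the patch; the kernel name, sizes, and launch setup are assumptions, and it presumes the triton-cpu backend is active so plain CPU tensors can be passed to the kernel.

# Illustrative sketch (not from this patch): a bf16 x bf16 -> fp32 dot whose
# lowering produces the mixed-precision contraction handled by the new option.
import torch
import triton
import triton.language as tl


@triton.jit
def mixed_dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # The bf16 tiles are loaded as-is; the dot accumulates and stores in fp32.
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    c = tl.dot(a, b, out_dtype=tl.float32)
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)


# Assumes the triton-cpu backend is installed and active, so CPU tensors work.
a = torch.randn((16, 16), dtype=torch.bfloat16)
b = torch.randn((16, 16), dtype=torch.bfloat16)
c = torch.empty((16, 16), dtype=torch.float32)
mixed_dot_kernel[(1, )](a, b, c, BLOCK=16)

Without the conversion, the resulting vector.contract mixes bf16 and fp32 element types and has no further lowering; with it, the bf16 operands are extended to fp32 before the contraction.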
diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td
index 0eb4910394fd..2e92bc42c6c5 100644
--- a/third_party/cpu/include/TritonCPUTransforms/Passes.td
+++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td
@@ -10,7 +10,16 @@ def ConvertUnsupportedOps : Pass<"triton-cpu-add-casts-for-unsupported-ops", "ml
     by the target natively. Operations are converted to a supported data type
     with casts added for inputs and the result.
   }];
-  // TODO: add options to specify which operations to convert.
+
+  let options = [
+    Option<"promoteBf16ToFp32", "promote-bf16-to-fp32",
+           "bool", /*default*/"false",
+           "Convert BF16 operations to FP32.">,
+    Option<"convertMixedPrecisionMatmul", "convert-mixed-precision-matmul",
+           "bool", /*default*/"false",
+           "Convert inputs of a mixed-precision matmul to a destination type.">,
+  ];
+
   let constructor = "mlir::triton::cpu::createConvertUnsupportedOps()";

   let dependentDialects = ["mlir::arith::ArithDialect",
diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp
index dec4970d1ccd..5d991b376902 100644
--- a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp
+++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp
@@ -11,8 +11,10 @@

 namespace mlir {
 namespace triton {
+namespace cpu {
 #define GEN_PASS_DEF_CONVERTUNSUPPORTEDOPS
 #include "cpu/include/TritonCPUTransforms/Passes.h.inc"
+} // namespace cpu
 } // namespace triton
 } // namespace mlir
@@ -165,24 +167,96 @@ struct ConvertBf16Abs : public OpRewritePattern {
   }
 };

+struct ConvertMixedPrecisionMatmul
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+    Value acc = op.getAcc();
+    auto lhsTy = cast<VectorType>(lhs.getType());
+    auto rhsTy = cast<VectorType>(rhs.getType());
+    auto accTy = cast<VectorType>(acc.getType());
+    auto resTy = cast<VectorType>(op.getType());
+
+    if (lhsTy.getElementType() == resTy.getElementType() &&
+        rhsTy.getElementType() == resTy.getElementType() &&
+        accTy.getElementType() == resTy.getElementType())
+      return failure();
+
+    Type commonElemTy = resTy.getElementType();
+    if (lhsTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth())
+      commonElemTy = lhsTy.getElementType();
+    if (rhsTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth())
+      commonElemTy = rhsTy.getElementType();
+    if (accTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth())
+      commonElemTy = accTy.getElementType();
+
+    lhs = castElemTy(loc, lhs, commonElemTy, rewriter);
+    rhs = castElemTy(loc, rhs, commonElemTy, rewriter);
+    acc = castElemTy(loc, acc, commonElemTy, rewriter);
+
+    Value newRes = rewriter.create<vector::ContractionOp>(
+        loc, lhs, rhs, acc, op.getIndexingMaps(), op.getIteratorTypes());
+    newRes = castElemTy(loc, newRes, resTy.getElementType(), rewriter);
+
+    rewriter.replaceOp(op, newRes);
+    return success();
+  }
+
+  Value castElemTy(Location loc, Value val, Type elemTy,
+                   PatternRewriter &rewriter) const {
+    auto valTy = cast<VectorType>(val.getType());
+    if (valTy.getElementType() == elemTy)
+      return val;
+
+    auto resTy = toTyOrVectorOf(valTy, elemTy);
+    if (valTy.getElementType().isInteger()) {
+      if (valTy.getElementTypeBitWidth() > elemTy.getIntOrFloatBitWidth())
+        return rewriter.create<arith::TruncIOp>(loc, resTy, val);
+      else
+        return rewriter.create<arith::ExtSIOp>(loc, resTy, val);
+    } else {
+      if (valTy.getElementTypeBitWidth() > elemTy.getIntOrFloatBitWidth())
+        return rewriter.create<arith::TruncFOp>(loc, resTy, val);
+      else
+        return rewriter.create<arith::ExtFOp>(loc, resTy, val);
+    }
+  }
+};
+
 struct ConvertUnsupportedOps
-    : public triton::impl::ConvertUnsupportedOpsBase<ConvertUnsupportedOps> {
-  using ConvertUnsupportedOpsBase::ConvertUnsupportedOpsBase;
+    : public triton::cpu::impl::ConvertUnsupportedOpsBase<
+          ConvertUnsupportedOps> {
+  ConvertUnsupportedOps() = default;
+
+  ConvertUnsupportedOps(bool promoteBf16ToFp32,
+                        bool convertMixedPrecisionMatmul) {
+    this->promoteBf16ToFp32 = promoteBf16ToFp32;
+    this->convertMixedPrecisionMatmul = convertMixedPrecisionMatmul;
+  }

   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp mod = getOperation();

     RewritePatternSet patterns(context);
-    patterns.add>(context);
-    patterns.add>(context);
-    patterns.add>(context);
-    patterns.add>(context);
-    patterns.add>(context);
-    patterns.add(context);
-    patterns.add(context);
-
-    patterns.add(context);
+    if (promoteBf16ToFp32) {
+      patterns.add>(context);
+      patterns.add>(context);
+      patterns.add>(context);
+      patterns.add>(context);
+      patterns.add>(context);
+      patterns.add(context);
+      patterns.add(context);
+      patterns.add(context);
+    }
+    if (convertMixedPrecisionMatmul) {
+      patterns.add<ConvertMixedPrecisionMatmul>(context);
+    }

     if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns))))
       return signalPassFailure();
@@ -199,6 +273,13 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertUnsupportedOps() {
   return std::make_unique<ConvertUnsupportedOps>();
 }

+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertUnsupportedOps(bool promoteBf16ToFp32,
+                            bool convertMixedPrecisionMatmul) {
+  return std::make_unique<ConvertUnsupportedOps>(promoteBf16ToFp32,
+                                                 convertMixedPrecisionMatmul);
+}
+
 } // namespace cpu
 } // namespace triton
 } // namespace mlir
diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc
index 06bfef0b299a..748b72fe549d 100644
--- a/third_party/cpu/triton_cpu.cc
+++ b/third_party/cpu/triton_cpu.cc
@@ -35,9 +35,12 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) {
   m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) {
     mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm);
   });
-  m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm) {
-    pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps());
-  });
+  m.def("add_convert_unsupported_ops",
+        [](mlir::PassManager &pm, bool promote_bf16_to_fp32,
+           bool convert_mixed_precision_matmul) {
+          pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps(
+              promote_bf16_to_fp32, convert_mixed_precision_matmul));
+        });
   m.def("add_decompose_fp_conversions", [](mlir::PassManager &pm) {
     pm.addPass(mlir::triton::cpu::createDecomposeFpConversions());
   });
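To summarize the rewrite, below is a rough Python analogue of what ConvertMixedPrecisionMatmul does to a vector.contract: promote every operand to the widest participating element type, compute the contraction there, and cast the result back to the expected output type. This is an illustrative assumption that uses torch only for the dtype bookkeeping; it is not code produced by the pass, and the helper names are made up.

# Rough Python analogue (illustrative, not pass output) of the
# ConvertMixedPrecisionMatmul rewrite: promote all operands of the contraction
# to the widest participating element type, compute, then cast the result back.
import torch


def element_bits(dt: torch.dtype) -> int:
    info = torch.finfo(dt) if dt.is_floating_point else torch.iinfo(dt)
    return info.bits


def mixed_precision_matmul(lhs, rhs, acc):
    # Mirrors the bitwidth comparisons in castElemTy: the widest type wins.
    common = max((lhs.dtype, rhs.dtype, acc.dtype), key=element_bits)
    res = torch.matmul(lhs.to(common), rhs.to(common)) + acc.to(common)
    # The real pattern adds a final cast back to the original result type.
    return res.to(acc.dtype)


a = torch.randn(16, 32, dtype=torch.bfloat16)
b = torch.randn(32, 16, dtype=torch.bfloat16)
c = torch.zeros(16, 16, dtype=torch.float32)
print(mixed_precision_matmul(a, b, c).dtype)  # torch.float32

The pattern itself additionally distinguishes integer from floating-point element types, inserting the corresponding arith truncation or extension casts, and it rebuilds the vector.contract with the original indexing maps and iterator types.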