Add conversions for mixed precision matmuls. (#32)
Signed-off-by: Ilya Enkovich <[email protected]>
ienkovich authored and minjang committed Oct 23, 2024
1 parent 5da71b5 commit feec8ad
Showing 6 changed files with 129 additions and 17 deletions.
13 changes: 13 additions & 0 deletions python/test/unit/language/test_core.py
@@ -3152,6 +3152,7 @@ def convert_fp8_to_fp32(x, device, dtype_str):
         assert "Unsupported float8 dtype"
 
 
+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize(
     "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack",
@@ -3182,6 +3183,18 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
     if is_interpreter():
         if in_dtype == 'bfloat16':
             pytest.skip("bfloat16 is not supported in the interpreter")
+    elif is_cpu():
+        if input_precision != "ieee":
+            pytest.skip(f"{input_precision} not supported on CPU")
+        if in_dtype == 'float8e4nv' or in_dtype == 'float8e5':
+            pytest.skip("float8e4nv and float8e5 not supported on CPU")
+        # This test kernel runs in a single thread and can take a long time
+        # for bigger sizes with the current codegen on CPU. Limit input sizes
+        # by default to get more reasonable test execution time.
+        if os.environ.get('TRITON_CPU_TEST_DOT_FULL_SIZE', '0') != '1':
+            M = min(M, 64)
+            N = min(N, 64)
+            K = min(K, 32)
     else:
         if is_cuda():
             capability = torch.cuda.get_device_capability()
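For context (not part of the commit), the snippet below sketches the kind of mixed-precision dot this test now covers on CPU: fp16 inputs accumulated into an fp32 result. The kernel, shapes, and launch are illustrative only, not the test's own code, and assume the CPU backend is the active Triton backend.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def mixed_dot_kernel(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # fp16 inputs, fp32 accumulator/output: the mixed-precision case that the
    # new conversion handles by casting operands to a common element type.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])  # fp16 tile
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])  # fp16 tile
    c = tl.dot(a, b, out_dtype=tl.float32)                      # fp32 result
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)


a = torch.randn((32, 32), dtype=torch.float16)
b = torch.randn((32, 32), dtype=torch.float16)
c = torch.empty((32, 32), dtype=torch.float32)
mixed_dot_kernel[(1,)](a, b, c, 32, 32, 32)
```

To run test_dot with its original, larger shapes on CPU, set TRITON_CPU_TEST_DOT_FULL_SIZE=1 as in the diff above.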
7 changes: 5 additions & 2 deletions third_party/cpu/backend/compiler.py
@@ -93,8 +93,11 @@ def make_tttcir(self, mod, metadata, opt):
         # TTCIR -> Target TTCIR
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
-        if self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features:
-            cpu.passes.ttcpuir.add_convert_unsupported_ops(pm)
+        promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features
+        # We don't have any lowering for mixed precision matmuls, so always use casts for now
+        convert_mixed_precision_matmul = True
+        cpu.passes.ttcpuir.add_convert_unsupported_ops(pm, promote_bf16_to_fp32, convert_mixed_precision_matmul)
+        if promote_bf16_to_fp32:
             cpu.passes.ttcpuir.add_decompose_fp_conversions(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
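For reference (not part of the commit), here is a minimal sketch of driving the new binding directly with both conversions forced on, mirroring make_tttcir above. `ir`, `passes`, and `cpu` are assumed to be the same modules the backend imports, and `mod` an already-built TTCIR module.

```python
# Sketch only: same pipeline as make_tttcir, with both flags hard-coded to True.
pm = ir.pass_manager(mod.context)
pm.enable_debug()
# promote_bf16_to_fp32=True, convert_mixed_precision_matmul=True
cpu.passes.ttcpuir.add_convert_unsupported_ops(pm, True, True)
cpu.passes.ttcpuir.add_decompose_fp_conversions(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
pm.run(mod)
```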
3 changes: 3 additions & 0 deletions third_party/cpu/include/TritonCPUTransforms/Passes.h
@@ -19,6 +19,9 @@ namespace cpu {
 #include "cpu/include/TritonCPUTransforms/Passes.h.inc"
 
 std::unique_ptr<OperationPass<ModuleOp>> createConvertUnsupportedOps();
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertUnsupportedOps(bool promoteBf16ToFp32,
+                            bool convertMixedPrecisionMatmul);
 std::unique_ptr<OperationPass<ModuleOp>> createDecomposeFpConversions();
 
 #define GEN_PASS_REGISTRATION
11 changes: 10 additions & 1 deletion third_party/cpu/include/TritonCPUTransforms/Passes.td
@@ -10,7 +10,16 @@ def ConvertUnsupportedOps : Pass<"triton-cpu-add-casts-for-unsupported-ops", "ml
     by the target natively. Operations are converted to a supported data type
     with casts added for inputs and the result.
   }];
-  // TODO: add options to specify which operations to convert.
+
+  let options = [
+    Option<"promoteBf16ToFp32", "promote-bf16-to-fp32",
+           "bool", /*default*/"false",
+           "Convert BF16 operations to FP32.">,
+    Option<"convertMixedPrecisionMatmul", "convert-mixed-precision-matmul",
+           "bool", /*default*/"false",
+           "Convert inputs of a mixed-precision matmul to a destination type.">,
+  ];
+
   let constructor = "mlir::triton::cpu::createConvertUnsupportedOps()";
 
   let dependentDialects = ["mlir::arith::ArithDialect",
103 changes: 92 additions & 11 deletions third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp
@@ -11,8 +11,10 @@
 
 namespace mlir {
 namespace triton {
+namespace cpu {
 #define GEN_PASS_DEF_CONVERTUNSUPPORTEDOPS
 #include "cpu/include/TritonCPUTransforms/Passes.h.inc"
+} // namespace cpu
 } // namespace triton
 } // namespace mlir
 
@@ -165,24 +167,96 @@ struct ConvertBf16Abs : public OpRewritePattern<math::AbsFOp> {
   }
 };
 
+struct ConvertMixedPrecisionMatmul
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+    Value acc = op.getAcc();
+    auto lhsTy = cast<VectorType>(lhs.getType());
+    auto rhsTy = cast<VectorType>(rhs.getType());
+    auto accTy = cast<VectorType>(acc.getType());
+    auto resTy = cast<VectorType>(op.getType());
+
+    if (lhsTy.getElementType() == resTy.getElementType() &&
+        rhsTy.getElementType() == resTy.getElementType() &&
+        accTy.getElementType() == resTy.getElementType())
+      return failure();
+
+    Type commonElemTy = resTy.getElementType();
+    if (lhsTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth())
+      commonElemTy = lhsTy;
+    if (rhsTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth())
+      commonElemTy = rhsTy;
+    if (accTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth())
+      commonElemTy = accTy;
+
+    lhs = castElemTy(loc, lhs, commonElemTy, rewriter);
+    rhs = castElemTy(loc, rhs, commonElemTy, rewriter);
+    acc = castElemTy(loc, acc, commonElemTy, rewriter);
+
+    Value newRes = rewriter.create<vector::ContractionOp>(
+        loc, lhs, rhs, acc, op.getIndexingMaps(), op.getIteratorTypes());
+    newRes = castElemTy(loc, newRes, resTy.getElementType(), rewriter);
+
+    rewriter.replaceOp(op, newRes);
+    return success();
+  }
+
+  Value castElemTy(Location loc, Value val, Type elemTy,
+                   PatternRewriter &rewriter) const {
+    auto valTy = cast<VectorType>(val.getType());
+    if (valTy.getElementType() == elemTy)
+      return val;
+
+    auto resTy = toTyOrVectorOf(valTy, elemTy);
+    if (valTy.getElementType().isInteger()) {
+      if (valTy.getElementTypeBitWidth() > elemTy.getIntOrFloatBitWidth())
+        return rewriter.create<arith::TruncIOp>(loc, resTy, val);
+      else
+        return rewriter.create<arith::ExtSIOp>(loc, resTy, val);
+    } else {
+      if (valTy.getElementTypeBitWidth() > elemTy.getIntOrFloatBitWidth())
+        return rewriter.create<arith::TruncFOp>(loc, resTy, val);
+      else
+        return rewriter.create<arith::ExtFOp>(loc, resTy, val);
+    }
+  }
+};

 struct ConvertUnsupportedOps
-    : public triton::impl::ConvertUnsupportedOpsBase<ConvertUnsupportedOps> {
-  using ConvertUnsupportedOpsBase::ConvertUnsupportedOpsBase;
+    : public triton::cpu::impl::ConvertUnsupportedOpsBase<
+          ConvertUnsupportedOps> {
+  ConvertUnsupportedOps() = default;
+
+  ConvertUnsupportedOps(bool promoteBf16ToFp32,
+                        bool convertMixedPrecisionMatmul) {
+    this->promoteBf16ToFp32 = promoteBf16ToFp32;
+    this->convertMixedPrecisionMatmul = convertMixedPrecisionMatmul;
+  }
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp mod = getOperation();
 
     RewritePatternSet patterns(context);
-    patterns.add<ConvertBf16ToFp32<arith::AddFOp>>(context);
-    patterns.add<ConvertBf16ToFp32<arith::SubFOp>>(context);
-    patterns.add<ConvertBf16ToFp32<arith::MulFOp>>(context);
-    patterns.add<ConvertIToBf16ToFp32<arith::SIToFPOp>>(context);
-    patterns.add<ConvertIToBf16ToFp32<arith::UIToFPOp>>(context);
-    patterns.add<ConvertBf16MaskedLoadOp>(context);
-    patterns.add<ConvertBf16MaskedStoreOp>(context);
-
-    patterns.add<ConvertBf16Abs>(context);
+    if (promoteBf16ToFp32) {
+      patterns.add<ConvertBf16ToFp32<arith::AddFOp>>(context);
+      patterns.add<ConvertBf16ToFp32<arith::SubFOp>>(context);
+      patterns.add<ConvertBf16ToFp32<arith::MulFOp>>(context);
+      patterns.add<ConvertIToBf16ToFp32<arith::SIToFPOp>>(context);
+      patterns.add<ConvertIToBf16ToFp32<arith::UIToFPOp>>(context);
+      patterns.add<ConvertBf16MaskedLoadOp>(context);
+      patterns.add<ConvertBf16MaskedStoreOp>(context);
+      patterns.add<ConvertBf16Abs>(context);
+    }
+    if (convertMixedPrecisionMatmul) {
+      patterns.add<ConvertMixedPrecisionMatmul>(context);
+    }
+
     if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns))))
       return signalPassFailure();
@@ -199,6 +273,13 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertUnsupportedOps() {
   return std::make_unique<ConvertUnsupportedOps>();
 }
 
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertUnsupportedOps(bool promoteBf16ToFp32,
+                            bool convertMixedPrecisionMatmul) {
+  return std::make_unique<ConvertUnsupportedOps>(promoteBf16ToFp32,
+                                                 convertMixedPrecisionMatmul);
+}
+
 } // namespace cpu
 } // namespace triton
 } // namespace mlir
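To make the rewrite easier to follow (commentary, not part of the commit), the promotion rule in ConvertMixedPrecisionMatmul can be modeled in plain NumPy: take the result element type, widen it to the widest of the lhs/rhs/acc element types, cast every operand to that common type, contract, then cast the result back. NumPy has no bf16, so float16 stands in for it, and the helper name is hypothetical.

```python
import numpy as np


def mixed_precision_contract(lhs, rhs, acc, res_dtype):
    # Common type = widest of {result, lhs, rhs, acc}; ties keep the earlier
    # entry, matching the strict ">" bit-width comparisons in the pass.
    candidates = [np.dtype(res_dtype), lhs.dtype, rhs.dtype, acc.dtype]
    common = max(candidates, key=lambda t: t.itemsize)
    # Cast all operands (cf. castElemTy), contract, then cast the result back.
    out = lhs.astype(common) @ rhs.astype(common) + acc.astype(common)
    return out.astype(res_dtype)


lhs = np.random.rand(4, 8).astype(np.float16)
rhs = np.random.rand(8, 4).astype(np.float16)
acc = np.zeros((4, 4), dtype=np.float32)
res = mixed_precision_contract(lhs, rhs, acc, np.float32)  # contraction done in fp32
```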
9 changes: 6 additions & 3 deletions third_party/cpu/triton_cpu.cc
@@ -35,9 +35,12 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) {
   m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) {
     mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm);
   });
-  m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm) {
-    pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps());
-  });
+  m.def("add_convert_unsupported_ops",
+        [](mlir::PassManager &pm, bool promote_bf16_to_fp32,
+           bool convert_mixed_precision_matmul) {
+          pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps(
+              promote_bf16_to_fp32, convert_mixed_precision_matmul));
+        });
   m.def("add_decompose_fp_conversions", [](mlir::PassManager &pm) {
     pm.addPass(mlir::triton::cpu::createDecomposeFpConversions());
   });
