Commit 825ca0f

Add unsupported op conversions for BF16 type.

Signed-off-by: Ilya Enkovich <[email protected]>
ienkovich committed Jun 18, 2024
1 parent ede8a8e
Showing 15 changed files with 458 additions and 2 deletions.
2 changes: 2 additions & 0 deletions bin/RegisterTritonDialects.h
@@ -14,6 +14,7 @@
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

#include "cpu/include/TritonCPUOpt/Passes.h"
#include "cpu/include/TritonCPUToLLVM/Passes.h"
#include "cpu/include/TritonToTritonCPU/Passes.h"
#include "nvidia/include/NVGPUToLLVM/Passes.h"
@@ -65,6 +66,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
// CPU passes
mlir::triton::cpu::registerTritonToTritonCPUPasses();
mlir::triton::cpu::registerTritonToTritonCPUPipeline();
mlir::triton::cpu::registerTritonCPUOptPasses();
mlir::triton::cpu::registerTritonCPUToLLVMPasses();
mlir::triton::cpu::registerTritonCPUToLLVMPipeline();

2 changes: 1 addition & 1 deletion python/setup.py
@@ -600,7 +600,7 @@ def get_install_requires():
"autopep8",
"flake8",
"isort",
"numpy",
"numpy<2.0.0",
"pytest",
"scipy>=1.7.1",
],
16 changes: 16 additions & 0 deletions python/src/llvm.cc
@@ -411,4 +411,20 @@ void init_triton_llvm(py::module &&m) {
}
}
});

m.def("get_cpu_tripple", []() { return llvm::sys::getProcessTriple(); });

m.def("get_cpu_name", []() { return llvm::sys::getHostCPUName().str(); });

m.def("get_cpu_features", []() {
llvm::StringMap<bool> features;
llvm::sys::getHostCPUFeatures(features);

std::set<std::string> res;
for (auto &f : features) {
if (f.second)
res.insert(f.first().str());
}
return res;
});
}
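
The three bindings added above give the backend host introspection from Python. A minimal usage sketch, assuming the module lands in Triton's usual triton._C.libtriton layout (the call names keep the get_cpu_tripple spelling used in the source):

from triton._C.libtriton import llvm

arch = llvm.get_cpu_tripple().split("-")[0]  # e.g. "x86_64"
name = llvm.get_cpu_name()                   # e.g. "skylake-avx512"
features = llvm.get_cpu_features()           # set of enabled feature names

# compiler.py below enables the BF16 fallback only on x86-64 hosts
# lacking AVX512-BF16:
needs_fallback = arch == "x86_64" and "avx512bf16" not in features
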
2 changes: 1 addition & 1 deletion third_party/cpu/CMakeLists.txt
@@ -3,6 +3,6 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
add_subdirectory(include)
add_subdirectory(lib)
if(TRITON_BUILD_PYTHON_MODULE)
add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM)
add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUOpt)
target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm)
endif()
17 changes: 17 additions & 0 deletions third_party/cpu/backend/compiler.py
@@ -43,6 +43,9 @@ def supports_target(target: GPUTarget):
def __init__(self, target: tuple) -> None:
super().__init__(target)
self.binary_ext = "bc"
self.cpu_arch = llvm.get_cpu_tripple().split("-")[0]
self.cpu_name = llvm.get_cpu_name()
self.cpu_features = llvm.get_cpu_features()

def parse_options(self, opts) -> Any:
args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts}
@@ -86,6 +89,19 @@ def make_ttcir(mod, metadata, opt):
metadata["cluster_dims"] = (opt.cluster_dims[0], opt.cluster_dims[1], opt.cluster_dims[2])
return mod

def make_ottcir(self, mod, metadata, opt):
# TTCIR -> Optimized TTCIR
pm = ir.pass_manager(mod.context)
pm.enable_debug()
if self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features:
cpu.passes.ttcpuir.add_convert_unsupported_ops(pm)
cpu.passes.ttcpuir.add_vectorize_fp_to_fp(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
passes.common.add_canonicalizer(pm)
pm.run(mod)
return mod

@staticmethod
def make_llir(src, metadata, options):
# warp-specialization mutates num_warps
@@ -152,6 +168,7 @@ def make_bc(src, metadata, options):
def add_stages(self, stages, options):
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options)
stages["ottcir"] = lambda src, metadata: self.make_ottcir(src, metadata, options)
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
stages["bc"] = lambda src, metadata: self.make_bc(src, metadata, options)

1 change: 1 addition & 0 deletions third_party/cpu/include/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(TritonCPUOpt)
add_subdirectory(TritonCPUToLLVM)
add_subdirectory(TritonToTritonCPU)
3 changes: 3 additions & 0 deletions third_party/cpu/include/TritonCPUOpt/CMakeLists.txt
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUOpt)
add_public_tablegen_target(TritonCPUOptPassIncGen)
32 changes: 32 additions & 0 deletions third_party/cpu/include/TritonCPUOpt/Passes.h
@@ -0,0 +1,32 @@
#ifndef TRITONCPUOPT_CONVERSION_PASSES_H
#define TRITONCPUOPT_CONVERSION_PASSES_H

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include <memory>

namespace mlir {

class ModuleOp;
template <typename T> class OperationPass;

namespace triton {
namespace cpu {

#define GEN_PASS_DECL
#include "cpu/include/TritonCPUOpt/Passes.h.inc"

std::unique_ptr<OperationPass<ModuleOp>> createConvertUnsupportedOps();
std::unique_ptr<OperationPass<ModuleOp>> createVectorizeFpToFp();

#define GEN_PASS_REGISTRATION
#include "cpu/include/TritonCPUOpt/Passes.h.inc"

} // namespace cpu
} // namespace triton

} // namespace mlir

#endif
39 changes: 39 additions & 0 deletions third_party/cpu/include/TritonCPUOpt/Passes.td
@@ -0,0 +1,39 @@
#ifndef TRITONCPUOPT_CONVERSION_PASSES
#define TRITONCPUOPT_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def ConvertUnsupportedOps : Pass<"triton-cpu-add-casts-for-unsupported-ops", "mlir::ModuleOp"> {
let summary = "Convert operations on unsupported types.";
let description = [{
This pass rewrites operations on data types that the target does not
support natively. Each such operation is converted to a supported data
type, with casts inserted for its inputs and result.
}];
// TODO: add options to specify which operations to convert.
let constructor = "mlir::triton::cpu::createConvertUnsupportedOps()";

let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::vector::VectorDialect",
"mlir::triton::TritonDialect",
"mlir::triton::cpu::TritonCPUDialect"];
}

def VectorizeFpToFp : Pass<"triton-cpu-vectorize-fp-to-fp", "mlir::ModuleOp"> {
let summary = "Decompose fp conversion ops.";
let description = [{
This pass is used for targets lacking native instructions to convert FP
vectors. By default, LLVM would decompose them using scalar FP conversion
intrinsics. This pass transforms such conversions into vector code
instead.
}];
// TODO: add options to specify which FP conversions to decompose.
let constructor = "mlir::triton::cpu::createVectorizeFpToFp()";

let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::vector::VectorDialect",
"mlir::triton::TritonDialect",
"mlir::triton::cpu::TritonCPUDialect"];
}

#endif
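
To make the two descriptions concrete: ConvertUnsupportedOps emulates a BF16 op by widening the operands to FP32 with arith.extf, running the op in FP32, and truncating the result back with arith.truncf. A rough semantic model in numpy, assuming the ml_dtypes package for a bfloat16 dtype:

import numpy as np
from ml_dtypes import bfloat16

def bf16_add_emulated(a, b):
    a32 = a.astype(np.float32)     # arith.extf on each operand
    b32 = b.astype(np.float32)
    sum32 = a32 + b32              # the op itself runs in fp32
    return sum32.astype(bfloat16)  # arith.truncf back to bf16

VectorizeFpToFp then targets the conversions themselves. One common vector formulation (a sketch of the idea, not necessarily the pass's exact output) exploits the fact that a BF16 value is the upper half of the equivalent FP32 bit pattern, so a whole vector can be widened with integer shifts rather than per-lane scalar conversion intrinsics:

def bf16_to_fp32_widen(x):
    # Place each bf16 bit pattern in the high 16 bits of an i32 word;
    # the result is the exact fp32 encoding of the same value.
    return (x.view(np.uint16).astype(np.uint32) << 16).view(np.float32)
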
1 change: 1 addition & 0 deletions third_party/cpu/lib/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(Analysis)
add_subdirectory(TritonCPUOpt)
add_subdirectory(TritonCPUToLLVM)
add_subdirectory(TritonToTritonCPU)
7 changes: 7 additions & 0 deletions third_party/cpu/lib/TritonCPUOpt/CMakeLists.txt
@@ -0,0 +1,7 @@
add_triton_library(TritonCPUOpt
ConvertUnsupportedOps.cpp
VectorizeFpToFp.cpp

DEPENDS
TritonCPUOptPassIncGen
)
204 changes: 204 additions & 0 deletions third_party/cpu/lib/TritonCPUOpt/ConvertUnsupportedOps.cpp
@@ -0,0 +1,204 @@
#include "OptCommon.h"

#include "cpu/include/TritonCPUOpt/Passes.h"

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"

namespace mlir {
namespace triton {
#define GEN_PASS_DEF_CONVERTUNSUPPORTEDOPS
#include "cpu/include/TritonCPUOpt/Passes.h.inc"
} // namespace triton
} // namespace mlir

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::cpu;

namespace {

template <typename OpT>
struct ConvertBf16ToFp32 : public OpRewritePattern<OpT> {
using OpRewritePattern<OpT>::OpRewritePattern;

LogicalResult matchAndRewrite(OpT op,
PatternRewriter &rewriter) const override {
// TODO: support mixed-type ops?
if (!isAllBf16(op->getOperandTypes()) || !isAllBf16(op->getResultTypes()))
return failure();

Location loc = op.getLoc();
OperationState newState(loc, OpT::getOperationName());
// Convert operands to fp32 and generate fp32 op.
for (auto operand : op->getOperands()) {
Value newOperand = rewriter.create<arith::ExtFOp>(
loc, toFp32(operand.getType()), operand);
newState.operands.push_back(newOperand);
}
newState.types = toFp32(op->getResultTypes());
newState.attributes = op->getAttrs();
auto newOp = rewriter.create(newState);

// Convert op results back to Bf16
SmallVector<Value> results;
for (auto res : llvm::enumerate(newOp->getResults()))
results.push_back(rewriter.create<arith::TruncFOp>(
loc, op->getResult(res.index()).getType(), res.value()));
rewriter.replaceOp(op, results);

return success();
}

bool isAllBf16(TypeRange types) const {
return std::all_of(types.begin(), types.end(),
[this](auto ty) { return isBf16(ty); });
}

SmallVector<Type> toFp32(TypeRange types) const {
SmallVector<Type> res;
for (auto ty : types)
res.push_back(::toFp32(ty));
return res;
}
};

template <typename OpT>
struct ConvertIToBf16ToFp32 : public OpRewritePattern<OpT> {
using OpRewritePattern<OpT>::OpRewritePattern;

LogicalResult matchAndRewrite(OpT op,
PatternRewriter &rewriter) const override {
if (!isBf16(op.getType()))
return failure();

Location loc = op.getLoc();
Value fp32Val =
rewriter.create<OpT>(loc, toFp32(op.getType()), op.getOperand());
Value res = rewriter.create<arith::TruncFOp>(loc, op.getType(), fp32Val);
rewriter.replaceOp(op, res);
return success();
}
};

Value convertMemRefToI16(Value memRef, PatternRewriter &rewriter) {
// Memory references for masked operations are always built
// with PtrToMemRefOp.
auto def = memRef.getDefiningOp<PtrToMemRefOp>();
assert(def);
auto insPoint = rewriter.saveInsertionPoint();
rewriter.setInsertionPointAfter(def);
MemRefType memRefTy = cast<MemRefType>(memRef.getType());
Type newMemRefTy =
MemRefType::get(memRefTy.getShape(), rewriter.getI16Type(),
memRefTy.getLayout(), memRefTy.getMemorySpace());
Value res = rewriter.create<PtrToMemRefOp>(memRef.getLoc(), newMemRefTy,
def.getSrc());
rewriter.restoreInsertionPoint(insPoint);
return res;
}

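// Note: bf16 and i16 have the same size and layout, so the two patterns
// below handle bf16 masked loads/stores by reinterpreting the memref as
// i16 and bitcasting the vector values, avoiding bf16-typed memory ops.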
struct ConvertBf16MaskedLoadOp : public OpRewritePattern<vector::MaskedLoadOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::MaskedLoadOp op,
PatternRewriter &rewriter) const override {
if (!isBf16(op.getType()))
return failure();

Location loc = op.getLoc();
Value newBase = convertMemRefToI16(op.getBase(), rewriter);
Value newPassThru = rewriter.create<arith::BitcastOp>(
loc, toInt16(op.getPassThru().getType()), op.getPassThru());
Value intVal = rewriter.create<vector::MaskedLoadOp>(
loc, toInt16(op.getType()), newBase, op.getIndices(), op.getMask(),
newPassThru);
Value res = rewriter.create<arith::BitcastOp>(loc, op.getType(), intVal);
rewriter.replaceOp(op, res);
return success();
}
};

struct ConvertBf16MaskedStoreOp
: public OpRewritePattern<vector::MaskedStoreOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::MaskedStoreOp op,
PatternRewriter &rewriter) const override {
if (!isBf16(op.getValueToStore().getType()))
return failure();

Location loc = op.getLoc();
Value newBase = convertMemRefToI16(op.getBase(), rewriter);
Value intVal = rewriter.create<arith::BitcastOp>(
loc, toInt16(op.getValueToStore().getType()), op.getValueToStore());
rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
op, newBase, op.getIndices(), op.getMask(), intVal);
return success();
}
};

struct ConvertBf16Abs : public OpRewritePattern<math::AbsFOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(math::AbsFOp op,
PatternRewriter &rewriter) const override {
if (!isBf16(op.getType()) || !isBf16(op.getOperand().getType()))
return failure();

Location loc = op.getLoc();
Value src = op.getOperand();
Value intSrc =
rewriter.create<arith::BitcastOp>(loc, toInt16(op.getType()), src);
TypedAttr maskAttr = rewriter.getI16IntegerAttr(0x7fff);
if (auto vecTy = dyn_cast<VectorType>(intSrc.getType()))
maskAttr = SplatElementsAttr::get(vecTy, maskAttr);
Value mask = rewriter.create<arith::ConstantOp>(loc, maskAttr);
Value res = rewriter.create<arith::AndIOp>(loc, intSrc, mask);
res = rewriter.create<arith::BitcastOp>(loc, op.getType(), res);
rewriter.replaceOp(op, res);
return success();
}
};

struct ConvertUnsupportedOps
: public triton::impl::ConvertUnsupportedOpsBase<ConvertUnsupportedOps> {
using ConvertUnsupportedOpsBase::ConvertUnsupportedOpsBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();

RewritePatternSet patterns(context);
patterns.add<ConvertBf16ToFp32<arith::AddFOp>>(context);
patterns.add<ConvertBf16ToFp32<arith::SubFOp>>(context);
patterns.add<ConvertBf16ToFp32<arith::MulFOp>>(context);
patterns.add<ConvertIToBf16ToFp32<arith::SIToFPOp>>(context);
patterns.add<ConvertIToBf16ToFp32<arith::UIToFPOp>>(context);
patterns.add<ConvertBf16MaskedLoadOp>(context);
patterns.add<ConvertBf16MaskedStoreOp>(context);

patterns.add<ConvertBf16Abs>(context);

if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns))))
return signalPassFailure();
}
};

} // namespace

namespace mlir {
namespace triton {
namespace cpu {

std::unique_ptr<OperationPass<ModuleOp>> createConvertUnsupportedOps() {
return std::make_unique<ConvertUnsupportedOps>();
}

} // namespace cpu
} // namespace triton
} // namespace mlir
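
A note on the ConvertBf16Abs pattern above: absolute value needs no FP32 round trip, because clearing the sign bit (bit 15 of the BF16 encoding, hence the 0x7fff mask) is exact. The same trick, sketched in numpy with ml_dtypes as before:

import numpy as np
from ml_dtypes import bfloat16

def bf16_abs_emulated(x):
    # bitcast bf16 -> i16, clear the sign bit, bitcast back
    return (x.view(np.uint16) & np.uint16(0x7FFF)).view(bfloat16)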