add LLJIT
Kuigesi committed Jun 14, 2024
1 parent 031be2f commit 4866464
Showing 1 changed file with 194 additions and 92 deletions.
286 changes: 194 additions & 92 deletions third_party/cpu/triton_cpu.cc
@@ -13,7 +13,12 @@
#include "triton/Conversion/TritonCPUToLLVM/Passes.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"

#define TRITON_CPU_USE_LLJIT

#include "llvm/Bitcode/BitcodeReader.h"
#ifdef TRITON_CPU_USE_LLJIT
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#else
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
@@ -22,6 +27,7 @@
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#endif
#include "llvm/IR/Constants.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
@@ -107,6 +113,101 @@ std::string llvmErrToString(const llvm::Error &err) {
  return res;
};

struct CompiledKernel;
std::vector<std::unique_ptr<CompiledKernel>> compiled_kernels;

#ifdef TRITON_CPU_USE_LLJIT
struct CompiledKernel {
  std::unique_ptr<llvm::orc::LLJIT> ll_jit;

  CompiledKernel() = default;
  CompiledKernel(CompiledKernel &&) = default;
  ~CompiledKernel() = default;
};

py::object cpu_load_binary(std::string name, std::string llvmBC, int shared,
                           int devId) {
  std::cout << "kernel_name: " << name << std::endl;
  std::cout << "llvmbc size: " << llvmBC.length() << std::endl;
  auto res = py::object(py::cast(nullptr));

  // Parse the serialized LLVM bitcode into a module.
  llvm::LLVMContext context;
  auto buf = llvm::MemoryBuffer::getMemBuffer(
      llvm::StringRef(llvmBC.c_str(), llvmBC.length()));
  auto mod = llvm::parseBitcodeFile(*buf, context);
  if (!mod) {
    std::cerr << "Failed to parse LLVM bitcode module" << std::endl;
    return res;
  }
  if (getBoolEnv("MLIR_ENABLE_DUMP")) {
    llvm::errs() << "********** Loaded Module (kernel_name=" << name
                 << ") **********\n"
                 << **mod << "\n";
  }

  auto init_err = llvm::InitializeNativeTarget();
  if (init_err) {
    std::cerr << "Failed to initialize native target." << std::endl;
    return res;
  }
  llvm::InitializeNativeTargetAsmPrinter();
  llvm::InitializeNativeTargetAsmParser();
  llvm::orc::LLJITBuilder Builder;

  auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!detect_host_res) {
    std::cerr << "Failed to initialize JITTargetMachineBuilder: "
              << llvmErrToString(detect_host_res.takeError());
    return res;
  }
  llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res);

  auto data_layout_res = tmb.getDefaultDataLayoutForTarget();
  if (!data_layout_res) {
    std::cerr << "Failed to initialize data layout: "
              << llvmErrToString(data_layout_res.takeError());
    return res;
  }
  auto triple = tmb.getTargetTriple().getTriple();

  std::cout << "cpu: " << tmb.getCPU() << std::endl;
  std::cout << "target triple: " << triple << std::endl;

  Builder.setJITTargetMachineBuilder(std::move(tmb));
  auto build_res = Builder.create();
  if (!build_res) {
    std::cerr << "Failed to create LLJIT instance: "
              << llvmErrToString(build_res.takeError());
    return res;
  }

  CompiledKernel kernel;
  kernel.ll_jit = std::move(*build_res);

  // Compile module.
  (**mod).setDataLayout(std::move(*data_layout_res));
  (**mod).setTargetTriple(triple);
  llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext());

  auto err = kernel.ll_jit->addIRModule(std::move(tsm));
  if (err) {
    std::cerr << "Cannot add LLVM module: " << llvmErrToString(err);
    return res;
  }

  // Find kernel function pointer.
  auto lookup_res = kernel.ll_jit->lookup(name);
  if (!lookup_res) {
    std::cerr << "Failed to find function " << std::string(name)
              << "\nError: " << llvmErrToString(lookup_res.takeError());
    return res;
  }
  uint64_t fn_ptr = (*lookup_res).getValue();

  compiled_kernels.push_back(
      std::make_unique<CompiledKernel>(std::move(kernel)));
  auto *kernel_ptr = compiled_kernels.back().get();

  return py::object(py::make_tuple(reinterpret_cast<uint64_t>(kernel_ptr),
                                   reinterpret_cast<uint64_t>(fn_ptr), 0, 0));
}
#else
struct CompiledKernel {
  std::unique_ptr<llvm::orc::ExecutionSession> execution_session;
  std::unique_ptr<llvm::DataLayout> data_layout;
@@ -124,112 +225,113 @@ struct CompiledKernel {
  }
};

py::object cpu_load_binary(std::string name, std::string llvmBC, int shared,
                           int devId) {
  auto res = py::object(py::cast(nullptr));

  llvm::LLVMContext context;
  auto buf = llvm::MemoryBuffer::getMemBuffer(
      llvm::StringRef(llvmBC.c_str(), llvmBC.length()));

  auto mod = llvm::parseBitcodeFile(*buf, context);
  if (!mod) {
    std::cerr << "Failed to parse LLVM bitcode module" << std::endl;
    return res;
  }

  if (getBoolEnv("MLIR_ENABLE_DUMP")) {
    llvm::errs() << "********** Loaded Module (kernel_name=" << name
                 << ") **********\n"
                 << **mod << "\n";
  }

  auto init_err = llvm::InitializeNativeTarget();
  if (init_err) {
    std::cerr << "Failed to initialize native target." << std::endl;
    return res;
  }

  llvm::InitializeNativeTargetAsmPrinter();
  llvm::InitializeNativeTargetAsmParser();

  auto self_epc =
      llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create());

  auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!detect_host_res) {
    std::cerr << "Failed to initialize JITTargetMachineBuilder: "
              << llvmErrToString(detect_host_res.takeError());
    return res;
  }
  llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res);

  auto data_layout_res = tmb.getDefaultDataLayoutForTarget();

  std::cout << "cpu: " << tmb.getCPU() << std::endl;
  std::cout << "target triple: " << tmb.getTargetTriple().getTriple()
            << std::endl;

  if (!data_layout_res) {
    std::cerr << "Failed to initialize data layout: "
              << llvmErrToString(data_layout_res.takeError());
    return res;
  }

  CompiledKernel kernel;
  kernel.execution_session =
      std::make_unique<llvm::orc::ExecutionSession>(std::move(self_epc));
  kernel.data_layout =
      std::make_unique<llvm::DataLayout>(std::move(*data_layout_res));
  kernel.mangle = std::make_unique<llvm::orc::MangleAndInterner>(
      *kernel.execution_session, *kernel.data_layout);
  kernel.object_layer = std::make_unique<llvm::orc::RTDyldObjectLinkingLayer>(
      *kernel.execution_session,
      []() { return std::make_unique<llvm::SectionMemoryManager>(); });
  kernel.compiler_layer = std::make_unique<llvm::orc::IRCompileLayer>(
      *kernel.execution_session, *kernel.object_layer,
      std::make_unique<llvm::orc::ConcurrentIRCompiler>(std::move(tmb)));

  auto dylib_res = kernel.execution_session->createJITDylib("<main>");
  if (!dylib_res) {
    std::cerr << "Failed to create JITDylib: "
              << llvmErrToString(dylib_res.takeError());
    return res;
  }

  kernel.dylib = &(*dylib_res);
  kernel.dylib->addGenerator(llvm::cantFail(
      llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
          kernel.data_layout->getGlobalPrefix())));

  // Compile module.
  (**mod).setDataLayout(*kernel.data_layout);
  llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext());
  auto err = kernel.compiler_layer->add(*kernel.dylib, std::move(tsm));
  if (err) {
    std::cerr << "Cannot add LLVM module: " << llvmErrToString(err);
    return res;
  }

  // Find kernel function pointer.
  auto lookup_res =
      kernel.execution_session->lookup({kernel.dylib}, (*kernel.mangle)(name));
  if (!lookup_res) {
    std::cerr << "Failed to find function " << std::string(name)
              << "\nError: " << llvmErrToString(lookup_res.takeError());
    return res;
  }
  uint64_t fn_ptr = lookup_res->getAddress().getValue();

  compiled_kernels.push_back(
      std::make_unique<CompiledKernel>(std::move(kernel)));
  auto *kernel_ptr = compiled_kernels.back().get();

  return py::object(py::make_tuple(reinterpret_cast<uint64_t>(kernel_ptr),
                                   reinterpret_cast<uint64_t>(fn_ptr), 0, 0));
}
#endif

void init_triton_cpu_utils(py::module &&m) {
  using namespace mlir::triton;
  m.def("load_binary", cpu_load_binary);
}

void init_triton_cpu(py::module &&m) {

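For reference, below is a minimal standalone sketch of the LLJIT flow this commit switches to: build an LLJIT instance, hand it a ThreadSafeModule, and look up the kernel symbol. It assumes an LLVM version where LLJIT::lookup returns an ExecutorAddr, which is what the getValue() call in the patch implies. The IR string, the add1 function, and the main() harness are made up for illustration; the Triton code above parses bitcode handed over from Python instead of textual IR.

```cpp
// Hypothetical standalone demo (not part of the patch): JIT an IR string with
// LLJIT and call the resulting function. Error handling is collapsed into
// cantFail for brevity.
#include <memory>

#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"

int main() {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  auto ctx = std::make_unique<llvm::LLVMContext>();
  llvm::SMDiagnostic diag;
  // Tiny made-up kernel: add1(x) = x + 1.
  std::unique_ptr<llvm::Module> mod = llvm::parseIR(
      *llvm::MemoryBuffer::getMemBuffer("define i32 @add1(i32 %x) {\n"
                                        "  %r = add i32 %x, 1\n"
                                        "  ret i32 %r\n"
                                        "}\n"),
      diag, *ctx);

  // One builder call replaces the ExecutionSession / object layer /
  // IRCompileLayer wiring done by hand in the #else branch above.
  auto jit = llvm::cantFail(llvm::orc::LLJITBuilder().create());
  llvm::cantFail(jit->addIRModule(
      llvm::orc::ThreadSafeModule(std::move(mod), std::move(ctx))));

  // Same lookup + getValue() pattern as cpu_load_binary.
  auto addr = llvm::cantFail(jit->lookup("add1"));
  auto *add1 = reinterpret_cast<int (*)(int)>(addr.getValue());
  return add1(41) == 42 ? 0 : 1;
}
```

The design trade-off mirrored by the #ifdef: LLJIT owns the ExecutionSession, mangling, object layer, and compile layer internally, so the per-kernel state shrinks to a single unique_ptr, while the #else path keeps the hand-wired ORC layers and therefore full control over each component.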