diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 1597249b0e23e..badde8f58db73 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -13,7 +13,12 @@ #include "triton/Conversion/TritonCPUToLLVM/Passes.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" +#define TRITON_CPU_USE_LLJIT + #include "llvm/Bitcode/BitcodeReader.h" +#ifdef TRITON_CPU_USE_LLJIT +#include "llvm/ExecutionEngine/Orc/LLJIT.h" +#else #include "llvm/ExecutionEngine/Orc/CompileUtils.h" #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" @@ -22,6 +27,7 @@ #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" +#endif #include "llvm/IR/Constants.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" @@ -107,6 +113,101 @@ std::string llvmErrToString(const llvm::Error &err) { return res; }; +struct CompiledKernel; +std::vector> compiled_kernels; + +#ifdef TRITON_CPU_USE_LLJIT +struct CompiledKernel { + std::unique_ptr ll_jit; + + CompiledKernel() = default; + CompiledKernel(CompiledKernel &&) = default; + ~CompiledKernel() = default; +}; + +py::object cpu_load_binary(std::string name, std::string llvmBC, int shared, int devId) { + std::cout << "kernel_name: " << name << std::endl; + std::cout << "llvmbc size: " << llvmBC.length() << std::endl; + auto res = py::object(py::cast(nullptr)); + + llvm::LLVMContext context; + auto buf = llvm::MemoryBuffer::getMemBuffer( + llvm::StringRef(llvmBC.c_str(), llvmBC.length())); + auto mod = llvm::parseBitcodeFile(*buf, context); + if (!mod) { + std::cerr << "Failed to parse LLVM bitcode module" << std::endl; + return res; + } + if (getBoolEnv("MLIR_ENABLE_DUMP")) { + llvm::errs() << "********** Loaded Module (kernel_name=" << name + << ") **********\n" + << **mod << "\n"; + } + auto init_err = llvm::InitializeNativeTarget(); + if (init_err) { + std::cerr << "Failed to initialize native target." << std::endl; + return res; + } + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + llvm::orc::LLJITBuilder Builder; + + auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!detect_host_res) { + std::cerr << "Failed to initialize JITTargetMachineBuilder: " + << llvmErrToString(detect_host_res.takeError()); + return res; + } + llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res); + + auto data_layout_res = tmb.getDefaultDataLayoutForTarget(); + if (!data_layout_res) { + std::cerr << "Failed to initialize data layout: " + << llvmErrToString(data_layout_res.takeError()); + return res; + } + auto triple = tmb.getTargetTriple().getTriple(); + + std::cout << "cpu: " << tmb.getCPU() << std::endl; + std::cout << "target triple: " << triple << std::endl; + + Builder.setJITTargetMachineBuilder(std::move(tmb)); + auto build_res = Builder.create(); + if (!build_res) { + std::cerr << "Failed to create LLJIT instance" << std::endl; + return res; + } + + CompiledKernel kernel; + kernel.ll_jit = std::move(*build_res); + + (**mod).setDataLayout(std::move(*data_layout_res)); + (**mod).setTargetTriple(triple); + llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext()); + + auto err = kernel.ll_jit->addIRModule(std::move(tsm)); + if (err) { + std::cerr << "Cannot add LLVM module: " << llvmErrToString(err); + return res; + } + + auto lookup_res = kernel.ll_jit->lookup(name); + if (!lookup_res) { + std::cerr << "Failed to find function " << std::string(name) + << "\nError: " << llvmErrToString(lookup_res.takeError()); + return res; + } + + uint64_t fn_ptr = (*lookup_res).getValue(); + + compiled_kernels.push_back( + std::make_unique(std::move(kernel))); + auto *kernel_ptr = compiled_kernels.back().get(); + + return py::object(py::make_tuple(reinterpret_cast(kernel_ptr), + reinterpret_cast(fn_ptr), 0, 0)); +}; +#else struct CompiledKernel { std::unique_ptr execution_session; std::unique_ptr data_layout; @@ -124,112 +225,113 @@ struct CompiledKernel { } }; -std::vector> compiled_kernels; - -void init_triton_cpu_utils(py::module &&m) { - using namespace mlir::triton; - m.def("load_binary", [](std::string name, std::string llvmBC, int shared, int devId) { - auto res = py::object(py::cast(nullptr)); +py::object cpu_load_binary(std::string name, std::string llvmBC, int shared, int devId) { + auto res = py::object(py::cast(nullptr)); - llvm::LLVMContext context; - auto buf = llvm::MemoryBuffer::getMemBuffer( - llvm::StringRef(llvmBC.c_str(), llvmBC.length())); - - auto mod = llvm::parseBitcodeFile(*buf, context); - if (!mod) { - std::cerr << "Failed to parse LLVM bitcode module" << std::endl; - return res; - } + llvm::LLVMContext context; + auto buf = llvm::MemoryBuffer::getMemBuffer( + llvm::StringRef(llvmBC.c_str(), llvmBC.length())); + + auto mod = llvm::parseBitcodeFile(*buf, context); + if (!mod) { + std::cerr << "Failed to parse LLVM bitcode module" << std::endl; + return res; + } - if (getBoolEnv("MLIR_ENABLE_DUMP")) { - llvm::errs() << "********** Loaded Module (kernel_name=" << name - << ") **********\n" - << **mod << "\n"; - } + if (getBoolEnv("MLIR_ENABLE_DUMP")) { + llvm::errs() << "********** Loaded Module (kernel_name=" << name + << ") **********\n" + << **mod << "\n"; + } - auto init_err = llvm::InitializeNativeTarget(); - if (init_err) { - std::cerr << "Failed to initialize native target." << std::endl; - return res; - } + auto init_err = llvm::InitializeNativeTarget(); + if (init_err) { + std::cerr << "Failed to initialize native target." << std::endl; + return res; + } - llvm::InitializeNativeTargetAsmPrinter(); - llvm::InitializeNativeTargetAsmParser(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); - auto self_epc = - llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create()); + auto self_epc = + llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create()); - auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost(); - if (!detect_host_res) { - std::cerr << "Failed to initialize JITTargetMachineBuilder: " - << llvmErrToString(detect_host_res.takeError()); - return res; - } - llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res); + auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!detect_host_res) { + std::cerr << "Failed to initialize JITTargetMachineBuilder: " + << llvmErrToString(detect_host_res.takeError()); + return res; + } + llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res); - auto data_layout_res = tmb.getDefaultDataLayoutForTarget(); + auto data_layout_res = tmb.getDefaultDataLayoutForTarget(); - std::cout << "cpu: " << tmb.getCPU() << std::endl; - std::cout << "target triple: " << tmb.getTargetTriple().getTriple() << std::endl; - - if (!data_layout_res) { - std::cerr << "Failed to initialize data layout: " - << llvmErrToString(data_layout_res.takeError()); - return res; - } + std::cout << "cpu: " << tmb.getCPU() << std::endl; + std::cout << "target triple: " << tmb.getTargetTriple().getTriple() << std::endl; + + if (!data_layout_res) { + std::cerr << "Failed to initialize data layout: " + << llvmErrToString(data_layout_res.takeError()); + return res; + } - CompiledKernel kernel; - kernel.execution_session = - std::make_unique(std::move(self_epc)); - kernel.data_layout = - std::make_unique(std::move(*data_layout_res)); - kernel.mangle = std::make_unique( - *kernel.execution_session, *kernel.data_layout); - kernel.object_layer = std::make_unique( - *kernel.execution_session, - []() { return std::make_unique(); }); - kernel.compiler_layer = std::make_unique( - *kernel.execution_session, *kernel.object_layer, - std::make_unique(std::move(tmb))); + CompiledKernel kernel; + kernel.execution_session = + std::make_unique(std::move(self_epc)); + kernel.data_layout = + std::make_unique(std::move(*data_layout_res)); + kernel.mangle = std::make_unique( + *kernel.execution_session, *kernel.data_layout); + kernel.object_layer = std::make_unique( + *kernel.execution_session, + []() { return std::make_unique(); }); + kernel.compiler_layer = std::make_unique( + *kernel.execution_session, *kernel.object_layer, + std::make_unique(std::move(tmb))); - auto dylib_res = kernel.execution_session->createJITDylib("
"); - if (!dylib_res) { - std::cerr << "Failed to create initialize JITDylib: " - << llvmErrToString(dylib_res.takeError()); - return res; - } + auto dylib_res = kernel.execution_session->createJITDylib("
"); + if (!dylib_res) { + std::cerr << "Failed to create initialize JITDylib: " + << llvmErrToString(dylib_res.takeError()); + return res; + } - kernel.dylib = &(*dylib_res); - kernel.dylib->addGenerator(llvm::cantFail( - llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( - kernel.data_layout->getGlobalPrefix()))); + kernel.dylib = &(*dylib_res); + kernel.dylib->addGenerator(llvm::cantFail( + llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( + kernel.data_layout->getGlobalPrefix()))); - // Compile module. - (**mod).setDataLayout(*kernel.data_layout); - llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext()); - auto err = kernel.compiler_layer->add(*kernel.dylib, std::move(tsm)); - if (err) { - std::cerr << "Cannot add LLVM module: " << llvmErrToString(err); - return res; - } + // Compile module. + (**mod).setDataLayout(*kernel.data_layout); + llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext()); + auto err = kernel.compiler_layer->add(*kernel.dylib, std::move(tsm)); + if (err) { + std::cerr << "Cannot add LLVM module: " << llvmErrToString(err); + return res; + } - // Find kernel function pointer. - auto lookup_res = - kernel.execution_session->lookup({kernel.dylib}, (*kernel.mangle)(name)); - if (!lookup_res) { - std::cerr << "Failed to find function " << std::string(name) - << "\nError: " << llvmErrToString(lookup_res.takeError()); - return res; - } - uint64_t fn_ptr = lookup_res->getAddress().getValue(); - - compiled_kernels.push_back( - std::make_unique(std::move(kernel))); - auto *kernel_ptr = compiled_kernels.back().get(); + // Find kernel function pointer. + auto lookup_res = + kernel.execution_session->lookup({kernel.dylib}, (*kernel.mangle)(name)); + if (!lookup_res) { + std::cerr << "Failed to find function " << std::string(name) + << "\nError: " << llvmErrToString(lookup_res.takeError()); + return res; + } + uint64_t fn_ptr = lookup_res->getAddress().getValue(); + + compiled_kernels.push_back( + std::make_unique(std::move(kernel))); + auto *kernel_ptr = compiled_kernels.back().get(); - return py::object(py::make_tuple(reinterpret_cast(kernel_ptr), - reinterpret_cast(fn_ptr), 0, 0)); - }); + return py::object(py::make_tuple(reinterpret_cast(kernel_ptr), + reinterpret_cast(fn_ptr), 0, 0)); +} +#endif + +void init_triton_cpu_utils(py::module &&m) { + using namespace mlir::triton; + m.def("load_binary", cpu_load_binary); } void init_triton_cpu(py::module &&m) {