Skip to content

Commit

Permalink
code clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
Kuigesi committed Jun 14, 2024
1 parent 4866464 commit f47a21b
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions third_party/cpu/triton_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define TRITON_CPU_USE_LLJIT

#include "llvm/Bitcode/BitcodeReader.h"

#ifdef TRITON_CPU_USE_LLJIT
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#else
Expand All @@ -28,6 +29,7 @@
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#endif

#include "llvm/IR/Constants.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -126,28 +128,30 @@ struct CompiledKernel {
};

py::object cpu_load_binary(std::string name, std::string llvmBC, int shared, int devId) {
std::cout << "kernel_name: " << name << std::endl;
std::cout << "llvmbc size: " << llvmBC.length() << std::endl;
auto res = py::object(py::cast(nullptr));

llvm::LLVMContext context;
auto buf = llvm::MemoryBuffer::getMemBuffer(
llvm::StringRef(llvmBC.c_str(), llvmBC.length()));

auto mod = llvm::parseBitcodeFile(*buf, context);
if (!mod) {
std::cerr << "Failed to parse LLVM bitcode module" << std::endl;
return res;
}

if (getBoolEnv("MLIR_ENABLE_DUMP")) {
llvm::errs() << "********** Loaded Module (kernel_name=" << name
<< ") **********\n"
<< **mod << "\n";
}

auto init_err = llvm::InitializeNativeTarget();
if (init_err) {
std::cerr << "Failed to initialize native target." << std::endl;
return res;
}

llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();
llvm::orc::LLJITBuilder Builder;
Expand All @@ -158,6 +162,7 @@ py::object cpu_load_binary(std::string name, std::string llvmBC, int shared, int
<< llvmErrToString(detect_host_res.takeError());
return res;
}

llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res);

auto data_layout_res = tmb.getDefaultDataLayoutForTarget();
Expand All @@ -166,12 +171,14 @@ py::object cpu_load_binary(std::string name, std::string llvmBC, int shared, int
<< llvmErrToString(data_layout_res.takeError());
return res;
}
auto triple = tmb.getTargetTriple().getTriple();

std::cout << "cpu: " << tmb.getCPU() << std::endl;
std::cout << "target triple: " << triple << std::endl;

(**mod).setDataLayout(*data_layout_res);
//std::string triple = tmb.getTargetTriple().getTriple();
//(**mod).setTargetTriple(triple);
llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext());

Builder.setJITTargetMachineBuilder(std::move(tmb));

auto build_res = Builder.create();
if (!build_res) {
std::cerr << "Failed to create LLJIT instance" << std::endl;
Expand All @@ -181,10 +188,6 @@ py::object cpu_load_binary(std::string name, std::string llvmBC, int shared, int
CompiledKernel kernel;
kernel.ll_jit = std::move(*build_res);

(**mod).setDataLayout(std::move(*data_layout_res));
(**mod).setTargetTriple(triple);
llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext());

auto err = kernel.ll_jit->addIRModule(std::move(tsm));
if (err) {
std::cerr << "Cannot add LLVM module: " << llvmErrToString(err);
Expand Down Expand Up @@ -265,9 +268,6 @@ py::object cpu_load_binary(std::string name, std::string llvmBC, int shared, int
llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res);

auto data_layout_res = tmb.getDefaultDataLayoutForTarget();

std::cout << "cpu: " << tmb.getCPU() << std::endl;
std::cout << "target triple: " << tmb.getTargetTriple().getTriple() << std::endl;

if (!data_layout_res) {
std::cerr << "Failed to initialize data layout: "
Expand Down Expand Up @@ -356,3 +356,5 @@ void init_triton_cpu(py::module &&m) {
return res;
});
}

#undef TRITON_CPU_USE_LLJIT

0 comments on commit f47a21b

Please sign in to comment.