Skip to content

Commit

Permalink
move all llvm dependent code for cpu into triton_cpu.cc to avoid doub…
Browse files Browse the repository at this point in the history
…le registration of llvm passes, and add flags for mac
  • Loading branch information
Kuigesi committed Jun 14, 2024
1 parent 7bf0591 commit 031be2f
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 233 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
MLIRGPUToROCDLTransforms

# LLVM
LLVMPasses
LLVMOrcJIT
LLVMNVPTXCodeGen
# LLVMNVPTXAsmPrinter
LLVMAMDGPUCodeGen
Expand Down
5 changes: 5 additions & 0 deletions python/triton/runtime/build.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import sys
import platform
import io
import sysconfig
import os
Expand All @@ -20,6 +21,7 @@ def quiet():

def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
suffix = sysconfig.get_config_var('EXT_SUFFIX')
system = platform.system()
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
# try to avoid setuptools if possible
cc = os.environ.get("CC")
Expand All @@ -42,6 +44,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
include_dirs = include_dirs + [srcdir, py_include_dir]
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
# On macOS, leave the Python symbols undefined at link time; they are resolved by dynamic lookup when the module is loaded.
if system == "Darwin":
cc_cmd += ["-undefined" ,"dynamic_lookup", "-flto"]
cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
cc_cmd += [f"-I{dir}" for dir in include_dirs]
Expand Down
170 changes: 0 additions & 170 deletions third_party/cpu/backend/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,6 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/TargetSelect.h"

#include <cstddef>
#include <iostream>
#include <string>
Expand All @@ -39,160 +25,6 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
return Py_BuildValue("{s:i}", "max_shared_mem", 0);
}

// Reports whether the environment variable named `env` holds a truthy value.
// The value is lower-cased before comparison, so "on", "true" and "1" are
// accepted in any letter case. An unset variable, or any other value, yields
// false.
bool getBoolEnv(const std::string &env) {
  const char *raw = std::getenv(env.c_str());
  if (raw == nullptr)
    return false;

  std::string lowered(raw);
  for (char &ch : lowered) {
    // Go through unsigned char: std::tolower is undefined for negative values.
    ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
  }
  return lowered == "1" || lowered == "true" || lowered == "on";
}

// Returns the process-wide ThreadSafeContext, creating it (together with the
// LLVMContext it wraps) on the first call. Every call returns a reference to
// the same object.
llvm::orc::ThreadSafeContext &getThreadSafeContext() {
  // Initialisation of a function-local static runs exactly once, even when
  // several threads reach this line concurrently.
  static llvm::orc::ThreadSafeContext shared_context(
      std::make_unique<llvm::LLVMContext>());
  return shared_context;
}

// Renders `err` as text by streaming it into a string.
// The error is taken by const reference, so it is only printed here and is
// not consumed; the caller still owns it afterwards.
std::string llvmErrToString(const llvm::Error &err) {
  std::string text;
  {
    llvm::raw_string_ostream stream(text);
    stream << err;
  } // stream is destroyed here, before `text` is handed back.
  return text;
}

// Owns the ORC JIT objects that belong to one loaded kernel.
// The type is movable but not copyable (every owning member is a
// std::unique_ptr). Members are destroyed in reverse declaration order, so
// the execution session is released last, after the layers declared below it.
struct CompiledKernel {
  std::unique_ptr<llvm::orc::ExecutionSession> execution_session;
  std::unique_ptr<llvm::DataLayout> data_layout;
  std::unique_ptr<llvm::orc::MangleAndInterner> mangle;
  std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer;
  std::unique_ptr<llvm::orc::IRCompileLayer> compiler_layer;
  // Non-owning pointer; null until a JITDylib has been assigned.
  llvm::orc::JITDylib *dylib{nullptr};

  CompiledKernel() = default;
  CompiledKernel(CompiledKernel &&) = default;

  // Ends the execution session if this object holds one. A default-constructed
  // or moved-from instance has no session and does nothing here.
  ~CompiledKernel() {
    if (!execution_session)
      return;
    llvm::cantFail(execution_session->endSession());
  }
};

// Global registry of loaded kernels. loadBitcode appends one entry per
// successful load and returns the raw address of that entry to Python, so
// entries are held by unique_ptr and never removed in this file.
std::vector<std::unique_ptr<CompiledKernel>> compiled_kernels;

// Reports a loadBitcode failure: echoes `message` to stderr, raises a Python
// RuntimeError carrying the same text, and returns NULL for the caller to
// propagate.
static PyObject *failLoadBitcode(const std::string &message) {
  std::cerr << message << std::endl;
  PyErr_SetString(PyExc_RuntimeError, message.c_str());
  return NULL;
}

// Python entry point: JIT-compiles an LLVM bitcode module for the host and
// looks up one kernel function in it.
//
// Python arguments ("sSii"):
//   name     - symbol name of the kernel function to look up
//   py_bytes - bytes object holding the LLVM bitcode
//   shared   - parsed but unused here
//   devId    - parsed but unused here
//
// Returns a 4-tuple (kernel_handle, function_address, 0, 0), where
// kernel_handle is the address of the CompiledKernel stored in
// compiled_kernels. On failure a Python exception is set and NULL is returned.
static PyObject *loadBitcode(PyObject *self, PyObject *args) {
  const char *name;
  int shared;
  PyObject *py_bytes;
  int devId;

  if (!PyArg_ParseTuple(args, "sSii", &name, &py_bytes, &shared, &devId)) {
    // PyArg_ParseTuple has already set the Python exception.
    std::cerr << "loadBitcode arg parse failed" << std::endl;
    return NULL;
  }

  std::string kernel_name = name;
  size_t binary_size = PyBytes_Size(py_bytes);
  const char *binary_ptr = PyBytes_AsString(py_bytes);

  // The module must live in the LLVMContext that its ThreadSafeModule owns, so
  // the context is created here and handed over together with the module.
  // It is declared before `mod` so that, on early return, the module is
  // destroyed before its context.
  auto context = std::make_unique<llvm::LLVMContext>();
  auto buf = llvm::MemoryBuffer::getMemBuffer(
      llvm::StringRef(binary_ptr, binary_size));
  auto mod = llvm::parseBitcodeFile(*buf, *context);
  if (!mod) {
    // toString consumes the error; an unconsumed llvm::Error aborts in builds
    // that enforce error checking.
    return failLoadBitcode("Failed to parse LLVM bitcode module: " +
                           llvm::toString(mod.takeError()));
  }

  if (getBoolEnv("MLIR_ENABLE_DUMP")) {
    llvm::errs() << "********** Loaded Module (kernel_name=" << name
                 << ") **********\n"
                 << **mod << "\n";
  }

  // InitializeNativeTarget returns true on failure.
  if (llvm::InitializeNativeTarget())
    return failLoadBitcode("Failed to initialize native target.");

  llvm::InitializeNativeTargetAsmPrinter();
  llvm::InitializeNativeTargetAsmParser();

  auto self_epc =
      llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create());

  auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!detect_host_res) {
    return failLoadBitcode("Failed to initialize JITTargetMachineBuilder: " +
                           llvm::toString(detect_host_res.takeError()));
  }
  llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res);

  auto data_layout_res = tmb.getDefaultDataLayoutForTarget();
  if (!data_layout_res) {
    return failLoadBitcode("Failed to initialize data layout: " +
                           llvm::toString(data_layout_res.takeError()));
  }

  // From here on `kernel` owns the execution session; its destructor ends the
  // session on every early return below.
  CompiledKernel kernel;
  kernel.execution_session =
      std::make_unique<llvm::orc::ExecutionSession>(std::move(self_epc));
  kernel.data_layout =
      std::make_unique<llvm::DataLayout>(std::move(*data_layout_res));
  kernel.mangle = std::make_unique<llvm::orc::MangleAndInterner>(
      *kernel.execution_session, *kernel.data_layout);
  kernel.object_layer = std::make_unique<llvm::orc::RTDyldObjectLinkingLayer>(
      *kernel.execution_session,
      []() { return std::make_unique<llvm::SectionMemoryManager>(); });
  kernel.compiler_layer = std::make_unique<llvm::orc::IRCompileLayer>(
      *kernel.execution_session, *kernel.object_layer,
      std::make_unique<llvm::orc::ConcurrentIRCompiler>(std::move(tmb)));

  auto dylib_res = kernel.execution_session->createJITDylib("<main>");
  if (!dylib_res) {
    return failLoadBitcode("Failed to create JITDylib: " +
                           llvm::toString(dylib_res.takeError()));
  }

  kernel.dylib = &(*dylib_res);
  // Let JIT-compiled code resolve symbols exported by the current process.
  kernel.dylib->addGenerator(llvm::cantFail(
      llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
          kernel.data_layout->getGlobalPrefix())));

  // Compile module.
  (**mod).setDataLayout(*kernel.data_layout);
  llvm::orc::ThreadSafeModule tsm(
      std::move(*mod), llvm::orc::ThreadSafeContext(std::move(context)));
  if (auto err = kernel.compiler_layer->add(*kernel.dylib, std::move(tsm))) {
    return failLoadBitcode("Cannot add LLVM module: " +
                           llvm::toString(std::move(err)));
  }

  // Find kernel function pointer.
  auto lookup_res =
      kernel.execution_session->lookup({kernel.dylib}, (*kernel.mangle)(name));
  if (!lookup_res) {
    return failLoadBitcode("Failed to find function " + kernel_name +
                           "\nError: " +
                           llvm::toString(lookup_res.takeError()));
  }
  uint64_t fn_ptr = lookup_res->getAddress().getValue();

  // Keep the JIT state alive for the rest of the process; the returned handle
  // is the address of this heap-allocated entry.
  compiled_kernels.push_back(
      std::make_unique<CompiledKernel>(std::move(kernel)));
  auto *kernel_ptr = compiled_kernels.back().get();

  // "K" expects unsigned long long, so convert explicitly.
  return Py_BuildValue(
      "(KKii)",
      static_cast<unsigned long long>(reinterpret_cast<uint64_t>(kernel_ptr)),
      static_cast<unsigned long long>(fn_ptr), 0, 0);
}

// Python entry point that returns a one-element tuple holding the integer 0.
// Both `self` and `args` are ignored.
static PyObject *initContext(PyObject *self, PyObject *args) {
  const uint64_t context_handle = 0;
  return Py_BuildValue("(K)", context_handle);
}
Expand All @@ -202,8 +34,6 @@ static PyObject *initDevices(PyObject *self, PyObject *args) {
}

static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBitcode, METH_VARARGS,
"Load provided SPV into ZE driver"},
{"get_device_properties", getDeviceProperties, METH_VARARGS,
"Get the properties for a given device"},
{NULL, NULL, 0, NULL} // sentinel
Expand Down
67 changes: 5 additions & 62 deletions third_party/cpu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,77 +2,20 @@
import hashlib
import tempfile
from pathlib import Path

from triton._C.libtriton import cpu

from triton.runtime.build import _build
from triton.runtime.cache import get_cache_manager
from triton.backends.driver import DriverBase
from triton.backends.compiler import GPUTarget

# Install prefix searched for headers and libraries; overridable through the
# TRITON_SYS_PATH environment variable.
dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local")
# Root of the LLVM installation; overridable through LLVM_PATH. "~" is expanded
# to the user's home directory.
llvm_root = os.getenv("LLVM_PATH", default="~/.triton/llvm")
llvm_root = os.path.expanduser(llvm_root)
# NOTE(review): os.listdir raises if llvm_root does not exist, and this runs at
# import time.
llvm_dirs = os.listdir(llvm_root)
# When the root holds exactly one entry, descend into it and use that as the
# LLVM root.
if len(llvm_dirs) == 1:
llvm_root = os.path.join(llvm_root, llvm_dirs[0])
include_dir = [
os.path.join(dirname, "include"),
os.path.join(llvm_root, "include"),
]
# NOTE(review): library_dir is assigned twice in this view; the second
# assignment wins and drops the LLVM lib directory. This looks like the old and
# new sides of a diff shown together -- confirm which line is current.
library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")]
library_dir = [os.path.join(dirname, "lib")]
# Library names without the "lib" prefix: LLVM component libraries followed by
# stdc++ and z.
# NOTE(review): presumably these are turned into -l flags when the driver
# module is built; the order may matter for static linking -- confirm before
# reordering or trimming.
libraries = [
"LLVMOrcJIT",
"LLVMPasses",
"LLVMX86CodeGen",
"LLVMX86AsmParser",
"LLVMX86Desc",
"LLVMX86Info",
"LLVMGlobalISel",
"LLVMSelectionDAG",
"LLVMHipStdPar",
"LLVMCoroutines",
"LLVMipo",
"LLVMFrontendOpenMP",
"LLVMInstrumentation",
"LLVMAsmPrinter",
"LLVMCodeGen",
"LLVMObjCARCOpts",
"LLVMLinker",
"LLVMVectorize",
"LLVMScalarOpts",
"LLVMInstCombine",
"LLVMFrontendOffloading",
"LLVMExecutionEngine",
"LLVMAggressiveInstCombine",
"LLVMTransformUtils",
"LLVMTarget",
"LLVMRuntimeDyld",
"LLVMJITLink",
"LLVMIRPrinter",
"LLVMBitWriter",
"LLVMAnalysis",
"LLVMProfileData",
"LLVMSymbolize",
"LLVMDebugInfoDWARF",
"LLVMObject",
"LLVMTextAPI",
"LLVMMCParser",
"LLVMMCDisassembler",
"LLVMMC",
"LLVMIRReader",
"LLVMCFGuard",
"LLVMBitReader",
"LLVMAsmParser",
"LLVMCore",
"LLVMBinaryFormat",
"LLVMOrcTargetProcess",
"LLVMTargetParser",
"LLVMRemarks",
"LLVMOrcShared",
"LLVMOption",
"LLVMDebugInfoCodeView",
"LLVMCodeGenTypes",
"LLVMBitstreamReader",
"LLVMSupport",
"LLVMDemangle",
"stdc++",
"z",
]
Expand Down Expand Up @@ -112,7 +55,7 @@ def __new__(cls):
def __init__(self):
# Build the helper module from the driver.cpp that sits next to this file.
dirname = os.path.dirname(os.path.realpath(__file__))
mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils")
# NOTE(review): load_binary is assigned twice in this view; the second
# assignment (from triton._C.libtriton's cpu.utils) overrides the first. This
# looks like the old and new sides of a diff shown together -- confirm which
# line is current.
self.load_binary = mod.load_binary
self.load_binary = cpu.utils.load_binary

def get_device_properties(self, *args):
    """Return the fixed device properties; every positional argument is ignored."""
    properties = {"max_shared_mem": 0}
    return properties
Expand Down
Loading

0 comments on commit 031be2f

Please sign in to comment.