diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index bd4d94c5f6b4..6c9f35827d73 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -48,6 +48,8 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): # CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag. if src.endswith(".cpp") or src.endswith(".cc"): cc_cmd += ["-std=c++17", "-fopenmp"] + if src.endswith(".s"): + cc_cmd += ["-gdwarf-5"] ret = subprocess.check_call(cc_cmd) if ret == 0: return so diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index c3f11334750a..0a98532eceba 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -42,7 +42,7 @@ def supports_target(target: GPUTarget): def __init__(self, target: tuple) -> None: super().__init__(target) - self.binary_ext = "bc" + self.binary_ext = "asm" def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} @@ -138,22 +138,14 @@ def make_llir(src, metadata, options): return ret @staticmethod - def make_bc(src, metadata, options): - if os.environ.get("TRITON_CPU_ASM_DUMP", "0") == "1": - from triton.runtime.cache import get_cache_manager - - asm = llvm.translate_to_host_asm(src, options.enable_fp_fusion) - fn_cache_manager = get_cache_manager(metadata['hash']) - fn_cache_manager.put(asm, f"{metadata['name']}.asm") - - ret = llvm.translate_to_bc(src) - return ret + def make_asm(src, metadata, options): + return llvm.translate_to_host_asm(src, options.enable_fp_fusion) def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) - stages["bc"] = lambda src, metadata: self.make_bc(src, metadata, options) + stages["asm"] = lambda src, metadata: self.make_asm(src, metadata, options) @functools.lru_cache() def hash(self): diff --git a/third_party/cpu/backend/driver.cpp b/third_party/cpu/backend/driver.cpp deleted file mode 100644 index babff3dfdebe..000000000000 --- a/third_party/cpu/backend/driver.cpp +++ /dev/null @@ -1,224 +0,0 @@ -//===- driver.cpp ---------------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "llvm/Bitcode/BitcodeReader.h" -#include "llvm/ExecutionEngine/Orc/CompileUtils.h" -#include "llvm/ExecutionEngine/Orc/Core.h" -#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" -#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" -#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" -#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" -#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" -#include "llvm/ExecutionEngine/SectionMemoryManager.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/TargetSelect.h" - -#include -#include -#include -#include -#include -#include -#include - -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -#include - -static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { - int device_id; - if (!PyArg_ParseTuple(args, "i", &device_id)) - return NULL; - - return Py_BuildValue("{s:i}", "max_shared_mem", 0); -} - -bool getBoolEnv(const std::string &env) { - const char *s = std::getenv(env.c_str()); - std::string str(s ? s : ""); - std::transform(str.begin(), str.end(), str.begin(), - [](unsigned char c) { return std::tolower(c); }); - return (str == "on" || str == "true" || str == "1"); -} - -llvm::orc::ThreadSafeContext &getThreadSafeContext() { - static llvm::orc::ThreadSafeContext tsc; - static std::once_flag init_flag; - std::call_once(init_flag, []() { - auto context = std::make_unique(); - tsc = llvm::orc::ThreadSafeContext(std::move(context)); - }); - return tsc; -} - -std::string llvmErrToString(const llvm::Error &err) { - std::string res; - llvm::raw_string_ostream os(res); - os << err; - return res; -}; - -struct CompiledKernel { - std::unique_ptr execution_session; - std::unique_ptr data_layout; - std::unique_ptr mangle; - std::unique_ptr object_layer; - std::unique_ptr compiler_layer; - llvm::orc::JITDylib *dylib = nullptr; - - CompiledKernel() = default; - CompiledKernel(CompiledKernel &&) = default; - - ~CompiledKernel() { - if (execution_session) - llvm::cantFail(execution_session->endSession()); - } -}; - -std::vector> compiled_kernels; - -static PyObject *loadBitcode(PyObject *self, PyObject *args) { - const char *name; - int shared; - PyObject *py_bytes; - int devId; - - if (!PyArg_ParseTuple(args, "sSii", &name, &py_bytes, &shared, &devId)) { - std::cerr << "loadBitcode arg parse failed" << std::endl; - return NULL; - } - - std::string kernel_name = name; - size_t binary_size = PyBytes_Size(py_bytes); - const char *binary_ptr = PyBytes_AsString(py_bytes); - - llvm::LLVMContext context; - auto buf = llvm::MemoryBuffer::getMemBuffer( - llvm::StringRef(binary_ptr, binary_size)); - auto mod = llvm::parseBitcodeFile(*buf, context); - if (!mod) { - std::cerr << "Failed to parse LLVM bitcode module" << std::endl; - return NULL; - } - - if (getBoolEnv("MLIR_ENABLE_DUMP")) { - llvm::errs() << "********** Loaded Module (kernel_name=" << name - << ") **********\n" - << **mod << "\n"; - } - - auto init_err = llvm::InitializeNativeTarget(); - if (init_err) { - std::cerr << "Failed to initialize native target." << std::endl; - return NULL; - } - - llvm::InitializeNativeTargetAsmPrinter(); - llvm::InitializeNativeTargetAsmParser(); - - auto self_epc = - llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create()); - - auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost(); - if (!detect_host_res) { - std::cerr << "Failed to initialize JITTargetMachineBuilder: " - << llvmErrToString(detect_host_res.takeError()); - return NULL; - } - llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res); - - auto data_layout_res = tmb.getDefaultDataLayoutForTarget(); - if (!data_layout_res) { - std::cerr << "Failed to initialize data layout: " - << llvmErrToString(data_layout_res.takeError()); - return NULL; - } - - CompiledKernel kernel; - kernel.execution_session = - std::make_unique(std::move(self_epc)); - kernel.data_layout = - std::make_unique(std::move(*data_layout_res)); - kernel.mangle = std::make_unique( - *kernel.execution_session, *kernel.data_layout); - kernel.object_layer = std::make_unique( - *kernel.execution_session, - []() { return std::make_unique(); }); - kernel.compiler_layer = std::make_unique( - *kernel.execution_session, *kernel.object_layer, - std::make_unique(std::move(tmb))); - - auto dylib_res = kernel.execution_session->createJITDylib("
"); - if (!dylib_res) { - std::cerr << "Failed to create initialize JITDylib: " - << llvmErrToString(dylib_res.takeError()); - return NULL; - } - - kernel.dylib = &(*dylib_res); - kernel.dylib->addGenerator(llvm::cantFail( - llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( - kernel.data_layout->getGlobalPrefix()))); - - // Compile module. - (**mod).setDataLayout(*kernel.data_layout); - llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext()); - auto err = kernel.compiler_layer->add(*kernel.dylib, std::move(tsm)); - if (err) { - std::cerr << "Cannot add LLVM module: " << llvmErrToString(err); - return NULL; - } - - // Find kernel function pointer. - auto lookup_res = - kernel.execution_session->lookup({kernel.dylib}, (*kernel.mangle)(name)); - if (!lookup_res) { - std::cerr << "Failed to find function " << std::string(name) - << "\nError: " << llvmErrToString(lookup_res.takeError()); - return NULL; - } - uint64_t fn_ptr = lookup_res->getAddress().getValue(); - - compiled_kernels.push_back( - std::make_unique(std::move(kernel))); - auto *kernel_ptr = compiled_kernels.back().get(); - - return Py_BuildValue("(KKii)", reinterpret_cast(kernel_ptr), - reinterpret_cast(fn_ptr), 0, 0); -} - -static PyObject *initContext(PyObject *self, PyObject *args) { - return Py_BuildValue("(K)", (uint64_t)0); -} - -static PyObject *initDevices(PyObject *self, PyObject *args) { - return Py_BuildValue("(i)", 1); -} - -static PyMethodDef ModuleMethods[] = { - {"load_binary", loadBitcode, METH_VARARGS, - "Load provided SPV into ZE driver"}, - {"get_device_properties", getDeviceProperties, METH_VARARGS, - "Get the properties for a given device"}, - {NULL, NULL, 0, NULL} // sentinel -}; - -static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cpu_utils", - NULL, // documentation - -1, // size - ModuleMethods}; - -PyMODINIT_FUNC PyInit_cpu_utils(void) { - PyObject *m = PyModule_Create(&ModuleDef); - if (m == NULL) { - return NULL; - } - PyModule_AddFunctions(m, ModuleMethods); - return m; -} diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 1018f64d5b35..126e41d3bd7a 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -8,74 +8,9 @@ from triton.backends.compiler import GPUTarget dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") -llvm_root = os.getenv("LLVM_PATH", default="~/.triton/llvm") -llvm_root = os.path.expanduser(llvm_root) -llvm_dirs = os.listdir(llvm_root) -if len(llvm_dirs) == 1: - llvm_root = os.path.join(llvm_root, llvm_dirs[0]) -include_dir = [ - os.path.join(dirname, "include"), - os.path.join(llvm_root, "include"), -] -library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")] -libraries = [ - "LLVMOrcJIT", - "LLVMPasses", - "LLVMX86CodeGen", - "LLVMX86AsmParser", - "LLVMX86Desc", - "LLVMX86Info", - "LLVMGlobalISel", - "LLVMSelectionDAG", - "LLVMHipStdPar", - "LLVMCoroutines", - "LLVMipo", - "LLVMFrontendOpenMP", - "LLVMInstrumentation", - "LLVMAsmPrinter", - "LLVMCodeGen", - "LLVMObjCARCOpts", - "LLVMLinker", - "LLVMVectorize", - "LLVMScalarOpts", - "LLVMInstCombine", - "LLVMFrontendOffloading", - "LLVMExecutionEngine", - "LLVMAggressiveInstCombine", - "LLVMTransformUtils", - "LLVMTarget", - "LLVMRuntimeDyld", - "LLVMJITLink", - "LLVMIRPrinter", - "LLVMBitWriter", - "LLVMAnalysis", - "LLVMProfileData", - "LLVMSymbolize", - "LLVMDebugInfoDWARF", - "LLVMObject", - "LLVMTextAPI", - "LLVMMCParser", - "LLVMMCDisassembler", - "LLVMMC", - "LLVMIRReader", - "LLVMCFGuard", - "LLVMBitReader", - "LLVMAsmParser", - "LLVMCore", - "LLVMBinaryFormat", - "LLVMOrcTargetProcess", - "LLVMTargetParser", - "LLVMRemarks", - "LLVMOrcShared", - "LLVMOption", - "LLVMDebugInfoCodeView", - "LLVMCodeGenTypes", - "LLVMBitstreamReader", - "LLVMSupport", - "LLVMDemangle", - "stdc++", - "z", -] +include_dir = [os.path.join(dirname, "include")] +library_dir = [os.path.join(dirname, "lib")] +libraries = ["stdc++"] def compile_module_from_src(src, name): @@ -110,9 +45,26 @@ def __new__(cls): return cls.instance def __init__(self): - dirname = os.path.dirname(os.path.realpath(__file__)) - mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils") - self.load_binary = mod.load_binary + pass + + def load_binary(self, name, src, shared_mem, device): + # src actually holds asm text, compile to a shared library. + key = hashlib.md5(src).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + asm_path = os.path.join(tmpdir, "kernel.s") + Path(asm_path).write_bytes(src) + Path("kernel.s").write_bytes(src) + so = _build(name, asm_path, tmpdir, library_dir, include_dir, ["gcc", "m"]) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import ctypes + lib = ctypes.cdll.LoadLibrary(cache_path) + fn_ptr = getattr(lib, name) + fn_ptr_as_void_p = ctypes.cast(fn_ptr, ctypes.c_void_p).value + return (fn_ptr, fn_ptr_as_void_p, 0, 0) def get_device_properties(self, *args): return {"max_shared_mem": 0}