[CPU] Add OpenMP launcher
minjang committed Jun 5, 2024
1 parent 6d74ccd commit c045887
Showing 4 changed files with 82 additions and 12 deletions.
19 changes: 19 additions & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -4,6 +4,7 @@
#include <algorithm>
#include <assert.h>
#include <cstdlib>
#include <optional>
#include <set>
#include <sstream>
#include <string>
@@ -29,6 +30,9 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {

inline const std::set<std::string> CACHE_NEUTRAL_ENV_VARS = {
"TRITON_REPRODUCER_PATH",
"TRITON_CPU_SINGLE_CORE",
"TRITON_CPU_MAX_THREADS",
"TRITON_CPU_OMP_DEBUG",
};

namespace tools {
@@ -52,6 +56,21 @@ inline std::string getStrEnv(const std::string &env) {
return result;
}

inline std::optional<int64_t> getIntEnv(const std::string &env) {
assertIsRecognized(env);
const char *cstr = std::getenv(env.c_str());
if (!cstr) {
return std::nullopt;
}

char *endptr;
long int result = std::strtol(cstr, &endptr, 10);
if (endptr == cstr) {
assert(false && "invalid integer");
}
return result;
}

// return value of a cache-invalidating boolean environment variable
inline bool getBoolEnv(const std::string &env) {
assertIsRecognized(env);
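As a quick reference, a minimal sketch of how the new getIntEnv helper is meant to be consumed (standalone example, assuming the triton include directory is on the compiler's include path; the namespace matches the launcher code further below):

#include "triton/Tools/Sys/GetEnv.hpp"
#include <cstdio>

int main() {
  // std::nullopt when TRITON_CPU_MAX_THREADS is unset; otherwise the parsed value.
  // A non-numeric value trips the assert inside getIntEnv.
  auto max_threads = mlir::triton::tools::getIntEnv("TRITON_CPU_MAX_THREADS");
  if (max_threads.has_value())
    std::printf("TRITON_CPU_MAX_THREADS=%lld\n", (long long)max_threads.value());
  else
    std::printf("TRITON_CPU_MAX_THREADS is not set\n");
  return 0;
}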
2 changes: 1 addition & 1 deletion python/triton/runtime/build.py
@@ -47,7 +47,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
cc_cmd += [f"-I{dir}" for dir in include_dirs]
# CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag.
if src.endswith(".cpp") or src.endswith(".cc"):
cc_cmd += ["-std=c++17"]
cc_cmd += ["-std=c++17", "-fopenmp"]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
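The new -fopenmp flag is what lets the CPU launcher generated by driver.py (see below) include <omp.h> and use #pragma omp. A minimal, hypothetical smoke test for the toolchain, not part of this commit:

// omp_check.cpp -- build with: c++ -std=c++17 -fopenmp omp_check.cpp
#include <omp.h>
#include <cstdio>

int main() {
  #pragma omp parallel
  {
    // With -fopenmp each OpenMP thread prints its id; without the flag the
    // pragma is ignored and the omp_* calls typically fail to link.
    std::printf("thread %d of %d\n", omp_get_thread_num(), omp_get_num_threads());
  }
  return 0;
}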
21 changes: 15 additions & 6 deletions python/tutorials/01-vector-add.py
@@ -23,7 +23,8 @@
import triton
import triton.language as tl

BLOCK_SIZE = 1024
GPU_BLOCK_SIZE = 1024
CPU_BLOCK_SIZE = 128


@triton.jit
@@ -72,7 +73,7 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
# - Each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
# - Don't forget to pass meta-parameters as keyword arguments.
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=CPU_BLOCK_SIZE if is_cpu else GPU_BLOCK_SIZE)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
return output
@@ -93,9 +94,9 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
print(f'The maximum difference between torch-cpu and triton-cpu is '
f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}')

LINE_VALS = ['triton-cpu', 'torch-cpu']
LINE_NAMES = ['TritonCPU', 'TorchCPU']
LINE_STYLES = [('blue', '-'), ('green', '-')]
LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu']
LINE_NAMES = ['TritonCPU 1-core', 'TritonCPU', 'TorchCPU']
LINE_STYLES = [('blue', '-'), ('cyan', '-'), ('green', '-')]

if triton.runtime.driver.get_active_gpus():
triton.runtime.driver.set_active_to_gpu()
@@ -136,16 +137,22 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
ylabel='GB/s', # Label name for the y-axis.
plot_name=
# Name for the plot. Used also as a file name for saving the plot.
f'vector-add-performance (BLOCK_SIZE={BLOCK_SIZE})',
f'vector-add-performance (CPU_BLOCK_SIZE={CPU_BLOCK_SIZE}, GPU_BLOCK_SIZE={GPU_BLOCK_SIZE})',
args={}, # Values for function arguments not in `x_names` and `y_name`.
))
def benchmark(size, provider):
import os

device = 'cpu' if 'cpu' in provider else 'cuda'
x = torch.rand(size, device=device, dtype=torch.float32)
y = torch.rand(size, device=device, dtype=torch.float32)

if device == 'cpu':
triton.runtime.driver.set_active_to_cpu()
if 'single' in provider:
os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
else:
os.unsetenv('TRITON_CPU_SINGLE_CORE')
else:
triton.runtime.driver.set_active_to_gpu()

@@ -156,6 +163,8 @@ def benchmark(size, provider):
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, False), quantiles=quantiles)
elif provider == 'torch-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu-single':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True)
gbps = lambda ms: 12 * size / ms * 1e-6
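A note on the gbps lambda above: vector add touches three float32 tensors per element (read x, read y, write output), i.e. 3 * 4 = 12 bytes, so 12 * size / ms * 1e-6 converts milliseconds into GB/s. The same arithmetic, spelled out (illustrative helper, not part of the commit):

// Effective bandwidth for vector add: 3 tensors (x, y, output) * 4 bytes/element.
double vector_add_gbps(double size_elems, double ms) {
  double bytes_moved = 12.0 * size_elems; // 12 bytes per element
  double seconds = ms * 1e-3;             // milliseconds -> seconds
  return bytes_moved / seconds / 1e9;     // bytes per second -> GB/s
}
// Algebraically identical to the tutorial's 12 * size / ms * 1e-6.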
52 changes: 47 additions & 5 deletions third_party/cpu/backend/driver.py
@@ -16,6 +16,7 @@
include_dir = [
os.path.join(dirname, "include"),
os.path.join(llvm_root, "include"),
os.path.join(".", "include"),
]
library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")]
libraries = [
@@ -77,6 +78,8 @@
"z",
]

MINIMUM_OMP_CHUNK_SIZE = 10


def compile_module_from_src(src, name):
key = hashlib.md5(src.encode("utf-8")).hexdigest()
@@ -110,7 +113,6 @@ def __new__(cls):
return cls.instance

def __init__(self):
pass
dirname = os.path.dirname(os.path.realpath(__file__))
mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils")
self.load_binary = mod.load_binary
@@ -186,6 +188,10 @@ def format_of(ty):
#include <string>
#include <iostream>
#include <iomanip>
#include <omp.h>
#include <cmath>
#include "triton/Tools/Sys/GetEnv.hpp"
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
@@ -233,20 +239,56 @@ def format_of(ty):
return ptr_info;
}}
static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
// TODO: add OMP pragmas to run in parallel
static std::unique_ptr<uint32_t[][3]> get_all_grids(uint32_t gridX, uint32_t gridY, uint32_t gridZ) {{
std::unique_ptr<uint32_t[][3]> grids(new uint32_t[gridX * gridY * gridZ][3]);
// TODO: which order would be more effective for cache locality?
for (uint32_t z = 0; z < gridZ; ++z) {{
for (uint32_t y = 0; y < gridY; ++y) {{
for (uint32_t x = 0; x < gridX; ++x) {{
(*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
grids[z * gridY * gridX + y * gridX + x][0] = x;
grids[z * gridY * gridX + y * gridX + x][1] = y;
grids[z * gridY * gridX + y * gridX + x][2] = z;
}}
}}
}}
return grids;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
auto all_grids = get_all_grids(gridX, gridY, gridZ);
size_t N = gridX * gridY * gridZ;
if (mlir::triton::tools::getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{
if (mlir::triton::tools::getBoolEnv("TRITON_CPU_OMP_DEBUG"))
printf("Single core launcher\\n");
for (uint32_t i = 0; i < N; ++i) {{
const auto [x, y, z] = all_grids[i];
(*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
}}
return;
}}
// Use the default static scheduling with a simple chunk size policy.
std::optional<int> max_threads = mlir::triton::tools::getIntEnv("TRITON_CPU_MAX_THREADS");
if (max_threads.has_value())
max_threads = std::max(1, std::min(max_threads.value(), omp_get_max_threads()));
else
max_threads = omp_get_max_threads();
int chunk_size = std::ceil((double)N / (double)max_threads.value());
chunk_size = std::max(chunk_size, {MINIMUM_OMP_CHUNK_SIZE});
if (mlir::triton::tools::getBoolEnv("TRITON_CPU_OMP_DEBUG"))
printf("N: %zu, max_threads: %d, chunk_size: %zu\\n", N, max_threads, chunk_size);
#pragma omp parallel for schedule(static, chunk_size) num_threads(max_threads.value())
for (uint32_t i = 0; i < N; ++i) {{
const auto [x, y, z] = all_grids[i];
(*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
}}
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
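To see the launcher's scheduling policy in isolation, here is a self-contained sketch (hypothetical demo, not part of the commit) that mirrors the generated code: the grid is flattened exactly as in get_all_grids, the chunk size follows the same ceil(N / max_threads) rule with a floor of MINIMUM_OMP_CHUNK_SIZE = 10, and the loop is split with OpenMP static scheduling:

// chunk_demo.cpp -- build with: c++ -std=c++17 -fopenmp chunk_demo.cpp
#include <omp.h>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t gridX = 128, gridY = 2, gridZ = 1;
  size_t N = (size_t)gridX * gridY * gridZ; // 256 flattened program ids

  int max_threads = omp_get_max_threads();
  int chunk_size = (int)std::ceil((double)N / (double)max_threads);
  chunk_size = std::max(chunk_size, 10); // MINIMUM_OMP_CHUNK_SIZE

  // e.g. 16 threads -> ceil(256 / 16) = 16; a tiny grid (N = 64) would clamp to 10.
  std::printf("N=%zu threads=%d chunk=%d\n", N, max_threads, chunk_size);

  #pragma omp parallel for schedule(static, chunk_size) num_threads(max_threads)
  for (int64_t i = 0; i < (int64_t)N; ++i) {
    // Recover (x, y, z) from the flattened index, matching the ordering
    // used by get_all_grids (z varies slowest, x fastest).
    uint32_t x = (uint32_t)(i % gridX);
    uint32_t y = (uint32_t)((i / gridX) % gridY);
    uint32_t z = (uint32_t)(i / ((size_t)gridX * gridY));
    (void)x; (void)y; (void)z; // the real launcher calls (*kernel_ptr)(..., x, y, z) here
  }
  return 0;
}

With a static chunk of this size, a large grid splits roughly evenly across the available threads, while the floor of 10 keeps very small grids from being diced into chunks that are not worth the scheduling overhead.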
