From 7a22fc7cd58ddf189f6302c2ed9e0ee8157828a7 Mon Sep 17 00:00:00 2001
From: minjang
Date: Fri, 7 Jun 2024 01:15:01 -0700
Subject: [PATCH] Address the comments

---
 include/triton/Tools/Sys/GetEnv.hpp | 19 ----------
 python/tutorials/01-vector-add.py   | 38 ++++++++++++--------
 third_party/cpu/backend/driver.py   | 55 +++++++++++++++++++----------
 3 files changed, 60 insertions(+), 52 deletions(-)

diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp
index 08bbd0525b45..a01c4adbba64 100644
--- a/include/triton/Tools/Sys/GetEnv.hpp
+++ b/include/triton/Tools/Sys/GetEnv.hpp
@@ -4,7 +4,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -30,9 +29,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
 
 inline const std::set<std::string> CACHE_NEUTRAL_ENV_VARS = {
     "TRITON_REPRODUCER_PATH",
-    "TRITON_CPU_SINGLE_CORE",
-    "TRITON_CPU_MAX_THREADS",
-    "TRITON_CPU_OMP_DEBUG",
 };
 
 namespace tools {
@@ -56,21 +52,6 @@ inline std::string getStrEnv(const std::string &env) {
   return result;
 }
 
-inline std::optional<int> getIntEnv(const std::string &env) {
-  assertIsRecognized(env);
-  const char *cstr = std::getenv(env.c_str());
-  if (!cstr) {
-    return std::nullopt;
-  }
-
-  char *endptr;
-  long int result = std::strtol(cstr, &endptr, 10);
-  if (endptr == cstr) {
-    assert(false && "invalid integer");
-  }
-  return result;
-}
-
 // return value of a cache-invalidating boolean environment variable
 inline bool getBoolEnv(const std::string &env) {
   assertIsRecognized(env);
diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py
index ab7790aeae41..ca2d71c2746d 100644
--- a/python/tutorials/01-vector-add.py
+++ b/python/tutorials/01-vector-add.py
@@ -24,7 +24,8 @@ import triton.language as tl
 
 
 GPU_BLOCK_SIZE = 1024
-CPU_BLOCK_SIZE = 128
+CPU_BLOCK_SIZE = 4096
+USE_GPU = True
 
 
 @triton.jit
@@ -60,10 +61,11 @@ def add_kernel(x_ptr,  # *Pointer* to first input vector.
 # and (2) enqueue the above kernel with appropriate grid/block sizes:
 
 
-def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
-    # We need to preallocate the output.
-    output = torch.empty_like(x)
-    assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu
+def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, is_cpu):
+    if output is None:
+        # We need to preallocate the output.
+        output = torch.empty_like(x)
+    assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu
     n_elements = output.numel()
     # The SPMD launch grid denotes the number of kernel instances that run in parallel.
     # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
@@ -88,22 +90,22 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
 x = torch.rand(size, device='cpu')
 y = torch.rand(size, device='cpu')
 output_torch_cpu = x + y
-output_triton_cpu = add(x, y, is_cpu=True)
+output_triton_cpu = add(x, y, None, is_cpu=True)
 print(output_torch_cpu)
 print(output_triton_cpu)
 print(f'The maximum difference between torch-cpu and triton-cpu is '
       f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}')
 
-LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu']
-LINE_NAMES = ['TritonCPU 1-core', 'TritonCPU', 'TorchCPU']
-LINE_STYLES = [('blue', '-'), ('cyan', '-'), ('green', '-')]
+LINE_VALS = ['triton-cpu-single-prealloc', 'triton-cpu-prealloc', 'triton-cpu-single', 'triton-cpu', 'torch-cpu']
+LINE_NAMES = ['TritonCPU 1+pre', 'TritonCPU pre', 'TritonCPU 1', 'TritonCPU', 'TorchCPU']
+LINE_STYLES = [('blue', '--'), ('green', '--'), ('blue', '-'), ('green', '-'), ('cyan', '-')]
 
-if triton.runtime.driver.get_active_gpus():
+if USE_GPU and triton.runtime.driver.get_active_gpus():
     triton.runtime.driver.set_active_to_gpu()
     x = x.to('cuda')
     y = y.to('cuda')
     output_torch_gpu = x + y
-    output_triton_gpu = add(x, y, is_cpu=False)
+    output_triton_gpu = add(x, y, None, is_cpu=False)
     print(output_torch_gpu)
     print(output_triton_gpu)
     print(f'The maximum difference between torch-gpu and triton-gpu is '
@@ -160,13 +162,21 @@ def benchmark(size, provider):
     if provider == 'torch-gpu':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
     elif provider == 'triton-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, False), quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles)
     elif provider == 'torch-cpu':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True)
     elif provider == 'triton-cpu-single':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, True), quantiles=quantiles, is_cpu=True)
+    elif provider == 'triton-cpu-prealloc':
+        # Note that we preallocate the output buffer here so that we measure only the kernel performance,
+        # not the cost of a large memory allocation.
+        output = torch.empty_like(x)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
+    elif provider == 'triton-cpu-single-prealloc':
+        output = torch.empty_like(x)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
     elif provider == 'triton-cpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, True), quantiles=quantiles, is_cpu=True)
     gbps = lambda ms: 12 * size / ms * 1e-6
     return gbps(ms), gbps(max_ms), gbps(min_ms)
 
diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py
index 6eb0acc2926c..6278dbfc46b4 100644
--- a/third_party/cpu/backend/driver.py
+++ b/third_party/cpu/backend/driver.py
@@ -16,7 +16,6 @@
 include_dir = [
     os.path.join(dirname, "include"),
     os.path.join(llvm_root, "include"),
-    os.path.join(".", "include"),
 ]
 library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")]
 libraries = [
@@ -78,8 +77,6 @@
     "z",
 ]
 
-MINIMUM_OMP_CHUNK_SIZE = 10
-
 
 def compile_module_from_src(src, name):
     key = hashlib.md5(src.encode("utf-8")).hexdigest()
@@ -184,18 +181,39 @@ def format_of(ty):
 
     # generate glue code
    src = f"""
+#include
+#include
 #include
-#include
-#include
 #include
+#include
 #include
-#include
-
-#include "triton/Tools/Sys/GetEnv.hpp"
+#include
+#include
+#include
 
 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
 #include
-#include
+
+inline bool getBoolEnv(const std::string &env) {{
+  const char *s = std::getenv(env.c_str());
+  std::string str(s ? s : "");
+  std::transform(str.begin(), str.end(), str.begin(),
+                 [](unsigned char c) {{ return std::tolower(c); }});
+  return str == "on" || str == "true" || str == "1";
+}}
+
+inline std::optional<int> getIntEnv(const std::string &env) {{
+  const char *cstr = std::getenv(env.c_str());
+  if (!cstr)
+    return std::nullopt;
+
+  char *endptr;
+  long int result = std::strtol(cstr, &endptr, 10);
+  if (endptr == cstr)
+    assert(false && "invalid integer");
+  return result;
+}}
 
 using kernel_ptr_t = void(*)({kernel_fn_arg_types});
@@ -255,12 +273,14 @@ def format_of(ty):
 }}
 
 static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+  // TODO: Consider using omp collapse(3) clause for simplicity?
   auto all_grids = get_all_grids(gridX, gridY, gridZ);
   size_t N = gridX * gridY * gridZ;
 
-  if (mlir::triton::tools::getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{
-    if (mlir::triton::tools::getBoolEnv("TRITON_CPU_OMP_DEBUG"))
+  if (getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{
+    if (getBoolEnv("TRITON_CPU_OMP_DEBUG"))
       printf("Single core launcher\\n");
+
     for (uint32_t i = 0; i < N; ++i) {{
       const auto [x, y, z] = all_grids[i];
       (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
@@ -268,20 +288,17 @@ def format_of(ty):
       return;
   }}
 
-  // Use the default static scheduling with a simple chuck size policy.
-  std::optional<int> max_threads = mlir::triton::tools::getIntEnv("TRITON_CPU_MAX_THREADS");
+  std::optional<int> max_threads = getIntEnv("TRITON_CPU_MAX_THREADS");
   if (max_threads.has_value())
     max_threads = std::max(1, std::min(max_threads.value(), omp_get_max_threads()));
   else
     max_threads = omp_get_max_threads();
 
-  int chunk_size = std::ceil((double)N / (double)max_threads.value());
-  chunk_size = std::max(chunk_size, {MINIMUM_OMP_CHUNK_SIZE});
-
-  if (mlir::triton::tools::getBoolEnv("TRITON_CPU_OMP_DEBUG"))
-    printf("N: %zu, max_threads: %d, chunk_size: %zu\\n", N, max_threads, chunk_size);
+  if (getBoolEnv("TRITON_CPU_OMP_DEBUG"))
+    printf("N: %zu, max_threads: %d\\n", N, max_threads.value());
 
-#pragma omp parallel for schedule(static, chunk_size) num_threads(max_threads.value())
+  // For now, use the default static chunk size: total iterations / max_threads.
+#pragma omp parallel for schedule(static) num_threads(max_threads.value())
   for (uint32_t i = 0; i < N; ++i) {{
     const auto [x, y, z] = all_grids[i];
     (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
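
For reference, a minimal usage sketch (not part of the patch) of the updated tutorial entry point and the environment variables read by the generated OpenMP launcher. The tensor size, the thread count, and the way `add` is obtained from the tutorial module are illustrative assumptions; only the `add(x, y, output, is_cpu)` signature and the TRITON_CPU_* variable names come from the diff above.

    import os
    import torch

    # Illustrative settings; the launcher reads these at launch time.
    os.environ["TRITON_CPU_MAX_THREADS"] = "8"   # cap the number of OpenMP threads
    os.environ["TRITON_CPU_OMP_DEBUG"] = "1"     # print N and max_threads from the launcher

    x = torch.rand(98432, device="cpu")
    y = torch.rand(98432, device="cpu")

    # Preallocate the output so a timed call measures only the kernel,
    # mirroring the new 'triton-cpu-prealloc' benchmark providers.
    output = torch.empty_like(x)
    add(x, y, output, is_cpu=True)   # `add` is the tutorial function patched above
    torch.testing.assert_close(output, x + y)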