Commit: Address the comments

minjang committed Jun 7, 2024
1 parent c045887 commit 7a22fc7
Showing 3 changed files with 60 additions and 52 deletions.
19 changes: 0 additions & 19 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -4,7 +4,6 @@
#include <algorithm>
#include <assert.h>
#include <cstdlib>
#include <optional>
#include <set>
#include <sstream>
#include <string>
@@ -30,9 +29,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {

inline const std::set<std::string> CACHE_NEUTRAL_ENV_VARS = {
"TRITON_REPRODUCER_PATH",
"TRITON_CPU_SINGLE_CORE",
"TRITON_CPU_MAX_THREADS",
"TRITON_CPU_OMP_DEBUG",
};

namespace tools {
@@ -56,21 +52,6 @@ inline std::string getStrEnv(const std::string &env) {
return result;
}

inline std::optional<int64_t> getIntEnv(const std::string &env) {
assertIsRecognized(env);
const char *cstr = std::getenv(env.c_str());
if (!cstr) {
return std::nullopt;
}

char *endptr;
long int result = std::strtol(cstr, &endptr, 10);
if (endptr == cstr) {
assert(false && "invalid integer");
}
return result;
}

// return value of a cache-invalidating boolean environment variable
inline bool getBoolEnv(const std::string &env) {
assertIsRecognized(env);
38 changes: 24 additions & 14 deletions python/tutorials/01-vector-add.py
@@ -24,7 +24,8 @@
import triton.language as tl

GPU_BLOCK_SIZE = 1024
CPU_BLOCK_SIZE = 128
CPU_BLOCK_SIZE = 4096
USE_GPU = True


@triton.jit
@@ -60,10 +61,11 @@ def add_kernel(x_ptr, # *Pointer* to first input vector.
# and (2) enqueue the above kernel with appropriate grid/block sizes:


def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
# We need to preallocate the output.
output = torch.empty_like(x)
assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu
def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, is_cpu):
if output is None:
# We need to preallocate the output.
output = torch.empty_like(x)
assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu
n_elements = output.numel()
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
@@ -88,22 +90,22 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
x = torch.rand(size, device='cpu')
y = torch.rand(size, device='cpu')
output_torch_cpu = x + y
output_triton_cpu = add(x, y, is_cpu=True)
output_triton_cpu = add(x, y, None, is_cpu=True)
print(output_torch_cpu)
print(output_triton_cpu)
print(f'The maximum difference between torch-cpu and triton-cpu is '
f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}')

LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu']
LINE_NAMES = ['TritonCPU 1-core', 'TritonCPU', 'TorchCPU']
LINE_STYLES = [('blue', '-'), ('cyan', '-'), ('green', '-')]
LINE_VALS = ['triton-cpu-single-prealloc', 'triton-cpu-prealloc', 'triton-cpu-single', 'triton-cpu', 'torch-cpu']
LINE_NAMES = ['TritonCPU 1+pre', 'TritonCPU pre', 'TritonCPU 1', 'TritonCPU', 'TorchCPU']
LINE_STYLES = [('blue', '--'), ('green', '--'), ('blue', '-'), ('green', '-'), ('cyan', '-')]

if triton.runtime.driver.get_active_gpus():
if USE_GPU and triton.runtime.driver.get_active_gpus():
triton.runtime.driver.set_active_to_gpu()
x = x.to('cuda')
y = y.to('cuda')
output_torch_gpu = x + y
output_triton_gpu = add(x, y, is_cpu=False)
output_triton_gpu = add(x, y, None, is_cpu=False)
print(output_torch_gpu)
print(output_triton_gpu)
print(f'The maximum difference between torch-gpu and triton-gpu is '
@@ -160,13 +162,21 @@ def benchmark(size, provider):
if provider == 'torch-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
elif provider == 'triton-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, False), quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles)
elif provider == 'torch-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu-single':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, True), quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu-prealloc':
# Note that we preallocate the output buffer here to only measure the kernel performance
# without a large chunk of memory allocation.
output = torch.empty_like(x)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu-single-prealloc':
output = torch.empty_like(x)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, True), quantiles=quantiles, is_cpu=True)
gbps = lambda ms: 12 * size / ms * 1e-6
return gbps(ms), gbps(max_ms), gbps(min_ms)

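For context on the tutorial changes above, here is a minimal usage sketch (not part of the commit) of the reworked add() signature: passing None for the output lets the wrapper allocate it, while passing a preallocated tensor, as the new *-prealloc benchmark providers do, keeps torch.empty_like() out of the timed region. It assumes the tutorial's add() and the fork's do_bench(..., is_cpu=True) keyword shown in the diff are in scope.

import torch
import triton

size = 98432
x = torch.rand(size, device='cpu')
y = torch.rand(size, device='cpu')

# Default path: add() allocates the output itself. Inside add(), the launch grid is a
# callable over the metaparameters, e.g. lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),).
out = add(x, y, None, is_cpu=True)

# Benchmark path: preallocate once so do_bench times only the kernel launch,
# not a fresh allocation on every iteration.
out_prealloc = torch.empty_like(x)
ms = triton.testing.do_bench(lambda: add(x, y, out_prealloc, True), is_cpu=True)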
55 changes: 36 additions & 19 deletions third_party/cpu/backend/driver.py
@@ -16,7 +16,6 @@
include_dir = [
os.path.join(dirname, "include"),
os.path.join(llvm_root, "include"),
os.path.join(".", "include"),
]
library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")]
libraries = [
@@ -78,8 +77,6 @@
"z",
]

MINIMUM_OMP_CHUNK_SIZE = 10


def compile_module_from_src(src, name):
key = hashlib.md5(src.encode("utf-8")).hexdigest()
Expand Down Expand Up @@ -184,18 +181,39 @@ def format_of(ty):

# generate glue code
src = f"""
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <string>
#include <iostream>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <omp.h>
#include <cmath>
#include "triton/Tools/Sys/GetEnv.hpp"
#include <optional>
#include <stdio.h>
#include <string>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include <stdio.h>
inline bool getBoolEnv(const std::string &env) {{
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c) {{ return std::tolower(c); }});
return str == "on" || str == "true" || str == "1";
}}
inline std::optional<int64_t> getIntEnv(const std::string &env) {{
const char *cstr = std::getenv(env.c_str());
if (!cstr)
return std::nullopt;
char *endptr;
long int result = std::strtol(cstr, &endptr, 10);
if (endptr == cstr)
assert(false && "invalid integer");
return result;
}}
using kernel_ptr_t = void(*)({kernel_fn_arg_types});
@@ -255,33 +273,32 @@ def format_of(ty):
Expand Down Expand Up @@ -255,33 +273,32 @@ def format_of(ty):
}}
static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
// TODO: Consider using omp collapse(3) clause for simplicity?
auto all_grids = get_all_grids(gridX, gridY, gridZ);
size_t N = gridX * gridY * gridZ;
if (mlir::triton::tools::getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{
if (mlir::triton::tools::getBoolEnv("TRITON_CPU_OMP_DEBUG"))
if (getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{
if (getBoolEnv("TRITON_CPU_OMP_DEBUG"))
printf("Single core launcher\\n");
for (uint32_t i = 0; i < N; ++i) {{
const auto [x, y, z] = all_grids[i];
(*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
}}
return;
}}
// Use the default static scheduling with a simple chunk size policy.
std::optional<int> max_threads = mlir::triton::tools::getIntEnv("TRITON_CPU_MAX_THREADS");
std::optional<int> max_threads = getIntEnv("TRITON_CPU_MAX_THREADS");
if (max_threads.has_value())
max_threads = std::max(1, std::min(max_threads.value(), omp_get_max_threads()));
else
max_threads = omp_get_max_threads();
int chunk_size = std::ceil((double)N / (double)max_threads.value());
chunk_size = std::max(chunk_size, {MINIMUM_OMP_CHUNK_SIZE});
if (mlir::triton::tools::getBoolEnv("TRITON_CPU_OMP_DEBUG"))
printf("N: %zu, max_threads: %d, chunk_size: %zu\\n", N, max_threads, chunk_size);
if (getBoolEnv("TRITON_CPU_OMP_DEBUG"))
printf("N: %zu, max_threads: %d\\n", N, max_threads.value());
#pragma omp parallel for schedule(static, chunk_size) num_threads(max_threads.value())
// For now, use the default chunk size, total iterations / max_threads.
#pragma omp parallel for schedule(static) num_threads(max_threads.value())
for (uint32_t i = 0; i < N; ++i) {{
const auto [x, y, z] = all_grids[i];
(*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
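As a reference for the new launcher behavior in driver.py, here is a small Python sketch (not part of the commit; helper names are illustrative) of the policy the generated glue code now follows: TRITON_CPU_MAX_THREADS, when set, is clamped to [1, omp_get_max_threads()], and with schedule(static) and no explicit chunk size OpenMP splits the N grid points into roughly equal contiguous blocks of about ceil(N / num_threads) iterations, which is the "default chunk size" the new comment refers to.

import math
import os

def resolve_max_threads(omp_max_threads: int) -> int:
    # Mirrors the glue code: clamp TRITON_CPU_MAX_THREADS into [1, omp_get_max_threads()],
    # falling back to omp_get_max_threads() when the variable is unset.
    env = os.getenv("TRITON_CPU_MAX_THREADS")
    if env is None:
        return omp_max_threads
    return max(1, min(int(env), omp_max_threads))

def default_static_chunk(n_grid_points: int, num_threads: int) -> int:
    # With `#pragma omp parallel for schedule(static)` and no chunk argument, iterations
    # are divided into contiguous blocks of roughly this size, one block per thread.
    return math.ceil(n_grid_points / num_threads)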
