Commit

Add num_threads option to control threading per kernel invocation. (#170)

Signed-off-by: Ilya Enkovich <[email protected]>
ienkovich authored Oct 28, 2024
1 parent 51427ed commit a7d1412
Showing 7 changed files with 35 additions and 46 deletions.
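
This commit replaces the TRITON_CPU_SINGLE_CORE and TRITON_CPU_MAX_THREADS
environment variables with a num_threads option that is set per kernel
invocation and forwarded through CPUOptions and the launcher metadata. A
minimal sketch of the resulting call pattern (the kernel, sizes, and block
size below are illustrative, not taken from this commit):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # Each program instance handles one BLOCK_SIZE-wide slice.
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    x = torch.rand(4096, device='cpu')
    y = torch.rand(4096, device='cpu')
    out = torch.empty_like(x)
    grid = (triton.cdiv(4096, 1024), )
    # num_threads=1 pins this launch to a single thread; the default of 0
    # lets the launcher use all available CPU cores.
    add_kernel[grid](x, y, out, 4096, BLOCK_SIZE=1024, num_threads=1)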
5 changes: 0 additions & 5 deletions python/tutorials/01-vector-add.py
@@ -195,18 +195,13 @@ def add_tiled_with_st_threshold(x: torch.Tensor, y: torch.Tensor, output):
         args={}, # Values for function arguments not in `x_names` and `y_name`.
     ))
 def benchmark(size, provider):
-    import os
 
     device = 'cpu' if 'cpu' in provider else 'cuda'
     x = torch.rand(size, device=device, dtype=torch.float32)
     y = torch.rand(size, device=device, dtype=torch.float32)
 
     if device == 'cpu':
         triton.runtime.driver.set_active_to_cpu()
-        if 'single' in provider:
-            os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
-        else:
-            os.unsetenv('TRITON_CPU_SINGLE_CORE')
     else:
         triton.runtime.driver.set_active_to_gpu()
     output = torch.empty_like(x)
10 changes: 3 additions & 7 deletions python/tutorials/02-fused-softmax-cpu.py
@@ -100,7 +100,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
 
 
-def softmax(x, y=None):
+def softmax(x, y=None, num_threads=0):
     n_rows, n_cols = x.shape
     # The block size is the smallest power of two greater than the number of columns in `x`
     BLOCK_SIZE = triton.next_power_of_2(n_cols)
@@ -126,6 +126,7 @@ def softmax(x, y=None):
         n_cols,
         num_warps=num_warps,
         BLOCK_SIZE=BLOCK_SIZE,
+        num_threads=num_threads,
     )
     return y
 
@@ -190,7 +191,6 @@ def softmax(x, y=None):
         args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
     ))
 def benchmark(M, N, provider):
-    import os
 
     # Currently compilation time is very long. Let's show the progress.
     print(f"Running {provider} with {M} x {N}...")
@@ -201,10 +201,6 @@ def benchmark(M, N, provider):
     if device == 'cpu':
         y = torch.empty_like(x)
         triton.runtime.driver.set_active_to_cpu()
-        if 'single' in provider:
-            os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
-        else:
-            os.unsetenv('TRITON_CPU_SINGLE_CORE')
     else:
         y = None
         triton.runtime.driver.set_active_to_gpu()
@@ -218,7 +214,7 @@ def benchmark(M, N, provider):
         compiled = torch.compile(naive_softmax)
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles)
     if provider == 'triton-cpu-single':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y, num_threads=1), quantiles=quantiles)
     if provider == 'triton-cpu':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles)
     if provider == 'triton-gpu':
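
The softmax helper now takes the thread budget as an argument, so callers pick
it at call time instead of toggling an environment variable. A short usage
sketch based on the helper above (shapes are illustrative):

    x = torch.randn(1823, 781, device='cpu')
    y = torch.empty_like(x)
    softmax(x, y, num_threads=1)  # serial run, as in the 'triton-cpu-single' provider
    softmax(x, y)                 # num_threads=0: parallelize over all cores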
10 changes: 3 additions & 7 deletions python/tutorials/03-matrix-multiplication-cpu.py
@@ -248,7 +248,7 @@ def matmul_kernel(
 # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
 
 
-def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
+def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -272,6 +272,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
         c.stride(0), c.stride(1),  #
         BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,  #
         GROUP_SIZE_M=GROUP_SIZE_M,  #
+        num_threads=num_threads,  #
     )
     return c
 
@@ -359,7 +360,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
         args={}, # Values for function arguments not in `x_names` and `y_name`.
     ))
 def benchmark(M, N, K, provider):
-    import os
 
     device = 'cpu' if 'cpu' in provider else 'cuda'
     a = torch.randn((M, K), device=device, dtype=torch.float32)
@@ -368,10 +368,6 @@ def benchmark(M, N, K, provider):
     if device == 'cpu':
         c = torch.empty((M, N), device=a.device, dtype=a.dtype)
         triton.runtime.driver.set_active_to_cpu()
-        if 'single' in provider:
-            os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
-        else:
-            os.unsetenv('TRITON_CPU_SINGLE_CORE')
     else:
         c = None
         triton.runtime.driver.set_active_to_gpu()
@@ -387,7 +383,7 @@ def benchmark(M, N, K, provider):
         compiled = torch.compile(torch.matmul)
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles)
     elif provider == 'triton-cpu-single':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c, num_threads=1), quantiles=quantiles)
     elif provider == 'triton-cpu':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles)
     perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3)
12 changes: 7 additions & 5 deletions python/tutorials/matrix-vector-multiplication-bf16.py
@@ -50,6 +50,7 @@ def gemv(
     weight: torch.Tensor,
     x: torch.Tensor,
     output: torch.Tensor,
+    num_threads=0,
 ):
     assert weight.shape[1] == x.shape[0], "Incompatible dimensions"
     assert weight.is_contiguous() and x.is_contiguous(), "Input and weight must be contiguous"
@@ -69,7 +70,8 @@ def gemv(
     # 1D launch kernel where each block gets its own program.
     grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), )
 
-    gemv_kernel[grid](output, weight, x, M, N, weight.stride(0), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N)
+    gemv_kernel[grid](output, weight, x, M, N, weight.stride(0), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,
+                      num_threads=num_threads)
 
     return output
 
@@ -148,7 +150,6 @@ def gemv(
         args={}, # Values for function arguments not in `x_names` and `y_name`.
     ))
 def benchmark(M, N, provider):
-    import os
 
     device = 'cpu' if 'cpu' in provider else 'cuda'
     weight = torch.randn((M, N), device=device, dtype=torch.bfloat16)
@@ -157,11 +158,11 @@ def benchmark(M, N, provider):
     if device == 'cpu':
         output = torch.empty((M), device=x.device, dtype=x.dtype)
         triton.runtime.driver.set_active_to_cpu()
+        num_threads = 0
         if 'single' in provider:
-            os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
+            num_threads = 1
             torch.set_num_threads(1)
         else:
-            os.unsetenv('TRITON_CPU_SINGLE_CORE')
             torch.set_num_threads(default_num_threads)
     else:
         output = None
@@ -178,7 +179,8 @@ def benchmark(M, N, provider):
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled_matmul(weight, x, out=output),
                                                      quantiles=quantiles)
     elif 'triton-cpu' in provider:
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output, num_threads=num_threads),
+                                                     quantiles=quantiles)
 
     perf = lambda ms: 2 * M * N * 1e-9 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)
12 changes: 5 additions & 7 deletions python/tutorials/matrix-vector-multiplication.py
@@ -49,6 +49,7 @@ def gemv(
     weight: torch.Tensor,
     x: torch.Tensor,
     output: torch.Tensor,
+    num_threads=0,
 ):
     assert weight.shape[1] == x.shape[0], "Incompatible dimensions"
     assert weight.is_contiguous() and x.is_contiguous(), "Input and weight must be contiguous"
@@ -68,7 +69,8 @@ def gemv(
     # 1D launch kernel where each block gets its own program.
     grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), )
 
-    gemv_kernel[grid](output, weight, x, M, N, weight.stride(0), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N)
+    gemv_kernel[grid](output, weight, x, M, N, weight.stride(0), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,
+                      num_threads=num_threads)
 
     return output
 
@@ -146,7 +148,6 @@ def gemv(
         args={'M': 4096}, # Values for function arguments not in `x_names` and `y_name`.
     ))
 def benchmark(M, N, provider):
-    import os
 
     device = 'cpu' if 'cpu' in provider else 'cuda'
     weight = torch.randn((M, N), device=device, dtype=torch.float32)
@@ -155,10 +156,6 @@ def benchmark(M, N, provider):
     if device == 'cpu':
         output = torch.empty((M), device=x.device, dtype=x.dtype)
         triton.runtime.driver.set_active_to_cpu()
-        if 'single' in provider:
-            os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
-        else:
-            os.unsetenv('TRITON_CPU_SINGLE_CORE')
 
     if 'transpose' in provider:
         weight = torch.transpose(weight, 0, 1)
@@ -190,7 +187,8 @@ def benchmark(M, N, provider):
         weight = torch.nn.Linear(N, M, bias=False, device=weight.device, dtype=weight.dtype)
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: weight.forward(x), quantiles=quantiles)
     elif provider == 'triton-cpu-single':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output, num_threads=1),
+                                                     quantiles=quantiles)
     elif provider == 'triton-cpu':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles)
     elif provider == 'triton-cpu-linear':
5 changes: 5 additions & 0 deletions third_party/cpu/backend/compiler.py
@@ -27,9 +27,14 @@ class CPUOptions:
     # GPU-specific options are used in several places.
     # For now, we just provide dummy values.
     backend_name: str = "cpu"
+    # These options provide compatibility with GPU kernel calls.
+    # All of them are ignored.
     num_warps: int = 0
     num_stages: int = 0
     num_ctas: int = 0
+    # Max number of threads to be used for a kernel call.
+    # Zero value is used to utilize all available CPU cores.
+    num_threads: int = 0
     cluster_dims: tuple = (1, 1, 1)
     extern_libs: dict = None
     debug: bool = False
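
With this field in place, num_threads behaves like the other per-kernel options
and defaults to 0, meaning all available cores. A hedged sketch of the two call
styles (my_kernel, grid, and the arguments are placeholders):

    my_kernel[grid](out, inp, n, BLOCK_SIZE=1024)                 # num_threads=0: all CPU cores
    my_kernel[grid](out, inp, n, BLOCK_SIZE=1024, num_threads=4)  # cap this launch at 4 threads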
27 changes: 12 additions & 15 deletions third_party/cpu/backend/driver.py
@@ -229,7 +229,7 @@ def format_of(ty):
       return grids;
     }}
 
-    static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+    static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_threads, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
       // TODO: Consider using omp collapse(3) clause for simplicity?
       size_t N = gridX * gridY * gridZ;
       if (N == 1) {{
@@ -238,28 +238,19 @@
       }}
 
       auto all_grids = get_all_grids(gridX, gridY, gridZ);
-      if (getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{
-        if (getBoolEnv("TRITON_CPU_OMP_DEBUG"))
-          printf("Single core launcher\\n");
+      int max_threads = (num_threads > 0) ? num_threads : omp_get_max_threads();
+      // Don't pay OMP overhead price when a single thread is used.
+      if (max_threads == 1) {{
         for (size_t i = 0; i < N; ++i) {{
           const auto [x, y, z] = all_grids[i];
           (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z, gridX, gridY, gridZ);
         }}
         return;
       }}
-      std::optional<int> max_threads = getIntEnv("TRITON_CPU_MAX_THREADS");
-      if (max_threads.has_value())
-        max_threads = std::max(1, std::min(max_threads.value(), omp_get_max_threads()));
-      else
-        max_threads = omp_get_max_threads();
-      if (getBoolEnv("TRITON_CPU_OMP_DEBUG"))
-        printf("N: %zu, max_threads: %d\\n", N, max_threads.value());
       // For now, use the default chunk size, total iterations / max_threads.
-#pragma omp parallel for schedule(static) num_threads(max_threads.value())
+#pragma omp parallel for schedule(static) num_threads(max_threads)
       for (size_t i = 0; i < N; ++i) {{
         const auto [x, y, z] = all_grids[i];
         (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z, gridX, gridY, gridZ);
@@ -285,6 +276,12 @@ def format_of(ty):
   void *pStream = PyLong_AsVoidPtr(py_obj_stream);
   kernel_ptr_t kernel_ptr = reinterpret_cast<kernel_ptr_t>(pKrnl);
 
+  // Extract num_threads metadata.
+  int num_threads = 0;
+  PyObject *num_threads_attr = PyObject_GetAttrString(kernel_metadata, "num_threads");
+  if (num_threads_attr && PyLong_Check(num_threads_attr))
+    num_threads = PyLong_AsLong(num_threads_attr);
+
   // extract launch metadata
   if (launch_enter_hook != Py_None){{
     PyObject* args = Py_BuildValue("(O)", launch_metadata);
@@ -295,7 +292,7 @@
   }}
 
   {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
 
-  run_omp_kernels(gridX, gridY, gridZ, kernel_ptr {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''});
+  run_omp_kernels(gridX, gridY, gridZ, num_threads, kernel_ptr {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''});
 
   if(launch_exit_hook != Py_None){{
     PyObject* args = Py_BuildValue("(O)", launch_metadata);
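
The launcher resolves the thread count to num_threads when it is positive and
to omp_get_max_threads() otherwise, and skips OpenMP dispatch entirely when the
result is 1. One way to observe the effect, assuming the softmax helper from
the tutorial above and a CPU-active driver (a sketch, not part of this commit):

    import torch
    import triton

    x = torch.randn(4096, 1024, device='cpu')
    y = torch.empty_like(x)
    for nt in (1, 4, 0):  # 0 means "all available cores"
        ms = triton.testing.do_bench(lambda: softmax(x, y, num_threads=nt))
        print(f"num_threads={nt}: {ms:.3f} ms")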
