Utility libxsmm Python extension (triton-lang#17)
Adds a Python wrapper for a parallelized in-place copy function using libxsmm and OpenMP.
It is intended to be used for an efficient tensor padding implementation.

The libxsmm paths have to be specified through environment variables:
  - XSMM_ROOT_DIR - path to libxsmm root dir with headers
  - XSMM_LIB_DIR - path to libxsmm.so location

The libxsmm shared library (libxsmm.so) also has to be available at runtime, e.g., exposed through LD_LIBRARY_PATH.
The XSMM Python module can be built and installed with the command:
  pip install -e ./third_party/cpu/python/
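
A minimal usage sketch (hypothetical shapes; assumes the module built successfully and libxsmm.so is reachable at runtime):
  import torch
  import xsmm_py

  src = torch.randn(512, 512, dtype=torch.float32)       # 2D source tensor
  dst = torch.empty(512, 512 + 32, dtype=torch.float32)  # destination with 32 columns of right padding
  xsmm_py.fastZeroPad2D(src, dst)                         # parallel copy of src plus zeroing of the padded region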
adam-smnk authored and Devjiu committed Jan 20, 2025
1 parent ad183bb commit b0ae885
Showing 3 changed files with 84 additions and 7 deletions.
14 changes: 7 additions & 7 deletions python/tutorials/03-matrix-multiplication-cpu.py
@@ -150,7 +150,6 @@
# ------------

import torch
import math

import triton
import triton.language as tl
@@ -175,6 +174,7 @@
DYNAMIC_K_BLOCK = False
CACHE_PADDING = False
PREPROCESS_EXTERNAL = False
XSMM_PAD = False

@triton.jit
def pad_kernel(in_ptr, out_ptr, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, PADDING: tl.constexpr):
@@ -316,12 +316,12 @@ def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, n
pad_kernel[(K // BLOCK_SIZE_K, )](b, b_scratch, N, BLOCK_SIZE_K, BLOCK_SIZE_N, 32, num_threads=num_threads)
b = b_scratch

# TODO: Check if padding is needed at all.
# Currently, cache padding is most useful together with dynamic K blocking
# to ensure that stride is non-power-of-two to improve cache behavior.
if CACHE_PADDING:
a = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0)
b = torch.nn.functional.pad(b, (0, 32, 0, 0), mode='constant', value=0)
# TODO: Check if padding is needed at all.
# Currently, cache padding is most useful together with dynamic K blocking
# to ensure that stride is non-power-of-two to improve cache behavior.
if CACHE_PADDING:
a = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0)
b = torch.nn.functional.pad(b, (0, 32, 0, 0), mode='constant', value=0)

#TODO: Currently masked load is not supported yet.
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
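The diff also introduces an XSMM_PAD flag, but its wiring inside matmul_preprocess_input is not shown here. A hypothetical sketch of such a branch, mirroring the existing pad_kernel call (32 padding columns, xsmm_py import assumed), might look like:

if XSMM_PAD:
    # Hypothetical: let the libxsmm extension copy b into a padded scratch buffer in parallel.
    b_scratch = torch.empty((K, N + 32), dtype=b.dtype)
    xsmm_py.fastZeroPad2D(b, b_scratch)
    b = b_scratch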
21 changes: 21 additions & 0 deletions third_party/cpu/python/setup.py
@@ -0,0 +1,21 @@
from setuptools import setup, Extension
from torch.utils import cpp_extension
import os

xsmm_root = os.getenv("XSMM_ROOT_DIR")
xsmm_lib = os.getenv("XSMM_LIB_DIR")
print(f'Using LIBXSMM root: {xsmm_root}')
print(f'LIBXSMM lib location: {xsmm_lib}')

setup(name='xsmm_py',
ext_modules=[
cpp_extension.CppExtension('xsmm_py', ['xsmm_utils.cpp'],
include_dirs=[
f'{xsmm_root}/include',
f'{xsmm_root}/src/template'
],
library_dirs=[f'{xsmm_lib}'],
libraries=['xsmm', 'omp'],
extra_compile_args=['-fopenmp']
)],
cmdclass={'build_ext': cpp_extension.BuildExtension})
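A quick post-install sanity check (hypothetical session): torch is imported first so the loader can resolve the libtorch symbols the extension links against, and libxsmm.so must be discoverable at runtime, e.g. via LD_LIBRARY_PATH.

import torch
import xsmm_py
print(xsmm_py.fastZeroPad2D.__doc__)  # pybind11 docstring, ends with "Fast 2D tensor zero padding"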
56 changes: 56 additions & 0 deletions third_party/cpu/python/xsmm_utils.cpp
@@ -0,0 +1,56 @@
#include <torch/extension.h>

#include "libxsmm.h"
#include <omp.h>

#include <cstring>

void fastZeroPad2D(const at::Tensor &input, torch::Tensor &output) {
const auto inSizes = input.sizes();
const auto outSizes = output.sizes();
const auto byteSize = input.element_size();
assert(input.is_floating_point() && inSizes.size() == 2 &&
outSizes.size() == 2 && outSizes[0] >= inSizes[0] &&
outSizes[1] >= inSizes[1] && byteSize == output.element_size() &&
"Invalid fastZeroPad2D tensors");

libxsmm_datatype dtype =
byteSize == 4 ? LIBXSMM_DATATYPE_F32 : LIBXSMM_DATATYPE_BF16;
libxsmm_meltw_unary_shape shape;
// Flipped to libxsmm's column-major convention.
shape.m = inSizes[1];
shape.n = 1;
shape.ldi = inSizes[1];
shape.ldo = outSizes[1];
shape.in0_type = dtype;
shape.out_type = dtype;
shape.comp_type = dtype;
libxsmm_bitfield flags = LIBXSMM_MELTW_FLAG_UNARY_NONE;
libxsmm_meltwfunction_unary identityFn = libxsmm_dispatch_meltw_unary(
LIBXSMM_MELTW_TYPE_UNARY_IDENTITY, shape, flags);
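// Each call of identityFn below copies one row (inSizes[1] contiguous elements)
// from the input into the corresponding output row; the TPP is dispatched once
// and reused across all rows.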

// Use char pointers for byte-granular arithmetic (arithmetic on void* is non-standard).
auto *baseIn = static_cast<char *>(input.data_ptr());
auto *outIn = static_cast<char *>(output.data_ptr());
const int padRight = outSizes[1] - inSizes[1];

#pragma omp parallel for schedule(static)
for (int i = 0; i < inSizes[0]; ++i) {
libxsmm_meltw_unary_param param;
param.in.primary = baseIn + i * inSizes[1] * byteSize;
param.out.primary = outIn + i * outSizes[1] * byteSize;
identityFn(&param);
// Zero out right padding.
std::memset(outIn + i * outSizes[1] * byteSize + inSizes[1] * byteSize, 0,
byteSize * padRight);
}

// Zero out bottom padding.
#pragma omp parallel for schedule(static)
for (int i = inSizes[0]; i < outSizes[0]; ++i) {
std::memset(outIn + i * outSizes[1] * byteSize, 0, byteSize * outSizes[1]);
}
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fastZeroPad2D", &fastZeroPad2D, "Fast 2D tensor zero padding");
}
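The result should agree with an ordinary PyTorch zero pad of the same geometry; a small equivalence check along those lines (hypothetical sizes):

import torch
import xsmm_py

src = torch.randn(300, 500, dtype=torch.float32)
dst = torch.empty(320, 532, dtype=torch.float32)
xsmm_py.fastZeroPad2D(src, dst)
# torch.nn.functional.pad takes (left, right, top, bottom) for the last two dims.
ref = torch.nn.functional.pad(src, (0, 32, 0, 20))
assert torch.equal(dst, ref)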
