
Commit

Utility libxsmm Python extension (triton-lang#17)
Adds a Python wrapper for a parallelized in-place copy function using libxsmm and OpenMP.
It is intended to provide an efficient tensor padding implementation.

The libxsmm paths have to be specified through environment variables:
  - XSMM_ROOT_DIR - path to the libxsmm root directory containing the headers
  - XSMM_LIB_DIR - path to the directory containing libxsmm.so

libxsmm.so also has to be available at runtime, e.g., exposed through LD_LIBRARY_PATH.
The XSMM Python module can be built and installed with:
  pip install -e ./third_party/cpu/python/
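
A minimal usage sketch (the shapes and dtype below are illustrative, not part of
the commit): fastZeroPad2D copies the input into the top-left corner of the
output tensor and zeroes the remaining elements.
  import torch
  import xsmm_py

  src = torch.randn(128, 100, dtype=torch.float32)
  dst = torch.empty(128, 128, dtype=torch.float32)  # output >= input in both dims
  xsmm_py.fastZeroPad2D(src, dst)
  assert torch.equal(dst[:, :100], src)
  assert dst[:, 100:].count_nonzero() == 0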
adam-smnk authored and Devjiu committed Nov 13, 2024
1 parent c46ed63 commit 54e8dea
Showing 3 changed files with 106 additions and 14 deletions.
43 changes: 29 additions & 14 deletions python/tutorials/03-matrix-multiplication-cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,12 @@
# ------------

import torch
import math

import triton
import triton.language as tl

import xsmm_py

BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 32
BLOCK_SIZE_K = 32
Expand All @@ -168,6 +169,7 @@
DYNAMIC_K_BLOCK = False
CACHE_PADDING = False
PREPROCESS_EXTERNAL = False
XSMM_PAD = False

@triton.jit
def matmul_kernel(
Expand Down Expand Up @@ -301,6 +303,8 @@ def matmul_kernel(
tl.store(c_ptrs, c)


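# Persistent scratch buffers for XSMM-based padding; resized in-place per call
# to the required padded shapes.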
a_scratch = torch.empty((), dtype=DATA_TYPE)
b_scratch = torch.empty((), dtype=DATA_TYPE)
def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
Expand All @@ -315,19 +319,30 @@ def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
# should be considered.
k_block = min(triton.next_power_of_2(K), 1024)

if K_DIM_PADDING or DYNAMIC_K_BLOCK:
padding_size = (math.ceil(K / k_block) * k_block) - K
if padding_size != 0:
a = torch.nn.functional.pad(a, (0, padding_size, 0, 0), mode='constant', value=0)
b = torch.nn.functional.pad(b, (0, 0, 0, padding_size), mode='constant', value=0)
K = a.shape[1]

# TODO: Check if padding is needed at all.
# Currently, cache padding is most useful together with dynamic K blocking
# to ensure that stride is non-power-of-two to improve cache behavior.
if CACHE_PADDING:
a = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0)
b = torch.nn.functional.pad(b, (0, 32, 0, 0), mode='constant', value=0)
if XSMM_PAD:
padding_size = (((K + k_block - 1) // k_block) * k_block) - K
col_pad = 32 if CACHE_PADDING else 0
a_scratch.resize_(M, K + padding_size + col_pad)
b_scratch.resize_(K + padding_size, N + col_pad)
xsmm_py.fastZeroPad2D(a, a_scratch)
xsmm_py.fastZeroPad2D(b, b_scratch)
K = K + padding_size
a = a_scratch
b = b_scratch
else:
if K_DIM_PADDING or DYNAMIC_K_BLOCK:
padding_size = (((K + k_block - 1) // k_block) * k_block) - K
if padding_size != 0:
a = torch.nn.functional.pad(a, (0, padding_size, 0, 0), mode='constant', value=0)
b = torch.nn.functional.pad(b, (0, 0, 0, padding_size), mode='constant', value=0)
K = a.shape[1]

# TODO: Check if padding is needed at all.
# Currently, cache padding is most useful together with dynamic K blocking
# to ensure that stride is non-power-of-two to improve cache behavior.
if CACHE_PADDING:
a = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0)
b = torch.nn.functional.pad(b, (0, 32, 0, 0), mode='constant', value=0)

# TODO: Masked loads are currently not supported.
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
Expand Down
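For reference, the ceil-to-multiple arithmetic used in both padding paths above,
as a standalone sketch (helper name hypothetical):
  def pad_to_multiple(k: int, block: int) -> int:
      # Number of extra elements needed to round k up to the next multiple of block.
      return ((k + block - 1) // block) * block - k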
21 changes: 21 additions & 0 deletions third_party/cpu/python/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from setuptools import setup, Extension
from torch.utils import cpp_extension
import os

xsmm_root = os.getenv("XSMM_ROOT_DIR")
xsmm_lib = os.getenv("XSMM_LIB_DIR")
print(f'Using LIBXSMM root: {xsmm_root}')
print(f'LIBXSMM lib location: {xsmm_lib}')

setup(name='xsmm_py',
ext_modules=[
cpp_extension.CppExtension('xsmm_py', ['xsmm_utils.cpp'],
include_dirs=[
f'{xsmm_root}/include',
f'{xsmm_root}/src/template'
],
library_dirs=[f'{xsmm_lib}'],
libraries=['xsmm', 'omp'],
extra_compile_args=['-fopenmp']
)],
cmdclass={'build_ext': cpp_extension.BuildExtension})
56 changes: 56 additions & 0 deletions third_party/cpu/python/xsmm_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include <torch/extension.h>

#include "libxsmm.h"
#include <omp.h>

#include <cstring>

void fastZeroPad2D(const at::Tensor &input, torch::Tensor &output) {
const auto inSizes = input.sizes();
const auto outSizes = output.sizes();
const auto byteSize = input.element_size();
assert(input.is_floating_point() && inSizes.size() == 2 &&
       outSizes.size() == 2 && outSizes[0] >= inSizes[0] &&
       outSizes[1] >= inSizes[1] && byteSize == output.element_size() &&
       "Invalid fastZeroPad2D tensors");

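// Map element size to a libxsmm datatype: 4-byte elements are treated as F32;
// anything else is assumed to be BF16 here.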
libxsmm_datatype dtype =
byteSize == 4 ? LIBXSMM_DATATYPE_F32 : LIBXSMM_DATATYPE_BF16;
libxsmm_meltw_unary_shape shape;
// Flipped to libxsmm's column-major convention.
shape.m = inSizes[1];
shape.n = 1;
shape.ldi = inSizes[1];
shape.ldo = outSizes[1];
shape.in0_type = dtype;
shape.out_type = dtype;
shape.comp_type = dtype;
libxsmm_bitfield flags = LIBXSMM_MELTW_FLAG_UNARY_NONE;
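// JIT an identity (copy) kernel that copies one full input row per invocation.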
libxsmm_meltwfunction_unary identityFn = libxsmm_dispatch_meltw_unary(
LIBXSMM_MELTW_TYPE_UNARY_IDENTITY, shape, flags);

char *baseIn = static_cast<char *>(input.data_ptr());
char *outIn = static_cast<char *>(output.data_ptr());
const int padRight = outSizes[1] - inSizes[1];

#pragma omp parallel for schedule(static)
for (int64_t i = 0; i < inSizes[0]; ++i) {
libxsmm_meltw_unary_param param;
param.in.primary = baseIn + i * inSizes[1] * byteSize;
param.out.primary = outIn + i * outSizes[1] * byteSize;
identityFn(&param);
// Zero out right padding.
std::memset(outIn + i * outSizes[1] * byteSize + inSizes[1] * byteSize, 0,
byteSize * padRight);
}

// Zero out bottom padding.
#pragma omp parallel for schedule(static)
for (int64_t i = inSizes[0]; i < outSizes[0]; ++i) {
std::memset(outIn + i * outSizes[1] * byteSize, 0, byteSize * outSizes[1]);
}
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fastZeroPad2D", &fastZeroPad2D, "Fast 2D tensor zero padding");
}
