forked from triton-lang/triton-cpu
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Utility libxsmm Python extension (triton-lang#17)
Adds a Python wrapper for a parallelized in-place copy function using libxsmm and OpenMP. It is intended to be used for an efficient tensor padding implementation. The libxsmm paths have to be specified through env variables: - XSMM_ROOT_DIR - path to the libxsmm root dir with headers - XSMM_LIB_DIR - path to the libxsmm.so location The libxsmm .so also has to be available during runtime execution, e.g., exposed through LD_LIBRARY_PATH. The XSMM python module can be built and installed using the command: pip install -e ./third_party/cpu/python/
- Loading branch information
Showing
3 changed files
with
84 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from setuptools import setup, Extension
from torch.utils import cpp_extension
import os

# libxsmm locations must be provided via environment variables:
#   XSMM_ROOT_DIR - libxsmm root directory (contains headers)
#   XSMM_LIB_DIR  - directory containing libxsmm.so
xsmm_root = os.getenv("XSMM_ROOT_DIR")
xsmm_lib = os.getenv("XSMM_LIB_DIR")
# Fail early with a clear message; otherwise unset variables silently become
# 'None/include' paths that only surface as cryptic compiler errors later.
if not xsmm_root or not xsmm_lib:
    raise RuntimeError(
        "XSMM_ROOT_DIR and XSMM_LIB_DIR environment variables must point to "
        "the libxsmm headers and the libxsmm.so location, respectively")
print(f'Using LIBXSMM root: {xsmm_root}')
print(f'LIBXSMM lib location: {xsmm_lib}')

# Build the 'xsmm_py' torch C++ extension from xsmm_utils.cpp, linking
# against libxsmm and OpenMP (the copy kernels are parallelized with omp).
setup(name='xsmm_py',
      ext_modules=[
          cpp_extension.CppExtension('xsmm_py', ['xsmm_utils.cpp'],
                                     include_dirs=[
                                         f'{xsmm_root}/include',
                                         f'{xsmm_root}/src/template'
                                     ],
                                     library_dirs=[f'{xsmm_lib}'],
                                     libraries=['xsmm', 'omp'],
                                     extra_compile_args=['-fopenmp']
                                     )],
      cmdclass={'build_ext': cpp_extension.BuildExtension})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#include <torch/extension.h>

#include "libxsmm.h"
#include <omp.h>

#include <cassert>
#include <cstdint>
#include <cstring>
|
||
// Copies 'input' into the top-left corner of 'output' and zero-fills the
// remaining right and bottom padding. Both tensors are expected to be 2D and
// contiguous, with matching element sizes, and 'output' at least as large as
// 'input' in both dimensions. Rows are processed in parallel with OpenMP;
// each row copy is performed by a libxsmm identity (copy) kernel.
void fastZeroPad2D(const at::Tensor &input, torch::Tensor &output) {
  const auto inSizes = input.sizes();
  const auto outSizes = output.sizes();
  const auto byteSize = input.element_size();
  // Parenthesization matters: '&&' binds tighter than '||', so the original
  // unparenthesized mix of '&&'/'||' validated almost nothing. All conditions
  // must hold simultaneously.
  assert(input.is_floating_point() && inSizes.size() == 2 &&
         outSizes.size() == 2 && outSizes[0] >= inSizes[0] &&
         outSizes[1] >= inSizes[1] && byteSize == output.element_size() &&
         "Invalid fastZeroPad2D tensors");

  // Element width selects the libxsmm datatype; only F32 (4 bytes) and
  // BF16 (2 bytes) are expected here.
  libxsmm_datatype dtype =
      byteSize == 4 ? LIBXSMM_DATATYPE_F32 : LIBXSMM_DATATYPE_BF16;
  libxsmm_meltw_unary_shape shape;
  // Flipped to libxsmm's column-major convention: one kernel call copies one
  // row of 'input' (m elements) into a row of 'output' (leading dim ldo).
  shape.m = inSizes[1];
  shape.n = 1;
  shape.ldi = inSizes[1];
  shape.ldo = outSizes[1];
  shape.in0_type = dtype;
  shape.out_type = dtype;
  shape.comp_type = dtype;
  libxsmm_bitfield flags = LIBXSMM_MELTW_FLAG_UNARY_NONE;
  libxsmm_meltwfunction_unary identityFn = libxsmm_dispatch_meltw_unary(
      LIBXSMM_MELTW_TYPE_UNARY_IDENTITY, shape, flags);

  // Use char* for byte-granular pointer arithmetic; arithmetic on void* is a
  // GNU extension and ill-formed in standard C++.
  char *baseIn = static_cast<char *>(input.data_ptr());
  char *baseOut = static_cast<char *>(output.data_ptr());
  const int64_t padRight = outSizes[1] - inSizes[1];

  // int64_t counters match the tensor size type and avoid overflow in the
  // 'i * stride * byteSize' offsets for large tensors.
#pragma omp parallel for schedule(static)
  for (int64_t i = 0; i < inSizes[0]; ++i) {
    libxsmm_meltw_unary_param param;
    param.in.primary = baseIn + i * inSizes[1] * byteSize;
    param.out.primary = baseOut + i * outSizes[1] * byteSize;
    identityFn(&param);
    // Zero out right padding of this row.
    std::memset(baseOut + i * outSizes[1] * byteSize + inSizes[1] * byteSize,
                0, byteSize * padRight);
  }

  // Zero out bottom padding rows in full.
#pragma omp parallel for schedule(static)
  for (int64_t i = inSizes[0]; i < outSizes[0]; ++i) {
    std::memset(baseOut + i * outSizes[1] * byteSize, 0,
                byteSize * outSizes[1]);
  }
}
|
||
// Python bindings: exposes fastZeroPad2D to the 'xsmm_py' extension module
// (TORCH_EXTENSION_NAME is supplied by the torch cpp_extension build).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fastZeroPad2D", &fastZeroPad2D, "Fast 2D tensor zero padding");
}