From b0ae885365208c67714ccdde3afad149b7240293 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Tue, 12 Nov 2024 13:46:05 +0100
Subject: [PATCH] Utility libxsmm Python extension (#17)

Adds a Python wrapper for a parallelized in-place copy function using
libxsmm and OpenMP. It is intended to be used for an efficient tensor
padding implementation.

The libxsmm paths have to be specified through env variables:
- XSMM_ROOT_DIR - path to the libxsmm root dir with headers
- XSMM_LIB_DIR - path to the libxsmm.so location

libxsmm.so also has to be available at runtime, e.g., exposed through
LD_LIBRARY_PATH.

The XSMM Python module can be built and installed using the command:
pip install -e ./third_party/cpu/python/
---
 .../tutorials/03-matrix-multiplication-cpu.py | 14 ++---
 third_party/cpu/python/setup.py               | 21 +++++++
 third_party/cpu/python/xsmm_utils.cpp         | 56 +++++++++++++++++++
 3 files changed, 84 insertions(+), 7 deletions(-)
 create mode 100644 third_party/cpu/python/setup.py
 create mode 100644 third_party/cpu/python/xsmm_utils.cpp

diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py
index 480577888704..702b6f4886cf 100644
--- a/python/tutorials/03-matrix-multiplication-cpu.py
+++ b/python/tutorials/03-matrix-multiplication-cpu.py
@@ -150,7 +150,6 @@
 # ------------
 
 import torch
-import math
 
 import triton
 import triton.language as tl
@@ -175,6 +174,7 @@
 DYNAMIC_K_BLOCK = False
 CACHE_PADDING = False
 PREPROCESS_EXTERNAL = False
+XSMM_PAD = False
 
 @triton.jit
 def pad_kernel(in_ptr, out_ptr, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, PADDING: tl.constexpr):
@@ -316,12 +316,12 @@ def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, n
         pad_kernel[(K // BLOCK_SIZE_K, )](b, b_scratch, N, BLOCK_SIZE_K, BLOCK_SIZE_N, 32, num_threads=num_threads)
         b = b_scratch
 
-    # TODO: Check if padding is needed at all.
-    # Currently, cache padding is most useful together with dynamic K blocking
-    # to ensure that stride is non-power-of-two to improve cache behavior.
-    if CACHE_PADDING:
-        a = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0)
-        b = torch.nn.functional.pad(b, (0, 32, 0, 0), mode='constant', value=0)
+    # TODO: Check if padding is needed at all.
+    # Currently, cache padding is most useful together with dynamic K blocking
+    # to ensure that stride is non-power-of-two to improve cache behavior.
+    if CACHE_PADDING:
+        a = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0)
+        b = torch.nn.functional.pad(b, (0, 32, 0, 0), mode='constant', value=0)
 
     #TODO: Currently masked load is not supported yet.
     assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
diff --git a/third_party/cpu/python/setup.py b/third_party/cpu/python/setup.py
new file mode 100644
index 000000000000..c3c963b92daf
--- /dev/null
+++ b/third_party/cpu/python/setup.py
@@ -0,0 +1,21 @@
+from setuptools import setup, Extension
+from torch.utils import cpp_extension
+import os
+
+xsmm_root = os.getenv("XSMM_ROOT_DIR")
+xsmm_lib = os.getenv("XSMM_LIB_DIR")
+print(f'Using LIBXSMM root: {xsmm_root}')
+print(f'LIBXSMM lib location: {xsmm_lib}')
+
+setup(name='xsmm_py',
+      ext_modules=[
+          cpp_extension.CppExtension('xsmm_py', ['xsmm_utils.cpp'],
+                                     include_dirs=[
+                                         f'{xsmm_root}/include',
+                                         f'{xsmm_root}/src/template'
+                                     ],
+                                     library_dirs=[f'{xsmm_lib}'],
+                                     libraries=['xsmm', 'omp'],
+                                     extra_compile_args=['-fopenmp']
+                                     )],
+      cmdclass={'build_ext': cpp_extension.BuildExtension})
diff --git a/third_party/cpu/python/xsmm_utils.cpp b/third_party/cpu/python/xsmm_utils.cpp
new file mode 100644
index 000000000000..a0e70c5ac15f
--- /dev/null
+++ b/third_party/cpu/python/xsmm_utils.cpp
@@ -0,0 +1,56 @@
+#include <torch/extension.h>
+
+#include "libxsmm.h"
+#include <omp.h>
+
+#include <cstring>
+
+void fastZeroPad2D(const at::Tensor &input, torch::Tensor &output) {
+  const auto inSizes = input.sizes();
+  const auto outSizes = output.sizes();
+  const auto byteSize = input.element_size();
+  assert(input.is_floating_point() && inSizes.size() == 2 &&
+         outSizes.size() == 2 && outSizes[0] >= inSizes[0] &&
+         outSizes[1] >= inSizes[1] && byteSize == output.element_size() &&
+         "Invalid fastZeroPad2D tensors");
+
+  libxsmm_datatype dtype =
+      byteSize == 4 ? LIBXSMM_DATATYPE_F32 : LIBXSMM_DATATYPE_BF16;
+  libxsmm_meltw_unary_shape shape;
+  // Flipped to libxsmm's column-major convention.
+  shape.m = inSizes[1];
+  shape.n = 1;
+  shape.ldi = inSizes[1];
+  shape.ldo = outSizes[1];
+  shape.in0_type = dtype;
+  shape.out_type = dtype;
+  shape.comp_type = dtype;
+  libxsmm_bitfield flags = LIBXSMM_MELTW_FLAG_UNARY_NONE;
+  libxsmm_meltwfunction_unary identityFn = libxsmm_dispatch_meltw_unary(
+      LIBXSMM_MELTW_TYPE_UNARY_IDENTITY, shape, flags);
+
+  char *baseIn = static_cast<char *>(input.data_ptr());
+  char *baseOut = static_cast<char *>(output.data_ptr());
+  const int padRight = outSizes[1] - inSizes[1];
+
+#pragma omp parallel for schedule(static)
+  for (int i = 0; i < inSizes[0]; ++i) {
+    libxsmm_meltw_unary_param param;
+    param.in.primary = baseIn + i * inSizes[1] * byteSize;
+    param.out.primary = baseOut + i * outSizes[1] * byteSize;
+    identityFn(&param);
+    // Zero out right padding.
+    std::memset(baseOut + i * outSizes[1] * byteSize + inSizes[1] * byteSize, 0,
+                byteSize * padRight);
+  }
+
+  // Zero out bottom padding.
+#pragma omp parallel for schedule(static)
+  for (int i = inSizes[0]; i < outSizes[0]; ++i) {
+    std::memset(baseOut + i * outSizes[1] * byteSize, 0, byteSize * outSizes[1]);
+  }
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("fastZeroPad2D", &fastZeroPad2D, "Fast 2D tensor zero padding");
+}
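
Usage sketch: a minimal example of driving the extension from Python, assuming
the module was installed as `xsmm_py` with the pip command above, libxsmm.so is
visible through LD_LIBRARY_PATH, and both tensors are contiguous row-major
float32; the shapes and variable names are illustrative only.

    import torch
    import xsmm_py

    # Source tensor and a preallocated destination with 32 extra columns of padding.
    a = torch.randn(512, 512, dtype=torch.float32)
    a_padded = torch.empty(512, 512 + 32, dtype=torch.float32)

    # Copies `a` into the top-left corner of `a_padded` (one libxsmm identity
    # kernel per row, rows parallelized with OpenMP) and zeroes the padding.
    xsmm_py.fastZeroPad2D(a, a_padded)

    assert torch.equal(a_padded[:, :512], a)
    assert torch.count_nonzero(a_padded[:, 512:]) == 0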