diff --git a/python/src/llvm.cc b/python/src/llvm.cc
index 9b1478dfd1f4..3055e91f8c0a 100644
--- a/python/src/llvm.cc
+++ b/python/src/llvm.cc
@@ -582,6 +582,24 @@ void init_triton_llvm(py::module &&m) {
       if (f.second)
         res.insert(f.first().str());
     }
+
+    // Likely something went wrong with the LLVM feature detection.
+    if (!res.size()) {
+      std::string triple = llvm::sys::getProcessTriple();
+      // e.g. arm64-apple-darwin24.1.0
+      //      ^^^^^
+      std::size_t pos = triple.find('-');
+      if (pos == std::string::npos) {
+        return res;
+      }
+
+      std::string arch = triple.substr(0, pos);
+      if (arch == "aarch64" || arch == "arm64") {
+        // Safe because NEON is a mandatory feature for aarch64.
+        res.insert("neon"); // For math tests
+      }
+    }
+
     return res;
   });
 }
diff --git a/python/test/unit/cpu/test_math.py b/python/test/unit/cpu/test_math.py
index 958913e7f9f1..1fd443db967a 100644
--- a/python/test/unit/cpu/test_math.py
+++ b/python/test/unit/cpu/test_math.py
@@ -5,10 +5,23 @@
 
 import triton
 import triton.language as tl
+from triton._C.libtriton import llvm
 from triton.language.extra import libdevice
 from itertools import chain, product
 
 
+def get_native_vector_size_in_bits():
+    """
+    Returns the native vector size of the CPU.
+    Assuming x86 always uses "auto dispatch" with 512-bit vectors for Sleef.
+    """
+    cpu_features = llvm.get_cpu_features()
+    # TODO support for arm sve w/ VLA
+    if "neon" in cpu_features:
+        return 128
+    return 512
+
+
 def is_interpreter():
     return os.environ.get('TRITON_INTERPRET', '0') == '1'
 
@@ -34,9 +47,13 @@ def check_num_vec_calls(meta, vec_lib, dtype_str, size, is_always_extern=False):
     # FP16 and BF16 are cast to FP32 for math ops
     elem_size = 8 if dtype_str == "float64" else 4
     data_size = size * elem_size
-    if data_size > 64:
-        num_vec_calls = data_size // 64
-    elif data_size >= 16:
+
+    vec_size = get_native_vector_size_in_bits() / 8  # bytes
+    # 128-bit vector is the smallest supported by Sleef for both x86 and arm
+    smallest_vec_size = 128 / 8  # bytes
+    if data_size > vec_size:
+        num_vec_calls = data_size // vec_size
+    elif data_size >= smallest_vec_size:
         num_vec_calls = 1
     else:
         num_vec_calls = 1 if is_always_extern else 0
diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py
index 5aabcc051b91..c4b5e6ecd918 100644
--- a/third_party/cpu/backend/compiler.py
+++ b/third_party/cpu/backend/compiler.py
@@ -162,7 +162,8 @@ def make_tttcir(self, mod, metadata, opt):
         pm.enable_debug()
         cpu.passes.ttcpuir.add_optimize_masks(pm)
         passes.common.add_canonicalizer(pm)
-        convert_bf16_dot_product = self.cpu_arch == "aarch64" and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features
+        convert_bf16_dot_product = ((self.cpu_arch == "aarch64" or self.cpu_arch == "armv8")
+                                    and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features)
         if convert_bf16_dot_product:
             use_horizontal_sum = os.getenv("TRITON_CPU_DOT_PROD_HORIZ_SUM", "1") == "1"
             cpu.passes.ttcpuir.add_convert_dot_product(pm, use_horizontal_sum)
@@ -215,7 +216,7 @@ def make_llir(self, src, metadata, options):
             VecLib.libmvec: {"avx512f"},
         }
         if (vec_lib := options.get_vec_lib()) and vec_lib_requirements[vec_lib] & self.cpu_features:
-            cpu.passes.ttcpuir.add_math_to_vec_lib(pm, vec_lib)
+            cpu.passes.ttcpuir.add_math_to_vec_lib(pm, vec_lib, self.cpu_features)
         passes.convert.add_math_to_llvmir(pm)
         cpu.passes.ttcpuir.add_math_to_libm(pm)
 
diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h
index 6e9892d00206..cc29821c580c 100644
--- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h
+++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h
@@ -32,7 +32,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createLowerMultiReductionPass();
 std::unique_ptr<OperationPass<ModuleOp>> createAtomicOpsToLLVMPass();
 std::unique_ptr<OperationPass<ModuleOp>> createDebugOpsToLLVMPass();
 std::unique_ptr<OperationPass<ModuleOp>>
-createMathToVecLibPass(VecLib lib = VecLib::Sleef);
+createMathToVecLibPass(VecLib lib = VecLib::Sleef,
+                       std::set<std::string> cpu_features = {});
 
 #define GEN_PASS_REGISTRATION
 #include "cpu/include/TritonCPUToLLVM/Passes.h.inc"
diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp
index b68d5a7473d0..2b1877c1c17b 100644
--- a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp
+++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp
@@ -56,14 +56,17 @@ template <typename OpT> struct VecOpToFp32 : public OpRewritePattern<OpT> {
 };
 
 // Decompose vector operation to single-dimensional vector operations
-// with a native AVX512 vector size.
+// with AVX512 for x86 or NEON for ARM.
 template <typename OpT>
 struct DecomposeToNativeVecs : public OpRewritePattern<OpT> {
 public:
   using OpRewritePattern<OpT>::OpRewritePattern;
+  // CPU SIMD vector size in bits
+  size_t vec_bits;
 
-  DecomposeToNativeVecs(MLIRContext *context)
-      : OpRewritePattern<OpT>(context) {}
+  DecomposeToNativeVecs(MLIRContext *context,
+                        size_t native_vec_size_in_bits = 512)
+      : OpRewritePattern<OpT>(context), vec_bits(native_vec_size_in_bits) {}
 
   LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const {
     Location loc = op.getLoc();
@@ -83,7 +86,7 @@ struct DecomposeToNativeVecs : public OpRewritePattern<OpT> {
     // vector size.
     auto shape = vecTy.getShape();
     SmallVector<int64_t> newShape(1, 1);
-    int64_t elemsPerVec = 512 / elemTy.getIntOrFloatBitWidth();
+    int64_t elemsPerVec = vec_bits / elemTy.getIntOrFloatBitWidth();
     for (int64_t i = shape.size() - 1; i >= 0; --i) {
       int64_t size = shape[i];
       if (newShape.size() > 1) {
@@ -330,9 +333,11 @@ struct ExternElementwiseOpConversion
 
 template <typename OpTy, typename GetVecFnNameFn>
 void populatePatternsForOp(RewritePatternSet &patterns,
-                           GetVecFnNameFn getVecFnName) {
+                           GetVecFnNameFn getVecFnName,
+                           size_t vec_size_in_bits = 512) {
   patterns.add<VecOpToFp32<OpTy>>(patterns.getContext());
-  patterns.add<DecomposeToNativeVecs<OpTy>>(patterns.getContext());
+  patterns.add<DecomposeToNativeVecs<OpTy>>(patterns.getContext(),
+                                            vec_size_in_bits);
   patterns.add<OpToVecLibConversion<OpTy>>(patterns.getContext(), getVecFnName);
 }
 
@@ -340,8 +345,27 @@ void populatePatternsForOp(RewritePatternSet &patterns,
 struct MathToVecLibPass
     : public mlir::triton::cpu::impl::MathToVecLibBase<MathToVecLibPass> {
   MathToVecLibPass() = default;
+  size_t vec_size_in_bits;
 
-  explicit MathToVecLibPass(VecLib lib) { this->lib = lib; }
+  explicit MathToVecLibPass(VecLib lib, std::set<std::string> cpu_features) {
+    this->lib = lib;
+    update_vec_size(cpu_features);
+  }
+
+  void update_vec_size(std::set<std::string> &cpu_features) {
+    // TODO:
+    // Refactor this as an independent function.
+    // And improve this to support other x86 SIMD ISAs and also for arm SVE
+    // (VLA)
+    vec_size_in_bits = 512;
+    for (auto feature : cpu_features) {
+      // Arm NEON is fixed 128-bit SIMD ISA.
+      if (feature == "neon") {
+        vec_size_in_bits = 128;
+        break;
+      }
+    }
+  }
 
   void runOnOperation() override {
     Operation *op = getOperation();
@@ -356,20 +380,20 @@ struct MathToVecLibPass
       }
       case VecLib::Sleef: {
         populateCommonPatterns<SleefNameGenerator>(patterns);
-        populatePatternsForOp<math::ExpM1Op>(patterns,
-                                             SleefNameGenerator("expm1"));
+        populatePatternsForOp<math::ExpM1Op>(
+            patterns, SleefNameGenerator("expm1"), vec_size_in_bits);
         populatePatternsForOp<math::FloorOp>(
-            patterns, SleefNameGenerator("floor", /*ulp=*/0));
+            patterns, SleefNameGenerator("floor", /*ulp=*/0), vec_size_in_bits);
         populatePatternsForOp<math::SqrtOp>(
-            patterns, SleefNameGenerator("sqrt", /*ulp=*/5));
+            patterns, SleefNameGenerator("sqrt", /*ulp=*/5), vec_size_in_bits);
         populatePatternsForOp<math::TruncOp>(
-            patterns, SleefNameGenerator("trunc", /*ulp=*/0));
+            patterns, SleefNameGenerator("trunc", /*ulp=*/0), vec_size_in_bits);
         break;
       }
     }
 
     patterns.add<DecomposeToNativeVecs<ExternElementwiseOp>>(
-        patterns.getContext());
+        patterns.getContext(), vec_size_in_bits);
     patterns.add<VecOpToFp32<ExternElementwiseOp>>(patterns.getContext());
     patterns.add<ExternElementwiseOpConversion>(patterns.getContext());
 
@@ -379,26 +403,46 @@ struct MathToVecLibPass
 
   template <typename VecFnNameGenerator>
   void populateCommonPatterns(RewritePatternSet &patterns) const {
-    populatePatternsForOp<math::AcosOp>(patterns, VecFnNameGenerator("acos"));
-    populatePatternsForOp<math::AcoshOp>(patterns, VecFnNameGenerator("acosh"));
-    populatePatternsForOp<math::AsinOp>(patterns, VecFnNameGenerator("asin"));
-    populatePatternsForOp<math::AsinhOp>(patterns, VecFnNameGenerator("asinh"));
-    populatePatternsForOp<math::AtanOp>(patterns, VecFnNameGenerator("atan"));
-    populatePatternsForOp<math::AtanhOp>(patterns, VecFnNameGenerator("atanh"));
-    populatePatternsForOp<math::CbrtOp>(patterns, VecFnNameGenerator("cbrt"));
-    populatePatternsForOp<math::CosOp>(patterns, VecFnNameGenerator("cos"));
-    populatePatternsForOp<math::CoshOp>(patterns, VecFnNameGenerator("cosh"));
-    populatePatternsForOp<math::ErfOp>(patterns, VecFnNameGenerator("erf"));
-    populatePatternsForOp<math::ExpOp>(patterns, VecFnNameGenerator("exp"));
-    populatePatternsForOp<math::Exp2Op>(patterns, VecFnNameGenerator("exp2"));
-    populatePatternsForOp<math::LogOp>(patterns, VecFnNameGenerator("log"));
-    populatePatternsForOp<math::Log2Op>(patterns, VecFnNameGenerator("log2"));
-    populatePatternsForOp<math::Log10Op>(patterns, VecFnNameGenerator("log10"));
-    populatePatternsForOp<math::Log1pOp>(patterns, VecFnNameGenerator("log1p"));
-    populatePatternsForOp<math::SinOp>(patterns, VecFnNameGenerator("sin"));
-    populatePatternsForOp<math::SinhOp>(patterns, VecFnNameGenerator("sinh"));
-    populatePatternsForOp<math::TanOp>(patterns, VecFnNameGenerator("tan"));
-    populatePatternsForOp<math::TanhOp>(patterns, VecFnNameGenerator("tanh"));
+    populatePatternsForOp<math::AcosOp>(patterns, VecFnNameGenerator("acos"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::AcoshOp>(patterns, VecFnNameGenerator("acosh"),
+                                         vec_size_in_bits);
+    populatePatternsForOp<math::AsinOp>(patterns, VecFnNameGenerator("asin"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::AsinhOp>(patterns, VecFnNameGenerator("asinh"),
+                                         vec_size_in_bits);
+    populatePatternsForOp<math::AtanOp>(patterns, VecFnNameGenerator("atan"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::AtanhOp>(patterns, VecFnNameGenerator("atanh"),
+                                         vec_size_in_bits);
+    populatePatternsForOp<math::CbrtOp>(patterns, VecFnNameGenerator("cbrt"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::CosOp>(patterns, VecFnNameGenerator("cos"),
+                                       vec_size_in_bits);
+    populatePatternsForOp<math::CoshOp>(patterns, VecFnNameGenerator("cosh"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::ErfOp>(patterns, VecFnNameGenerator("erf"),
+                                       vec_size_in_bits);
+    populatePatternsForOp<math::ExpOp>(patterns, VecFnNameGenerator("exp"),
+                                       vec_size_in_bits);
+    populatePatternsForOp<math::Exp2Op>(patterns, VecFnNameGenerator("exp2"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::LogOp>(patterns, VecFnNameGenerator("log"),
+                                       vec_size_in_bits);
+    populatePatternsForOp<math::Log2Op>(patterns, VecFnNameGenerator("log2"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::Log10Op>(patterns, VecFnNameGenerator("log10"),
+                                         vec_size_in_bits);
+    populatePatternsForOp<math::Log1pOp>(patterns, VecFnNameGenerator("log1p"),
+                                         vec_size_in_bits);
+    populatePatternsForOp<math::SinOp>(patterns, VecFnNameGenerator("sin"),
+                                       vec_size_in_bits);
+    populatePatternsForOp<math::SinhOp>(patterns, VecFnNameGenerator("sinh"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::TanOp>(patterns, VecFnNameGenerator("tan"),
+                                       vec_size_in_bits);
+    populatePatternsForOp<math::TanhOp>(patterns, VecFnNameGenerator("tanh"),
+                                        vec_size_in_bits);
   }
 };
 
@@ -408,8 +452,9 @@ namespace mlir {
 namespace triton {
 namespace cpu {
 
-std::unique_ptr<OperationPass<ModuleOp>> createMathToVecLibPass(VecLib lib) {
-  return std::make_unique<MathToVecLibPass>(lib);
+std::unique_ptr<OperationPass<ModuleOp>>
+createMathToVecLibPass(VecLib lib, std::set<std::string> cpu_features) {
+  return std::make_unique<MathToVecLibPass>(lib, cpu_features);
 }
 
 } // namespace cpu
diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc
index 3c3555d3c983..a412190bbcf8 100644
--- a/third_party/cpu/triton_cpu.cc
+++ b/third_party/cpu/triton_cpu.cc
@@ -145,8 +145,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) {
   m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) {
     pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
   });
-  m.def("add_math_to_vec_lib", [](mlir::PassManager &pm, cpu::VecLib lib) {
-    pm.addPass(mlir::triton::cpu::createMathToVecLibPass(lib));
+  m.def("add_math_to_vec_lib", [](mlir::PassManager &pm, cpu::VecLib lib,
+                                  std::set<std::string> cpu_features) {
+    pm.addPass(mlir::triton::cpu::createMathToVecLibPass(lib, cpu_features));
   });
   m.def("add_math_to_libm", [](mlir::PassManager &pm) {
     pm.addPass(mlir::createConvertMathToLibmPass());