Commit 5662d4c: Add more libdevice lowerings (#97)
Author: int3, Aug 9, 2024
Parent: 9035f4b

Showing 4 changed files with 42 additions and 3 deletions.
python/src/ir.cc (9 additions, 1 deletion)

@@ -1,4 +1,4 @@
 #include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>

@@ -1425,6 +1425,10 @@ void init_triton_ir(py::module &&m) {
            [](TritonOpBuilder &self, Value &val) -> Value {
              return self.create<math::Exp2Op>(val);
            })
+      .def("create_expm1",
+           [](TritonOpBuilder &self, Value &val) -> Value {
+             return self.create<math::ExpM1Op>(val);
+           })
       .def("create_cos",
            [](TritonOpBuilder &self, Value &val) -> Value {
              return self.create<math::CosOp>(val);
@@ -1477,6 +1481,10 @@ void init_triton_ir(py::module &&m) {
            [](TritonOpBuilder &self, Value &val) -> Value {
              return self.create<math::LogOp>(val);
            })
+      .def("create_log1p",
+           [](TritonOpBuilder &self, Value &val) -> Value {
+             return self.create<math::Log1pOp>(val);
+           })
       .def("create_log2",
            [](TritonOpBuilder &self, Value &val) -> Value {
              return self.create<math::Log2Op>(val);
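
Why bind expm1 and log1p as dedicated ops rather than composing exp/log with an add? For arguments near zero, exp(x) - 1 and log(1 + x) cancel catastrophically in floating point, which the fused forms avoid. A quick standard-library illustration (not part of the commit):

    import math

    x = 1e-12
    # Naive compositions lose most of their significant bits to cancellation...
    print(math.exp(x) - 1.0)  # ~1.00009e-12: only a few digits survive
    print(math.log(1.0 + x))  # ~9.99978e-13: similarly degraded
    # ...while the fused functions stay accurate to full double precision.
    print(math.expm1(x))      # ~1.0000000000005e-12
    print(math.log1p(x))      # ~9.999999999995e-13
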
python/test/unit/cpu/test_libdevice.py (6 additions, 2 deletions)

@@ -21,14 +21,18 @@ def is_cpu():

@pytest.mark.parametrize("dtype_str", float_dtypes)
@pytest.mark.parametrize("math_fn", [
"acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "cos", "cosh", "erf", "exp", "exp2", "log", "log2",
"log10", "sin", "sinh", "tan", "tanh"
"acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "cos", "cosh", "erf", "exp", "exp2", "expm1", "floor",
"log", "log1p", "log2", "log10", "rsqrt", "sin", "sinh", "sqrt", "tan", "tanh"
])
@pytest.mark.parametrize("size", [1, 4, 16, 64])
def test_libdevice(dtype_str, math_fn, size, device):
if not is_cpu():
pytest.skip("This test is CPU-specific")

if dtype_str == "bfloat16":
if math_fn == "floor" or math_fn == "rsqrt":
pytest.skip("libgcc < 13 does not define __truncsfbf2, which this op needs")

@triton.jit
def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr):
idxs = tl.arange(0, BLOCK_SIZE)
Expand Down
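
The hunk ends at the first line of the kernel body; the remainder is unchanged context that the diff view truncates. Purely as a hypothetical sketch of how a MATH_FN constexpr string is typically dispatched in Triton tests (this is not the committed body):

    # Hypothetical continuation, NOT the actual file contents: a constexpr
    # string can select the libdevice wrapper at compile time.
    x = tl.load(src + idxs)
    y = getattr(libdevice, MATH_FN)(x)
    tl.store(dst + idxs, y)
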
python/triton/language/extra/cpu/libdevice.py (25 additions)

@@ -61,6 +61,16 @@ def exp2(arg0, _builder=None):
     return core.tensor(_builder.create_exp2(arg0.handle), arg0.type)


+@core.extern
+def expm1(arg0, _builder=None):
+    return core.tensor(_builder.create_expm1(arg0.handle), arg0.type)
+
+
+@core.extern
+def floor(arg0, _builder=None):
+    return core.tensor(_builder.create_floor(arg0.handle), arg0.type)
+
+
 @core.extern
 def log(arg0, _builder=None):
     return core.tensor(_builder.create_log(arg0.handle), arg0.type)
@@ -76,11 +86,26 @@ def log10(arg0, _builder=None):
     return core.tensor(_builder.create_log10(arg0.handle), arg0.type)


+@core.extern
+def log1p(arg0, _builder=None):
+    return core.tensor(_builder.create_log1p(arg0.handle), arg0.type)
+
+
 @core.extern
 def sin(arg0, _builder=None):
     return core.tensor(_builder.create_sin(arg0.handle), arg0.type)


+@core.extern
+def rsqrt(arg0, _builder=None):
+    return core.tensor(_builder.create_rsqrt(arg0.handle), arg0.type)
+
+
+@core.extern
+def sqrt(arg0, _builder=None):
+    return core.tensor(_builder.create_sqrt(arg0.handle), arg0.type)
+
+
 @core.extern
 def sinh(arg0, _builder=None):
     return core.tensor(_builder.create_sinh(arg0.handle), arg0.type)
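
With the wrappers in place, kernels on the CPU backend can call these functions directly. A minimal usage sketch (kernel and argument names are illustrative, not from the commit):

    import triton
    import triton.language as tl
    from triton.language.extra.cpu import libdevice

    @triton.jit
    def log1p_kernel(src, dst, BLOCK_SIZE: tl.constexpr):
        idxs = tl.arange(0, BLOCK_SIZE)
        x = tl.load(src + idxs)
        # Lowers through create_log1p (bound in ir.cc above) to math.log1p.
        tl.store(dst + idxs, libdevice.log1p(x))
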
Fourth changed file (path not preserved in this capture; 2 additions)

@@ -220,9 +220,11 @@ struct ConvertElementwiseOps
patterns.add<OpTypeConversion<math::AbsIOp>>(typeConverter, context);
patterns.add<OpTypeConversion<math::ExpOp>>(typeConverter, context);
patterns.add<OpTypeConversion<math::Exp2Op>>(typeConverter, context);
patterns.add<OpTypeConversion<math::ExpM1Op>>(typeConverter, context);
patterns.add<OpTypeConversion<math::LogOp>>(typeConverter, context);
patterns.add<OpTypeConversion<math::Log2Op>>(typeConverter, context);
patterns.add<OpTypeConversion<math::Log10Op>>(typeConverter, context);
patterns.add<OpTypeConversion<math::Log1pOp>>(typeConverter, context);
patterns.add<OpTypeConversion<math::SinOp>>(typeConverter, context);
patterns.add<OpTypeConversion<math::SinhOp>>(typeConverter, context);
patterns.add<OpTypeConversion<math::CosOp>>(typeConverter, context);
Expand Down
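
Registering the new ops with OpTypeConversion appears to be what lets them accept element types the CPU math library has no native routines for: the pattern rewrites the op to compute in the type converter's target type. That reading matches the bfloat16 skip in the test above, since truncating an f32 result back to bf16 is exactly where libgcc's __truncsfbf2 comes in. Schematically, in PyTorch terms (an analogy, not the pass itself):

    import torch

    x = torch.rand(16, dtype=torch.bfloat16)
    # Promote, compute, truncate back: the shape of what the conversion emits.
    y = torch.expm1(x.to(torch.float32)).to(torch.bfloat16)
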
