Skip to content

Commit

Permalink
[IR] Support max/min functions (#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
chhzh123 authored Oct 18, 2024
1 parent 17bfa47 commit eeed759
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 0 deletions.
19 changes: 19 additions & 0 deletions allo/ir/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1674,6 +1674,25 @@ def build_Call(ctx, node):
node.args[0].dtype,
Int(32) if node.func.id == "int" else Float(32),
)

if node.func.id in {"min", "max"}:
stmts = build_stmts(ctx, node.args)
if isinstance(node.dtype, Float):
opcls = {
"min": arith_d.MinimumFOp,
"max": arith_d.MaximumFOp,
}.get(node.func.id)
elif isinstance(node.dtype, Int):
opcls = {
"min": arith_d.MinSIOp,
"max": arith_d.MaxSIOp,
}.get(node.func.id)
elif isinstance(node.dtype, UInt):
opcls = {
"min": arith_d.MinUIOp,
"max": arith_d.MaxUIOp,
}.get(node.func.id)
return opcls(stmts[0].result, stmts[1].result, ip=ctx.get_ip())
raise RuntimeError(f"Cannot resolve function `{node.func.id}`")

if (
Expand Down
8 changes: 8 additions & 0 deletions allo/ir/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,14 @@ def visit_Call(ctx, node):
new_args = visit_stmts(ctx, node.args)
node.shape = tuple()
node.dtype = float32 if node.func.id == "float" else int32
elif node.func.id in {"min", "max"}:
# Python-Builtin functions
assert (
len(node.args) == 2
), "Only support two arguments for `min` and `max`"
new_args = visit_stmts(ctx, node.args)
node.shape = new_args[0].shape
node.dtype = new_args[0].dtype
else:
raise RuntimeError(f"Unsupported function call {node.func.id}")
return node
Expand Down
3 changes: 3 additions & 0 deletions allo/ir/use_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def visit_Call(self, node):
if node.func.id in {"float", "int"}:
# Python-Builtin functions
return list(self.visit(node.args[0]))
if node.func.id in {"min", "max"}:
assert len(node.args) == 2, "min/max only support two arguments"
return list(self.visit(node.args[0])) + list(self.visit(node.args[1]))
raise RuntimeError(f"Unsupported function call {node.func.id}")

if obj.__module__.startswith("allo") and not obj.__module__.startswith(
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/allo/Dialect/Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class HLSCppVisitorBase {
arith::CmpIOp, arith::AddIOp, arith::SubIOp, arith::MulIOp,
arith::DivSIOp, arith::RemSIOp, arith::DivUIOp, arith::RemUIOp,
arith::MaxSIOp, arith::MinSIOp, arith::MaxUIOp, arith::MinUIOp,
arith::MaximumFOp, arith::MinimumFOp,
// Logical expressions.
arith::XOrIOp, arith::AndIOp, arith::OrIOp, arith::ShLIOp,
arith::ShRSIOp, arith::ShRUIOp, allo::GetIntBitOp, allo::SetIntBitOp,
Expand Down Expand Up @@ -185,6 +186,8 @@ class HLSCppVisitorBase {
HANDLE(arith::MinSIOp);
HANDLE(arith::MaxUIOp);
HANDLE(arith::MinUIOp);
HANDLE(arith::MaximumFOp);
HANDLE(arith::MinimumFOp);

// Bit operations.
HANDLE(arith::XOrIOp);
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Translation/EmitVivadoHLS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,12 @@ class ExprVisitor : public HLSCppVisitorBase<ExprVisitor, bool> {
bool visitOp(arith::MinUIOp op) {
return emitter.emitMaxMin(op, "min"), true;
}
bool visitOp(arith::MaximumFOp op) {
return emitter.emitMaxMin(op, "max"), true;
}
bool visitOp(arith::MinimumFOp op) {
return emitter.emitMaxMin(op, "min"), true;
}

/// Logical expressions.
bool visitOp(arith::XOrIOp op) { return emitter.emitBinary(op, "^"), true; }
Expand Down
27 changes: 27 additions & 0 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,5 +687,32 @@ def kernel(A: float32[10], B: float32[10]) -> (float32[10], float32[10]):
np.testing.assert_allclose(np_D, np_D_ref)


@pytest.mark.parametrize("T", [int8, int32, float32])
def test_minmax(T):
def kernel(A: T[10]) -> (T[2], T[2]):
min_val: T[2] = 0x3F3F3F3F
max_val: T[2] = -0x3F3F3F3F
for i in range(10):
min_val[0] = min(min_val[0], A[i])
max_val[0] = max(max_val[0], A[i])
return min_val, max_val

s = allo.customize(kernel)
print(s.module)
mod = s.build()
if T == int8:
np_A = np.random.randint(-64, 64, size=(10,)).astype(np.int8)
elif T == int32:
np_A = np.random.randint(-1000, 1000, size=(10,)).astype(np.int32)
elif T == float32:
np_A = np.random.random((10,)).astype(np.float32)
allo_min, allo_max = mod(np_A)
assert allo_min[0] == np.min(np_A)
assert allo_max[0] == np.max(np_A)
mod = s.build(target="vhls")
assert "min" in mod.hls_code
assert "max" in mod.hls_code


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit eeed759

Please sign in to comment.