Commit

[Builder] Add index from/to Float/Fixed/UFixed type casting (#…
zzzDavid authored Nov 3, 2024
1 parent a84f31f commit d61322e
Showing 2 changed files with 77 additions and 5 deletions.
19 changes: 14 additions & 5 deletions allo/ir/builder.py
@@ -15,7 +15,6 @@
     RankedTensorType,
     ShapedType,
     IntegerType,
-    IndexType,
     F32Type,
     UnitAttr,
     IntegerAttr,
@@ -356,6 +355,7 @@ def build_cast_op(ctx, op, src_type, res_type, shape=None):
     if type(res_type) is type(src_type) and res_type == src_type:
         return op
 
+    # Single-step type conversions
     cast_map = {
         # Index <-> UInt/Int
         (Int, Index): arith_d.IndexCastOp,
@@ -367,9 +367,6 @@ def build_cast_op(ctx, op, src_type, res_type, shape=None):
         (UInt, Float): arith_d.UIToFPOp,
         (Float, Int): arith_d.FPToSIOp,
         (Float, UInt): arith_d.FPToUIOp,
-        # FP to Index is not supported in MLIR
-        # (Float, Index): RuntimeError,
-        # (Index, Float): RuntimeError,
         # Float <-> Fixed/UFixed
         (Float, Fixed): allo_d.FloatToFixedOp,
         (Float, UFixed): allo_d.FloatToFixedOp,
@@ -395,13 +392,25 @@ def build_cast_op(ctx, op, src_type, res_type, shape=None):
     elif isinstance(src_type, Float) and isinstance(res_type, Index):
         # FP to Index is not supported in MLIR
         # we need to cast to UInt first, then cast to Index
-        op = arith_d.FPToUIOp(IndexType.get(), op.result, ip=ctx.get_ip())
+        op = arith_d.FPToUIOp(
+            IntegerType.get_signless(32), op.result, ip=ctx.get_ip()
+        )
         opcls = arith_d.IndexCastOp  # proceed to build cast to index
     elif isinstance(src_type, Index) and isinstance(res_type, Float):
         op = arith_d.IndexCastOp(
             IntegerType.get_signless(32), op.result, ip=ctx.get_ip()
         )
         opcls = arith_d.SIToFPOp  # proceed to build cast to float
+    elif isinstance(src_type, Index) and isinstance(res_type, (Fixed, UFixed)):
+        op = arith_d.IndexCastOp(
+            IntegerType.get_signless(32), op.result, ip=ctx.get_ip()
+        )
+        opcls = allo_d.IntToFixedOp  # proceed to build cast to fixed
+    elif isinstance(src_type, (Fixed, UFixed)) and isinstance(res_type, Index):
+        op = allo_d.FixedToIntOp(
+            IntegerType.get_signless(32), op.result, ip=ctx.get_ip()
+        )
+        opcls = arith_d.IndexCastOp
     elif isinstance(src_type, (Int, UInt)) and isinstance(res_type, (Int, UInt)):
         if src_type.bits > res_type.bits:
             opcls = arith_d.TruncIOp
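
The new branches all route through a 32-bit integer because MLIR's arith dialect has no direct cast between index and float (or Allo's fixed-point) values. Below is a minimal sketch, not part of the commit, of the float-to-index op pair this branch produces, written against the upstream MLIR Python bindings; the mlir.ir / mlir.dialects module paths are an assumption here, since Allo bundles its own MLIR build.

from mlir.ir import (
    Context,
    Location,
    Module,
    InsertionPoint,
    F32Type,
    IndexType,
    IntegerType,
)
from mlir.dialects import arith, func

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):

        @func.FuncOp.from_py_func(F32Type.get())
        def cast_f32_to_index(x):
            # No direct f32 -> index cast exists, so mirror the builder:
            # arith.fptoui (f32 -> i32), then arith.index_cast (i32 -> index).
            as_i32 = arith.FPToUIOp(IntegerType.get_signless(32), x)
            return arith.IndexCastOp(IndexType.get(), as_i32.result).result

    print(module)  # shows the arith.fptoui + arith.index_cast pair
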
63 changes: 63 additions & 0 deletions tests/test_types.py
@@ -14,6 +14,7 @@
     uint1,
     int32,
     float32,
+    index,
 )
 import allo.ir.types as T
 
@@ -42,6 +43,68 @@ def kernel(a: int32) -> float32:
     assert mod(1) == kernel(1)
 
 
+def test_index_fixed_casting():
+    def test_one_cast(fixed):
+        def kernel(a: index) -> float32:
+            sum_val: fixed = a  # casting
+            ret_val: float32 = 0.0
+            for step in range(10):
+                sum_val += step  # casting
+                ret_val = sum_val
+            return ret_val
+
+        s = allo.customize(kernel)
+        mod = s.build()
+        assert mod(1) == kernel(1)
+
+    for i in range(1, 20):
+        test_one_cast(Fixed(32, i))
+        test_one_cast(UFixed(32, i))
+
+
+def test_fixed_index_casting():
+
+    def test_one_cast(fixed):
+        def kernel(a: float32) -> int32:
+            a_fix: fixed = a
+            a_idx: index = a_fix
+            b: index = 2
+            ret: int32 = a_idx + b
+            return ret
+
+        s = allo.customize(kernel)
+        mod = s.build()
+        assert mod(2.0) == kernel(2.0)
+
+    for i in range(1, 20):
+        test_one_cast(Fixed(32, i))
+        test_one_cast(UFixed(32, i))
+
+
+def test_index_float_casting():
+    def kernel(a: index) -> float32:
+        sum_val: float32 = a  # casting
+        for step in range(10):
+            sum_val += step  # casting
+        return sum_val
+
+    s = allo.customize(kernel)
+    mod = s.build()
+    assert mod(1) == kernel(1)
+
+
+def test_float_index_casting():
+    def kernel(a: float32) -> int32:
+        a_idx: index = a  # casting
+        b: index = 2
+        ret: int32 = a_idx + b
+        return ret
+
+    s = allo.customize(kernel)
+    mod = s.build()
+    assert mod(2.0) == kernel(2.0)
+
+
 def test_large_bitwidth():
     def kernel(a: Int(65536)[1], b: Int(345)[1], c: Int(65536)[1]):
         c[0] = a[0] + b[0]
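
As a quick way to see what these new casts lower to, the schedule's IR can be printed before building. This is a minimal usage sketch along the lines of the tests above; the print(s.module) call is an assumption based on Allo's tutorials, not part of this commit.

import allo
from allo.ir.types import float32, index


def kernel(a: float32) -> float32:
    i: index = a  # float -> index: fptoui to i32, then index_cast
    out: float32 = i  # index -> float: index_cast to i32, then sitofp
    return out


s = allo.customize(kernel)
print(s.module)  # inspect the emitted cast ops
mod = s.build()
print(mod(3.0))
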
