diff --git a/allo/ir/builder.py b/allo/ir/builder.py index 1a288a66..7bc7f768 100644 --- a/allo/ir/builder.py +++ b/allo/ir/builder.py @@ -128,7 +128,10 @@ def build_shaped_type(ctx, dtype, shape, layout=None): def build_array(ctx, dtype, shape): if not ctx.enable_tensor: memref_type = MemRefType.get(shape, dtype.build()) - return memref_d.AllocOp(memref_type, [], [], ip=ctx.get_ip()) + alloc_op = memref_d.AllocOp(memref_type, [], [], ip=ctx.get_ip()) + if isinstance(dtype, UInt): + alloc_op.attributes["unsigned"] = UnitAttr.get() + return alloc_op return tensor_d.EmptyOp(shape, dtype.build(), ip=ctx.get_ip()) @staticmethod @@ -690,10 +693,14 @@ def build_general_binop(ctx, node, lhs, rhs): if isinstance(node.op, (ast.LShift, ast.RShift)) and isinstance( node.dtype, (Fixed, UFixed) ): - return opcls[ty_cls]( + op = opcls[ty_cls]( node.dtype.build(), lhs.result, rhs.result, ip=ctx.get_ip() ) - return opcls[ty_cls](lhs.result, rhs.result, ip=ctx.get_ip()) + else: + op = opcls[ty_cls](lhs.result, rhs.result, ip=ctx.get_ip()) + if isinstance(node.dtype, UInt): + op.attributes["unsigned"] = UnitAttr.get() + return op @staticmethod def build_UnaryOp(ctx, node): @@ -1122,6 +1129,8 @@ def build_memory_access(ctx, node, val=None, idx=0): affine_attr, ip=ctx.get_ip(), ) + if isinstance(node.value.dtype, UInt): + op.attributes["unsigned"] = UnitAttr.get() else: # ast.Store op = affine_d.AffineStoreOp( val.results[idx], value.result, ivs, affine_attr, ip=ctx.get_ip() @@ -1131,6 +1140,8 @@ def build_memory_access(ctx, node, val=None, idx=0): if isinstance(node.ctx, ast.Load): # pylint: disable=redefined-variable-type op = memref_d.LoadOp(value.result, new_indices, ip=ctx.get_ip()) + if isinstance(node.value.dtype, UInt): + op.attributes["unsigned"] = UnitAttr.get() else: # ast.Store op = memref_d.StoreOp( val.result, diff --git a/mlir/lib/Translation/EmitVivadoHLS.cpp b/mlir/lib/Translation/EmitVivadoHLS.cpp index f21e6bf3..fd56deb8 100644 --- a/mlir/lib/Translation/EmitVivadoHLS.cpp +++ b/mlir/lib/Translation/EmitVivadoHLS.cpp @@ -1768,7 +1768,9 @@ void ModuleEmitter::emitBitcast(arith::BitcastOp op) { template void ModuleEmitter::emitCast(CastOpType op) { indent(); - emitValue(op.getResult()); + Value result = op.getResult(); + fixUnsignedType(result, op->hasAttr("unsigned")); + emitValue(result); os << " = "; emitValue(op.getOperand()); os << ";"; diff --git a/tests/test_types.py b/tests/test_types.py index ac1287f0..0c50a3a0 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -43,6 +43,21 @@ def kernel(a: int32) -> float32: assert mod(1) == kernel(1) +def test_uint(): + def casting(): + buf1: UInt(17)[16, 16] = 0 + buf2: float32[16, 16] + + for i, j in allo.grid(16, 16): + buf2[i, j] = float(buf1[i, j] + buf1[j, i]) + + s = allo.customize(casting) + mod = s.build(target="vhls") + code = mod.hls_code + assert "ap_uint<17>" in code and "ap_uint<18>" in code + assert "ap_int<" not in code + + def test_index_fixed_casting(): def test_one_cast(fixed): def kernel(a: index) -> float32: