[IR] Implement scalar as rank-0 memref (#230)
chhzh123 authored Oct 19, 2024
1 parent eeed759 commit f183d50
Showing 6 changed files with 62 additions and 58 deletions.
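
The gist: a Python scalar used to be lowered as a size-1 buffer, memref<1xT>, with every access going through an affine map whose single result was the constant 0; this commit lowers it as a rank-0 buffer, memref<T>, accessed through an affine map with no results, so no subscript is ever materialized. Below is a minimal standalone sketch (not part of the commit) of the two building blocks the diff switches to; it assumes the upstream MLIR Python bindings, the same MemRefType/AffineMap API the allo frontend calls through its own wrappers:

    # Sketch only: contrast the old size-1 lowering with the new rank-0 one.
    from mlir.ir import AffineMap, Context, IntegerType, MemRefType

    with Context():
        i32 = IntegerType.get_signless(32)
        size1_ty = MemRefType.get((1,), i32)  # old: memref<1xi32>, indexed at [0]
        scalar_ty = MemRefType.get((), i32)   # new: memref<i32>, zero dims, one element
        # New accesses use an affine map with no results at all, replacing
        # the previous single AffineConstantExpr.get(0) result.
        rank0_map = AffineMap.get(dim_count=0, symbol_count=0, exprs=[])
        print(size1_ty, scalar_ty, rank0_map)

A rank-0 memref still holds exactly one element; dropping the fake dimension is what lets the HLS emitters below delete their size-1 special cases.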
5 changes: 1 addition & 4 deletions allo/ir/builder.py

@@ -86,9 +86,7 @@ def build_Name(ctx, node, val=None):
             buffer.op.result if isinstance(buffer, MockScalar) else buffer.result
         )
         if not ctx.enable_tensor:
-            affine_map = AffineMap.get(
-                dim_count=0, symbol_count=0, exprs=[AffineConstantExpr.get(0)]
-            )
+            affine_map = AffineMap.get(dim_count=0, symbol_count=0, exprs=[])
             affine_attr = AffineMapAttr.get(affine_map)
             store_op = affine_d.AffineStoreOp(
                 val.result, target, [], affine_attr, ip=ctx.get_ip()
@@ -1246,7 +1244,6 @@ def build_AnnAssign(ctx, node):
     elif isinstance(node.dtype, Stream):
         ctx.buffers[node.target.id] = rhs
     else:
-        # TODO: figure out why zero-ranked cannot work
         ctx.buffers[node.target.id] = MockScalar(
             node.target.id,
             node.dtype,
8 changes: 2 additions & 6 deletions allo/ir/utils.py

@@ -13,7 +13,6 @@
     IntegerAttr,
     FloatAttr,
     StringAttr,
-    AffineConstantExpr,
     AffineMap,
     AffineMapAttr,
 )
@@ -188,11 +187,10 @@ def __init__(self, name, dtype, ctx, value=None):
         self.name = name
         self.ctx = ctx
         self.value = value
-        shape = (1,)
         assert isinstance(dtype, AlloType), f"Expect AlloType, got {dtype}"
         self.dtype = dtype
         if not ctx.enable_tensor:
-            memref_type = MemRefType.get(shape, dtype.build())
+            memref_type = MemRefType.get((), dtype.build())
             alloc_op = memref_d.AllocOp(memref_type, [], [], ip=ctx.get_ip())
             alloc_op.attributes["name"] = StringAttr.get(name)
         else:
@@ -203,9 +201,7 @@
     def result(self):
         # pylint: disable=no-else-return
         if not self.ctx.enable_tensor:
-            affine_map = AffineMap.get(
-                dim_count=0, symbol_count=0, exprs=[AffineConstantExpr.get(0)]
-            )
+            affine_map = AffineMap.get(dim_count=0, symbol_count=0, exprs=[])
             affine_attr = AffineMapAttr.get(affine_map)
             load = affine_d.AffineLoadOp(
                 self.dtype.build(),
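
With both changes in place, MockScalar now allocates memref<T> and every read/write passes an empty index list plus the empty map, which the IR printer renders as %alloc[], exactly the string the new test_scalar asserts on. An illustrative standalone sketch of that store/load pattern follows; it mirrors the builder calls visible in this diff, and assumes MLIR Python bindings whose AllocOp/AffineLoadOp/AffineStoreOp builder signatures match the ones allo's wrappers use:

    # Sketch only: the rank-0 store/load pattern the frontend now emits.
    from mlir.ir import (AffineMap, AffineMapAttr, Context, InsertionPoint,
                         IntegerType, Location, MemRefType, Module)
    from mlir.dialects import affine as affine_d, arith, memref as memref_d

    with Context(), Location.unknown():
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            alloc = memref_d.AllocOp(MemRefType.get((), i32), [], [])
            zero = arith.ConstantOp(i32, 0)
            empty_attr = AffineMapAttr.get(
                AffineMap.get(dim_count=0, symbol_count=0, exprs=[]))
            # Empty index list + empty map: printed as affine.store/.load
            # on %alloc[] : memref<i32>.
            affine_d.AffineStoreOp(zero.result, alloc.result, [], empty_attr)
            affine_d.AffineLoadOp(i32, alloc.result, [], empty_attr)
        print(module)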
32 changes: 10 additions & 22 deletions mlir/lib/Translation/EmitIntelHLS.cpp

@@ -570,12 +570,8 @@ void ModuleEmitter::emitArrayDecl(Value array, bool isFunc, std::string name) {
       os << " */";
     } else {
       emitValue(array, 0, false, name);
-      if (arrayType.getShape().size() == 1 && arrayType.getShape()[0] == 1) {
-        // do nothing;
-      } else {
-        for (auto &shape : arrayType.getShape())
-          os << "[" << shape << "]";
-      }
+      for (auto &shape : arrayType.getShape())
+        os << "[" << shape << "]";
     }
   } else { // tensor
     emitValue(array, 0, false, name);
@@ -678,14 +674,10 @@ void ModuleEmitter::emitAffineLoad(AffineLoadOp op) {
   AffineExprEmitter affineEmitter(state, affineMap.getNumDims(),
                                   op.getMapOperands());
   auto arrayType = memref.getType().cast<ShapedType>();
-  if (arrayType.getShape().size() == 1 && arrayType.getShape()[0] == 1) {
-    // do nothing;
-  } else {
-    for (auto index : affineMap.getResults()) {
-      os << "[";
-      affineEmitter.emitAffineExpr(index);
-      os << "]";
-    }
-  }
+  for (auto index : affineMap.getResults()) {
+    os << "[";
+    affineEmitter.emitAffineExpr(index);
+    os << "]";
+  }
   os << ";";
   emitInfoAndNewLine(op);
@@ -711,14 +703,10 @@ void ModuleEmitter::emitAffineStore(AffineStoreOp op) {
   AffineExprEmitter affineEmitter(state, affineMap.getNumDims(),
                                   op.getMapOperands());
   auto arrayType = memref.getType().cast<ShapedType>();
-  if (arrayType.getShape().size() == 1 && arrayType.getShape()[0] == 1) {
-    // do nothing;
-  } else {
-    for (auto index : affineMap.getResults()) {
-      os << "[";
-      affineEmitter.emitAffineExpr(index);
-      os << "]";
-    }
-  }
+  for (auto index : affineMap.getResults()) {
+    os << "[";
+    affineEmitter.emitAffineExpr(index);
+    os << "]";
+  }
   os << " = ";
   emitValue(op.getValueToStore());
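
The emitter side is a pure deletion: a rank-0 access has an affine map with zero results, so the subscript loop naturally prints nothing and the value is emitted as a bare scalar, while a genuine size-1 array, which the old special case silently swallowed, keeps its [1] declaration and [0] subscripts (the behavior test_size1_array below locks in). A toy illustration of that reasoning, using a hypothetical helper rather than repo code:

    # Toy model only: why no special case is needed once scalars are rank-0.
    def emit_subscripts(affine_map_results):
        return "".join(f"[{expr}]" for expr in affine_map_results)

    print(emit_subscripts([]))     # rank-0 scalar -> "" (emitted as "v0")
    print(emit_subscripts(["0"]))  # size-1 array  -> "[0]" (no longer dropped)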
35 changes: 12 additions & 23 deletions mlir/lib/Translation/EmitVivadoHLS.cpp

@@ -922,14 +922,10 @@ void ModuleEmitter::emitAffineLoad(AffineLoadOp op) {
     emitValue(memref, 0, false, load_from_name); // comment
   }
   auto arrayType = memref.getType().cast<ShapedType>();
-  if (arrayType.getShape().size() == 1 && arrayType.getShape()[0] == 1) {
-    // do nothing;
-  } else {
-    for (auto index : affineMap.getResults()) {
-      os << "[";
-      affineEmitter.emitAffineExpr(index);
-      os << "]";
-    }
-  }
+  for (auto index : affineMap.getResults()) {
+    os << "[";
+    affineEmitter.emitAffineExpr(index);
+    os << "]";
+  }
   os << ";";
   emitInfoAndNewLine(op);
@@ -972,14 +968,10 @@ void ModuleEmitter::emitAffineStore(AffineStoreOp op) {
     emitValue(memref, 0, false, store_to_name); // comment
   }
   auto arrayType = memref.getType().cast<ShapedType>();
-  if (arrayType.getShape().size() == 1 && arrayType.getShape()[0] == 1) {
-    // do nothing;
-  } else {
-    for (auto index : affineMap.getResults()) {
-      os << "[";
-      affineEmitter.emitAffineExpr(index);
-      os << "]";
-    }
-  }
+  for (auto index : affineMap.getResults()) {
+    os << "[";
+    affineEmitter.emitAffineExpr(index);
+    os << "]";
+  }
   os << " = ";
   emitValue(op.getValueToStore());
@@ -1762,12 +1754,8 @@ void ModuleEmitter::emitArrayDecl(Value array, bool isFunc, std::string name) {
       os << " */";
     } else {
      emitValue(array, 0, false, name);
-      if (arrayType.getShape().size() == 1 && arrayType.getShape()[0] == 1) {
-        // do nothing;
-      } else {
-        for (auto &shape : arrayType.getShape())
-          os << "[" << shape << "]";
-      }
+      for (auto &shape : arrayType.getShape())
+        os << "[" << shape << "]";
     }
   } else { // tensor
     emitValue(array, 0, false, name);
@@ -2127,7 +2115,8 @@ void ModuleEmitter::emitFunction(func::FuncOp func) {
     unsigned idx = 0;
     for (auto result : funcReturn.getOperands()) {
       if (std::find(args.begin(), args.end(), result) == args.end()) {
-        os << ",\n";
+        if (func.getArguments().size() > 0)
+          os << ",\n";
         indent();

         // TODO: a known bug, cannot return a value twice, e.g. return %0, %0
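
The emitFunction hunk fixes a related artifact of scalar returns: return values that are not also arguments get appended to the C++ parameter list, and the separating comma must be skipped when the function has no arguments, otherwise the generated signature would open with a stray comma, which is what the assert "," not in mod.hls_code check in test_scalar guards against. A toy model of the guard with hypothetical names, not repo code:

    # Toy model only: mirrors if (func.getArguments().size() > 0) os << ",\n";
    # Like the committed guard, it assumes a single appended return value;
    # returning a value twice is a known TODO in the emitter.
    def emit_signature(arg_decls, appended_return_decls):
        text = ", ".join(arg_decls)
        for decl in appended_return_decls:
            if arg_decls:  # skip the comma when there are no function arguments
                text += ", "
            text += decl
        return text

    print(emit_signature([], ["int32_t *v0"]))              # int32_t *v0
    print(emit_signature(["int32_t v0"], ["int32_t *v1"]))  # int32_t v0, int32_t *v1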
15 changes: 15 additions & 0 deletions tests/test_builder.py

@@ -714,5 +714,20 @@ def kernel(A: T[10]) -> (T[2], T[2]):
     assert "max" in mod.hls_code


+def test_scalar():
+    def kernel() -> int32:
+        a: int32 = 0
+        b: int32 = a + 1
+        return b
+
+    s = allo.customize(kernel)
+    print(s.module)
+    assert "%alloc[]" in str(s.module)
+    mod = s.build()
+    assert mod() == 1
+    mod = s.build(target="vhls")
+    assert "," not in mod.hls_code
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
25 changes: 22 additions & 3 deletions tests/test_vhls.py

@@ -121,9 +121,8 @@ def compute(x: int32[N], y: int32[N]):


 def test_pointer_generation():
-    def top(inst: bool[1], C: int32[3]):
-        flag: bool = inst[0]
-        if flag:
+    def top(inst: bool, C: int32[3]):
+        if inst:
             C[0] = C[0] + 1

     s = allo.customize(top)
@@ -166,5 +165,25 @@ def case1(C: int32) -> int32:
     # Note: Should not expect it to run using csim! Need to generate correct binding for mutable scalars in PyBind.


+def test_size1_array():
+    def top(A: int32[1]) -> int32[1]:
+        A[0] = A[0] + 1
+        return A
+
+    s = allo.customize(top)
+    mod = s.build()
+    np_A = np.array([1], dtype=np.int32)
+    np.testing.assert_allclose(mod(np_A), [2], rtol=1e-5)
+    print("Passed CPU simulation!")
+    mod = s.build(target="vitis_hls", mode="csim", project="test_size1_array.prj")
+    print(mod.hls_code)
+    assert "[1]" in mod.hls_code
+    if hls.is_available("vitis_hls"):
+        np_B = np.array([0], dtype=np.int32)
+        mod(np_A, np_B)
+        np.testing.assert_allclose(np_A, [2], rtol=1e-5)
+        print("Passed!")
+
+
 if __name__ == "__main__":
     pytest.main([__file__])