Skip to content

Commit

Permalink
[dataflow] Supplement, refine, and organize designs of unified systol…
Browse files Browse the repository at this point in the history
…ic array (#282)
  • Loading branch information
AdrianLiu00 authored Dec 25, 2024
1 parent 3f968a0 commit 63e83a7
Show file tree
Hide file tree
Showing 3 changed files with 413 additions and 25 deletions.
17 changes: 13 additions & 4 deletions allo/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ def _build_top(s, stream_info):
with s.module.context, Location.unknown():
# create new func
func_type = FunctionType.get(input_types, [])
new_top = func_d.FuncOp(name="top", type=func_type, ip=InsertionPoint(top_func))
new_top = func_d.FuncOp(
name=s.top_func_name, type=func_type, ip=InsertionPoint(top_func)
)
new_top.add_entry_block()
return_op = func_d.ReturnOp([], ip=InsertionPoint(new_top.entry_block))
for op in top_func.entry_block.operations:
Expand Down Expand Up @@ -231,13 +233,19 @@ def wrapper(*args, **kwargs):
return actual_decorator


def customize(func):
def df_primitive_default(s):
df_pipeline(s.module, rewind=True)


def customize(func, opt_default=True):
global_vars = get_global_vars(func)
s = _customize(func, global_vars=global_vars)
stream_info = move_stream_to_interface(s)
s = _build_top(s, stream_info)

df_pipeline(s.module, rewind=True)
if opt_default:
df_primitive_default(s)

return s


Expand All @@ -248,6 +256,7 @@ def build(
project="top.prj",
configs=None,
wrap_io=True,
opt_default=True,
):
if target == "aie":
global_vars = get_global_vars(func)
Expand All @@ -257,7 +266,7 @@ def build(
mod.build()
return mod
# FPGA backend
s = customize(func)
s = customize(func, opt_default)
hls_mod = s.build(
target=target,
mode=mode,
Expand Down
2 changes: 0 additions & 2 deletions tests/dataflow/test_daisy_chain_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ def gemm(A: int16[M, K], B: int16[K, N], C: int16[M, N]):
fifo_A[i - 1, j].put(a)
with allo.meta_if(i < M):
fifo_B[i, j - 1].put(b)
with allo.meta_else():
pass

with allo.meta_if(i == 1):
packed_tmp: UInt(M * 16) = 0
Expand Down
Loading

0 comments on commit 63e83a7

Please sign in to comment.