Skip to content

Commit

Permalink
Bank Output Memories & Write to output as soon as PE is ready (#1586)
Browse files Browse the repository at this point in the history
* bank output memories

* rewrite frontend tests
  • Loading branch information
calebmkim committed Jul 6, 2023
1 parent 507201a commit d95f53e
Show file tree
Hide file tree
Showing 11 changed files with 449 additions and 129 deletions.
7 changes: 5 additions & 2 deletions frontends/systolic-lang/check-output.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@
for c in range(tl):
top[r][c] = json_data[f"t{c}"][r]

matmul_result = np.matmul(left, top).flatten()
matmul_result = np.matmul(left, top)

json_result = np.array(json_data["out_mem"])
res = []
for r in range(ll):
res.append(json_data[f"out_mem_{r}"])
json_result = np.array(res)

if np.array_equal(json_result, matmul_result):
print("Correct")
Expand Down
60 changes: 32 additions & 28 deletions frontends/systolic-lang/gen-systolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,17 @@ def instantiate_output_move(comp: cb.ComponentBuilder, row, col, cols):
Generates groups to move the final value from a PE into the output array.
"""
group_name = NAME_SCHEME["out mem move"].format(pe=f"pe_{row}_{col}")
idx = row * cols + col
pe = comp.get_cell(f"pe_{row}_{col}")
out = comp.get_cell(OUT_MEM)
out = comp.get_cell(OUT_MEM + f"_{row}")
with comp.static_group(group_name, 1):
out.addr0 = idx
out.addr0 = col
out.write_data = pe.out
out.write_en = 1


def gen_schedules(top_length, top_depth, left_length, left_depth):
"""
Generates 4 arrays that are the same size as the output (systolic) array
Generates 5 arrays that are the same size as the output (systolic) array
Each entry in the array is a tuple [start, end) that gives the range of cycles
during which it is active
`update_sched` contains when to update the indices of the input memories and feed
Expand All @@ -191,19 +190,23 @@ def gen_schedules(top_length, top_depth, left_length, left_depth):
`pe_accum_sched` contains when to invoke PE and accumulate (bc the multipliers
are ready with an output)
`pe_move_sched` contains when to "move" the PE (i.e., pass data)
`pe_write_sched` contains when to "write" the PE value into memory (i.e., when
the PE is "finished")
"""
update_sched = np.zeros((left_length, top_length), dtype=object)
pe_fill_sched = np.zeros((left_length, top_length), dtype=object)
pe_accum_sched = np.zeros((left_length, top_length), dtype=object)
pe_move_sched = np.zeros((left_length, top_length), dtype=object)
pe_write_sched = np.zeros((left_length, top_length), dtype=object)
for row in range(0, left_length):
for col in range(0, top_length):
pos = row + col
update_sched[row][col] = (pos, pos + left_depth)
pe_fill_sched[row][col] = (pos + 1, pos + min(4, left_depth) + 1)
pe_accum_sched[row][col] = (pos + 5, pos + left_depth + 5)
pe_move_sched[row][col] = (pos + 1, pos + left_depth + 1)
return (update_sched, pe_fill_sched, pe_accum_sched, pe_move_sched)
pe_write_sched[row][col] = (pos + left_depth + 5, pos + left_depth + 6)
return (update_sched, pe_fill_sched, pe_accum_sched, pe_move_sched, pe_write_sched)


def accum_nec_ranges(nec_ranges, schedule):
Expand Down Expand Up @@ -377,6 +380,7 @@ def generate_control(
fill_sched,
accum_sched,
move_sched,
write_sched,
nec_ranges,
):
"""
Expand Down Expand Up @@ -449,13 +453,20 @@ def counter():
accum_sched[r][c][1],
[get_pe_invoke(r, c, top_length, left_length, 1)],
)
pe_writes = execute_if_between(
comp,
write_sched[r][c][0],
write_sched[r][c][1],
[py_ast.Enable(NAME_SCHEME["out mem move"].format(pe=f"pe_{r}_{c}"))],
)
pe_control = input_mem_updates + pe_fills + pe_moves + pe_accums + pe_writes
control_stmts.append(py_ast.StaticParComp(pe_control))
# providing metadata
tag = counter()
source_map[
tag
] = f"pe_{r}_{c} filling: [{fill_sched[r][c][0]},{fill_sched[r][c][1]}) \
accumulating: [{accum_sched[r][c][0]} {accum_sched[r][c][1]})"
pe_control = input_mem_updates + pe_fills + pe_moves + pe_accums
control_stmts.append(py_ast.StaticParComp(pe_control))
for start, end in nec_ranges:
# build the control stmts that assign correct values to
# idx_between_{start}_{end}_reg, which is what the if stmts above^ rely on
Expand All @@ -468,20 +479,11 @@ def counter():
# build the static repeat
# num repeats = (top_length - 1) + (left_length - 1) + top_depth + 5 + 1
static_repeat = cb.static_repeat(
top_length + left_length + top_depth + 3, repeat_body
top_length + left_length + top_depth + 4, repeat_body
)

control.append(static_repeat)

# Move all the results into output memory
mover_groups = []
for row in range(left_length):
for col in range(top_length):
mover_groups.append(
py_ast.Enable(NAME_SCHEME["out mem move"].format(pe=f"pe_{row}_{col}"))
)

control.append(py_ast.StaticSeqComp(mover_groups))
return py_ast.StaticSeqComp(stmts=control), source_map


Expand All @@ -500,14 +502,15 @@ def create_systolic_array(
f"{top_length}x{top_depth} and {left_depth}x{left_length}"
)

(update_sched, fill_sched, accum_sched, move_sched) = gen_schedules(
(update_sched, fill_sched, accum_sched, move_sched, write_sched) = gen_schedules(
top_length, top_depth, left_length, left_depth
)
nec_ranges = set()
accum_nec_ranges(nec_ranges, update_sched)
accum_nec_ranges(nec_ranges, fill_sched)
accum_nec_ranges(nec_ranges, accum_sched)
accum_nec_ranges(nec_ranges, move_sched)
accum_nec_ranges(nec_ranges, write_sched)

main = prog.component("main")

Expand All @@ -524,15 +527,15 @@ def create_systolic_array(
instantiate_memory(main, "left", col, left_depth)

# Instantiate output memory
total_size = left_length * top_length
out_idx_size = bits_needed(total_size)
main.mem_d1(
OUT_MEM,
BITWIDTH,
total_size,
out_idx_size,
is_external=True,
)
out_idx_size = bits_needed(top_length)
for i in range(left_length):
main.mem_d1(
OUT_MEM + f"_{i}",
BITWIDTH,
top_length,
out_idx_size,
is_external=True,
)

# Instantiate all the PEs
for row in range(left_length):
Expand All @@ -545,7 +548,7 @@ def create_systolic_array(
# Instantiate output movement structure
instantiate_output_move(main, row, col, top_length)

iter_limit = top_length + left_length + top_depth + 3
iter_limit = top_length + left_length + top_depth + 4
iter_idx_size = bits_needed(iter_limit)
# instantiate groups that initialize idx to 0 and increment it
instantiate_idx_groups(main, iter_idx_size, iter_limit)
Expand All @@ -566,6 +569,7 @@ def create_systolic_array(
fill_sched,
accum_sched,
move_sched,
write_sched,
nec_ranges,
)
main.control = control
Expand Down
8 changes: 5 additions & 3 deletions tests/correctness/systolic/output/array-2-3-4.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 21,
"cycles": 14,
"memories": {
"l0": [
62,
Expand All @@ -11,11 +11,13 @@
28,
61
],
"out_mem": [
"out_mem_0": [
5304,
5634,
8244,
1030,
1030
],
"out_mem_1": [
8518,
8879,
11617,
Expand Down
13 changes: 11 additions & 2 deletions tests/correctness/systolic/output/array-2-3-4.systolic.data
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,21 @@
"width": 32
}
},
"out_mem": {
"out_mem_0": {
"data": [
0,
0,
0,
0,
0
],
"format": {
"is_signed": false,
"numeric_type": "bitnum",
"width": 32
}
},
"out_mem_1": {
"data": [
0,
0,
0,
Expand Down
32 changes: 23 additions & 9 deletions tests/correctness/systolic/output/array-8.expect
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cycles": 92,
"cycles": 29,
"memories": {
"l0": [
26,
Expand Down Expand Up @@ -81,63 +81,77 @@
70,
91
],
"out_mem": [
"out_mem_0": [
15082,
17066,
25978,
13905,
17367,
27929,
17607,
13732,
13732
],
"out_mem_1": [
9378,
10449,
15741,
10148,
12877,
18998,
9314,
9333,
9333
],
"out_mem_2": [
15735,
12897,
24104,
16444,
16455,
29104,
17296,
15490,
15490
],
"out_mem_3": [
22450,
26165,
32194,
24043,
23784,
33638,
26276,
24976,
24976
],
"out_mem_4": [
15650,
19069,
21323,
13406,
19967,
24453,
17448,
14934,
14934
],
"out_mem_5": [
18516,
22029,
30577,
17767,
20837,
35265,
21524,
14972,
14972
],
"out_mem_6": [
13426,
16673,
19948,
13367,
15650,
23464,
18419,
10693,
10693
],
"out_mem_7": [
15791,
22708,
22926,
Expand Down
Loading

0 comments on commit d95f53e

Please sign in to comment.