From e7ead0ce16afb0fd383b60d098af7616d9488d02 Mon Sep 17 00:00:00 2001 From: calebmkim <55243755+calebmkim@users.noreply.github.com> Date: Tue, 7 Nov 2023 13:08:45 -0500 Subject: [PATCH] Systolic Generation Optimization (#1760) * better systolic array * cleaner code * small chnage * metadata format * remove getters * flake8 * sort for testing * plz work * does this work * simple sort * another try * plzzzz workkk --- .../systolic-lang/gen_array_component.py | 308 +++--------------- .../systolic-lang/systolic_scheduling.py | 240 ++++++++++++++ tests/frontend/systolic/array-1.expect | 64 +--- 3 files changed, 309 insertions(+), 303 deletions(-) create mode 100644 frontends/systolic-lang/systolic_scheduling.py diff --git a/frontends/systolic-lang/gen_array_component.py b/frontends/systolic-lang/gen_array_component.py index a3f9070f9a..76f09ae398 100644 --- a/frontends/systolic-lang/gen_array_component.py +++ b/frontends/systolic-lang/gen_array_component.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 -import numpy as np from gen_pe import pe, PE_NAME, BITWIDTH import calyx.builder as cb from calyx import py_ast from calyx.utils import bits_needed from systolic_arg_parser import SystolicConfiguration +from systolic_scheduling import gen_schedules # Global constant for the current bitwidth. DEPTH = "depth" @@ -33,44 +33,6 @@ } -class CalyxAdd: - """ - A class that represents addition in Calyx between a port and a constant - """ - - def __init__(self, port, const): - self.port = port - self.const = const - - def __eq__(self, other): - if type(other) != CalyxAdd: - return False - return ( - cb.ExprBuilder.unwrap(self.port) == cb.ExprBuilder.unwrap(other.port) - and self.const == other.const - ) - - def __hash__(self): - return hash(self.const) - - def __str__(self): - return ( - str(cb.ExprBuilder.unwrap(self.port).item.id.name) - + "_plus_" - + str(self.const) - ) - - def implement_add(self, comp: cb.ComponentBuilder) -> str: - """ - Implements the `CalyxAdd` by creating an adder that adds the two values - """ - if comp.try_get_cell(str(self)) is None: - add = comp.add(BITWIDTH, str(self)) - with comp.continuous: - add.left = self.port - add.right = self.const - - def add_systolic_output_params(comp: cb.ComponentBuilder, row_num, addr_width): """ Add output arguments to systolic array component `comp` for row `row_num`. @@ -161,21 +123,17 @@ def instantiate_data_move( from the `write` register of the PE at (row, col) to the read register of the PEs at (row+1, col) and (row, col+1) """ - name = f"pe_{row}_{col}" - if not right_edge: - group_name = NAME_SCHEME["register move right"].format(pe=name) src_reg = comp.get_cell(f"left_{row}_{col}") dst_reg = comp.get_cell(f"left_{row}_{col + 1}") - with comp.static_group(group_name, 1): + with comp.continuous: dst_reg.in_ = src_reg.out dst_reg.write_en = 1 if not down_edge: - group_name = NAME_SCHEME["register move down"].format(pe=name) src_reg = comp.get_cell(f"top_{row}_{col}") dst_reg = comp.get_cell(f"top_{row + 1}_{col}") - with comp.static_group(group_name, 1): + with comp.continuous: dst_reg.in_ = src_reg.out dst_reg.write_en = 1 @@ -211,19 +169,6 @@ def get_memory_updates(row, col): return mover_enables -def get_pe_moves(r, c, top_length, left_length): - """ - Gets the PE moves for the PE at (r,c) - """ - pe_moves = [] - if r < left_length - 1: - pe_moves.append(NAME_SCHEME["register move down"].format(pe=f"pe_{r}_{c}")) - if c < top_length - 1: - pe_moves.append(NAME_SCHEME["register move right"].format(pe=f"pe_{r}_{c}")) - pe_enables = [py_ast.Enable(name) for name in pe_moves] - return pe_enables - - def get_pe_invoke(r, c, mul_ready): """ gets the PE invokes for the PE at (r,c). mul_ready signals whether 1 or 0 @@ -239,7 +184,7 @@ def get_pe_invoke(r, c, mul_ready): ), ( "mul_ready", - py_ast.ConstantPort(1, mul_ready), + mul_ready, ), ], out_connects=[], @@ -292,158 +237,25 @@ def instantiate_idx_groups(comp: cb.ComponentBuilder, config: SystolicConfigurat lt_iter_limit.right = iter_limit.out -def instantiate_calyx_adds(comp, nec_ranges) -> list: - """ - Instantiates the CalyxAdd objects to adders and actual groups that perform the - specified add. - Returns a list of all the group names that we created. - """ - for lo, hi in nec_ranges: - if type(lo) == CalyxAdd: - lo.implement_add(comp) - if type(hi) == CalyxAdd: - hi.implement_add(comp) - - -def check_idx_lower_bound(comp: cb.ComponentBuilder, lo): - """ - Creates assignments to test if idx >= lo - """ - if type(lo) == CalyxAdd: - lo_value = comp.get_cell(str(lo)).port("out") - else: - lo_value = lo - idx = comp.get_cell("idx") - index_ge = f"index_ge_{lo}" - ge = comp.ge(BITWIDTH, index_ge) - with comp.continuous: - ge.left = idx.out - ge.right = lo_value - - -def check_idx_upper_bound(comp: cb.ComponentBuilder, hi): - """ - Creates assignments to test if idx < hi - """ - if type(hi) == CalyxAdd: - hi_value = comp.get_cell(str(hi)).port("out") - else: - hi_value = hi - idx = comp.get_cell("idx") - index_lt = f"index_lt_{hi}" - lt = comp.lt(BITWIDTH, index_lt) - with comp.continuous: - lt.left = idx.out - lt.right = hi_value - - -def check_idx_between(comp: cb.ComponentBuilder, lo, hi) -> list: - """ - Creates assignments to check whether idx is between [lo, hi). - That is, whether lo <= idx < hi. - """ - # This is the name of the combinational cell that checks the condition - idx_between_str = f"idx_between_{lo}_{hi}_comb" - lt = comp.get_cell(f"index_lt_{hi}") - # if lo == 0, then only need to check if reg < hi - if type(lo) == int and lo == 0: - # In this case, the `wire` cell is the cell checking the condition. - wire = comp.wire(idx_between_str, 1) - with comp.continuous: - wire.in_ = lt.out - # need to check if reg >= lo and reg < hi - else: - ge = comp.get_cell(f"index_ge_{lo}") - # In this case, the `and` cell is the cell checking the condition. - and_ = comp.and_(1, idx_between_str) - with comp.continuous: - and_.right = lt.out - and_.left = ge.out - - -def accum_nec_ranges(nec_ranges, schedule): - """ - Essentially creates a set that contains all of the idx ranges that - we need to check for (e.g., [1,3) [2,4)] in order to realize - the schedule - - nec_ranges is a set of tuples. - schedule is either a 2d array or 1d array with tuple (start,end) entries. - Adds all intervals (start,end) in schedule to nec_ranges if the it's - not already in nec_ranges. - """ - if schedule.ndim == 1: - for r in schedule: - nec_ranges.add(r) - elif schedule.ndim == 2: - for r in schedule: - for c in r: - nec_ranges.add(c) - else: - raise Exception("accum_nec_ranges expects only 1d or 2d arrays") - return nec_ranges - - -def gen_schedules( - config: SystolicConfiguration, - comp: cb.ComponentBuilder, -): +def execute_if_between(comp: cb.ComponentBuilder, start, end, body): """ - Generates 4 arrays that are the same size as the output (systolic) array - Each entry in the array has tuple [start, end) that indicates the cycles that - they are active - `update_sched` contains when to update the indices of the input memories and feed - them into the systolic array - `pe_fill_sched` contains when to invoke PE but not accumulate (bc the multipliers - are not ready with an output yet) - `pe_accum_sched` contains when to invoke PE and accumulate (bc the multipliers - are ready with an output) - `pe_move_sched` contains when to "move" the PE (i.e., pass data) - `pe_write_sched` contains when to "write" the PE value into the output ports - (e.g., this.r0_valid) + body is a list of control stmts + if body is empty, return an empty list + otherwise, builds an if stmt that executes body in parallel if + idx is between start and end """ - - def depth_plus_const(const: int): - """ - Returns depth + const. If config.static, then this is an int. - Otherwise, we need to perform a Calyx addition to figure this out. - """ - if config.static: - # return an int - return config.get_contraction_dimension() + const - else: - # return a CalyxAdd object, whose value is determined after generation - depth_port = comp.this().depth - return CalyxAdd(depth_port, const) - - left_length, top_length = config.left_length, config.top_length - - schedules = {} - update_sched = np.zeros((left_length, top_length), dtype=object) - pe_fill_sched = np.zeros((left_length, top_length), dtype=object) - pe_accum_sched = np.zeros((left_length, top_length), dtype=object) - pe_move_sched = np.zeros((left_length, top_length), dtype=object) - pe_write_sched = np.zeros((left_length, top_length), dtype=object) - for row in range(0, left_length): - for col in range(0, top_length): - pos = row + col - update_sched[row][col] = (pos, depth_plus_const(pos)) - pe_fill_sched[row][col] = (pos + 1, pos + 5) - pe_accum_sched[row][col] = (pos + 5, depth_plus_const(pos + 5)) - pe_move_sched[row][col] = (pos + 1, depth_plus_const(pos + 1)) - pe_write_sched[row][col] = ( - depth_plus_const(pos + 5), - depth_plus_const(pos + 6), - ) - schedules["update_sched"] = update_sched - schedules["fill_sched"] = pe_fill_sched - schedules["accum_sched"] = pe_accum_sched - schedules["move_sched"] = pe_move_sched - schedules["write_sched"] = pe_write_sched - return schedules + if not body: + return [] + if_cell = comp.get_cell(f"idx_between_{start}_{end}_comb") + return [ + cb.static_if( + if_cell.out, + py_ast.StaticParComp(body), + ) + ] -def execute_if_between(comp: cb.ComponentBuilder, start, end, body): +def execute_if_eq(comp: cb.ComponentBuilder, val, body): """ body is a list of control stmts if body is empty, return an empty list @@ -452,7 +264,7 @@ def execute_if_between(comp: cb.ComponentBuilder, start, end, body): """ if not body: return [] - if_cell = comp.get_cell(f"idx_between_{start}_{end}_comb") + if_cell = comp.get_cell(f"index_eq_{val}") return [ cb.static_if( if_cell.out, @@ -462,7 +274,7 @@ def execute_if_between(comp: cb.ComponentBuilder, start, end, body): def generate_control( - comp: cb.ComponentBuilder, config: SystolicConfiguration, schedules + comp: cb.ComponentBuilder, config: SystolicConfiguration, schedule ): """ Logically, control performs the following actions: @@ -500,51 +312,45 @@ def counter(): while_body_stmts = [py_ast.Enable("incr_idx")] for r in range(left_length): for c in range(top_length): - # build 4 if stmts for the 4 schedules that we need to account for + # Execute input_mem_updates = execute_if_between( comp, - schedules["update_sched"][r][c][0], - schedules["update_sched"][r][c][1], + schedule.mappings["update_sched"][r][c].i1, + schedule.mappings["update_sched"][r][c].i2, get_memory_updates(r, c), ) - pe_fills = execute_if_between( - comp, - schedules["fill_sched"][r][c][0], - schedules["fill_sched"][r][c][1], - [get_pe_invoke(r, c, 0)], - ) - pe_moves = execute_if_between( - comp, - schedules["move_sched"][r][c][0], - schedules["move_sched"][r][c][1], - get_pe_moves(r, c, top_length, left_length), + pe_accum_thresh = schedule.mappings["pe_accum_cond"][r][c].i1 + pe_accum_cond = py_ast.CompPort( + py_ast.CompVar(f"index_ge_{pe_accum_thresh}"), "out" ) - pe_accums = execute_if_between( + pe_executions = execute_if_between( comp, - schedules["accum_sched"][r][c][0], - schedules["accum_sched"][r][c][1], - [get_pe_invoke(r, c, 1)], + schedule.mappings["pe_sched"][r][c].i1, + schedule.mappings["pe_sched"][r][c].i2, + [get_pe_invoke(r, c, pe_accum_cond)], ) - output_writes = execute_if_between( + output_writes = execute_if_eq( comp, - schedules["write_sched"][r][c][0], - schedules["write_sched"][r][c][1], + schedule.mappings["pe_write_sched"][r][c].i1, [py_ast.Enable(NAME_SCHEME["out write"].format(pe=f"pe_{r}_{c}"))], ) - pe_control = ( - input_mem_updates + pe_fills + pe_moves + pe_accums + output_writes + while_body_stmts.append( + py_ast.StaticParComp(input_mem_updates + pe_executions + output_writes) ) - while_body_stmts.append(py_ast.StaticParComp(pe_control)) # providing metadata tag = counter() + boundary_fill_sched = "" + if r == 0 or c == 0: + boundary_fill_sched = f"Feeding Boundary PE: \ +[{schedule.mappings['update_sched'][r][c].i1},\ +{schedule.mappings['update_sched'][r][c].i2}) || " source_map[ tag - ] = f"pe_{r}_{c} filling: [{schedules['fill_sched'][r][c][0]},\ -{schedules['fill_sched'][r][c][1]}), \ -accumulating: [{schedules['accum_sched'][r][c][0]} \ -{schedules['accum_sched'][r][c][1]}), \ -writing: [{schedules['write_sched'][r][c][0]} \ -{schedules['write_sched'][r][c][1]})" + ] = f"pe_{r}_{c}: \ +{boundary_fill_sched}\ +Invoking PE: [{schedule.mappings['pe_sched'][r][c].i1}, \ +{schedule.mappings['pe_sched'][r][c].i2}) || \ +Writing PE Result: {schedule.mappings['pe_write_sched'][r][c].i1}" while_body = py_ast.StaticParComp(while_body_stmts) @@ -575,24 +381,16 @@ def create_systolic_array(prog: cb.Builder, config: SystolicConfiguration): # initialize the iteration limit to top_length + left_length + depth + 4 init_iter_limit(computational_unit, depth_port, config) - schedules = gen_schedules(config, computational_unit) - nec_ranges = set() - for sched in schedules.values(): - accum_nec_ranges(nec_ranges, sched) - instantiate_calyx_adds(computational_unit, nec_ranges) + # Generate the Schedule + schedule = gen_schedules(config, computational_unit) # instantiate groups that handles the idx variables instantiate_idx_groups(computational_unit, config) - list1, list2 = zip(*nec_ranges) - nec_ranges_beg = set(list1) - nec_ranges_end = set(list2) - for val in nec_ranges_beg: - check_idx_lower_bound(computational_unit, val) - for val in nec_ranges_end: - check_idx_upper_bound(computational_unit, val) - for start, end in nec_ranges: - # create the assignments that help determine if idx is in between - check_idx_between(computational_unit, start, end) + + # Generate the hardware For the schedule + schedule.build_hardware( + computational_unit, idx_reg=computational_unit.get_cell("idx") + ) for row in range(config.left_length): for col in range(config.top_length): @@ -629,6 +427,6 @@ def create_systolic_array(prog: cb.Builder, config: SystolicConfiguration): instantiate_output_move(computational_unit, row, col) # Generate the control and set the source map - control, source_map = generate_control(computational_unit, config, schedules) + control, source_map = generate_control(computational_unit, config, schedule) computational_unit.control = control prog.program.meta = source_map diff --git a/frontends/systolic-lang/systolic_scheduling.py b/frontends/systolic-lang/systolic_scheduling.py new file mode 100644 index 0000000000..9e36e1d74d --- /dev/null +++ b/frontends/systolic-lang/systolic_scheduling.py @@ -0,0 +1,240 @@ +import calyx.builder as cb +from gen_pe import BITWIDTH +from enum import Enum +from systolic_arg_parser import SystolicConfiguration +import numpy as np + + +class CalyxAdd: + """ + A class that represents addition in Calyx between a port and a constant + """ + + def __init__(self, port, const): + self.port = port + self.const = const + + def __eq__(self, other): + if type(other) != CalyxAdd: + return False + return ( + cb.ExprBuilder.unwrap(self.port) == cb.ExprBuilder.unwrap(other.port) + and self.const == other.const + ) + + def __hash__(self): + return hash(self.const) + + def __str__(self): + return ( + str(cb.ExprBuilder.unwrap(self.port).item.id.name) + + "_plus_" + + str(self.const) + ) + + def implement_add(self, comp: cb.ComponentBuilder) -> str: + """ + Implements the `CalyxAdd` by creating an adder that adds the two values + """ + if comp.try_get_cell(str(self)) is None: + add = comp.add(BITWIDTH, str(self)) + with comp.continuous: + add.left = self.port + add.right = self.const + + +class ScheduleType(Enum): + GE = 1 + LT = 2 + EQ = 3 + INTERVAL = 4 + + +class ScheduleInstance: + def __init__(self, type: ScheduleType, i1, i2=None): + self.type = type + self.i1 = i1 + self.i2 = i2 + if type == ScheduleType.INTERVAL and self.i2 is None: + raise Exception("INTERVAL type must specify beginning and end") + + def __lt__(self, other): + return (self.type, self.i1, self.i2) < (other.type, other.i1, other.i2) + + +class Schedule: + def __init__(self): + # XXX(Caleb): self.instances could be a set, but I'm running into annoying + # ordering errors on tests. Python dictionaries are luckily ordered. + self.instances = {} + self.mappings = {} + + def add_instances(self, name, schedule_instances): + """ """ + self.mappings[name] = schedule_instances + for schedule_instance in schedule_instances.flatten(): + self.instances[schedule_instance] = None + + def __instantiate_calyx_adds(self, comp) -> list: + """ """ + for schedule_instance in self.instances.keys(): + if type(schedule_instance.i1) == CalyxAdd: + schedule_instance.i1.implement_add(comp) + if type(schedule_instance.i2) == CalyxAdd: + schedule_instance.i2.implement_add(comp) + + def __check_idx_eq(self, comp: cb.ComponentBuilder, idx_reg: cb.CellBuilder, eq): + """ + Creates assignments to test if idx >= lo + """ + if type(eq) == CalyxAdd: + eq_value = comp.get_cell(str(eq)).port("out") + else: + eq_value = eq + eq = comp.eq(BITWIDTH, f"index_eq_{eq}") + with comp.continuous: + eq.left = idx_reg.out + eq.right = eq_value + + def __check_idx_lower_bound( + self, comp: cb.ComponentBuilder, idx_reg: cb.CellBuilder, lo + ): + """ + Creates assignments to test if idx >= lo + """ + if type(lo) == int and lo == 0: + return + if type(lo) == CalyxAdd: + lo_value = comp.get_cell(str(lo)).port("out") + else: + lo_value = lo + ge = comp.ge(BITWIDTH, f"index_ge_{lo}") + with comp.continuous: + ge.left = idx_reg.out + ge.right = lo_value + + def __check_idx_upper_bound( + self, comp: cb.ComponentBuilder, idx_reg: cb.CellBuilder, hi + ): + """ + Creates assignments to test if idx < hi + """ + if type(hi) == CalyxAdd: + hi_value = comp.get_cell(str(hi)).port("out") + else: + hi_value = hi + lt = comp.lt(BITWIDTH, f"index_lt_{hi}") + with comp.continuous: + lt.left = idx_reg.out + lt.right = hi_value + + def __check_idx_between(self, comp: cb.ComponentBuilder, lo, hi) -> list: + """ + Creates assignments to check whether idx is between [lo, hi). + That is, whether lo <= idx < hi. + IMPORTANT: Assumes the lt and gt cells ahve already been created + """ + # This is the name of the combinational cell that checks the condition + idx_between_str = f"idx_between_{lo}_{hi}_comb" + lt = comp.get_cell(f"index_lt_{hi}") + # if lo == 0, then only need to check if reg < hi + if type(lo) == int and lo == 0: + # In this case, the `wire` cell is the cell checking the condition. + wire = comp.wire(idx_between_str, 1) + with comp.continuous: + wire.in_ = lt.out + # need to check if reg >= lo and reg < hi + else: + ge = comp.get_cell(f"index_ge_{lo}") + # In this case, the `and` cell is the cell checking the condition. + and_ = comp.and_(1, idx_between_str) + with comp.continuous: + and_.right = lt.out + and_.left = ge.out + + def build_hardware(self, comp: cb.ComponentBuilder, idx_reg: cb.CellBuilder): + """ """ + # instantiate groups that handles the idx variables + # Dictionary to keep consistent ordering. + ge_ranges = {} + lt_ranges = {} + eq_ranges = {} + interval_ranges = {} + for schedule_instance in self.instances.keys(): + sched_type = schedule_instance.type + if sched_type == ScheduleType.GE: + ge_ranges[schedule_instance.i1] = None + elif sched_type == ScheduleType.LT: + lt_ranges[schedule_instance.i1] = None + elif sched_type == ScheduleType.EQ: + eq_ranges[schedule_instance.i1] = None + elif sched_type == ScheduleType.INTERVAL: + ge_ranges[schedule_instance.i1] = None + lt_ranges[schedule_instance.i2] = None + interval_ranges[(schedule_instance.i1, schedule_instance.i2)] = None + self.__instantiate_calyx_adds(comp) + # Need to sort for testing purposes + for val in eq_ranges: + self.__check_idx_eq(comp, idx_reg, val) + for val in ge_ranges: + self.__check_idx_lower_bound(comp, idx_reg, val) + for val in lt_ranges: + self.__check_idx_upper_bound(comp, idx_reg, val) + for start, end in interval_ranges: + self.__check_idx_between(comp, start, end) + + +def gen_schedules( + config: SystolicConfiguration, + comp: cb.ComponentBuilder, +): + """ + Generates 4 arrays that are the same size as the output (systolic) array + Each entry in the array has tuple [start, end) that indicates the cycles that + they are active + `update_sched` contains when to update the indices of the input memories and feed + them into the systolic array + `pe_sched` contains when to invoke PE + `pe_accum_cond` contains when to allow the PEs to accumulate (bc the multipliers + are ready with an output) + `pe_write_sched` contains when to "write" the PE value into the output ports + (e.g., this.r0_valid) + """ + + def depth_plus_const(const: int): + """ + Returns depth + const. If config.static, then this is an int. + Otherwise, we need to perform a Calyx addition to figure this out. + """ + if config.static: + # return an int + return config.get_contraction_dimension() + const + else: + # return a CalyxAdd object, whose value is determined after generation + depth_port = comp.this().depth + return CalyxAdd(depth_port, const) + + left_length, top_length = config.left_length, config.top_length + update_sched = np.zeros((left_length, top_length), dtype=object) + pe_sched = np.zeros((left_length, top_length), dtype=object) + pe_accum_cond = np.zeros((left_length, top_length), dtype=object) + pe_write_sched = np.zeros((left_length, top_length), dtype=object) + for row in range(0, left_length): + for col in range(0, top_length): + pos = row + col + update_sched[row][col] = ScheduleInstance( + ScheduleType.INTERVAL, pos, depth_plus_const(pos) + ) + pe_sched[row][col] = ScheduleInstance( + ScheduleType.INTERVAL, pos + 1, depth_plus_const(pos + 5) + ) + pe_accum_cond[row][col] = ScheduleInstance(ScheduleType.GE, pos + 5) + pe_write_sched[row][col] = ScheduleInstance( + ScheduleType.EQ, depth_plus_const(pos + 5) + ) + schedule = Schedule() + schedule.add_instances("update_sched", update_sched) + schedule.add_instances("pe_sched", pe_sched) + schedule.add_instances("pe_accum_cond", pe_accum_cond) + schedule.add_instances("pe_write_sched", pe_write_sched) + return schedule diff --git a/tests/frontend/systolic/array-1.expect b/tests/frontend/systolic/array-1.expect index 1c7b2f7036..a0a53c6bf7 100644 --- a/tests/frontend/systolic/array-1.expect +++ b/tests/frontend/systolic/array-1.expect @@ -31,27 +31,18 @@ component systolic_array_comp(depth: 32, t0_read_data: 32, l0_read_data: 32) -> cells { iter_limit = std_reg(32); iter_limit_add = std_add(32); - depth_plus_5 = std_add(32); - depth_plus_0 = std_add(32); - depth_plus_1 = std_add(32); - depth_plus_6 = std_add(32); idx = std_reg(32); idx_add = std_add(32); lt_iter_limit = std_lt(32); - index_ge_0 = std_ge(32); + depth_plus_0 = std_add(32); + depth_plus_5 = std_add(32); + index_eq_depth_plus_5 = std_eq(32); index_ge_1 = std_ge(32); - index_ge_depth_plus_5 = std_ge(32); index_ge_5 = std_ge(32); index_lt_depth_plus_0 = std_lt(32); - index_lt_depth_plus_1 = std_lt(32); - index_lt_5 = std_lt(32); index_lt_depth_plus_5 = std_lt(32); - index_lt_depth_plus_6 = std_lt(32); - idx_between_5_depth_plus_5_comb = std_and(1); - idx_between_1_5_comb = std_and(1); idx_between_0_depth_plus_0_comb = std_wire(1); - idx_between_1_depth_plus_1_comb = std_and(1); - idx_between_depth_plus_5_depth_plus_6_comb = std_and(1); + idx_between_1_depth_plus_5_comb = std_and(1); pe_0_0 = mac_pe(); top_0_0 = std_reg(32); left_0_0 = std_reg(32); @@ -65,14 +56,6 @@ component systolic_array_comp(depth: 32, t0_read_data: 32, l0_read_data: 32) -> iter_limit.in = iter_limit_add.out; iter_limit.write_en = 1'd1; } - depth_plus_5.left = depth; - depth_plus_5.right = 32'd5; - depth_plus_0.left = depth; - depth_plus_0.right = 32'd0; - depth_plus_1.left = depth; - depth_plus_1.right = 32'd1; - depth_plus_6.left = depth; - depth_plus_6.right = 32'd6; static<1> group init_idx { idx.in = 32'd0; idx.write_en = 1'd1; @@ -85,33 +68,23 @@ component systolic_array_comp(depth: 32, t0_read_data: 32, l0_read_data: 32) -> } lt_iter_limit.left = idx.out; lt_iter_limit.right = iter_limit.out; - index_ge_0.left = idx.out; - index_ge_0.right = 32'd0; + depth_plus_0.left = depth; + depth_plus_0.right = 32'd0; + depth_plus_5.left = depth; + depth_plus_5.right = 32'd5; + index_eq_depth_plus_5.left = idx.out; + index_eq_depth_plus_5.right = depth_plus_5.out; index_ge_1.left = idx.out; index_ge_1.right = 32'd1; - index_ge_depth_plus_5.left = idx.out; - index_ge_depth_plus_5.right = depth_plus_5.out; index_ge_5.left = idx.out; index_ge_5.right = 32'd5; index_lt_depth_plus_0.left = idx.out; index_lt_depth_plus_0.right = depth_plus_0.out; - index_lt_depth_plus_1.left = idx.out; - index_lt_depth_plus_1.right = depth_plus_1.out; - index_lt_5.left = idx.out; - index_lt_5.right = 32'd5; index_lt_depth_plus_5.left = idx.out; index_lt_depth_plus_5.right = depth_plus_5.out; - index_lt_depth_plus_6.left = idx.out; - index_lt_depth_plus_6.right = depth_plus_6.out; - idx_between_5_depth_plus_5_comb.right = index_lt_depth_plus_5.out; - idx_between_5_depth_plus_5_comb.left = index_ge_5.out; - idx_between_1_5_comb.right = index_lt_5.out; - idx_between_1_5_comb.left = index_ge_1.out; idx_between_0_depth_plus_0_comb.in = index_lt_depth_plus_0.out; - idx_between_1_depth_plus_1_comb.right = index_lt_depth_plus_1.out; - idx_between_1_depth_plus_1_comb.left = index_ge_1.out; - idx_between_depth_plus_5_depth_plus_6_comb.right = index_lt_depth_plus_6.out; - idx_between_depth_plus_5_depth_plus_6_comb.left = index_ge_depth_plus_5.out; + idx_between_1_depth_plus_5_comb.right = index_lt_depth_plus_5.out; + idx_between_1_depth_plus_5_comb.left = index_ge_1.out; idx_minus_0.left = idx.out; idx_minus_0.right = 32'd0; idx_minus_0_res.in = idx_minus_0.out; @@ -147,17 +120,12 @@ component systolic_array_comp(depth: 32, t0_read_data: 32, l0_read_data: 32) -> t0_move; } } - static if idx_between_1_5_comb.out { - static par { - static invoke pe_0_0(top=top_0_0.out, left=left_0_0.out, mul_ready=1'd0)(); - } - } - static if idx_between_5_depth_plus_5_comb.out { + static if idx_between_1_depth_plus_5_comb.out { static par { - static invoke pe_0_0(top=top_0_0.out, left=left_0_0.out, mul_ready=1'd1)(); + static invoke pe_0_0(top=top_0_0.out, left=left_0_0.out, mul_ready=index_ge_5.out)(); } } - static if idx_between_depth_plus_5_depth_plus_6_comb.out { + static if index_eq_depth_plus_5.out { static par { pe_0_0_out_write; } @@ -228,5 +196,5 @@ component main() -> () { } } metadata #{ -0: pe_0_0 filling: [1,5), accumulating: [5 depth_plus_5), writing: [depth_plus_5 depth_plus_6) +0: pe_0_0: Feeding Boundary PE: [0,depth_plus_0) || Invoking PE: [1, depth_plus_5) || Writing PE Result: depth_plus_5 }#