Skip to content

Commit

Permalink
Builder: use register to make the queue nicer (#1681)
Browse files Browse the repository at this point in the history
  • Loading branch information
anshumanmohan authored Aug 21, 2023
1 parent 69deb30 commit 75abc97
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 50 deletions.
38 changes: 38 additions & 0 deletions calyx-py/calyx/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,44 @@ def sub_store_in_reg(self, left, right, cellname, width, ans_reg=None):
sub_group.done = ans_reg.done
return sub_group, ans_reg

def eq_store_in_reg(self, left, right, cellname, width, ans_reg=None):
"""Adds wiring into component `self` to compute `left` == `right`
and store it in `ans_reg`.
1. Within component `self`, creates a group called `cellname`_group.
2. Within `group`, create a cell `cellname` that computes equality.
3. Puts the values of `left` and `right` into `cell`.
4. Then puts the answer of the computation into `ans_reg`.
4. Returns the equality group and the register.
"""
eq_cell = self.eq(width, cellname)
ans_reg = ans_reg or self.reg(f"reg_{cellname}", 1)
with self.group(f"{cellname}_group") as eq_group:
eq_cell.left = left
eq_cell.right = right
ans_reg.write_en = 1
ans_reg.in_ = eq_cell.out
eq_group.done = ans_reg.done
return eq_group, ans_reg

def neq_store_in_reg(self, left, right, cellname, width, ans_reg=None):
"""Adds wiring into component `self` to compute `left` != `right`
and store it in `ans_reg`.
1. Within component `self`, creates a group called `cellname`_group.
2. Within `group`, create a cell `cellname` that computes inequality.
3. Puts the values of `left` and `right` into `cell`.
4. Then puts the answer of the computation into `ans_reg`.
4. Returns the inequality group and the register.
"""
neq_cell = self.neq(width, cellname)
ans_reg = ans_reg or self.reg(f"reg_{cellname}", 1)
with self.group(f"{cellname}_group") as neq_group:
neq_cell.left = left
neq_cell.right = right
ans_reg.write_en = 1
ans_reg.in_ = neq_cell.out
neq_group.done = ans_reg.done
return neq_group, ans_reg


@dataclass(frozen=True)
class CellAndGroup:
Expand Down
89 changes: 39 additions & 50 deletions calyx-py/calyx/queue_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,6 @@
ANS_MEM_LEN = 10


def insert_raise_err_if_i_eq_max_cmds(prog):
"""Inserts a the component `raise_err_if_i_eq_MAX_CMDS` into the program.
It has:
- one input, `i`.
- one ref register, `err`.
If `i` equals MAX_CMDS, it raises the `err` flag.
"""
raise_err_if_i_eq_max_cmds: cb.ComponentBuilder = prog.component(
"raise_err_if_i_eq_MAX_CMDS"
)
i = raise_err_if_i_eq_max_cmds.input("i", 32)
err = raise_err_if_i_eq_max_cmds.reg("err", 1, is_ref=True)

i_eq_max_cmds = raise_err_if_i_eq_max_cmds.eq_use(i, MAX_CMDS, 32)
raise_err = raise_err_if_i_eq_max_cmds.reg_store(err, 1, "raise_err")

raise_err_if_i_eq_max_cmds.control += [
cb.if_with(
i_eq_max_cmds,
raise_err,
)
]

return raise_err_if_i_eq_max_cmds


def insert_main(prog, queue):
"""Inserts the component `main` into the program.
This will be used to `invoke` the component `queue` and feed it a list of commands.
Expand Down Expand Up @@ -67,15 +39,12 @@ def insert_main(prog, queue):

# The two components we'll use:
queue = main.cell("myqueue", queue)
raise_err_if_i_eq_max_cmds = main.cell(
"raise_err_if_i_eq_MAX_CMDS", insert_raise_err_if_i_eq_max_cmds(prog)
)

# We will use the `invoke` method to call the `queue` component.
# The queue component takes two inputs by reference and one input directly.
# The two `ref` inputs:
err = main.reg("err", 1) # A flag to indicate an error
ans = main.reg("ans", 32) # A memory to hold the answer of a pop
ans = main.reg("ans", 32) # A memory to hold the answer of a pop or peek

# We will set up a while loop that runs over the command list, relaying
# the commands to the `queue` component.
Expand All @@ -88,7 +57,6 @@ def insert_main(prog, queue):

incr_i = main.incr(i, 32) # i++
incr_j = main.incr(j, 32) # j++
err_eq_0 = main.eq_use(err.out, 0, 1) # is `err` flag down?
cmd_le_1 = main.le_use(cmd.out, 1, 2) # cmd <= 1

read_cmd = main.mem_read_seq_d1(commands, i.out, "read_cmd_phase1")
Expand All @@ -100,36 +68,57 @@ def insert_main(prog, queue):
)
write_ans = main.mem_store_seq_d1(ans_mem, j.out, ans.out, "write_ans")

loop_goes_on = main.reg(
"loop_goes_on", 1
) # A flag to indicate whether the loop should continue
update_err_is_down, _ = main.eq_store_in_reg(
err.out,
0,
"err_is_down",
1,
loop_goes_on
# Does the `err` flag say that the loop should continue?
)
update_i_neq_15, _ = main.neq_store_in_reg(
i.out,
cb.const(32, 15),
"i_neq_15",
32,
loop_goes_on
# Does the `i` index say that the loop should continue?
)

main.control += [
cb.while_with(
err_eq_0, # Run while the `err` flag is down
update_err_is_down,
cb.while_(
loop_goes_on.out, # Run while the `err` flag is down
[
read_cmd,
write_cmd_to_reg,
# `cmd := commands[i]`
write_cmd_to_reg, # `cmd := commands[i]`
read_value,
write_value_to_reg,
# `value := values[i]`
write_value_to_reg, # `value := values[i]`
cb.invoke( # Invoke the queue.
queue,
in_cmd=cmd.out,
in_value=value.out,
ref_ans=ans,
ref_err=err,
),
cb.if_with( # If it was a pop or a peek, write ans to the answer list
cmd_le_1,
[ # AM: I'd like to have an additional check hereL
# if err flag comes back raised,
# we do not perform this write_ans or this incr_j
write_ans,
incr_j,
update_err_is_down, # Does `err` say that the loop should be broken?
cb.if_(
loop_goes_on.out, # If the loop is not meant to be broken...
[
cb.if_with(
cmd_le_1, # If the command was a pop or peek,
[
write_ans, # Write the answer to the answer list
incr_j, # And increment the answer index.
],
),
incr_i, # Increment the command index
update_i_neq_15, # Did this increment make us need to break?
],
),
incr_i, # Increment the command index
cb.invoke( # If i = MAX_CMDS, raise error flag
raise_err_if_i_eq_max_cmds, in_i=i.out, ref_err=err
), # AM: hella hacky
],
),
]

0 comments on commit 75abc97

Please sign in to comment.