Skip to content

Commit

Permalink
Improve Systolic Array Generator (#1582)
Browse files Browse the repository at this point in the history
* have example for reference, starting to hack

* pipelined systolic array

* changes test cases

* deleted unnecesary file

* rewrote tests

* fixed other errors

* python lint

* metadata and some small stuff

* removed space

* fixed small changes

* tested correctness

* deleted unnecessary file

* added test cases
  • Loading branch information
calebmkim authored Jul 4, 2023
1 parent fb9dc99 commit 507201a
Show file tree
Hide file tree
Showing 27 changed files with 2,197 additions and 629 deletions.
59 changes: 53 additions & 6 deletions calyx-py/calyx/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def __init__(self):
self.import_("primitives/core.futil")
self._index: Dict[str, ComponentBuilder] = {}

def component(self, name: str, cells=None) -> ComponentBuilder:
def component(self, name: str, cells=None, latency=None) -> ComponentBuilder:
"""Create a new component builder."""
cells = cells or []
comp_builder = ComponentBuilder(self, name, cells)
comp_builder = ComponentBuilder(self, name, cells, latency)
self.program.components.append(comp_builder.component)
self._index[name] = comp_builder
return comp_builder
Expand All @@ -49,7 +49,11 @@ class ComponentBuilder:
"""Builds Calyx components definitions."""

def __init__(
self, prog: Builder, name: str, cells: Optional[List[ast.Cell]] = None
self,
prog: Builder,
name: str,
cells: Optional[List[ast.Cell]] = None,
latency: Optional[int] = None,
):
"""Contructs a new component in the current program. If `cells` is
provided, the component will be initialized with those cells."""
Expand All @@ -61,6 +65,7 @@ def __init__(
outputs=[],
structs=cells,
controls=ast.Empty(),
latency=latency,
)
self.index: Dict[str, Union[GroupBuilder, CellBuilder]] = {}
for cell in cells:
Expand Down Expand Up @@ -114,6 +119,14 @@ def get_cell(self, name: str) -> CellBuilder:
f"Known cells: {list(map(lambda c: c.id.name, self.component.cells))}"
)

def try_get_cell(self, name: str) -> CellBuilder:
"""Tries to get a cell builder by name. If cannot find it, return None"""
out = self.index.get(name)
if out and isinstance(out, CellBuilder):
return out
else:
return None

def get_group(self, name: str) -> GroupBuilder:
"""Retrieve a group builder by name."""
out = self.index.get(name)
Expand All @@ -137,7 +150,17 @@ def group(self, name: str, static_delay: Optional[int] = None) -> GroupBuilder:
def comb_group(self, name: str) -> GroupBuilder:
"""Create a new combinational group with the given name."""
group = ast.CombGroup(ast.CompVar(name), connections=[])
assert group not in self.component.wires, f"comb group '{name}' already exists"
assert group not in self.component.wires, f"group '{name}' already exists"

self.component.wires.append(group)
builder = GroupBuilder(group, self)
self.index[name] = builder
return builder

def static_group(self, name: str, latency: int) -> GroupBuilder:
"""Create a new combinational group with the given name."""
group = ast.StaticGroup(ast.CompVar(name), connections=[], latency=latency)
assert group not in self.component.wires, f"group '{name}' already exists"

self.component.wires.append(group)
builder = GroupBuilder(group, self)
Expand Down Expand Up @@ -253,6 +276,15 @@ def le(self, name: str, size: int, signed=False):
self.prog.import_("primitives/binary_operators.futil")
return self.cell(name, ast.Stdlib.op("le", size, signed))

def and_(self, name: str, size: int) -> CellBuilder:
"""Generate a StdAnd cell."""
return self.cell(name, ast.Stdlib.op("and", size, False))

def pipelined_mult(self, name: str) -> CellBuilder:
"""Generate a pipelined multiplier."""
self.prog.import_("primitives/pipelined.futil")
return self.cell(name, ast.Stdlib.pipelined_mult())


def as_control(obj):
"""Convert a Python object into a control statement.
Expand Down Expand Up @@ -307,13 +339,18 @@ def while_(port: ExprBuilder, cond: Optional[GroupBuilder], body) -> ast.While:
return ast.While(port.expr, cg, as_control(body))


def static_repeat(num_repeats: int, body) -> ast.StaticRepeat:
"""Build a `static repeat` control statement."""
return ast.StaticRepeat(num_repeats, as_control(body))


def if_(
port: ExprBuilder,
cond: Optional[GroupBuilder],
body,
else_body=None,
) -> ast.If:
"""Build an `if` control statement."""
"""Build an `static if` control statement."""
else_body = ast.Empty() if else_body is None else else_body

if cond:
Expand All @@ -326,6 +363,16 @@ def if_(
return ast.If(port.expr, cg, as_control(body), as_control(else_body))


def static_if(
port: ExprBuilder,
body,
else_body=None,
) -> ast.If:
"""Build an `if` control statement."""
else_body = ast.Empty() if else_body is None else else_body
return ast.StaticIf(port.expr, as_control(body), as_control(else_body))


def invoke(cell: CellBuilder, **kwargs) -> ast.Invoke:
"""Build an `invoke` control statement.
Expand Down Expand Up @@ -654,7 +701,7 @@ def infer_width(expr):
return inst.args[0]
elif port_name == "write_en":
return 1
elif prim in ("std_add", "std_lt", "std_eq"):
elif prim in ("std_add", "std_lt", "std_le", "std_ge", "std_gt", "std_eq"):
if port_name == "left" or port_name == "right":
return inst.args[0]
elif prim == "std_mem_d1" or prim == "seq_mem_d1":
Expand Down
101 changes: 100 additions & 1 deletion calyx-py/calyx/py_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Component:
wires: list[Structure]
cells: list[Cell]
controls: Control
latency: Optional[int]

def __init__(
self,
Expand All @@ -58,11 +59,13 @@ def __init__(
outputs: list[PortDef],
structs: list[Structure],
controls: Control,
latency: Optional[int] = None,
):
self.inputs = inputs
self.outputs = outputs
self.name = name
self.controls = controls
self.latency = latency

# Partition cells and wires.
def is_cell(x):
Expand All @@ -82,7 +85,10 @@ def get_cell(self, name: str) -> Cell:
def doc(self) -> str:
ins = ", ".join([s.doc() for s in self.inputs])
outs = ", ".join([s.doc() for s in self.outputs])
signature = f"component {self.name}({ins}) -> ({outs})"
latency_annotation = (
f"static<{self.latency}> " if self.latency is not None else ""
)
signature = f"{latency_annotation}component {self.name}({ins}) -> ({outs})"
cells = block("cells", [c.doc() for c in self.cells])
wires = block("wires", [w.doc() for w in self.wires])
controls = block("control", [self.controls.doc()])
Expand Down Expand Up @@ -219,6 +225,19 @@ def doc(self) -> str:
)


@dataclass
class StaticGroup(Structure):
id: CompVar
connections: list[Connect]
latency: int

def doc(self) -> str:
return block(
f"static<{self.latency}> group {self.id.doc()}",
[c.doc() for c in self.connections],
)


@dataclass
class CompInst(Emittable):
id: str
Expand Down Expand Up @@ -309,6 +328,14 @@ def doc(self) -> str:
return block("seq", [s.doc() for s in self.stmts])


@dataclass
class StaticSeqComp(Control):
stmts: list[Control]

def doc(self) -> str:
return block("static seq", [s.doc() for s in self.stmts])


@dataclass
class ParComp(Control):
stmts: list[Control]
Expand All @@ -317,6 +344,14 @@ def doc(self) -> str:
return block("par", [s.doc() for s in self.stmts])


@dataclass
class StaticParComp(Control):
stmts: list[Control]

def doc(self) -> str:
return block("static par", [s.doc() for s in self.stmts])


@dataclass
class Invoke(Control):
id: CompVar
Expand Down Expand Up @@ -356,6 +391,40 @@ def with_attr(self, key: str, value: int) -> Invoke:
return self


@dataclass
class StaticInvoke(Control):
id: CompVar
in_connects: List[Tuple[str, Port]]
out_connects: List[Tuple[str, Port]]
ref_cells: List[Tuple[str, CompVar]] = field(default_factory=list)
attributes: List[Tuple[str, int]] = field(default_factory=list)

def doc(self) -> str:
inv = f"static invoke {self.id.doc()}"

# Add attributes if present
if len(self.attributes) > 0:
attrs = " ".join([f"@{tag}({val})" for tag, val in self.attributes])
inv = f"{attrs} {inv}"

# Add ref cells if present
if len(self.ref_cells) > 0:
rcs = ", ".join([f"{n}={arg.doc()}" for (n, arg) in self.ref_cells])
inv += f"[{rcs}]"

# Inputs and outputs
in_defs = ", ".join([f"{p}={a.doc()}" for p, a in self.in_connects])
out_defs = ", ".join([f"{p}={a.doc()}" for p, a in self.out_connects])
inv += f"({in_defs})({out_defs})"
inv += ";"

return inv

def with_attr(self, key: str, value: int) -> Invoke:
self.attributes.append((key, value))
return self


@dataclass
class While(Control):
port: Port
Expand All @@ -370,6 +439,16 @@ def doc(self) -> str:
return block(cond, self.body.doc(), sep="")


@dataclass
class StaticRepeat(Control):
num_repeats: int
body: Control

def doc(self) -> str:
cond = f"static repeat {self.num_repeats}"
return block(cond, self.body.doc(), sep="")


@dataclass
class Empty(Control):
def doc(self) -> str:
Expand All @@ -396,6 +475,22 @@ def doc(self) -> str:
return block(cond, true_branch, sep="") + false_branch


@dataclass
class StaticIf(Control):
port: Port
true_branch: Control
false_branch: Control = field(default_factory=Empty)

def doc(self) -> str:
cond = f"static if {self.port.doc()}"
true_branch = self.true_branch.doc()
if isinstance(self.false_branch, Empty):
false_branch = ""
else:
false_branch = block(" else", self.false_branch.doc(), sep="")
return block(cond, true_branch, sep="") + false_branch


# Standard Library
# XXX: This is a funky way to build the standard library. Maybe we can have a
# better "theory of standard library" to figure out what the right way to do
Expand Down Expand Up @@ -532,3 +627,7 @@ def fixed_point_op(
return CompInst(
f'std_fp_{"s" if signed else ""}{op}', [width, int_width, frac_width]
)

@staticmethod
def pipelined_mult():
return CompInst(f"pipelined_mult", [])
59 changes: 59 additions & 0 deletions frontends/systolic-lang/check-output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import numpy as np
import argparse
import json


if __name__ == "__main__":
"""
This is a script to help you know whether the Calyx's systolic array
generator is giving you the correct answers.
How to use this script: run Calyx's systolic array generator and get an
output json. Then run this script on the output json, and this script
will check the answers against numpy's matrix multiplication implementation.
Command line arguments are (no json support yet):
-tl -td -ll -ld are the same as the systolic array arguments.
-j which is the path to the json you want to check
"""
parser = argparse.ArgumentParser(description="Process some integers.")
parser.add_argument("file", nargs="?", type=str)
parser.add_argument("-tl", "--top-length", type=int)
parser.add_argument("-td", "--top-depth", type=int)
parser.add_argument("-ll", "--left-length", type=int)
parser.add_argument("-ld", "--left-depth", type=int)
parser.add_argument("-j", "--json-file", type=str)

args = parser.parse_args()

tl = args.top_length
td = args.top_depth
ll = args.left_length
ld = args.left_depth
json_file = args.json_file

assert td == ld, f"Cannot multiply matrices: " f"{tl}x{td} and {ld}x{ll}"

left = np.zeros((ll, ld), dtype="i")
top = np.zeros((td, tl), dtype="i")
json_data = json.load(open(json_file))["memories"]

for r in range(ll):
for c in range(ld):
left[r][c] = json_data[f"l{r}"][c]

for r in range(td):
for c in range(tl):
top[r][c] = json_data[f"t{c}"][r]

matmul_result = np.matmul(left, top).flatten()

json_result = np.array(json_data["out_mem"])

if np.array_equal(json_result, matmul_result):
print("Correct")
else:
print("Incorrect\n. Should have been:\n")
print(matmul_result)
print("\nBut got:\n")
print(json_result)
Loading

0 comments on commit 507201a

Please sign in to comment.