Skip to content
This repository has been archived by the owner on May 5, 2024. It is now read-only.

[NO MERGE] hack support memref::subview #10

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions openhls/compiler/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def parfor(**kwargs):

def wrapper(body):
for args in itertools.product(*kwargs):
print(f"{args=}")
idx = tuple(i for arg, i in args)
pe_idx = extend_idx(idx)
state.state.update_current_pe_idx(pe_idx=pe_idx)
Expand Down
22 changes: 12 additions & 10 deletions openhls/compiler/state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

import networkx as nx
from contextlib import contextmanager
from threading import RLock

from openhls.config import VAL_PREFIX, DTYPE, DEBUG, INCLUDE_AUX_DEPS
from openhls.util import extend_idx
Expand All @@ -17,7 +17,6 @@
class State:
_var_count = 0
_op_call_count = 0
op_graph = nx.MultiDiGraph()
cst_map = {}
cst_count = 0
_pe_idx = (0,)
Expand All @@ -26,12 +25,17 @@ class State:
pe_idx_to_most_recent_op_id = {}
op_id_to_pe_idx = {}
pe_deps = set()
rlock = None

def __init__(self, output_file):
self.op_graph.add_nodes_from(
[INPUT_ARG, MEMREF_ARG, GLOBAL_MEMREF_ARG, CONSTANT]
)
self.output_file = output_file
self.rlock = RLock()

@contextmanager
def with_rlock(self):
self.rlock.acquire()
yield
self.rlock.release()

def incr_var(self):
self._var_count += 1
Expand Down Expand Up @@ -70,12 +74,10 @@ def add_op_res(self, v, op):
self.val_source[v] = op

def maybe_add_op(self, op):
if op not in self.op_graph.nodes:
self.op_graph.add_node(op)
pass

def add_edge(self, op, arg, out_v):
val_source = self.get_arg_src(arg)
self.op_graph.add_edge(val_source, op, input=arg, output=out_v, id=op.op_id)
pass

def update_most_recent_pe_idx(self, pe_idx, op):
self.pe_idx_to_most_recent_op_id[pe_idx] = op.op_id
Expand Down
14 changes: 12 additions & 2 deletions openhls/ir/memref.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from dataclasses import dataclass
from typing import Tuple

Expand Down Expand Up @@ -81,9 +82,18 @@ def reduce_add(self):
def reduce_max(self):
return ReduceMax(list(self.registers.flatten()))

def alias(self, other_memref):
def alias(self, other_memref, offsets=None, sizes=None, strides=None):
assert isinstance(other_memref, MemRef)
self.registers = other_memref.registers
if offsets is not None and sizes is not None and strides is not None:
subview = []
for o, si, st in zip(offsets, sizes, strides):
subview.append(slice(o, o + si, st))
print("subview", subview, file=sys.stderr)
print("before subview", self.registers.shape, file=sys.stderr)
self.registers = other_memref.registers[tuple(subview)]
print("aftier subview", self.registers.shape, file=sys.stderr)
else:
self.registers = other_memref.registers


class GlobalMemRef:
Expand Down
1 change: 1 addition & 0 deletions openhls/ir/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def parse_mlir_module(module_str):
value_float = float(str(value).split(":")[0])
csts[res_val] = value_float
else:
vals.add(res_val)
vals.update(set(args))

start_time = reg_start_time.findall(line)
Expand Down
4 changes: 2 additions & 2 deletions openhls/rtl/emit_verilog.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def make_pe_always(fsm, pe, op_datas: list[Op], vals, input_wires, ip_res_val_ma
not_latches = set()
for op in op_datas:
if DEBUG:
tree_conds.append(f"\n\t// {op.emit()} start")
tree_conds.append(f"\n\t// {op.emit()} start"[:100])
ip = getattr(pe, op.type.value, None)
args = op.args
start_time = op.attrs["start_time"]
Expand Down Expand Up @@ -103,7 +103,7 @@ def make_pe_always(fsm, pe, op_datas: list[Op], vals, input_wires, ip_res_val_ma
raise NotImplementedError(str(op))

if DEBUG:
tree_conds.append(f"\t// {op.emit()} end\n")
tree_conds.append(f"\t// {op.emit()} end\n"[:100])

return make_always_tree(tree_conds, not_latches)

Expand Down
35 changes: 22 additions & 13 deletions openhls_translate/EmitHLSPy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ class ModuleEmitter : public OpenHLSEmitterBase {
void emitLoad(memref::LoadOp op);
void emitStore(memref::StoreOp op);
void emitMemCpy(memref::CopyOp op);
void emitMemSubview(memref::SubViewOp op);
void emitGlobal(memref::GlobalOp op);
void emitGetGlobal(memref::GetGlobalOp op);
void emitTensorStore(memref::TensorStoreOp op);
Expand Down Expand Up @@ -420,6 +421,7 @@ class StmtVisitor : public HLSVisitorBase<StmtVisitor, bool> {
bool visitOp(memref::StoreOp op) { return emitter.emitStore(op), true; }
bool visitOp(memref::DeallocOp op) { return true; }
bool visitOp(memref::CopyOp op) { return emitter.emitMemCpy(op), true; }
bool visitOp(memref::SubViewOp op) { return emitter.emitMemSubview(op), true; }
bool visitOp(memref::GlobalOp op) { return emitter.emitGlobal(op), true; }
bool visitOp(memref::GetGlobalOp op) {
return emitter.emitGetGlobal(op), true;
Expand Down Expand Up @@ -1169,33 +1171,40 @@ void ModuleEmitter::emitStore(memref::StoreOp op) {
}

void ModuleEmitter::emitMemCpy(memref::CopyOp op) {
// indent() << "memcpy(";
indent() << "";
// emitValue(op.target());
// os << " = ";
emitValue(op.target());
os << ".alias(";
emitValue(op.getSource());
os << ")";
// os << ", ";
os << "\n";
}

// auto type = op.target().getType().cast<MemRefType>();
// os << type.getNumElements() << " * sizeof(" << getTypeName(op.target())
// << "))";
// os << "\n";
void ModuleEmitter::emitMemSubview(memref::SubViewOp op) {
indent() << "";
emitArrayDecl(op.getResult());
os << "\n";
indent() << "";
emitValue(op.result());
os << ".alias(";
emitValue(op.getSource());
os << ", offsets=" << op.getStaticOffsets();
os << ", sizes=" << op.getStaticSizes();
os << ", strides=" << op.getStaticStrides();
os << ")";
os << "\n";
}

void ModuleEmitter::emitGlobal(memref::GlobalOp op) {
auto initial_val = op.initial_value();
auto elem = initial_val->dyn_cast<DenseFPElementsAttr>();
os << op.sym_name().str() << " = np.array([";
for (const auto &item : elem.getValues<FloatAttr>())
os << item.getValueAsDouble() << ", ";
os << "]).reshape(";

os << op.sym_name().str() << " = np.full((";
for (const auto &item : elem.getType().getShape())
os << item << ", ";
os << "), ";
for (const auto &item : elem.getValues<FloatAttr>()) {
os << item.getValueAsDouble();
break;
}
os << ")\n";
}

Expand Down
3 changes: 2 additions & 1 deletion openhls_translate/Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class HLSVisitorBase {
// Memref-related statements.
memref::AllocOp, memref::AllocaOp, memref::LoadOp, memref::StoreOp,
memref::GlobalOp, memref::GetGlobalOp,
memref::DeallocOp, memref::CopyOp, memref::TensorStoreOp,
memref::DeallocOp, memref::CopyOp, memref::SubViewOp, memref::TensorStoreOp,
tensor::ReshapeOp, memref::ReshapeOp, memref::CollapseShapeOp,
memref::ExpandShapeOp, memref::ReinterpretCastOp,
bufferization::ToMemrefOp, bufferization::ToTensorOp,
Expand Down Expand Up @@ -132,6 +132,7 @@ class HLSVisitorBase {
HANDLE(memref::GetGlobalOp);
HANDLE(memref::DeallocOp);
HANDLE(memref::CopyOp);
HANDLE(memref::SubViewOp);
HANDLE(memref::TensorStoreOp);
HANDLE(tensor::ReshapeOp);
HANDLE(memref::ReshapeOp);
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
requires = [
"setuptools>=42",
"wheel",
"cmake==3.21",
"cmake>=3.24",
# MLIR build depends.
"ninja",
"numpy==1.23.1",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ numpy
networkx
astor
jinja2
cocotb==1.6.2
cocotb
matplotlib
xeda
16 changes: 3 additions & 13 deletions scripts/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ if [ ! -f "${OPENHLS_DIR}"/build/llvm/CMakeCache.txt ]; then
-DCMAKE_BUILD_TYPE=DEBUG \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DLLVM_TARGETS_TO_BUILD=host \
-DPython3_FIND_VIRTUALENV=ONLY \
-DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-S "${OPENHLS_DIR}"/externals/llvm-project/llvm \
-B "${OPENHLS_DIR}"/build/llvm
Expand Down Expand Up @@ -137,7 +138,7 @@ if [ ! -f "${OPENHLS_DIR}"/build/flopoco_converter/CMakeCache.txt ]; then
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DLLVM_TARGETS_TO_BUILD=host \
-DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-S "${OPENHLS_DIR}"/flopoco_convert_ext \
-S "${OPENHLS_DIR}"/extensions/flopoco_convert_ext \
-B "${OPENHLS_DIR}"/build/flopoco_converter
fi

Expand All @@ -154,15 +155,4 @@ if [ ! -f "${OPENHLS_DIR}"/build/ghdl/bin/ghdl ]; then
mkdir -p "${OPENHLS_DIR}"/build/ghdl
tar -xvf ghdl-gha-ubuntu-20.04-llvm.tgz -C "${OPENHLS_DIR}"/build/ghdl
fi
fi


# TODO
#PYBIND11_DIR=${PREFIX}/lib/python3.10/site-packages/pybind11/share/cmake/
#PYBIND11_DIR=$(python -c "import pybind11; print(pybind11.get_cmake_dir())")
#-DPYTHON_LIBRARY="/Users/mlevental/miniforge3/envs/openhls/lib/libpython3.10.dylib" -DPYTHON_INCLUDE_DIR="/Users/mlevental/miniforge3/envs/openhls/include/python3.10" \

# -DPYTHON_INCLUDE_DIR="$(python -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())")" \
# -DPYTHON_LIBRARY="$(python -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR'))")" \

#-Dpybind11_DIR=/home/mlevental/miniconda3/envs/openhls/lib/python3.10/site-packages/pybind11/share/cmake/pybind11 -DPython_EXECUTABLE=/home/mlevental/miniconda3/envs/openhls/bin/python
fi
9 changes: 1 addition & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,6 @@ def build_torch_mlir(base_cmake_args):
)


def install_torch_mlir_from_wheel():
torch_mlir_wheel = get_latest_torch_mlir()
subprocess.check_call(
[sys.executable, "-m", "pip", "install", torch_mlir_wheel],
cwd=CWD,
)


def build_circt(base_cmake_args):
circt_dir = os.path.join(EXTERNALS, "circt")
circt_build_dir = os.path.join(ROOT_BUILD_DIR, "circt")
Expand Down Expand Up @@ -168,6 +160,7 @@ def build_openhls_translate(base_cmake_args):
f'-DMLIR_DIR={os.path.join(LLVM_BUILD_DIR, "lib", "cmake", "mlir")}',
f'-DLLVM_DIR={os.path.join(LLVM_BUILD_DIR, "lib", "cmake", "llvm")}',
"-DMLIR_ENABLE_BINDINGS_PYTHON=ON",
"-DLLVM_ENABLE_ABI_BREAKING_CHECKS=OFF"
f"-Dpybind11_DIR={pybind11.get_cmake_dir()}",
]
run_cmake(openhls_dir, cmake_args, openhls_build_dir, target="openhls_translate")
Expand Down
Loading