Skip to content

Commit

Permalink
Rewrite while as for loop
Browse files Browse the repository at this point in the history
  • Loading branch information
BowenBao committed Feb 8, 2024
1 parent 34805b2 commit 93a070b
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions onnxscript/function_libs/torch_lib/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,8 @@ def _add_torchscript_op_call(
@runtime_typing.checked
def fetch_opset_imports(self) -> Dict[str, int]:
# TODO: All local function domain versions are hardcoded as 1.
# TODO: Improve the efficiency of this approach.
# This function recursively fetches opset imports from subgraphs.
opset_imports: Dict[str, int] = {}
for subgraph in self._sub_torch_script_graphs.values():
opset_imports.update(subgraph.fetch_opset_imports())
Expand Down Expand Up @@ -940,9 +942,9 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func

unique_custom_domains = self.fetch_opset_imports()

while len(onnx_model.opset_import) > 0:
opsetid = onnx_model.opset_import.pop()
unique_custom_domains[opsetid.domain] = opsetid.version
for opset in onnx_model.opset_import:
unique_custom_domains[opset.domain] = opset.version
del onnx_model.opset_import[:]
onnx_model.opset_import.extend(
[
onnx.helper.make_opsetid(domain, version)
Expand Down Expand Up @@ -1038,9 +1040,9 @@ def to_model_proto(
# export opset_imports for nested functions, since it does not have access to
# them. We manually add them back and merge with existing opset_imports in the
# model proto.
while len(onnx_model.opset_import) > 0:
opsetid = onnx_model.opset_import.pop()
unique_custom_domains[opsetid.domain] = opsetid.version
for opset in onnx_model.opset_import:
unique_custom_domains[opset.domain] = opset.version
del onnx_model.opset_import[:]
onnx_model.opset_import.extend(
[
onnx.helper.make_opsetid(domain, version)
Expand Down

0 comments on commit 93a070b

Please sign in to comment.