From 93a070b0d47e587269c49c87477f84bd816c2de0 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Thu, 8 Feb 2024 23:46:16 +0000 Subject: [PATCH] Rewrite while as for loop --- .../function_libs/torch_lib/graph_building.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py index dc0af0362..057e1a922 100644 --- a/onnxscript/function_libs/torch_lib/graph_building.py +++ b/onnxscript/function_libs/torch_lib/graph_building.py @@ -734,6 +734,8 @@ def _add_torchscript_op_call( @runtime_typing.checked def fetch_opset_imports(self) -> Dict[str, int]: # TODO: All local function domain versions are hardcoded as 1. + # TODO: Improve the efficiency of this approach. + # This function recursively fetches opset imports from subgraphs. opset_imports: Dict[str, int] = {} for subgraph in self._sub_torch_script_graphs.values(): opset_imports.update(subgraph.fetch_opset_imports()) @@ -940,9 +942,9 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func unique_custom_domains = self.fetch_opset_imports() - while len(onnx_model.opset_import) > 0: - opsetid = onnx_model.opset_import.pop() - unique_custom_domains[opsetid.domain] = opsetid.version + for opset in onnx_model.opset_import: + unique_custom_domains[opset.domain] = opset.version + del onnx_model.opset_import[:] onnx_model.opset_import.extend( [ onnx.helper.make_opsetid(domain, version) @@ -1038,9 +1040,9 @@ def to_model_proto( # export opset_imports for nested functions, since it does not have access to # them. We manually add them back and merge with existing opset_imports in the # model proto. - while len(onnx_model.opset_import) > 0: - opsetid = onnx_model.opset_import.pop() - unique_custom_domains[opsetid.domain] = opsetid.version + for opset in onnx_model.opset_import: + unique_custom_domains[opset.domain] = opset.version + del onnx_model.opset_import[:] onnx_model.opset_import.extend( [ onnx.helper.make_opsetid(domain, version)