diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py index dc0af0362..057e1a922 100644 --- a/onnxscript/function_libs/torch_lib/graph_building.py +++ b/onnxscript/function_libs/torch_lib/graph_building.py @@ -734,6 +734,8 @@ def _add_torchscript_op_call( @runtime_typing.checked def fetch_opset_imports(self) -> Dict[str, int]: # TODO: All local function domain versions are hardcoded as 1. + # TODO: Improve the efficiency of this approach. + # This function recursively fetches opset imports from subgraphs. opset_imports: Dict[str, int] = {} for subgraph in self._sub_torch_script_graphs.values(): opset_imports.update(subgraph.fetch_opset_imports()) @@ -940,9 +942,9 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func unique_custom_domains = self.fetch_opset_imports() - while len(onnx_model.opset_import) > 0: - opsetid = onnx_model.opset_import.pop() - unique_custom_domains[opsetid.domain] = opsetid.version + for opset in onnx_model.opset_import: + unique_custom_domains[opset.domain] = opset.version + del onnx_model.opset_import[:] onnx_model.opset_import.extend( [ onnx.helper.make_opsetid(domain, version) @@ -1038,9 +1040,9 @@ def to_model_proto( # export opset_imports for nested functions, since it does not have access to # them. We manually add them back and merge with existing opset_imports in the # model proto. - while len(onnx_model.opset_import) > 0: - opsetid = onnx_model.opset_import.pop() - unique_custom_domains[opsetid.domain] = opsetid.version + for opset in onnx_model.opset_import: + unique_custom_domains[opset.domain] = opset.version + del onnx_model.opset_import[:] onnx_model.opset_import.extend( [ onnx.helper.make_opsetid(domain, version)