diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py
index aa318d523..057e1a922 100644
--- a/onnxscript/function_libs/torch_lib/graph_building.py
+++ b/onnxscript/function_libs/torch_lib/graph_building.py
@@ -731,6 +731,21 @@ def _add_torchscript_op_call(
             tensor.name = _rename_intermediate_value(tensor.name)
         return tensors
 
+    @runtime_typing.checked
+    def fetch_opset_imports(self) -> Dict[str, int]:
+        # This function recursively fetches opset imports from subgraphs.
+        # TODO: All local function domain versions are hardcoded as 1.
+        # TODO: Improve the efficiency of this approach.
+        opset_imports: Dict[str, int] = {}
+        for subgraph in self._sub_torch_script_graphs.values():
+            opset_imports.update(subgraph.fetch_opset_imports())
+            if subgraph.domain_name is not None:
+                opset_imports[subgraph.domain_name] = 1
+        for function in self._function_store.values():
+            opset_imports.update(function.function_ir.get_opset_import())
+            opset_imports[function.function_ir.domain] = 1
+        return opset_imports
+
     @runtime_typing.checked
     def fetch_function_proto_dict(
         self, opset_version: int
@@ -925,6 +940,18 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func
         onnx_model = onnx.load_from_string(proto)
 
+        unique_custom_domains = self.fetch_opset_imports()
+
+        for opset in onnx_model.opset_import:
+            unique_custom_domains[opset.domain] = opset.version
+        del onnx_model.opset_import[:]
+        onnx_model.opset_import.extend(
+            [
+                onnx.helper.make_opsetid(domain, version)
+                for domain, version in unique_custom_domains.items()
+            ]
+        )
+
         # Dissect the model proto and transform to function proto.
         domain = self.domain_name
         if domain is None:
@@ -947,11 +974,7 @@ def to_model_proto(
         function_proto_dict: Mapping[
             Tuple[str, str], onnx.FunctionProto
         ] = self.fetch_function_proto_dict(opset_version)
-        unique_custom_domains: Dict[str, int] = {}
-
-        for function_proto in function_proto_dict.values():
-            # TODO(BowenBao): All local function domain versions are hardcoded as 1.
-            unique_custom_domains[function_proto.domain] = 1
+        unique_custom_domains = self.fetch_opset_imports()
 
         initializers_size = sum(
             _tensor_rawdata_size(tensor) for tensor in self.initializers.values()
@@ -1017,20 +1040,13 @@ def to_model_proto(
         # export opset_imports for nested functions, since it does not have access to
         # them. We manually add them back and merge with existing opset_imports in the
         # model proto.
-        while len(onnx_model.opset_import) > 0:
-            opsetid = onnx_model.opset_import.pop()
-            unique_custom_domains[opsetid.domain] = opsetid.version
+        for opset in onnx_model.opset_import:
+            unique_custom_domains[opset.domain] = opset.version
+        del onnx_model.opset_import[:]
         onnx_model.opset_import.extend(
             [
                 onnx.helper.make_opsetid(domain, version)
                 for domain, version in unique_custom_domains.items()
             ]
         )
-        # Include the library shared opset domain
-        # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
-        onnx_model.opset_import.append(
-            onnx.helper.make_opsetid(
-                common_ops.common_opset.domain, common_ops.common_opset.version
-            )
-        )
         return onnx_model