Skip to content

Commit

Permalink
Correctly construct nested function opset_imports for TorchScriptGrap…
Browse files Browse the repository at this point in the history
…h serialization (#1272)

Fixes #1263
  • Loading branch information
BowenBao authored Feb 9, 2024
1 parent f9cf9ee commit 9923238
Showing 1 changed file with 31 additions and 15 deletions.
46 changes: 31 additions & 15 deletions onnxscript/function_libs/torch_lib/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,21 @@ def _add_torchscript_op_call(
tensor.name = _rename_intermediate_value(tensor.name)
return tensors

@runtime_typing.checked
def fetch_opset_imports(self) -> Dict[str, int]:
# TODO: All local function domain versions are hardcoded as 1.
# TODO: Improve the efficiency of this approach.
# This function recursively fetches opset imports from subgraphs.
opset_imports: Dict[str, int] = {}
for subgraph in self._sub_torch_script_graphs.values():
opset_imports.update(subgraph.fetch_opset_imports())
if subgraph.domain_name is not None:
opset_imports[subgraph.domain_name] = 1
for function in self._function_store.values():
opset_imports.update(function.function_ir.get_opset_import())
opset_imports[function.function_ir.domain] = 1
return opset_imports

@runtime_typing.checked
def fetch_function_proto_dict(
self, opset_version: int
Expand Down Expand Up @@ -925,6 +940,18 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func

onnx_model = onnx.load_from_string(proto)

unique_custom_domains = self.fetch_opset_imports()

for opset in onnx_model.opset_import:
unique_custom_domains[opset.domain] = opset.version
del onnx_model.opset_import[:]
onnx_model.opset_import.extend(
[
onnx.helper.make_opsetid(domain, version)
for domain, version in unique_custom_domains.items()
]
)

# Dissect the model proto and transform to function proto.
domain = self.domain_name
if domain is None:
Expand All @@ -947,11 +974,7 @@ def to_model_proto(
function_proto_dict: Mapping[
Tuple[str, str], onnx.FunctionProto
] = self.fetch_function_proto_dict(opset_version)
unique_custom_domains: Dict[str, int] = {}

for function_proto in function_proto_dict.values():
# TODO(BowenBao): All local function domain versions are hardcoded as 1.
unique_custom_domains[function_proto.domain] = 1
unique_custom_domains = self.fetch_opset_imports()

initializers_size = sum(
_tensor_rawdata_size(tensor) for tensor in self.initializers.values()
Expand Down Expand Up @@ -1017,20 +1040,13 @@ def to_model_proto(
# export opset_imports for nested functions, since it does not have access to
# them. We manually add them back and merge with existing opset_imports in the
# model proto.
while len(onnx_model.opset_import) > 0:
opsetid = onnx_model.opset_import.pop()
unique_custom_domains[opsetid.domain] = opsetid.version
for opset in onnx_model.opset_import:
unique_custom_domains[opset.domain] = opset.version
del onnx_model.opset_import[:]
onnx_model.opset_import.extend(
[
onnx.helper.make_opsetid(domain, version)
for domain, version in unique_custom_domains.items()
]
)
# Include the library shared opset domain
# TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
onnx_model.opset_import.append(
onnx.helper.make_opsetid(
common_ops.common_opset.domain, common_ops.common_opset.version
)
)
return onnx_model

0 comments on commit 9923238

Please sign in to comment.