diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 340bccd6289a3..9d2b3fc5eb8f7 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -108,7 +108,9 @@ def _reverse_map(d: Dict[Any, Enum]): return {v.value: k for k, v in d.items()} -MetaType = Union[FakeTensor, int, torch.SymInt, bool, torch.SymBool, ep.CustomObjArgument] +MetaType = Union[ + FakeTensor, int, torch.SymInt, bool, torch.SymBool, ep.CustomObjArgument +] ST_DELIMITER = ";" @@ -126,7 +128,7 @@ def _reverse_map(d: Dict[Any, Enum]): torch.complex64: ScalarType.COMPLEXFLOAT, torch.complex128: ScalarType.COMPLEXDOUBLE, torch.bool: ScalarType.BOOL, - torch.bfloat16: ScalarType.BFLOAT16 + torch.bfloat16: ScalarType.BFLOAT16, } @@ -212,7 +214,9 @@ def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt: if s.node.hint is None: return SymInt.create(as_expr=SymExpr(str(s))) else: - return SymInt.create(as_expr=SymExpr(str(s), hint=SymExprHint.create(as_int=s.node.hint))) + return SymInt.create( + as_expr=SymExpr(str(s), hint=SymExprHint.create(as_int=s.node.hint)) + ) else: raise SerializeError( f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`" @@ -252,16 +256,22 @@ def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta: def _reduce_fake_tensor(fake_tensor: FakeTensor): is_parameter = isinstance(fake_tensor, torch.nn.Parameter) tensor_meta = serialize_tensor_meta(fake_tensor) - tensor_meta_bytes = json.dumps(_dataclass_to_dict(tensor_meta), cls=EnumEncoder).encode("utf-8") + tensor_meta_bytes = json.dumps( + _dataclass_to_dict(tensor_meta), cls=EnumEncoder + ).encode("utf-8") return _reconstruct_fake_tensor, (tensor_meta_bytes, is_parameter) -def _reconstruct_fake_tensor(serialized_tensor_meta: bytes, is_parameter: bool) -> FakeTensor: +def _reconstruct_fake_tensor( + serialized_tensor_meta: bytes, is_parameter: bool +) -> FakeTensor: # Deserialize the bytes into a TensorMeta json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8")) tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta) # Find the current fake mode - assert _CURRENT_DESERIALIZER is not None, "Need access to current deserializer state" + assert ( + _CURRENT_DESERIALIZER is not None + ), "Need access to current deserializer state" fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta) if is_parameter: fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment] @@ -269,7 +279,9 @@ def _reconstruct_fake_tensor(serialized_tensor_meta: bytes, is_parameter: bool) def serialize_torch_artifact(artifact: Dict[str, Any]) -> bytes: - assert FakeTensor not in copyreg.dispatch_table, "Refusing to stomp on existing FakeTensor reducer" + assert ( + FakeTensor not in copyreg.dispatch_table + ), "Refusing to stomp on existing FakeTensor reducer" try: copyreg.pickle(FakeTensor, _reduce_fake_tensor) buffer = io.BytesIO() @@ -305,9 +317,7 @@ def _sympy_int_to_int(val: sympy.Expr): return -math.inf if isinstance(val, sympy.Integer): return int(val) - raise RuntimeError( - "Export constraints cannot be non-integer expressions" - ) + raise RuntimeError("Export constraints cannot be non-integer expressions") def _int_to_sympy_int(val) -> sympy.Expr: @@ -370,7 +380,7 @@ class GraphModuleSerializer: def __init__( self, graph_signature: ep.ExportGraphSignature, - module_call_graph: List[ep.ModuleCallEntry] + module_call_graph: List[ep.ModuleCallEntry], ): self.graph_state = GraphState() self.graph_signature = graph_signature @@ -388,17 +398,23 @@ def save_graph_state(self): def handle_placeholder(self, node: torch.fx.Node): assert node.op == "placeholder" - if isinstance(node.meta['val'], torch.Tensor): + if isinstance(node.meta["val"], torch.Tensor): graph_input = Argument.create(as_tensor=TensorArgument(name=node.name)) - self.graph_state.tensor_values[node.name] = serialize_tensor_meta(node.meta["val"]) - elif isinstance(node.meta['val'], torch.SymInt): + self.graph_state.tensor_values[node.name] = serialize_tensor_meta( + node.meta["val"] + ) + elif isinstance(node.meta["val"], torch.SymInt): raise AssertionError("SymInt graph input is not implemented yet.") - elif isinstance(node.meta['val'], (int, bool, str, float, type(None))): - graph_input = self.serialize_input(node.meta['val']) - elif isinstance(node.meta['val'], ep.CustomObjArgument): + elif isinstance(node.meta["val"], (int, bool, str, float, type(None))): + graph_input = self.serialize_input(node.meta["val"]) + elif isinstance(node.meta["val"], ep.CustomObjArgument): class_fqn = node.meta["val"].class_fqn - graph_input = Argument.create(as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn)) - self.graph_state.custom_obj_values[node.name] = self.serialize_script_obj_meta(node.meta["val"]) + graph_input = Argument.create( + as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn) + ) + self.graph_state.custom_obj_values[node.name] = ( + self.serialize_script_obj_meta(node.meta["val"]) + ) else: raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}") self.graph_state.inputs.append(graph_input) @@ -439,7 +455,11 @@ def handle_call_function(self, node: torch.fx.Node): ex_node = Node( target=self.serialize_operator(node.target), inputs=self.serialize_sym_op_inputs(node.target, node.args), - outputs=[Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val))], + outputs=[ + Argument.create( + as_sym_int=self.serialize_sym_int_output(node.name, meta_val) + ) + ], metadata=self.serialize_metadata(node), ) elif node.target in _SYM_BOOL_OPS: @@ -448,7 +468,11 @@ def handle_call_function(self, node: torch.fx.Node): ex_node = Node( target=self.serialize_operator(node.target), inputs=self.serialize_sym_op_inputs(node.target, node.args), - outputs=[Argument.create(as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val))], + outputs=[ + Argument.create( + as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val) + ) + ], metadata=self.serialize_metadata(node), ) elif isinstance(node.target, torch._ops.OpOverload): @@ -480,6 +504,7 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: ret["stack_trace"] = stack_trace if nn_module_stack := node.meta.get("nn_module_stack"): + def export_nn_module_stack(val): assert isinstance(val, tuple) and len(val) == 2 path, ty = val @@ -503,18 +528,22 @@ def export_nn_module_stack(val): # Serialize to "key,orig_path,type_str" nn_module_list = [ - f"{k},{export_nn_module_stack(v)}" - for k, v in nn_module_stack.items() + f"{k},{export_nn_module_stack(v)}" for k, v in nn_module_stack.items() ] ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list) if source_fn_st := node.meta.get("source_fn_stack"): - source_fn_list = [f"{source_fn[0]},{self.serialize_operator(source_fn[1])}" for source_fn in source_fn_st] + source_fn_list = [ + f"{source_fn[0]},{self.serialize_operator(source_fn[1])}" + for source_fn in source_fn_st + ] ret["source_fn_stack"] = ST_DELIMITER.join(source_fn_list) return ret - def serialize_script_obj_meta(self, script_obj_meta: ep.CustomObjArgument) -> CustomObjArgument: + def serialize_script_obj_meta( + self, script_obj_meta: ep.CustomObjArgument + ) -> CustomObjArgument: return CustomObjArgument( name=script_obj_meta.name, class_fqn=script_obj_meta.class_fqn, @@ -555,7 +584,6 @@ def serialize_inputs( # with default values pass - return serialized_args def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]: @@ -566,28 +594,32 @@ def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]: NamedArgument( name="", arg=self.serialize_input(a), - ) for a in args + ) + for a in args ] - inputs.extend([ - NamedArgument( - name=name, - arg=self.serialize_input(a) - ) for name, a in kwargs.items() - ]) + inputs.extend( + [ + NamedArgument(name=name, arg=self.serialize_input(a)) + for name, a in kwargs.items() + ] + ) return inputs def is_sym_int_arg(self, arg) -> bool: return isinstance(arg, int) or ( - isinstance(arg, torch.fx.Node) and arg.name in self.graph_state.sym_int_values + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_int_values ) def is_sym_bool_arg(self, arg) -> bool: return isinstance(arg, bool) or ( - isinstance(arg, torch.fx.Node) and arg.name in self.graph_state.sym_bool_values + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_bool_values ) def serialize_input(self, arg) -> Argument: import torch._inductor.ir as inductor_ir + inductor_tensor_buffers = ( inductor_ir.Buffer, inductor_ir.ReinterpretView, @@ -599,20 +631,34 @@ def serialize_input(self, arg) -> Argument: attr = getattr(arg.graph.owning_module, arg.target) if isinstance(attr, torch.Tensor): - raise SerializeError("getattr nodes containing tensors should not appear in the graph") + raise SerializeError( + "getattr nodes containing tensors should not appear in the graph" + ) elif isinstance(attr, torch.fx.GraphModule): with self.save_graph_state(): graph = self.serialize_graph(attr) - return Argument.create(as_graph=GraphArgument(name=arg.target, graph=graph)) + return Argument.create( + as_graph=GraphArgument(name=arg.target, graph=graph) + ) else: - raise SerializeError(f"Unsupported getattr attribute {arg.target} with type: {type(attr)}") + raise SerializeError( + f"Unsupported getattr attribute {arg.target} with type: {type(attr)}" + ) elif self.is_sym_int_arg(arg): - return Argument.create(as_sym_int=SymIntArgument.create(as_name=arg.name)) + return Argument.create( + as_sym_int=SymIntArgument.create(as_name=arg.name) + ) elif self.is_sym_bool_arg(arg): - return Argument.create(as_sym_bool=SymBoolArgument.create(as_name=arg.name)) + return Argument.create( + as_sym_bool=SymBoolArgument.create(as_name=arg.name) + ) else: if isinstance(arg.meta["val"], ep.CustomObjArgument): - return Argument.create(as_custom_obj=CustomObjArgument(name=arg.name, class_fqn=arg.meta["val"].class_fqn)) + return Argument.create( + as_custom_obj=CustomObjArgument( + name=arg.name, class_fqn=arg.meta["val"].class_fqn + ) + ) return Argument.create(as_tensor=TensorArgument(name=arg.name)) elif isinstance(arg, inductor_tensor_buffers): # Other branches are for arguments in fx node. @@ -679,7 +725,9 @@ def serialize_input(self, arg) -> Argument: arguments = [] for a in arg: if a.op == "get_attr": - raise SerializeError("getattr nodes containing tensors should not appear in the graph") + raise SerializeError( + "getattr nodes containing tensors should not appear in the graph" + ) arguments.append(TensorArgument(name=a.name)) return Argument.create(as_tensors=arguments) elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg): @@ -693,6 +741,7 @@ def serialize_optional_tensor_args(a): ) else: raise SerializeError(f"Unsupported list/tuple argument: {a}") + return Argument.create( as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) ) @@ -701,7 +750,9 @@ def serialize_optional_tensor_args(a): return Argument.create( as_tensors=[TensorArgument(name=a.get_name()) for a in arg], ) - elif all(isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg): + elif all( + isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg + ): # list of inductor buffers as optional tensors def serialize_optional_tensor_args(a): if a is None: @@ -712,23 +763,28 @@ def serialize_optional_tensor_args(a): ) else: raise SerializeError(f"Unsupported list/tuple argument: {a}") + return Argument.create( as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) ) else: - raise SerializeError(f"Unsupported list/tuple argument type: {[type(a) for a in arg]}") + raise SerializeError( + f"Unsupported list/tuple argument type: {[type(a) for a in arg]}" + ) elif isinstance(arg, torch.dtype): return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg]) elif isinstance(arg, torch.device): return Argument.create(as_device=Device(type=arg.type, index=arg.index)) elif isinstance(arg, torch.memory_format): - return Argument.create(as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg]) + return Argument.create( + as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg] + ) elif isinstance(arg, torch.layout): return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg]) elif isinstance(arg, torch._C.ScriptObject): if not ( - arg._has_method("__getstate__") and # type: ignore[attr-defined] - arg._has_method("__setstate__") # type: ignore[attr-defined] + arg._has_method("__getstate__") # type: ignore[attr-defined] + and arg._has_method("__setstate__") # type: ignore[attr-defined] ): raise SerializeError( f"Unable to serialize custom class {arg}. Please define " @@ -741,7 +797,9 @@ def serialize_optional_tensor_args(a): custom_obj_name = f"_custom_obj_{len(self.custom_objs)}" self.custom_objs[custom_obj_name] = arg class_fqn = arg._type().qualified_name() # type: ignore[attr-defined] - return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn)) + return Argument.create( + as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn) + ) elif isinstance(arg, torch._ops.OpOverload): return Argument.create(as_operator=self.serialize_operator(arg)) else: @@ -765,9 +823,7 @@ def serialize_sym_bool_output(self, name, meta_val) -> SymIntArgument: def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: if spec.kind == ep.InputKind.USER_INPUT: return InputSpec.create( - user_input=UserInputSpec( - arg=self.serialize_argument_spec(spec.arg) - ) + user_input=UserInputSpec(arg=self.serialize_argument_spec(spec.arg)) ) elif spec.kind == ep.InputKind.PARAMETER: assert spec.target is not None @@ -803,7 +859,9 @@ def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: assert isinstance(spec.arg, ep.CustomObjArgument) return InputSpec.create( custom_obj=InputToCustomObjSpec( - arg=CustomObjArgument(name=spec.arg.name, class_fqn=spec.arg.class_fqn), + arg=CustomObjArgument( + name=spec.arg.name, class_fqn=spec.arg.class_fqn + ), custom_obj_name=spec.target, ) ) @@ -820,16 +878,12 @@ def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec: if spec.kind == ep.OutputKind.USER_OUTPUT: return OutputSpec.create( - user_output=UserOutputSpec( - arg=self.serialize_argument_spec(spec.arg) - ) + user_output=UserOutputSpec(arg=self.serialize_argument_spec(spec.arg)) ) elif spec.kind == ep.OutputKind.LOSS_OUTPUT: assert isinstance(spec.arg, ep.TensorArgument) return OutputSpec.create( - loss_output=LossOutputSpec( - arg=TensorArgument(name=spec.arg.name) - ) + loss_output=LossOutputSpec(arg=TensorArgument(name=spec.arg.name)) ) elif spec.kind == ep.OutputKind.BUFFER_MUTATION: assert spec.target is not None @@ -891,24 +945,39 @@ def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument: elif isinstance(x, ep.ConstantArgument): return self.serialize_input(x.value) elif isinstance(x, ep.CustomObjArgument): - return Argument.create(as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn)) + return Argument.create( + as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn) + ) else: raise AssertionError("TODO") - def serialize_module_call_signature(self, module_call_signature: ep.ModuleCallSignature) -> ModuleCallSignature: + def serialize_module_call_signature( + self, module_call_signature: ep.ModuleCallSignature + ) -> ModuleCallSignature: return ModuleCallSignature( - inputs=[self.serialize_argument_spec(x) for x in module_call_signature.inputs], - outputs=[self.serialize_argument_spec(x) for x in module_call_signature.outputs], + inputs=[ + self.serialize_argument_spec(x) for x in module_call_signature.inputs + ], + outputs=[ + self.serialize_argument_spec(x) for x in module_call_signature.outputs + ], in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION), out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION), ) - def serialize_module_call_graph(self, module_call_graph: List[ep.ModuleCallEntry]) -> List[ModuleCallEntry]: + def serialize_module_call_graph( + self, module_call_graph: List[ep.ModuleCallEntry] + ) -> List[ModuleCallEntry]: return [ ModuleCallEntry( fqn=entry.fqn, - signature=self.serialize_module_call_signature(entry.signature) if entry.signature else None, - ) for entry in module_call_graph + signature=( + self.serialize_module_call_signature(entry.signature) + if entry.signature + else None + ), + ) + for entry in module_call_graph ] def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: @@ -930,7 +999,9 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: mostly reuse the names coming from FX. This function computes a mapping from the FX representation to our representation, preserving the names. """ - assert node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload) + assert node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ) assert isinstance(node.target, torch._ops.OpOverload) returns = node.target._schema.returns @@ -967,7 +1038,9 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: output_arguments = [] for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)): if meta is None: - assert isinstance(return_schema.real_type, (torch.OptionalType, torch.TensorType)) + assert isinstance( + return_schema.real_type, (torch.OptionalType, torch.TensorType) + ) # When the return type is annoated as Tensor type, the op can also return an # undefined Tensor which will be implicitly converted to None in Python. output_arguments.append(Argument.create(as_none=())) @@ -1008,7 +1081,9 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: ) output_arguments.append(self.serialize_output(name, meta)) else: - raise ValueError(f"Unhandled output type {type(meta)} from node {node.format_node()}") + raise ValueError( + f"Unhandled output type {type(meta)} from node {node.format_node()}" + ) return output_arguments @@ -1076,13 +1151,19 @@ def serialize_output(self, name: str, meta_val: Any) -> Argument: return Argument.create(as_none=()) if isinstance(meta_val, torch.Tensor): # e.g "-> Tensor" - return Argument.create(as_tensor=self.serialize_tensor_output(name, meta_val)) + return Argument.create( + as_tensor=self.serialize_tensor_output(name, meta_val) + ) elif isinstance(meta_val, (int, torch.SymInt)): # e.g "-> SymInt" - return Argument.create(as_sym_int=self.serialize_sym_int_output(name, meta_val)) + return Argument.create( + as_sym_int=self.serialize_sym_int_output(name, meta_val) + ) elif isinstance(meta_val, torch.SymBool): # e.g "-> SymBool" - return Argument.create(as_sym_bool=self.serialize_sym_bool_output(name, meta_val)) + return Argument.create( + as_sym_bool=self.serialize_sym_bool_output(name, meta_val) + ) # list outputs should've been handled earlier raise SerializeError(f"Unable to serialize output {meta_val}") @@ -1092,7 +1173,9 @@ def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]: idx_to_name = {} for user in node.users: - assert user.target is operator.getitem, f"User node {user} of {node} is incorrect" + assert ( + user.target is operator.getitem + ), f"User node {user} of {node} is incorrect" idx_to_name[user.args[1]] = user.name for idx, _ in enumerate(meta_val): @@ -1116,7 +1199,9 @@ def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph: try: getattr(self, f"handle_{node.op}")(node) except Exception as e: - raise SerializeError(f"Failed serializing node {node} in graph: {node.format_node()}") from e + raise SerializeError( + f"Failed serializing node {node} in graph: {node.format_node()}" + ) from e return Graph( inputs=self.graph_state.inputs, @@ -1155,11 +1240,12 @@ def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram: exported_program._validate() gm_serializer = GraphModuleSerializer( - exported_program.graph_signature, - exported_program.module_call_graph + exported_program.graph_signature, exported_program.module_call_graph ) serialized_graph_module = gm_serializer.serialize(exported_program.graph_module) - serialized_range_constraints = serialize_range_constraints(exported_program.range_constraints) + serialized_range_constraints = serialize_range_constraints( + exported_program.range_constraints + ) # TODO: Directly serialize exported_program.constants once # CustomClassHolders get stored in the ExportedProgram rather than in @@ -1210,7 +1296,12 @@ def __init__(self): @contextmanager def save_graph_module(self) -> Iterator[None]: - saved = self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta + saved = ( + self.graph, + self.module, + self.serialized_name_to_node, + self.serialized_name_to_meta, + ) self.graph = torch.fx.Graph() self.module = torch.nn.Module() self.serialized_name_to_node = {} @@ -1218,10 +1309,17 @@ def save_graph_module(self) -> Iterator[None]: try: yield finally: - self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta = saved + ( + self.graph, + self.module, + self.serialized_name_to_node, + self.serialized_name_to_meta, + ) = saved def deserialize_operator(self, serialized_target: str): - if serialized_target.startswith("_operator"): # TODO(zhxchen17) Follow up on this. + if serialized_target.startswith( + "_operator" + ): # TODO(zhxchen17) Follow up on this. module = operator serialized_target_names = serialized_target.split(".")[1:] elif serialized_target.startswith("torch"): @@ -1258,7 +1356,9 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: # Here we force symbols corresponding to SymInts to be at least integers. # Otherwise some expressions that the shape env would otherwise evaluate to False, # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. - sym = sym.subs({s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols}) + sym = sym.subs( + {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols} + ) if isinstance(sym, sympy.Symbol): self.symbol_name_to_symbol[val.expr_str] = sym if hint is not None: @@ -1325,7 +1425,9 @@ def deserialize_tensor_meta( ), ) - def deserialize_script_obj_meta(self, script_obj_meta: CustomObjArgument) -> ep.CustomObjArgument: + def deserialize_script_obj_meta( + self, script_obj_meta: CustomObjArgument + ) -> ep.CustomObjArgument: return ep.CustomObjArgument( name=script_obj_meta.name, class_fqn=script_obj_meta.class_fqn, @@ -1351,10 +1453,14 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: self.serialized_name_to_meta[name] = self.deserialize_sym_int(sym_int_value) for name, sym_bool_value in serialized_graph.sym_bool_values.items(): - self.serialized_name_to_meta[name] = self.deserialize_sym_bool(sym_bool_value) + self.serialized_name_to_meta[name] = self.deserialize_sym_bool( + sym_bool_value + ) for name, script_obj_meta in serialized_graph.custom_obj_values.items(): - self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta(script_obj_meta) + self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta( + script_obj_meta + ) # Inputs: convert to placeholder nodes in FX. for i, input_ in enumerate(serialized_graph.inputs): @@ -1362,7 +1468,13 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: node_name = input_.value.name placeholder_node = self.graph.placeholder(node_name) self.sync_fx_node(node_name, placeholder_node) - elif input_.type in ("as_int", "as_float", "as_bool", "as_none", "as_string"): + elif input_.type in ( + "as_int", + "as_float", + "as_bool", + "as_none", + "as_string", + ): node_name = f"arg{i}" placeholder_node = self.graph.placeholder(node_name) placeholder_node.meta["val"] = self.deserialize_input(input_) @@ -1376,7 +1488,9 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: self.deserialize_node(serialized_node, target) except Exception as e: - raise SerializeError(f"Failed deserializing node {serialized_node}") from e + raise SerializeError( + f"Failed deserializing node {serialized_node}" + ) from e # Outputs: convert to a single `output` node. outputs = [] @@ -1417,7 +1531,8 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None: # HOPs don't have schema yet, just check the output lengths and as_tensor attribute name = ( serialized_node.outputs[0].as_tensor.name - if len(serialized_node.outputs) == 1 and hasattr(serialized_node.outputs[0], "as_tensor") + if len(serialized_node.outputs) == 1 + and hasattr(serialized_node.outputs[0], "as_tensor") else None ) fx_node = self.graph.create_node( @@ -1436,10 +1551,14 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None: else None # FX will generate a name for us. ) args, kwargs = self.deserialize_inputs(target, serialized_node) - fx_node = self.graph.create_node("call_function", target, args, kwargs, name) + fx_node = self.graph.create_node( + "call_function", target, args, kwargs, name + ) self.deserialize_outputs(serialized_node, fx_node) else: - raise SerializeError(f"Unsupported target type for node {serialized_node}: {target}") + raise SerializeError( + f"Unsupported target type for node {serialized_node}: {target}" + ) fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) @@ -1448,7 +1567,7 @@ def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: return ep.InputSpec( kind=ep.InputKind.USER_INPUT, arg=self.deserialize_argument_spec(i.user_input.arg), - target=None + target=None, ) elif i.type == "parameter": return ep.InputSpec( @@ -1472,7 +1591,9 @@ def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: elif i.type == "custom_obj": return ep.InputSpec( kind=ep.InputKind.CUSTOM_OBJ, - arg=ep.CustomObjArgument(name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn), + arg=ep.CustomObjArgument( + name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn + ), target=i.custom_obj.custom_obj_name, ) if i.type == "token": @@ -1501,25 +1622,25 @@ def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec: return ep.OutputSpec( kind=ep.OutputKind.BUFFER_MUTATION, arg=ep.TensorArgument(name=o.buffer_mutation.arg.name), - target=o.buffer_mutation.buffer_name + target=o.buffer_mutation.buffer_name, ) elif o.type == "gradient_to_parameter": return ep.OutputSpec( kind=ep.OutputKind.GRADIENT_TO_PARAMETER, arg=ep.TensorArgument(name=o.gradient_to_parameter.arg.name), - target=o.gradient_to_parameter.parameter_name + target=o.gradient_to_parameter.parameter_name, ) elif o.type == "gradient_to_user_input": return ep.OutputSpec( kind=ep.OutputKind.GRADIENT_TO_USER_INPUT, arg=ep.TensorArgument(name=o.gradient_to_user_input.arg.name), - target=o.gradient_to_user_input.user_input_name + target=o.gradient_to_user_input.user_input_name, ) elif o.type == "user_input_mutation": return ep.OutputSpec( kind=ep.OutputKind.USER_INPUT_MUTATION, arg=ep.TensorArgument(name=o.user_input_mutation.arg.name), - target=o.user_input_mutation.user_input_name + target=o.user_input_mutation.user_input_name, ) elif o.type == "token": return ep.OutputSpec( @@ -1533,7 +1654,7 @@ def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec: def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature: return ep.ExportGraphSignature( input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs], - output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs] + output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs], ) def deserialize( @@ -1554,14 +1675,22 @@ def deserialize( shape_env=self.shape_env, ) self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {} - self.symbol_name_to_range = {} if symbol_name_to_range is None else symbol_name_to_range - self.signature = self.deserialize_signature(serialized_graph_module.signature) + self.symbol_name_to_range = ( + {} if symbol_name_to_range is None else symbol_name_to_range + ) + self.signature = self.deserialize_signature( + serialized_graph_module.signature + ) self.constants = deserialize_torch_artifact(constants) self.deserialize_graph(serialized_graph_module.graph) - module_call_graph = self.deserialize_module_call_graph(serialized_graph_module.module_call_graph) + module_call_graph = self.deserialize_module_call_graph( + serialized_graph_module.module_call_graph + ) return GraphModuleDeserializer.Result( - graph_module=ep._create_graph_module_for_export(self.module, self.graph), + graph_module=ep._create_graph_module_for_export( + self.module, self.graph + ), signature=self.signature, module_call_graph=module_call_graph, names_to_symbols=self.symbol_name_to_symbol, @@ -1584,12 +1713,15 @@ def deserialize_sym_op_inputs(self, inputs): def deserialize_inputs(self, target: torch._ops.OpOverload, serialized_node: Node): schema_args = target._schema.arguments actual_args = { - input.name: self.deserialize_input(input.arg) for input in serialized_node.inputs + input.name: self.deserialize_input(input.arg) + for input in serialized_node.inputs } args = [] kwargs = {} for schema_arg in schema_args: - is_positional = not schema_arg.has_default_value() and not schema_arg.kwarg_only + is_positional = ( + not schema_arg.has_default_value() and not schema_arg.kwarg_only + ) if is_positional: args.append(actual_args[schema_arg.name]) else: @@ -1664,6 +1796,7 @@ def deserialize_input(self, inp: Argument) -> Any: elif typ_ in ("as_sym_ints", "as_sym_bools"): return [self.deserialize_sym_argument(arg) for arg in value] elif typ_ == "as_optional_tensors": + def deserialize_optional_tensor_args(a): if a.type == "as_none": return None @@ -1671,6 +1804,7 @@ def deserialize_optional_tensor_args(a): return self.serialized_name_to_node[a.value.name] else: raise SerializeError(f"Unhandled argument {inp}") + return list(map(deserialize_optional_tensor_args, value)) else: raise SerializeError(f"Unhandled argument {inp}") @@ -1710,25 +1844,33 @@ def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): ): self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node) return - elif ( - len(serialized_node.outputs) == 1 and - isinstance(serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument)) + elif len(serialized_node.outputs) == 1 and isinstance( + serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument) ): self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) return self.deserialize_multiple_outputs(serialized_node, fx_node) - def deserialize_multiple_outputs(self, serialized_node: Node, fx_node: torch.fx.Node) -> None: + def deserialize_multiple_outputs( + self, serialized_node: Node, fx_node: torch.fx.Node + ) -> None: deserialized_metadata = self.deserialize_metadata(serialized_node.metadata) - def generate_getitem(meta_val, fx_node: torch.fx.Node, arg: Union[TensorArgument, SymIntArgument], idx: int): + def generate_getitem( + meta_val, + fx_node: torch.fx.Node, + arg: Union[TensorArgument, SymIntArgument], + idx: int, + ): if isinstance(arg, TensorArgument): name = arg.name elif isinstance(arg, SymIntArgument): name = arg.as_name else: - raise AssertionError(f"generate_getitem got unknown argument type {type(arg)}") + raise AssertionError( + f"generate_getitem got unknown argument type {type(arg)}" + ) individual_output = self.graph.create_node( "call_function", operator.getitem, @@ -1756,7 +1898,7 @@ def generate_getitems(meta_val, fx_node: torch.fx.Node, args): meta_val.append([]) generate_getitems(meta_val[-1], list_output, arg) list_output.meta.update(deserialized_metadata) - list_output.meta['val'] = meta_val[-1] + list_output.meta["val"] = meta_val[-1] else: raise NotImplementedError(f"Unimplemented node output type: {arg}") @@ -1806,6 +1948,7 @@ def deserialize_meta_func(serialized_target: str): # Originally serialized to "key,orig_path,type_str" def import_nn_module_stack(key, path, ty): return key, (path, ty) + nn_module_stack = dict( import_nn_module_stack(*item.split(",")) for item in nn_module_stack_str.split(ST_DELIMITER) @@ -1829,20 +1972,33 @@ def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec: else: return ep.ConstantArgument(value=self.deserialize_input(x)) - def deserialize_module_call_signature(self, module_call_signature: ModuleCallSignature) -> ep.ModuleCallSignature: + def deserialize_module_call_signature( + self, module_call_signature: ModuleCallSignature + ) -> ep.ModuleCallSignature: return ep.ModuleCallSignature( - inputs=[self.deserialize_argument_spec(x) for x in module_call_signature.inputs], - outputs=[self.deserialize_argument_spec(x) for x in module_call_signature.outputs], + inputs=[ + self.deserialize_argument_spec(x) for x in module_call_signature.inputs + ], + outputs=[ + self.deserialize_argument_spec(x) for x in module_call_signature.outputs + ], in_spec=treespec_loads(module_call_signature.in_spec), out_spec=treespec_loads(module_call_signature.out_spec), ) - def deserialize_module_call_graph(self, module_call_graph: List[ModuleCallEntry]) -> List[ep.ModuleCallEntry]: + def deserialize_module_call_graph( + self, module_call_graph: List[ModuleCallEntry] + ) -> List[ep.ModuleCallEntry]: return [ ep.ModuleCallEntry( fqn=entry.fqn, - signature=self.deserialize_module_call_signature(entry.signature) if entry.signature else None, - ) for entry in module_call_graph + signature=( + self.deserialize_module_call_signature(entry.signature) + if entry.signature + else None + ), + ) + for entry in module_call_graph ] @@ -1882,25 +2038,27 @@ def deserialize( ) symbol_name_to_range = { - k: symbolic_shapes.ValueRanges(_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val)) + k: symbolic_shapes.ValueRanges( + _int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val) + ) for k, v in exported_program.range_constraints.items() } - res = ( - GraphModuleDeserializer() - .deserialize( - exported_program.graph_module, - state_dict, - constants, - symbol_name_to_range, - ) + res = GraphModuleDeserializer().deserialize( + exported_program.graph_module, + state_dict, + constants, + symbol_name_to_range, ) range_constraints = self.deserialize_range_constraints( - symbol_name_to_range, res.names_to_symbols, + symbol_name_to_range, + res.names_to_symbols, ) model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version self._validate_model_opset_version(model_opset_version) - upgrader = GraphModuleOpUpgrader(self.expected_opset_version, model_opset_version) + upgrader = GraphModuleOpUpgrader( + self.expected_opset_version, model_opset_version + ) exported_program = ep.ExportedProgram( root=res.graph_module, @@ -1915,7 +2073,9 @@ def deserialize( ) return upgrader.upgrade(exported_program) - def _validate_model_opset_version(self, model_opset_version: Optional[Dict[str, int]]): + def _validate_model_opset_version( + self, model_opset_version: Optional[Dict[str, int]] + ): """Compare model_opset_version with expected_opset_version and raise error if we can't resolve the version difference. E.g., model_opset_version = {"aten": 3, "custom": 4} @@ -1934,14 +2094,16 @@ def _validate_model_opset_version(self, model_opset_version: Optional[Dict[str, """ if not model_opset_version: raise RuntimeError("Serialized model should have opset version.") - common_namespaces = {key for key in model_opset_version if key in self.expected_opset_version} + common_namespaces = { + key for key in model_opset_version if key in self.expected_opset_version + } for namespace in common_namespaces: - assert ( - isinstance(model_version := model_opset_version[namespace], int) + assert isinstance( + model_version := model_opset_version[namespace], int ), f"model_opset_version value should be int, got {model_opset_version[namespace]}" - assert ( - isinstance(compiler_version := self.expected_opset_version[namespace], int) + assert isinstance( + compiler_version := self.expected_opset_version[namespace], int ), f"expected_opset_version value should be int, got {self.expected_opset_version[namespace]}" # TODO(larryliu0820): Add support for upgrader & downgrader @@ -1953,7 +2115,10 @@ def _validate_model_opset_version(self, model_opset_version: Optional[Dict[str, for namespace in model_opset_version: if namespace in common_namespaces: continue - log.warning("Compiler doesn't have a version table for op namespace: {ns}. ", extra={"ns": namespace}) + log.warning( + "Compiler doesn't have a version table for op namespace: {ns}. ", + extra={"ns": namespace}, + ) class EnumEncoder(json.JSONEncoder): @@ -1961,7 +2126,7 @@ def default(self, obj): if isinstance(obj, Enum): return obj.value if isinstance(obj, bytes): - return base64.b64encode(obj).decode('utf-8') + return base64.b64encode(obj).decode("utf-8") return super().default(obj) @@ -1988,19 +2153,17 @@ def serialize( exported_program: ep.ExportedProgram, opset_version: Optional[Dict[str, int]] = None, ) -> SerializedArtifact: - serialized_program = ( - ExportedProgramSerializer(opset_version).serialize(exported_program) + serialized_program = ExportedProgramSerializer(opset_version).serialize( + exported_program ) assert isinstance(serialized_program.exported_program, ExportedProgram) json_program = json.dumps( _dataclass_to_dict(serialized_program.exported_program), cls=EnumEncoder ) - json_bytes = json_program.encode('utf-8') + json_bytes = json_program.encode("utf-8") artifact = SerializedArtifact( - json_bytes, - serialized_program.state_dict, - serialized_program.constants + json_bytes, serialized_program.state_dict, serialized_program.constants ) return artifact @@ -2033,16 +2196,10 @@ def _dict_to_dataclass(cls, data): if len(data) == 0: return data d_type = typing.get_args(cls)[0] - return [ - _dict_to_dataclass(d_type, d) - for d in data - ] + return [_dict_to_dataclass(d_type, d) for d in data] elif isinstance(data, dict): v_type = typing.get_args(cls)[1] - return { - k: _dict_to_dataclass(v_type, v) - for k, v in data.items() - } + return {k: _dict_to_dataclass(v_type, v) for k, v in data.items()} return data @@ -2051,20 +2208,19 @@ def deserialize( expected_opset_version: Optional[Dict[str, int]] = None, ) -> ep.ExportedProgram: assert isinstance(artifact.exported_program, bytes) - exported_program_str = artifact.exported_program.decode('utf-8') + exported_program_str = artifact.exported_program.decode("utf-8") exported_program_dict = json.loads(exported_program_str) - serialized_exported_program = _dict_to_dataclass(ExportedProgram, exported_program_dict) - return ( - ExportedProgramDeserializer(expected_opset_version) - .deserialize( - serialized_exported_program, - artifact.state_dict, - artifact.constants - ) + serialized_exported_program = _dict_to_dataclass( + ExportedProgram, exported_program_dict + ) + return ExportedProgramDeserializer(expected_opset_version).deserialize( + serialized_exported_program, artifact.state_dict, artifact.constants ) -def _canonicalize_graph(sorted_inputs, sorted_outputs, graph) -> Tuple[Graph, Dict[str, str]]: +def _canonicalize_graph( + sorted_inputs, sorted_outputs, graph +) -> Tuple[Graph, Dict[str, str]]: def _get_argument(a: Argument): if a.type == "as_none": return None @@ -2156,13 +2312,15 @@ def get_name(a) -> Optional[str]: raise AssertionError(f"Unknown argument type: {a}") for i in sorted_inputs: + def add_input(a): if s := get_name(a): graph_inputs.add(s) - for_args(add_input , i) + for_args(add_input, i) for idx, node in enumerate(nodes): + def add_def(a): if s := get_name(a): assert s not in def_table @@ -2174,6 +2332,7 @@ def add_def(a): edges[idx] = Edges([], 0) for idx, user in enumerate(nodes): + def add_edge(a): if s := get_name(a): if s not in def_table: @@ -2205,6 +2364,7 @@ def get_ranks(i): ranks = [] for_args(lambda x: ranks.append(get_rank(x)), i) return ranks + node = nodes[idx] args_rank = [(a.name, get_ranks(a.arg)) for a in node.inputs] heapq.heappush(candidates, (node.target, args_rank, idx)) @@ -2295,8 +2455,12 @@ def replace_use(a): # Stage 4: Aggregate values. sorted_tensor_values = dict(sorted(graph.tensor_values.items(), key=lambda x: x[0])) - sorted_sym_int_values = dict(sorted(graph.sym_int_values.items(), key=lambda x: x[0])) - sorted_sym_bool_values = dict(sorted(graph.sym_bool_values.items(), key=lambda x: x[0])) + sorted_sym_int_values = dict( + sorted(graph.sym_int_values.items(), key=lambda x: x[0]) + ) + sorted_sym_bool_values = dict( + sorted(graph.sym_bool_values.items(), key=lambda x: x[0]) + ) # Stage 5: Recurse in subgraphs. counter = 0 @@ -2305,9 +2469,7 @@ def replace_use(a): a = i.arg if a.type == "as_graph": a.as_graph.graph = _canonicalize_graph( - a.as_graph.graph.inputs, - a.as_graph.graph.outputs, - a.as_graph.graph + a.as_graph.graph.inputs, a.as_graph.graph.outputs, a.as_graph.graph ) a.as_graph.name = f"_g{counter}" counter += 1 @@ -2392,13 +2554,19 @@ def rank_output(out) -> Tuple[int, Optional[str], int]: else: raise AssertionError(f"Unknown output type: {spec}") - sorted_ins = sorted(enumerate(zip(graph.inputs, signature.input_specs)), key=rank_input) + sorted_ins = sorted( + enumerate(zip(graph.inputs, signature.input_specs)), key=rank_input + ) sorted_inputs, input_specs = zip(*(i for idx, i in sorted_ins)) # type: ignore[assignment] - sorted_outs = sorted(enumerate(zip(graph.outputs, signature.output_specs)), key=rank_output) + sorted_outs = sorted( + enumerate(zip(graph.outputs, signature.output_specs)), key=rank_output + ) sorted_outputs, output_specs = zip(*(i for idx, i in sorted_outs)) # type: ignore[assignment] - sorted_graph, replace_table = _canonicalize_graph(sorted_inputs, sorted_outputs, graph) + sorted_graph, replace_table = _canonicalize_graph( + sorted_inputs, sorted_outputs, graph + ) def replace_input(inp): assert isinstance(spec, InputSpec) @@ -2415,7 +2583,13 @@ def replace_input(inp): pass else: raise AssertionError(f"Unknown sym_int type: {s}") - elif arg.type in ("as_none", "as_int", "as_float", "as_string", "as_custom_obj"): + elif arg.type in ( + "as_none", + "as_int", + "as_float", + "as_string", + "as_custom_obj", + ): return else: raise AssertionError(f"Unknown input type: {arg}")