diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index bef78a799..aedb3d821 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -390,9 +390,6 @@ def eval_function( # type: ignore[override] else: # Python constants are scalars return 0 - elif function.traceable: - # Trace the function call instead of adding the function as a node - return function.function(*args, **kwargs) # args/kwargs are TorchScriptTensor/python built-in based param_schemas = function.param_schemas() @@ -422,6 +419,10 @@ def eval_function( # type: ignore[override] value, float ): attributes[name] = (value,) + if function.traceable: + inputs = self._graph.preprocess_inputs(inputs) + # Trace the function call instead of adding the function as a node + return function.function(*inputs, **attributes) return self._graph.add_function_call(function, inputs, attributes) @@ -730,14 +731,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value: value.setDebugName(_rename_intermediate_value(value.debugName())) return value - @runtime_typing.checked - def _add_torchscript_op_call( - self, - name: str, - onnx_inputs: Sequence[ValidInputType], - onnx_attributes: Mapping[str, ValidArgumentType], - n_outputs: int, - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: + def preprocess_inputs(self, onnx_inputs: Sequence[ValidInputType]) -> List[torch.Value]: unwrapped_inputs = _unwrap_tensors_to_torch_values(onnx_inputs) graph_inputs = [] assert isinstance(unwrapped_inputs, Sequence) @@ -761,6 +755,17 @@ def _add_torchscript_op_call( graph_inputs.append(self._add_constant_to_graph(input)) else: graph_inputs.append(input) + return graph_inputs + + @runtime_typing.checked + def _add_torchscript_op_call( + self, + name: str, + onnx_inputs: Sequence[ValidInputType], + onnx_attributes: Mapping[str, ValidArgumentType], + n_outputs: int, + ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: + graph_inputs = self.preprocess_inputs(onnx_inputs) for key, value in onnx_attributes.items(): assert not isinstance( value, TorchScriptTensor