Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
titaiwangms committed Oct 23, 2024
1 parent d4b81dc commit 60dec6b
Showing 1 changed file with 16 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,6 @@ def eval_function( # type: ignore[override]
else:
# Python constants are scalars
return 0
elif function.traceable:
# Trace the function call instead of adding the function as a node
return function.function(*args, **kwargs)

# args/kwargs are TorchScriptTensor/python built-in based
param_schemas = function.param_schemas()
Expand Down Expand Up @@ -422,6 +419,10 @@ def eval_function( # type: ignore[override]
value, float
):
attributes[name] = (value,)
if function.traceable:
inputs = self._graph.preprocess_inputs(inputs)
# Trace the function call instead of adding the function as a node
return function.function(*inputs, **attributes)
return self._graph.add_function_call(function, inputs, attributes)


Expand Down Expand Up @@ -730,14 +731,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
value.setDebugName(_rename_intermediate_value(value.debugName()))
return value

@runtime_typing.checked
def _add_torchscript_op_call(
self,
name: str,
onnx_inputs: Sequence[ValidInputType],
onnx_attributes: Mapping[str, ValidArgumentType],
n_outputs: int,
) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
def preprocess_inputs(self, onnx_inputs: Sequence[ValidInputType]) -> List[torch.Value]:
unwrapped_inputs = _unwrap_tensors_to_torch_values(onnx_inputs)
graph_inputs = []
assert isinstance(unwrapped_inputs, Sequence)
Expand All @@ -761,6 +755,17 @@ def _add_torchscript_op_call(
graph_inputs.append(self._add_constant_to_graph(input))
else:
graph_inputs.append(input)
return graph_inputs

@runtime_typing.checked
def _add_torchscript_op_call(
self,
name: str,
onnx_inputs: Sequence[ValidInputType],
onnx_attributes: Mapping[str, ValidArgumentType],
n_outputs: int,
) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
graph_inputs = self.preprocess_inputs(onnx_inputs)
for key, value in onnx_attributes.items():
assert not isinstance(
value, TorchScriptTensor
Expand Down

0 comments on commit 60dec6b

Please sign in to comment.