Skip to content

Commit

Permalink
chore: Initialize shape key as non-empty string to validate no input …
Browse files Browse the repository at this point in the history
…tensor
  • Loading branch information
keehyuna committed Nov 14, 2024
1 parent 377248e commit 0a98180
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ struct TRTEngine : torch::CustomClassHolder {
at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream();
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key;
std::string shape_key = "None";
bool cudagraphs_enabled = false;
bool use_pre_allocated_outputs = true;
std::vector<at::Tensor> pre_allocated_outputs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def create_output_tensors(self) -> List[torch.Tensor]:
outputs.append(output)
return outputs

def set_output_opt(self, enable: bool) -> None:
def set_pre_allocated_outputs(self, enable: bool) -> None:
self.use_pre_allocated_outputs = enable

def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def setup_engine(self) -> None:
if self.engine is not None:
return
self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info())
self.set_output_opt(True)
self.set_pre_allocated_outputs(False)

def encode_metadata(self, metadata: Any) -> str:
metadata = copy.deepcopy(metadata)
Expand Down Expand Up @@ -272,7 +272,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
self.input_binding_names = state[2]
self.output_binding_names = state[3]

def set_output_opt(self, enable: bool) -> None:
def set_pre_allocated_outputs(self, enable: bool) -> None:
self.engine.use_pre_allocated_outputs = enable

def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
Expand Down

0 comments on commit 0a98180

Please sign in to comment.