Skip to content

Commit

Permalink
device setter
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 25, 2024
1 parent b610212 commit 9923084
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ def device(self) -> torch.device:
"""
return self._device

@device.setter
def device(self, **kwargs):
raise AttributeError("The device attribute is read-only, please use the `to` method to change the device.")

@property
def use_io_binding(self):
return check_io_binding(self.providers, self._use_io_binding)
Expand All @@ -317,13 +321,13 @@ def to(self, device: Union[torch.device, str, int]):
Returns:
`ORTModel`: the model placed on the requested device.
"""

device, provider_options = parse_device(device)

if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
return self

self.device = device
provider = get_provider_for_device(self.device)
provider = get_provider_for_device(device)
validate_provider_availability(provider) # raise error if the provider is not available

# IOBinding is only supported for CPU and CUDA Execution Providers.
Expand All @@ -339,6 +343,7 @@ def to(self, device: Union[torch.device, str, int]):

self.model.set_providers([provider], provider_options=[provider_options])
self.providers = self.model.get_providers()
self._device = device

return self

Expand Down

0 comments on commit 9923084

Please sign in to comment.