From 99230847cec6520bad64dc2415cb85e561a63440 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 25 Jul 2024 13:56:14 +0200 Subject: [PATCH] device setter --- optimum/onnxruntime/modeling_ort.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index c7fa287a88..b24f74a6f5 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -297,6 +297,10 @@ def device(self) -> torch.device: """ return self._device + @device.setter + def device(self, **kwargs): + raise AttributeError("The device attribute is read-only, please use the `to` method to change the device.") + @property def use_io_binding(self): return check_io_binding(self.providers, self._use_io_binding) @@ -317,13 +321,13 @@ def to(self, device: Union[torch.device, str, int]): Returns: `ORTModel`: the model placed on the requested device. """ + device, provider_options = parse_device(device) if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider": return self - self.device = device - provider = get_provider_for_device(self.device) + provider = get_provider_for_device(device) validate_provider_availability(provider) # raise error if the provider is not available # IOBinding is only supported for CPU and CUDA Execution Providers. @@ -339,6 +343,7 @@ def to(self, device: Union[torch.device, str, int]): self.model.set_providers([provider], provider_options=[provider_options]) self.providers = self.model.get_providers() + self._device = device return self