Skip to content

Commit

Permalink
fix device attribution
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 25, 2024
1 parent 0cb6be7 commit 88831a5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
8 changes: 7 additions & 1 deletion optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,14 @@ def to(self, device: Union[torch.device, str, int]):
Returns:
`ORTModel`: the model placed on the requested device.
"""

device, provider_options = parse_device(device)
provider = get_provider_for_device(device)
validate_provider_availability(provider) # raise error if the provider is not available
self.device = device

if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
return self

self.vae_decoder.session.set_providers([provider], provider_options=[provider_options])
self.text_encoder.session.set_providers([provider], provider_options=[provider_options])
self.unet.session.set_providers([provider], provider_options=[provider_options])
Expand All @@ -464,6 +468,8 @@ def to(self, device: Union[torch.device, str, int]):
self.vae_encoder.session.set_providers([provider], provider_options=[provider_options])

self.providers = self.vae_decoder.session.get_providers()
self._device = device

return self

@classmethod
Expand Down
3 changes: 2 additions & 1 deletion optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,12 +1124,13 @@ def to(self, device: Union[torch.device, str, int]):
provider = get_provider_for_device(device)
validate_provider_availability(provider) # raise error if the provider is not available

self.device = device
self.encoder.session.set_providers([provider], provider_options=[provider_options])
self.decoder.session.set_providers([provider], provider_options=[provider_options])
if self.decoder_with_past is not None:
self.decoder_with_past.session.set_providers([provider], provider_options=[provider_options])

self.providers = self.encoder.session.get_providers()
self._device = device

return self

Expand Down

0 comments on commit 88831a5

Please sign in to comment.