Skip to content

Commit

Permalink
match diffusers numpy input
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Sep 16, 2024
1 parent 0869f1c commit b70b641
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 21 deletions.
4 changes: 2 additions & 2 deletions optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@ class ORTDiffusionPipeline(ConfigMixin):

@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
def from_pretrained(cls, pretrained_model_or_path, **kwargs) -> ORTPipeline:
load_config_kwargs = {
"force_download": kwargs.get("force_download", False),
"resume_download": kwargs.get("resume_download", None),
Expand Down Expand Up @@ -953,7 +953,7 @@ class ORTPipelineForTask(ConfigMixin):
config_name = "model_index.json"

@classmethod
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
def from_pretrained(cls, pretrained_model_or_path, **kwargs) -> ORTPipeline:
load_config_kwargs = {
"force_download": kwargs.get("force_download", False),
"resume_download": kwargs.get("resume_download", None),
Expand Down
23 changes: 11 additions & 12 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,19 +407,18 @@ def evaluation_loop(

def np_to_pt(np_object, device):
if isinstance(np_object, np.ndarray):
if np_object.ndim == 4:
return torch.from_numpy(np_object).permute(0, 3, 1, 2)
elif np_object.ndim == 3:
return torch.from_numpy(np_object).permute(2, 0, 1)
else:
return torch.from_numpy(np_object)
elif isinstance(np_object, list) and isinstance(np_object[0], np.ndarray):
return [np_to_pt(a, device) for a in np_object]
elif isinstance(np_object, dict) and isinstance(next(iter(np_object.values())), np.ndarray):
return {k: np_to_pt(v, device) for k, v in np_object.items()}
return torch.from_numpy(np_object)
elif isinstance(np_object, np.random.RandomState):
return torch.Generator(device=device).manual_seed(int(np_object.get_state()[1][0]))
elif isinstance(np_object, list) and isinstance(np_object[0], np.random.RandomState):
return [torch.Generator(device=device).manual_seed(int(a.get_state()[1][0])) for a in np_object]
elif isinstance(np_object, np.random.Generator):
return torch.Generator(device=device).manual_seed(int(np_object.bit_generator.state[1][0]))
elif isinstance(np_object, list) and isinstance(
np_object[0], (np.ndarray, np.random.RandomState, np.random.Generator)
):
return [np_to_pt(a, device) for a in np_object]
elif isinstance(np_object, dict) and isinstance(
next(iter(np_object.values())), (np.ndarray, np.random.RandomState, np.random.Generator)
):
return {k: np_to_pt(v, device) for k, v in np_object.items()}
else:
return np_object
12 changes: 5 additions & 7 deletions tests/onnxruntime/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _generate_images(height=128, width=128, batch_size=1, channel=3, input_type=
"/in_paint/overture-creations-5sI6fQgYIuo.png"
).resize((width, height))
elif input_type == "np":
image = np.random.rand(height, width, channel)
image = np.random.rand(channel, height, width)
elif input_type == "pt":
image = torch.rand((channel, height, width))

Expand Down Expand Up @@ -461,10 +461,9 @@ def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str):
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)

pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider)
self.assertEqual(pipeline.device.type, "cuda")

outputs = pipeline(**inputs).images
# Verify model devices
self.assertEqual(pipeline.device.type.lower(), "cuda")
# Verify model outptus
self.assertIsInstance(outputs, np.ndarray)
self.assertEqual(outputs.shape, (batch_size, height, width, 3))

Expand Down Expand Up @@ -650,9 +649,8 @@ def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str):
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)

pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider)
self.assertEqual(pipeline.device, "cuda")

outputs = pipeline(**inputs).images
# Verify model devices
self.assertEqual(pipeline.device.type.lower(), "cuda")
# Verify model outptus
self.assertIsInstance(outputs, np.ndarray)
self.assertEqual(outputs.shape, (batch_size, height, width, 3))

0 comments on commit b70b641

Please sign in to comment.