From 4933c7ce1378516b2895ea75312bb5791e428bfa Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Tue, 17 Sep 2024 14:19:45 +0200
Subject: [PATCH] extend tests and only translate generators
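
np_to_pt previously converted every numpy object it was given, including
np.ndarray inputs that ONNX Runtime consumes natively. The helper is renamed
to np_to_pt_generators and now only translates numpy RNGs
(np.random.RandomState and np.random.Generator, alone or in lists and dicts)
into torch.Generator, leaving arrays untouched. Alongside: ORTVaeWrapper gets
explicit encode/decode/forward methods instead of bound forwards, pipeline
parts save their own sub-component config, input_dtype now emits a
deprecation warning in favor of input_dtypes,
ORTPipelineForTask.from_pretrained is wrapped with validate_hf_hub_args, and
the pipeline tests now sweep batch size, resolution, num_images_per_prompt,
and all input/output types.

A condensed, illustrative sketch of the helper's intended behavior, for
reviewers (RandomState branch only; the "cpu" device default and the
assertions here are not part of the patch, see the utils.py hunk below for
the full version):

    import numpy as np
    import torch

    def np_to_pt_generators(np_object, device="cpu"):
        # numpy RNGs become torch.Generators seeded from the numpy state
        if isinstance(np_object, np.random.RandomState):
            # get_state() -> ("MT19937", key_array, pos, ...); key_array[0] acts as the seed
            return torch.Generator(device=device).manual_seed(int(np_object.get_state()[1][0]))
        # everything else (notably np.ndarray) now passes through unchanged
        return np_object

    latents = np.zeros((1, 4, 64, 64), dtype=np.float32)
    assert np_to_pt_generators(latents, "cpu") is latents
    assert isinstance(np_to_pt_generators(np.random.RandomState(42), "cpu"), torch.Generator)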
---
 optimum/onnxruntime/modeling_diffusion.py | 58 +++++++++++----
 optimum/onnxruntime/modeling_seq2seq.py   |  2 +-
 optimum/onnxruntime/utils.py              | 16 ++---
 tests/onnxruntime/test_diffusion.py       | 86 ++++++++++++-----------
 4 files changed, 94 insertions(+), 68 deletions(-)

diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py
index e91142dfa8..4d52dd9cd9 100644
--- a/optimum/onnxruntime/modeling_diffusion.py
+++ b/optimum/onnxruntime/modeling_diffusion.py
@@ -67,7 +67,7 @@
 from .utils import (
     ONNX_WEIGHTS_NAME,
     get_provider_for_device,
-    np_to_pt,
+    np_to_pt_generators,
     parse_device,
     validate_provider_availability,
 )
@@ -248,7 +248,10 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
         for external_data_path in external_data_paths:
             shutil.copyfile(external_data_path, model_save_path.parent / external_data_path.name)
         # copy config
-        shutil.copyfile(model_path.parent / CONFIG_NAME, model_save_path.parent / CONFIG_NAME)
+        shutil.copyfile(
+            model_path.parent / self.sub_component_config_name,
+            model_save_path.parent / self.sub_component_config_name,
+        )
 
         self.scheduler.save_pretrained(save_directory / "scheduler")
 
@@ -486,28 +489,28 @@ def __call__(self, *args, **kwargs):
         device = self._execution_device
 
         for i in range(len(args)):
-            args[i] = np_to_pt(args[i], device)
+            args[i] = np_to_pt_generators(args[i], device)
 
         for k, v in kwargs.items():
-            kwargs[k] = np_to_pt(v, device)
+            kwargs[k] = np_to_pt_generators(v, device)
 
         return self.auto_model_class.__call__(self, *args, **kwargs)
 
 
 class ORTPipelinePart(ORTModelPart):
     def __init__(self, session: ort.InferenceSession, parent_model: ORTPipeline):
-        config_path = Path(session._model_path).parent / "config.json"
-
-        if config_path.is_file():
-            self.config = FrozenDict(parent_model._dict_from_json_file(config_path))
-        else:
-            self.config = FrozenDict({})
-
         super().__init__(session, parent_model)
 
+        config_path = Path(session._model_path).parent / "config.json"
+        config_dict = parent_model._dict_from_json_file(config_path) if config_path.is_file() else {}
+        self.config = FrozenDict(config_dict)
+
     @property
     def input_dtype(self):
-        # for backward compatibility and diffusion mixins (will be standardized in the future)
+        logger.warning(
+            "The `input_dtype` property is deprecated and will be removed in the next release. "
+            "Please use `input_dtypes` along with `TypeHelper` to get the `numpy` types."
+        )
         return {name: TypeHelper.ort_type_to_numpy_type(ort_type) for name, ort_type in self.input_dtypes.items()}
 
@@ -593,7 +596,8 @@ def forward(self, sample: Union[np.ndarray, torch.Tensor], return_dict: bool = F
         if "latent_sample" in model_outputs:
             model_outputs["latents"] = model_outputs.pop("latent_sample")
-        elif "latent_parameters" in model_outputs:
+
+        if "latent_parameters" in model_outputs:
             model_outputs["latent_dist"] = DiagonalGaussianDistribution(
                 parameters=model_outputs.pop("latent_parameters")
             )
@@ -631,9 +635,32 @@
 class ORTVaeWrapper(ORTPipelinePart):
     def __init__(self, vae_encoder: ORTModelVaeEncoder, vae_decoder: ORTModelVaeDecoder, parent_model: ORTPipeline):
         super().__init__(vae_decoder.session, parent_model)
+        self.vae_encoder = vae_encoder
+        self.vae_decoder = vae_decoder
+
+    def encode(
+        self,
+        sample: Union[np.ndarray, torch.Tensor],
+        return_dict: bool = False,
+    ):
+        return self.vae_encoder(sample, return_dict)
 
-        self.encode = vae_encoder.forward
-        self.decode = vae_decoder.forward
+    def decode(
+        self,
+        latent_sample: Union[np.ndarray, torch.Tensor],
+        generator: Optional[torch.Generator] = None,
+        return_dict: bool = False,
+    ):
+        return self.vae_decoder(latent_sample, generator, return_dict)
+
+    def forward(
+        self,
+        sample: Union[np.ndarray, torch.Tensor],
+        generator: Optional[torch.Generator] = None,
+        return_dict: bool = False,
+    ):
+        latent_sample = self.encode(sample).latent_dist.sample(generator=generator)
+        return self.decode(latent_sample, generator, return_dict)
 
 
 @add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@@ -930,6 +957,7 @@ class ORTPipelineForTask(ConfigMixin):
     config_name = "model_index.json"
 
     @classmethod
+    @validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_or_path, **kwargs) -> ORTPipeline:
         load_config_kwargs = {
             "force_download": kwargs.get("force_download", False),
diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index 3cecadafe3..30f042dcc3 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -68,7 +68,7 @@
 if check_if_transformers_greater("4.25.0"):
     from transformers.generation import GenerationMixin
 else:
-    from transformers.generation_utils import GenerationMixin
+    from transformers.generation_utils import GenerationMixin  # type: ignore
 
 
 if check_if_transformers_greater("4.43.0"):
diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py
index 1da49a65a2..128e2406f1 100644
--- a/optimum/onnxruntime/utils.py
+++ b/optimum/onnxruntime/utils.py
@@ -405,20 +405,16 @@ def evaluation_loop(
     return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=len(dataset))
 
 
-def np_to_pt(np_object, device):
-    if isinstance(np_object, np.ndarray):
-        return torch.from_numpy(np_object)
-    elif isinstance(np_object, np.random.RandomState):
+def np_to_pt_generators(np_object, device):
+    if isinstance(np_object, np.random.RandomState):
         return torch.Generator(device=device).manual_seed(int(np_object.get_state()[1][0]))
     elif isinstance(np_object, np.random.Generator):
         return torch.Generator(device=device).manual_seed(int(np_object.bit_generator.state[1][0]))
-    elif isinstance(np_object, list) and isinstance(
-        np_object[0], (np.ndarray, np.random.RandomState, np.random.Generator)
-    ):
-        return [np_to_pt(a, device) for a in np_object]
+    elif isinstance(np_object, list) and isinstance(np_object[0], (np.random.RandomState, np.random.Generator)):
+        return [np_to_pt_generators(a, device) for a in np_object]
     elif isinstance(np_object, dict) and isinstance(
-        next(iter(np_object.values())), (np.ndarray, np.random.RandomState, np.random.Generator)
+        next(iter(np_object.values())), (np.random.RandomState, np.random.Generator)
     ):
-        return {k: np_to_pt(v, device) for k, v in np_object.items()}
+        return {k: np_to_pt_generators(v, device) for k, v in np_object.items()}
     else:
         return np_object
diff --git a/tests/onnxruntime/test_diffusion.py b/tests/onnxruntime/test_diffusion.py
index cdcee7f613..7bb6128878 100644
--- a/tests/onnxruntime/test_diffusion.py
+++ b/tests/onnxruntime/test_diffusion.py
@@ -62,7 +62,7 @@ def _generate_images(height=128, width=128, batch_size=1, channel=3, input_type=
             "/in_paint/overture-creations-5sI6fQgYIuo.png"
         ).resize((width, height))
     elif input_type == "np":
-        image = np.random.rand(channel, height, width)
+        image = np.random.rand(height, width, channel)
     elif input_type == "pt":
         image = torch.rand((channel, height, width))
 
@@ -115,17 +115,16 @@ def test_ort_pipeline_class_dispatch(self, model_arch: str):
     def test_num_images_per_prompt(self, model_arch: str):
         model_args = {"test_name": model_arch, "model_arch": model_arch}
         self._setup(model_args)
-        pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
-        self.assertEqual(pipeline.vae_scale_factor, 2)
-        self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4)
-        self.assertEqual(pipeline.unet.config["in_channels"], 4)
 
-        height, width, batch_size = 64, 64, 1
-        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
+        pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
 
-        for num_images in [1, 3]:
-            outputs = pipeline(**inputs, num_images_per_prompt=num_images).images
-            self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
+        for batch_size in [1, 3]:
+            for height in [64, 128]:
+                for width in [64, 128]:
+                    for num_images_per_prompt in [1, 3]:
+                        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
+                        outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
+                        self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     @require_diffusers
@@ -184,17 +183,21 @@ def __call__(self, *args, **kwargs) -> None:
     def test_shape(self, model_arch: str):
         model_args = {"test_name": model_arch, "model_arch": model_arch}
         self._setup(model_args)
-        height, width, batch_size = 128, 64, 1
+
         pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
+
+        height, width, batch_size = 128, 64, 1
         inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
 
-        for output_type in ["np", "pil", "latent"]:
+        for output_type in ["pil", "np", "pt", "latent"]:
             inputs["output_type"] = output_type
             outputs = pipeline(**inputs).images
             if output_type == "pil":
                 self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width))
             elif output_type == "np":
                 self.assertEqual(outputs.shape, (batch_size, height, width, 3))
+            elif output_type == "pt":
+                self.assertEqual(outputs.shape, (batch_size, 3, height, width))
             else:
                 self.assertEqual(
                     outputs.shape,
@@ -334,16 +337,14 @@ def test_num_images_per_prompt(self, model_arch: str):
         model_args = {"test_name": model_arch, "model_arch": model_arch}
         self._setup(model_args)
         pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
 
-        self.assertEqual(pipeline.vae_scale_factor, 2)
-        self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4)
-        self.assertEqual(pipeline.unet.config["in_channels"], 4)
-        batch_size, height = 1, 32
-        for width in [64, 32]:
-            inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
-            for num_images in [1, 3]:
-                outputs = pipeline(**inputs, num_images_per_prompt=num_images).images
-                self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
+        for batch_size in [1, 3]:
+            for height in [64, 128]:
+                for width in [64, 128]:
+                    for num_images_per_prompt in [1, 3]:
+                        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
+                        outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
+                        self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     @require_diffusers
@@ -383,18 +384,21 @@ def test_shape(self, model_arch: str):
         model_args = {"test_name": model_arch, "model_arch": model_arch}
         self._setup(model_args)
         pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
 
-        height, width, batch_size = 32, 64, 1
-        for input_type in ["np", "pil", "pt"]:
+        height, width, batch_size = 128, 64, 1
+
+        for input_type in ["pil", "np", "pt"]:
             inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type)
 
-            for output_type in ["np", "pil", "latent"]:
+            for output_type in ["pil", "np", "pt", "latent"]:
                 inputs["output_type"] = output_type
                 outputs = pipeline(**inputs).images
                 if output_type == "pil":
                     self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width))
                 elif output_type == "np":
                     self.assertEqual(outputs.shape, (batch_size, height, width, 3))
+                elif output_type == "pt":
+                    self.assertEqual(outputs.shape, (batch_size, 3, height, width))
                 else:
                     self.assertEqual(
                         outputs.shape,
@@ -477,17 +481,14 @@ class ORTPipelineForInpaintingTest(ORTModelTestMixin):
     TASK = "inpainting"
 
     def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil"):
-        assert batch_size == 1, "Inpainting models only support batch_size=1"
-        assert input_type == "pil", "Inpainting models only support input_type='pil'"
-
        inputs = _generate_prompts(batch_size=batch_size)
 
         inputs["image"] = _generate_images(
             height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type
-        )[0]
+        )
         inputs["mask_image"] = _generate_images(
-            height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type
-        )[0]
+            height=height, width=width, batch_size=batch_size, channel=1, input_type=input_type
+        )
 
         inputs["strength"] = 0.75
         inputs["height"] = height
@@ -522,16 +523,14 @@ def test_num_images_per_prompt(self, model_arch: str):
         model_args = {"test_name": model_arch, "model_arch": model_arch}
         self._setup(model_args)
         pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
 
-        self.assertEqual(pipeline.vae_scale_factor, 2)
-        self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4)
-        self.assertEqual(pipeline.unet.config["in_channels"], 4)
-
-        batch_size, height = 1, 32
-        for width in [64, 32]:
-            inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
-            for num_images in [1, 3]:
-                outputs = pipeline(**inputs, num_images_per_prompt=num_images).images
-                self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
+        for batch_size in [1, 3]:
+            for height in [64, 128]:
+                for width in [64, 128]:
+                    for num_images_per_prompt in [1, 3]:
+                        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
+                        outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
+                        self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     @require_diffusers
@@ -571,18 +570,21 @@ def test_shape(self, model_arch: str):
         model_args = {"test_name": model_arch, "model_arch": model_arch}
         self._setup(model_args)
         pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
 
-        height, width, batch_size = 32, 64, 1
-        for input_type in ["pil"]:
+        height, width, batch_size = 128, 64, 1
+
+        for input_type in ["pil", "np", "pt"]:
             inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type)
 
-            for output_type in ["np", "pil", "latent"]:
+            for output_type in ["pil", "np", "pt", "latent"]:
                 inputs["output_type"] = output_type
                 outputs = pipeline(**inputs).images
                 if output_type == "pil":
                     self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width))
                 elif output_type == "np":
                     self.assertEqual(outputs.shape, (batch_size, height, width, 3))
+                elif output_type == "pt":
+                    self.assertEqual(outputs.shape, (batch_size, 3, height, width))
                 else:
                     self.assertEqual(
                         outputs.shape,