extend tests and only translate generators
IlyasMoutawwakil committed Sep 17, 2024
1 parent 7f77b1c · commit 4933c7c
Showing 4 changed files with 94 additions and 68 deletions.
58 changes: 43 additions & 15 deletions optimum/onnxruntime/modeling_diffusion.py
@@ -67,7 +67,7 @@
from .utils import (
ONNX_WEIGHTS_NAME,
get_provider_for_device,
np_to_pt,
np_to_pt_generators,
parse_device,
validate_provider_availability,
)
@@ -248,7 +248,10 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
for external_data_path in external_data_paths:
shutil.copyfile(external_data_path, model_save_path.parent / external_data_path.name)
# copy config
shutil.copyfile(model_path.parent / CONFIG_NAME, model_save_path.parent / CONFIG_NAME)
shutil.copyfile(
model_path.parent / self.sub_component_config_name,
model_save_path.parent / self.sub_component_config_name,
)

self.scheduler.save_pretrained(save_directory / "scheduler")

@@ -486,28 +489,28 @@ def __call__(self, *args, **kwargs):
device = self._execution_device

for i in range(len(args)):
args[i] = np_to_pt(args[i], device)
args[i] = np_to_pt_generators(args[i], device)

for k, v in kwargs.items():
kwargs[k] = np_to_pt(v, device)
kwargs[k] = np_to_pt_generators(v, device)

return self.auto_model_class.__call__(self, *args, **kwargs)


class ORTPipelinePart(ORTModelPart):
def __init__(self, session: ort.InferenceSession, parent_model: ORTPipeline):
config_path = Path(session._model_path).parent / "config.json"

if config_path.is_file():
self.config = FrozenDict(parent_model._dict_from_json_file(config_path))
else:
self.config = FrozenDict({})

super().__init__(session, parent_model)

config_path = Path(session._model_path).parent / "config.json"
config_dict = parent_model._dict_from_json_file(config_path) if config_path.is_file() else {}
self.config = FrozenDict(config_dict)

@property
def input_dtype(self):
# for backward compatibility and diffusion mixins (will be standardized in the future)
logger.warning(
"The `input_dtype` property is deprecated and will be removed in the next release. "
"Please use `input_dtypes` along with `TypeHelper` to get the `numpy` types."
)
return {name: TypeHelper.ort_type_to_numpy_type(ort_type) for name, ort_type in self.input_dtypes.items()}


@@ -593,7 +596,8 @@ def forward(self, sample: Union[np.ndarray, torch.Tensor], return_dict: bool = F

if "latent_sample" in model_outputs:
model_outputs["latents"] = model_outputs.pop("latent_sample")
elif "latent_parameters" in model_outputs:

if "latent_parameters" in model_outputs:
model_outputs["latent_dist"] = DiagonalGaussianDistribution(
parameters=model_outputs.pop("latent_parameters")
)
@@ -631,9 +635,32 @@ def forward(
class ORTVaeWrapper(ORTPipelinePart):
def __init__(self, vae_encoder: ORTModelVaeEncoder, vae_decoder: ORTModelVaeDecoder, parent_model: ORTPipeline):
super().__init__(vae_decoder.session, parent_model)
self.vae_encoder = vae_encoder
self.vae_decoder = vae_decoder

def encode(
self,
sample: Union[np.ndarray, torch.Tensor],
return_dict: bool = False,
):
return self.vae_encoder(sample, return_dict)

self.encode = vae_encoder.forward
self.decode = vae_decoder.forward
def decode(
self,
latent_sample: Union[np.ndarray, torch.Tensor],
generator: Optional[torch.Generator] = None,
return_dict: bool = False,
):
return self.vae_decoder(latent_sample, generator, return_dict)

def forward(
self,
sample: Union[np.ndarray, torch.Tensor],
generator: Optional[torch.Generator] = None,
return_dict: bool = False,
):
latent_sample = self.encode(sample).latent_dist.sample(generator=generator)
return self.decode(latent_sample, generator, return_dict)


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@@ -930,6 +957,7 @@ class ORTPipelineForTask(ConfigMixin):
config_name = "model_index.json"

@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path, **kwargs) -> ORTPipeline:
load_config_kwargs = {
"force_download": kwargs.get("force_download", False),
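Note: the rewritten ORTVaeWrapper replaces the earlier `self.encode = vae_encoder.forward` / `self.decode = vae_decoder.forward` aliases with explicit methods, so the wrapper now mirrors the encode/sample/decode interface of diffusers' AutoencoderKL. Below is a minimal standalone sketch of the round trip the new `forward` performs; `vae` stands in for the wrapper, and the helper name and seed handling are illustrative, not part of the commit.

```python
import torch

# Sketch of the encode/sample/decode round trip implemented by the new
# ORTVaeWrapper.forward; `vae` is any object exposing the methods shown above.
def vae_round_trip(vae, sample: torch.Tensor, seed: int = 0):
    generator = torch.Generator().manual_seed(seed)
    latent_dist = vae.encode(sample).latent_dist       # DiagonalGaussianDistribution
    latents = latent_dist.sample(generator=generator)  # reproducible latent sample
    return vae.decode(latents, generator, return_dict=False)
```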
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_seq2seq.py
@@ -68,7 +68,7 @@
if check_if_transformers_greater("4.25.0"):
from transformers.generation import GenerationMixin
else:
from transformers.generation_utils import GenerationMixin
from transformers.generation_utils import GenerationMixin # type: ignore


if check_if_transformers_greater("4.43.0"):
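Note: the only change here is a `# type: ignore` on the legacy fallback import. In isolation, the version gate looks like this (a sketch assuming `check_if_transformers_greater` is importable from `optimum.utils`, as elsewhere in the codebase):

```python
from optimum.utils import check_if_transformers_greater

# transformers moved GenerationMixin to transformers.generation in v4.25;
# the old module only exists on older versions, so the fallback import is
# excluded from static type checking.
if check_if_transformers_greater("4.25.0"):
    from transformers.generation import GenerationMixin
else:
    from transformers.generation_utils import GenerationMixin  # type: ignore
```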
16 changes: 6 additions & 10 deletions optimum/onnxruntime/utils.py
@@ -405,20 +405,16 @@ def evaluation_loop(
return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=len(dataset))


def np_to_pt(np_object, device):
if isinstance(np_object, np.ndarray):
return torch.from_numpy(np_object)
elif isinstance(np_object, np.random.RandomState):
def np_to_pt_generators(np_object, device):
if isinstance(np_object, np.random.RandomState):
return torch.Generator(device=device).manual_seed(int(np_object.get_state()[1][0]))
elif isinstance(np_object, np.random.Generator):
return torch.Generator(device=device).manual_seed(int(np_object.bit_generator.state[1][0]))
elif isinstance(np_object, list) and isinstance(
np_object[0], (np.ndarray, np.random.RandomState, np.random.Generator)
):
return [np_to_pt(a, device) for a in np_object]
elif isinstance(np_object, list) and isinstance(np_object[0], (np.random.RandomState, np.random.Generator)):
return [np_to_pt_generators(a, device) for a in np_object]
elif isinstance(np_object, dict) and isinstance(
next(iter(np_object.values())), (np.ndarray, np.random.RandomState, np.random.Generator)
next(iter(np_object.values())), (np.random.RandomState, np.random.Generator)
):
return {k: np_to_pt(v, device) for k, v in np_object.items()}
return {k: np_to_pt_generators(v, device) for k, v in np_object.items()}
else:
return np_object
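Note: per the commit title, the renamed helper now translates only numpy RNGs and leaves arrays (and everything else) untouched. A quick behavioral sketch, with the import path taken from the file above and illustrative assertions:

```python
import numpy as np
import torch

from optimum.onnxruntime.utils import np_to_pt_generators

# A numpy RandomState becomes a torch.Generator seeded from its MT19937 state
gen = np_to_pt_generators(np.random.RandomState(42), device="cpu")
assert isinstance(gen, torch.Generator)

# Arrays are no longer converted to tensors; they pass through unchanged
arr = np.ones((2, 2))
assert np_to_pt_generators(arr, device="cpu") is arr

# Lists of RNGs are translated element-wise (e.g. one generator per prompt)
gens = np_to_pt_generators([np.random.RandomState(i) for i in range(2)], device="cpu")
assert all(isinstance(g, torch.Generator) for g in gens)
```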
86 changes: 44 additions & 42 deletions tests/onnxruntime/test_diffusion.py
@@ -62,7 +62,7 @@ def _generate_images(height=128, width=128, batch_size=1, channel=3, input_type=
"/in_paint/overture-creations-5sI6fQgYIuo.png"
).resize((width, height))
elif input_type == "np":
image = np.random.rand(channel, height, width)
image = np.random.rand(height, width, channel)
elif input_type == "pt":
image = torch.rand((channel, height, width))

@@ -115,17 +115,16 @@ def test_ort_pipeline_class_dispatch(self, model_arch: str):
def test_num_images_per_prompt(self, model_arch: str):
model_args = {"test_name": model_arch, "model_arch": model_arch}
self._setup(model_args)
pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
self.assertEqual(pipeline.vae_scale_factor, 2)
self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4)
self.assertEqual(pipeline.unet.config["in_channels"], 4)

height, width, batch_size = 64, 64, 1
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])

for num_images in [1, 3]:
outputs = pipeline(**inputs, num_images_per_prompt=num_images).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
for batch_size in [1, 3]:
for height in [64, 128]:
for width in [64, 128]:
for num_images_per_prompt in [1, 3]:
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_diffusers
@@ -184,17 +183,21 @@ def __call__(self, *args, **kwargs) -> None:
def test_shape(self, model_arch: str):
model_args = {"test_name": model_arch, "model_arch": model_arch}
self._setup(model_args)
height, width, batch_size = 128, 64, 1

pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])

height, width, batch_size = 128, 64, 1
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)

for output_type in ["np", "pil", "latent"]:
for output_type in ["pil", "np", "pt", "latent"]:
inputs["output_type"] = output_type
outputs = pipeline(**inputs).images
if output_type == "pil":
self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width))
elif output_type == "np":
self.assertEqual(outputs.shape, (batch_size, height, width, 3))
elif output_type == "pt":
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
else:
self.assertEqual(
outputs.shape,
@@ -334,16 +337,14 @@ def test_num_images_per_prompt(self, model_arch: str):
self._setup(model_args)

pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
self.assertEqual(pipeline.vae_scale_factor, 2)
self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4)
self.assertEqual(pipeline.unet.config["in_channels"], 4)

batch_size, height = 1, 32
for width in [64, 32]:
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
for num_images in [1, 3]:
outputs = pipeline(**inputs, num_images_per_prompt=num_images).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
for batch_size in [1, 3]:
for height in [64, 128]:
for width in [64, 128]:
for num_images_per_prompt in [1, 3]:
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_diffusers
@@ -383,18 +384,21 @@ def test_shape(self, model_arch: str):
self._setup(model_args)

pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
height, width, batch_size = 32, 64, 1

for input_type in ["np", "pil", "pt"]:
height, width, batch_size = 128, 64, 1

for input_type in ["pil", "np", "pt"]:
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type)

for output_type in ["np", "pil", "latent"]:
for output_type in ["pil", "np", "pt", "latent"]:
inputs["output_type"] = output_type
outputs = pipeline(**inputs).images
if output_type == "pil":
self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width))
elif output_type == "np":
self.assertEqual(outputs.shape, (batch_size, height, width, 3))
elif output_type == "pt":
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
else:
self.assertEqual(
outputs.shape,
@@ -477,17 +481,14 @@ class ORTPipelineForInpaintingTest(ORTModelTestMixin):
TASK = "inpainting"

def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil"):
assert batch_size == 1, "Inpainting models only support batch_size=1"
assert input_type == "pil", "Inpainting models only support input_type='pil'"

inputs = _generate_prompts(batch_size=batch_size)

inputs["image"] = _generate_images(
height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type
)[0]
)
inputs["mask_image"] = _generate_images(
height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type
)[0]
height=height, width=width, batch_size=batch_size, channel=1, input_type=input_type
)

inputs["strength"] = 0.75
inputs["height"] = height
@@ -522,16 +523,14 @@ def test_num_images_per_prompt(self, model_arch: str):
self._setup(model_args)

pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
self.assertEqual(pipeline.vae_scale_factor, 2)
self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4)
self.assertEqual(pipeline.unet.config["in_channels"], 4)

batch_size, height = 1, 32
for width in [64, 32]:
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
for num_images in [1, 3]:
outputs = pipeline(**inputs, num_images_per_prompt=num_images).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
for batch_size in [1, 3]:
for height in [64, 128]:
for width in [64, 128]:
for num_images_per_prompt in [1, 3]:
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_diffusers
@@ -571,18 +570,21 @@ def test_shape(self, model_arch: str):
self._setup(model_args)

pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
height, width, batch_size = 32, 64, 1

for input_type in ["pil"]:
height, width, batch_size = 128, 64, 1

for input_type in ["pil", "np", "pt"]:
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type)

for output_type in ["np", "pil", "latent"]:
for output_type in ["pil", "np", "pt", "latent"]:
inputs["output_type"] = output_type
outputs = pipeline(**inputs).images
if output_type == "pil":
self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width))
elif output_type == "np":
self.assertEqual(outputs.shape, (batch_size, height, width, 3))
elif output_type == "pt":
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
else:
self.assertEqual(
outputs.shape,
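Note: the nested parameter loops added across these tests enumerate the full batch_size × height × width × num_images_per_prompt grid. An equivalent, flatter formulation with `itertools.product` is sketched below as a style alternative; the commit itself keeps the explicit nesting.

```python
from itertools import product

# Same 2 * 2 * 2 * 2 = 16-case grid as the four nested loops in the tests
for batch_size, height, width, num_images_per_prompt in product([1, 3], [64, 128], [64, 128], [1, 3]):
    print(batch_size, height, width, num_images_per_prompt)
```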
