Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Jul 13, 2023
1 parent 6c12200 commit c522840
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 24 deletions.
19 changes: 13 additions & 6 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,19 @@ def main_export(

# Saving the additional components needed to perform inference.
model.scheduler.save_pretrained(output.joinpath("scheduler"))
if getattr(model, "feature_extractor", None) is not None:
model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
if getattr(model, "tokenizer", None) is not None:
model.tokenizer.save_pretrained(output.joinpath("tokenizer"))
if getattr(model, "tokenizer_2", None) is not None:
model.tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))

feature_extractor = getattr(model, "feature_extractor", None)
if feature_extractor is not None:
feature_extractor.save_pretrained(output.joinpath("feature_extractor"))

tokenizer = getattr(model, "tokenizer", None)
if tokenizer is not None:
tokenizer.save_pretrained(output.joinpath("tokenizer"))

tokenizer_2 = getattr(model, "tokenizer_2", None)
if tokenizer_2 is not None:
tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))

model.save_config(output)

_, onnx_outputs = export_models(
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def _run_validation(
if isinstance(value, (list, tuple)):
value = config.flatten_output_collection_property(name, value)
onnx_inputs.update({tensor_name: pt_tensor.cpu().numpy() for tensor_name, pt_tensor in value.items()})
elif isinstance(value, (dict)):
elif isinstance(value, dict):
onnx_inputs.update({tensor_name: pt_tensor.cpu().numpy() for tensor_name, pt_tensor in value.items()})
else:
onnx_inputs[name] = value.cpu().numpy()
Expand Down
29 changes: 18 additions & 11 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,29 +678,29 @@ def inputs(self) -> Dict[str, Dict[int, str]]:

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
outputs = {
common_outputs = {
"text_embeds": {0: "batch_size", 1: "sequence_length"},
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
}
if self._normalized_config.output_hidden_states:
for i in range(self._normalized_config.num_layers + 1):
outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}
common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

return outputs
return common_outputs


class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig):
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
outputs = {
common_outputs = {
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
"pooler_output": {0: "batch_size"},
}
if self._normalized_config.output_hidden_states:
for i in range(self._normalized_config.num_layers + 1):
outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}
common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

return outputs
return common_outputs

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
Expand Down Expand Up @@ -734,18 +734,18 @@ class UNetOnnxConfig(VisionOnnxConfig):

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
inputs = {
common_inputs = {
"sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
"timestep": {0: "steps"},
"encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
}

# TODO : add text_image, image and image_embeds
if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
inputs["text_embeds"] = {0: "batch_size"}
inputs["time_ids"] = {0: "batch_size"}
common_inputs["text_embeds"] = {0: "batch_size"}
common_inputs["time_ids"] = {0: "batch_size"}

return inputs
return common_inputs

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
Expand All @@ -772,7 +772,14 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
return dummy_inputs

def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:
return self.inputs
inputs = super().ordered_inputs(model=model)
# to fix mismatch between model forward signature and expected inputs
# a dictionnary of additional embeddings `added_cond_kwargs` is expected depending on config.addition_embed_type
if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
inputs["text_embeds"] = self.inputs["text_embeds"]
inputs["time_ids"] = self.inputs["time_ids"]

return inputs


class VaeEncoderOnnxConfig(VisionOnnxConfig):
Expand Down
7 changes: 4 additions & 3 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ def _get_submodels_for_export_stable_diffusion(
vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
models_for_export["vae_decoder"] = vae_decoder

if getattr(pipeline, "text_encoder_2", None) is not None:
pipeline.text_encoder_2.config.output_hidden_states = True
models_for_export["text_encoder_2"] = pipeline.text_encoder_2
text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
if text_encoder_2 is not None:
text_encoder_2.config.output_hidden_states = True
models_for_export["text_encoder_2"] = text_encoder_2

return models_for_export

Expand Down
3 changes: 2 additions & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1491,9 +1491,10 @@ def get_model_from_task(
elif device is None:
device = torch.device("cpu")

# TODO : fix EulerDiscreteScheduler loading to enable for SD models
if (
version.parse(torch.__version__) >= version.parse("2.0")
and TasksManager._TASKS_TO_LIBRARY[task] != "diffusers"
and TasksManager._TASKS_TO_LIBRARY[task.replace("-with-past", "")] != "diffusers"
):
with device:
# Initialize directly in the requested device, to save allocation time. Especially useful for large
Expand Down
4 changes: 2 additions & 2 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def check_inputs(
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
pooled_prompt_embeds: Optional[np.ndarray] = None,
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
Expand Down

0 comments on commit c522840

Please sign in to comment.