Skip to content

Commit

Permalink
fix the image size value inferring
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Aug 10, 2023
1 parent 11470f6 commit 43c34dc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
7 changes: 7 additions & 0 deletions optimum/exporters/neuron/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,13 @@ def outputs(self) -> List[str]:
return ["sample"]

def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
# For neuron, we use static shape for compiling the unet. Unlike `optimum`, we use the given `height` and `width` instead of the `sample_size`.
if self.height == self.width:
self._normalized_config.image_size = self.height
else:
raise ValueError(
"You need to input the same value for `self.height({self.height})` and `self.width({self.width})`."
)
dummy_inputs = super().generate_dummy_inputs(**kwargs)
dummy_inputs["timestep"] = dummy_inputs["timestep"].float()
dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
Expand All @@ -176,9 +174,9 @@ def __call__(
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 0. Height and width to unet (static shapes)
height = self.unet.config.neuron["static_height"] * self.vae_scale_factor
width = self.unet.config.neuron["static_width"] * self.vae_scale_factor

# 1. Check inputs. Raise error if not correct
self.check_inputs(
Expand Down

0 comments on commit 43c34dc

Please sign in to comment.