Skip to content

Commit

Permalink
stage
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Aug 20, 2024
1 parent ca69d91 commit 61a9fad
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 90 deletions.
5 changes: 5 additions & 0 deletions optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,11 @@ def export_models(
failed_models = []
total_compilation_time = 0
compile_configs = {}
models_and_neuron_configs.pop("text_encoder")
models_and_neuron_configs.pop("text_encoder_2")
# models_and_neuron_configs.pop('unet')
models_and_neuron_configs.pop('vae_encoder')
models_and_neuron_configs.pop('vae_decoder')
for i, model_name in enumerate(models_and_neuron_configs.keys()):
logger.info(f"***** Compiling {model_name} *****")
submodel, sub_neuron_config = models_and_neuron_configs[model_name]
Expand Down
14 changes: 10 additions & 4 deletions optimum/neuron/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,15 +1058,15 @@ def forward(
inputs = (sample, timestep, encoder_hidden_states)
if timestep_cond is not None:
inputs = inputs + (timestep_cond,)
if added_cond_kwargs is not None:
text_embeds = added_cond_kwargs.pop("text_embeds", None)
time_ids = added_cond_kwargs.pop("time_ids", None)
inputs = inputs + (text_embeds, time_ids)
if mid_block_additional_residual is not None:
inputs = inputs + (mid_block_additional_residual,)
if down_block_additional_residuals is not None:
for idx in range(len(down_block_additional_residuals)):
inputs = inputs + (down_block_additional_residuals[idx],)
if added_cond_kwargs:
text_embeds = added_cond_kwargs.pop("text_embeds", None)
time_ids = added_cond_kwargs.pop("time_ids", None)
inputs = inputs + (text_embeds, time_ids)

outputs = self.model(*inputs)
return outputs
Expand Down Expand Up @@ -1139,9 +1139,15 @@ def forward(
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
guess_mode: bool = False,
added_cond_kwargs: Optional[Dict] = None,
return_dict: bool = True,
) -> Union["ControlNetOutput", Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
timestep = timestep.expand((sample.shape[0],)).to(torch.long)
inputs = (sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale)
if added_cond_kwargs:
text_embeds = added_cond_kwargs.pop("text_embeds", None)
time_ids = added_cond_kwargs.pop("time_ids", None)
inputs += (text_embeds, time_ids)
outputs = self.model(*inputs)

if guess_mode:
Expand Down
Loading

0 comments on commit 61a9fad

Please sign in to comment.