diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index ca488c47a..1f74e49a3 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -853,9 +853,15 @@ def __init__( def forward( self, input_ids: torch.Tensor, - *args, - **kwargs, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): + if attention_mask is not None: + assert torch.equal(torch.ones_like(attention_mask), attention_mask), "attention_mask is expected to be only all ones" + if output_hidden_states is not None: + assert bool(self.text_encoder.config.output_hidden_states) == bool(output_hidden_states), "output_hidden_states is expected to be consistent with how it was compiled" + input_ids = input_ids.to(torch.long) # dummy generator uses long int for tracing inputs = (input_ids,)