diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index 22d0165bb2..8909f7b430 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -389,32 +389,31 @@ def forward(
 
         return BaseModelOutput(last_hidden_state=last_hidden_state)
 
     def compute_encoder_known_output_shapes(self, pixel_values: torch.FloatTensor) -> Dict[str, List[int]]:
-        if isinstance(self.normalized_config.config.image_size, int):
+        if self.normalized_config.config.model_type == "vit":
             # for vit models
             encoder_sequence_length = (
-                self.normalized_config.config.image_size // self.normalized_config.config.patch_size
+                self.normalized_config.image_size // self.normalized_config.config.patch_size
             ) ** 2 + 1  # plus cls token
-            return {
-                "last_hidden_state": [
-                    pixel_values.shape[0],  # batch_size
-                    encoder_sequence_length,  # encoder_sequence_length
-                    self.normalized_config.config.hidden_size,  # hidden_size
-                ]
-            }
-        else:
-            # for donut-swim models
+        elif self.normalized_config.config.model_type == "donut-swin":
+            # for donut-swin models
             encoder_sequence_length = (
-                (self.normalized_config.config.image_size[0] // self.normalized_config.config.patch_size)
-                * (self.normalized_config.config.image_size[1] // self.normalized_config.config.patch_size)
+                self.normalized_config.config.image_size[0]
+                * self.normalized_config.config.image_size[1] // self.normalized_config.config.hidden_size
             )
-            return {
-                "last_hidden_state": [
-                    pixel_values.shape[0],  # batch_size
-                    encoder_sequence_length,  # encoder_sequence_length
-                    self.normalized_config.config.hidden_size,  # hidden_size
-                ]
-            }
+        else:
+            raise ValueError(
+                f"Unsupported encoder model type {self.normalized_config.config.model_type} for VisionEncoderDecoder. "
+                "Please submit a PR to add support for this model type."
+            )
+
+        return {
+            "last_hidden_state": [
+                pixel_values.shape[0],  # batch_size
+                encoder_sequence_length,  # encoder_sequence_length
+                self.normalized_config.config.hidden_size,  # hidden_size
+            ]
+        }
 
 
 class ORTModelForConditionalGeneration(ORTModel, ABC):
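
Reviewer note: a quick, standalone sanity check of the two sequence-length formulas above. The config values used below (ViT-base: 224x224 images, patch size 16; donut-base: 2560x1920 images, hidden size 1024) are assumed typical defaults for google/vit-base-patch16-224 and naver-clova-ix/donut-base, used here only for illustration, not read from this PR.

# Sanity check for the shape formulas in compute_encoder_known_output_shapes.
# Adjust the config values to match the actual checkpoint under test.

# ViT: one token per non-overlapping patch, plus the CLS token.
vit_image_size, vit_patch_size = 224, 16
vit_seq_len = (vit_image_size // vit_patch_size) ** 2 + 1
assert vit_seq_len == 197  # (224 // 16) ** 2 + 1

# Donut-Swin: the final feature map is downsampled 32x per spatial dimension
# (patch_size 4 with three 2x patch mergings), so seq_len = H * W // 32**2.
# With embed_dim 128, hidden_size = 128 * 2**3 = 1024 = 32**2, which is why
# dividing the image area by hidden_size gives the same number.
donut_image_size, donut_hidden_size = (2560, 1920), 1024
donut_seq_len = donut_image_size[0] * donut_image_size[1] // donut_hidden_size
assert donut_seq_len == (2560 // 32) * (1920 // 32) == 4800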