Skip to content

Commit

Permalink
corrected swin formula
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Aug 16, 2023
1 parent 25b36f0 commit eb55f96
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,32 +389,31 @@ def forward(
return BaseModelOutput(last_hidden_state=last_hidden_state)

def compute_encoder_known_output_shapes(self, pixel_values: torch.FloatTensor) -> Dict[str, List[int]]:
    """Compute the encoder output shape from the model config and the input batch.

    Only ``"vit"`` and ``"donut-swin"`` encoders are handled; every other
    model type is rejected.

    Args:
        pixel_values: Input tensor; only ``pixel_values.shape[0]`` is read
            and it is used as the batch size.

    Returns:
        A dict with the single key ``"last_hidden_state"`` mapping to
        ``[batch_size, encoder_sequence_length, hidden_size]``.

    Raises:
        ValueError: If the encoder model type is neither ``"vit"`` nor
            ``"donut-swin"``.
    """
    # Read everything from the same underlying config object so that all
    # attribute lookups in this method go through one access path.
    config = self.normalized_config.config
    model_type = config.model_type

    if model_type == "vit":
        # Square grid of patches, plus one for the cls token.
        # Assumes `image_size` is a single int for vit models.
        encoder_sequence_length = (config.image_size // config.patch_size) ** 2 + 1
    elif model_type == "donut-swin":
        # `image_size` is indexed as (size_0, size_1) for donut-swin models.
        # NOTE(review): the formula (area floor-divided by hidden_size) is kept
        # exactly as committed; confirm it holds for non-default configs.
        encoder_sequence_length = config.image_size[0] * config.image_size[1] // config.hidden_size
    else:
        raise ValueError(
            f"Unsupported encoder model type {model_type} for VisionEncoderDecoder. "
            "Please submit a PR to add support for this model type."
        )

    return {
        "last_hidden_state": [
            pixel_values.shape[0],  # batch_size
            encoder_sequence_length,  # encoder_sequence_length
            config.hidden_size,  # hidden_size
        ]
    }


class ORTModelForConditionalGeneration(ORTModel, ABC):
Expand Down

0 comments on commit eb55f96

Please sign in to comment.