diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 472382c4f5..80ec5fdda8 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -1486,8 +1486,8 @@ def forward( flattened_patches=flattened_patches, attention_mask=attention_mask, ) - # work around for Unexpected input data type (tensor(float)) for attention_mask - # Need to be looked into + + #decoder requires torch.LongTensor for attention_mask attention_mask = attention_mask.to(torch.int64) # Decode