diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 455236126b..8e654ce617 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -509,8 +509,7 @@ def _from_pretrained( if model_save_dir is None: model_save_dir = new_model_save_dir - # Since v1.7.0 decoder with past models have fixed sequence length of 1 - # To keep these models compatible we set this dimension to dynamic + onnx_model = onnx.load(str(model_cache_path), load_external_data=False) model_uses_external_data = check_model_uses_external_data(onnx_model) @@ -521,15 +520,28 @@ def _from_pretrained( node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] for node in onnx_model.graph.input } + output_dims = { + node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] + for node in onnx_model.graph.output + } + + override_dims = False + + # we changed the meaning of past sequence length at some point ? + for dim in input_dims.keys(): + if "past" in dim and input_dims[dim][2] == "past_sequence_length + sequence_length": + input_dims[dim][2] = "past_sequence_length" + override_dims = True + + # Since v1.7.0 decoder with past models have fixed sequence length of 1 + # To keep these models compatible we set this dimension to dynamic if input_dims["input_ids"][1] == 1: input_dims["input_ids"][1] = "sequence_length" - output_dims = { - node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] - for node in onnx_model.graph.output - } output_dims["logits"][1] = "sequence_length" - onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims) + override_dims = True + if override_dims: + onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims) onnx.save( onnx_model, str(model_cache_path), @@ -539,6 +551,7 @@ def _from_pretrained( size_threshold=0, convert_attribute=True, ) + del onnx_model model = ORTModel.load_model(