diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 8e654ce617..e3150da05e 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -509,7 +509,6 @@ def _from_pretrained( if model_save_dir is None: model_save_dir = new_model_save_dir - onnx_model = onnx.load(str(model_cache_path), load_external_data=False) model_uses_external_data = check_model_uses_external_data(onnx_model) @@ -527,12 +526,6 @@ def _from_pretrained( override_dims = False - # we changed the meaning of past sequence length at some point ? - for dim in input_dims.keys(): - if "past" in dim and input_dims[dim][2] == "past_sequence_length + sequence_length": - input_dims[dim][2] = "past_sequence_length" - override_dims = True - # Since v1.7.0 decoder with past models have fixed sequence length of 1 # To keep these models compatible we set this dimension to dynamic if input_dims["input_ids"][1] == 1: @@ -540,16 +533,31 @@ def _from_pretrained( output_dims["logits"][1] = "sequence_length" override_dims = True + # Since https://github.com/huggingface/optimum/pull/871/files + # changed axis notation/naming during export, we need to update the dims + for dim in input_dims.keys(): + if "past" in dim and input_dims[dim][2] == "past_sequence_length + sequence_length": + input_dims[dim][2] = "past_sequence_length" + override_dims = True + if override_dims: + # this is kinda dangerous, warning the user is the least we can do + logger.warning( + "The ONNX model was probably exported with an older version of optimum. " + "We are updating the input/output dimensions and overwriting the model file " + "with new dimensions. This is necessary for the model to work correctly with " + "the current version of optimum. If you encounter any issues, please re-export " + "the model with the latest version of optimum for optimal performance." + ) onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims) onnx.save( onnx_model, str(model_cache_path), save_as_external_data=model_uses_external_data, - all_tensors_to_one_file=True, location=model_cache_path.name + "_data", - size_threshold=0, + all_tensors_to_one_file=True, convert_attribute=True, + size_threshold=0, ) del onnx_model