Skip to content

Commit

Permalink
warn
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed May 25, 2024
1 parent d744499 commit 3885f8d
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,6 @@ def _from_pretrained(
if model_save_dir is None:
model_save_dir = new_model_save_dir


onnx_model = onnx.load(str(model_cache_path), load_external_data=False)
model_uses_external_data = check_model_uses_external_data(onnx_model)

Expand All @@ -527,29 +526,38 @@ def _from_pretrained(

override_dims = False

# we changed the meaning of past sequence length at some point ?
for dim in input_dims.keys():
if "past" in dim and input_dims[dim][2] == "past_sequence_length + sequence_length":
input_dims[dim][2] = "past_sequence_length"
override_dims = True

# Since v1.7.0 decoder with past models have fixed sequence length of 1
# To keep these models compatible we set this dimension to dynamic
if input_dims["input_ids"][1] == 1:
input_dims["input_ids"][1] = "sequence_length"
output_dims["logits"][1] = "sequence_length"
override_dims = True

# Since https://github.com/huggingface/optimum/pull/871/files
# changed axis notation/naming during export, we need to update the dims
for dim in input_dims.keys():
if "past" in dim and input_dims[dim][2] == "past_sequence_length + sequence_length":
input_dims[dim][2] = "past_sequence_length"
override_dims = True

if override_dims:
# this is kinda dangerous, warning the user is the least we can do
logger.warning(
"The ONNX model was probably exported with an older version of optimum. "
"We are updating the input/output dimensions and overwriting the model file "
"with new dimensions. This is necessary for the model to work correctly with "
"the current version of optimum. If you encounter any issues, please re-export "
"the model with the latest version of optimum for optimal performance."
)
onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims)
onnx.save(
onnx_model,
str(model_cache_path),
save_as_external_data=model_uses_external_data,
all_tensors_to_one_file=True,
location=model_cache_path.name + "_data",
size_threshold=0,
all_tensors_to_one_file=True,
convert_attribute=True,
size_threshold=0,
)

del onnx_model
Expand Down

0 comments on commit 3885f8d

Please sign in to comment.