Skip to content

Commit

Permalink
fix past kv in old model
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed May 24, 2024
1 parent 6e86081 commit d744499
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,7 @@ def _from_pretrained(
if model_save_dir is None:
model_save_dir = new_model_save_dir

# Since v1.7.0 decoder with past models have fixed sequence length of 1
# To keep these models compatible we set this dimension to dynamic

onnx_model = onnx.load(str(model_cache_path), load_external_data=False)
model_uses_external_data = check_model_uses_external_data(onnx_model)

Expand All @@ -521,15 +520,28 @@ def _from_pretrained(
node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]
for node in onnx_model.graph.input
}
output_dims = {
node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]
for node in onnx_model.graph.output
}

override_dims = False

# we changed the meaning of past sequence length at some point ?
for dim in input_dims.keys():
if "past" in dim and input_dims[dim][2] == "past_sequence_length + sequence_length":
input_dims[dim][2] = "past_sequence_length"
override_dims = True

# Since v1.7.0 decoder with past models have fixed sequence length of 1
# To keep these models compatible we set this dimension to dynamic
if input_dims["input_ids"][1] == 1:
input_dims["input_ids"][1] = "sequence_length"
output_dims = {
node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]
for node in onnx_model.graph.output
}
output_dims["logits"][1] = "sequence_length"
onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims)
override_dims = True

if override_dims:
onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims)
onnx.save(
onnx_model,
str(model_cache_path),
Expand All @@ -539,6 +551,7 @@ def _from_pretrained(
size_threshold=0,
convert_attribute=True,
)

del onnx_model

model = ORTModel.load_model(
Expand Down

0 comments on commit d744499

Please sign in to comment.