Skip to content

Commit

Permalink
use assign argument if available in nn.Module.load_state_dict (#1032)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Jul 1, 2024
1 parent b534ddf commit ae25a83
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion nbs/common.base_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,10 @@
" content = torch.load(f, **kwargs)\n",
" with _disable_torch_init():\n",
" model = cls(**content['hyper_parameters']) \n",
" model.load_state_dict(content['state_dict'], strict=True, assign=True)\n",
" if \"assign\" in inspect.signature(model.load_state_dict).parameters:\n",
" model.load_state_dict(content[\"state_dict\"], strict=True, assign=True)\n",
" else: # pytorch<2.1\n",
" model.load_state_dict(content[\"state_dict\"], strict=True)\n",
" return model"
]
}
Expand Down
5 changes: 4 additions & 1 deletion neuralforecast/common/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,5 +445,8 @@ def load(cls, path, **kwargs):
content = torch.load(f, **kwargs)
with _disable_torch_init():
model = cls(**content["hyper_parameters"])
model.load_state_dict(content["state_dict"], strict=True, assign=True)
if "assign" in inspect.signature(model.load_state_dict).parameters:
model.load_state_dict(content["state_dict"], strict=True, assign=True)
else: # pytorch<2.1
model.load_state_dict(content["state_dict"], strict=True)
return model

0 comments on commit ae25a83

Please sign in to comment.