diff --git a/nbs/common.base_model.ipynb b/nbs/common.base_model.ipynb index deb042caa..31d3da407 100644 --- a/nbs/common.base_model.ipynb +++ b/nbs/common.base_model.ipynb @@ -465,7 +465,10 @@ " content = torch.load(f, **kwargs)\n", " with _disable_torch_init():\n", " model = cls(**content['hyper_parameters']) \n", - " model.load_state_dict(content['state_dict'], strict=True, assign=True)\n", + " if \"assign\" in inspect.signature(model.load_state_dict).parameters:\n", + " model.load_state_dict(content[\"state_dict\"], strict=True, assign=True)\n", + " else: # pytorch<2.1\n", + " model.load_state_dict(content[\"state_dict\"], strict=True)\n", " return model" ] } diff --git a/neuralforecast/common/_base_model.py b/neuralforecast/common/_base_model.py index 6192be525..47587a66c 100644 --- a/neuralforecast/common/_base_model.py +++ b/neuralforecast/common/_base_model.py @@ -445,5 +445,8 @@ def load(cls, path, **kwargs): content = torch.load(f, **kwargs) with _disable_torch_init(): model = cls(**content["hyper_parameters"]) - model.load_state_dict(content["state_dict"], strict=True, assign=True) + if "assign" in inspect.signature(model.load_state_dict).parameters: + model.load_state_dict(content["state_dict"], strict=True, assign=True) + else: # pytorch<2.1 + model.load_state_dict(content["state_dict"], strict=True) return model