From ae25a8379e69b1655882a0f3b0d3b063ca2d3694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Mon, 1 Jul 2024 15:50:19 -0600 Subject: [PATCH] use assign argument if available in nn.Module.load_state_dict (#1032) --- nbs/common.base_model.ipynb | 5 ++++- neuralforecast/common/_base_model.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/nbs/common.base_model.ipynb b/nbs/common.base_model.ipynb index deb042caa..31d3da407 100644 --- a/nbs/common.base_model.ipynb +++ b/nbs/common.base_model.ipynb @@ -465,7 +465,10 @@ " content = torch.load(f, **kwargs)\n", " with _disable_torch_init():\n", " model = cls(**content['hyper_parameters']) \n", - " model.load_state_dict(content['state_dict'], strict=True, assign=True)\n", + " if \"assign\" in inspect.signature(model.load_state_dict).parameters:\n", + " model.load_state_dict(content[\"state_dict\"], strict=True, assign=True)\n", + " else: # pytorch<2.1\n", + " model.load_state_dict(content[\"state_dict\"], strict=True)\n", " return model" ] } diff --git a/neuralforecast/common/_base_model.py b/neuralforecast/common/_base_model.py index 6192be525..47587a66c 100644 --- a/neuralforecast/common/_base_model.py +++ b/neuralforecast/common/_base_model.py @@ -445,5 +445,8 @@ def load(cls, path, **kwargs): content = torch.load(f, **kwargs) with _disable_torch_init(): model = cls(**content["hyper_parameters"]) - model.load_state_dict(content["state_dict"], strict=True, assign=True) + if "assign" in inspect.signature(model.load_state_dict).parameters: + model.load_state_dict(content["state_dict"], strict=True, assign=True) + else: # pytorch<2.1 + model.load_state_dict(content["state_dict"], strict=True) return model