From 29b8e72a0b7f395a80f55842a6ee31dab7c8237d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Mon, 1 Jul 2024 15:50:45 -0600 Subject: [PATCH] suppress warning when saving hyperparameters in base auto (#1034) --- nbs/common.base_auto.ipynb | 7 ++++++- neuralforecast/common/_base_auto.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/nbs/common.base_auto.ipynb b/nbs/common.base_auto.ipynb index 516878c87..788c1d446 100644 --- a/nbs/common.base_auto.ipynb +++ b/nbs/common.base_auto.ipynb @@ -62,6 +62,7 @@ "outputs": [], "source": [ "#| export\n", + "import warnings\n", "from copy import deepcopy\n", "from os import cpu_count\n", "\n", @@ -169,7 +170,11 @@ " callbacks=None,\n", " ):\n", " super(BaseAuto, self).__init__()\n", - " self.save_hyperparameters() # Allows instantiation from a checkpoint from class\n", + " with warnings.catch_warnings(record=False):\n", + " warnings.filterwarnings('ignore')\n", + " # the following line issues a warning about the loss attribute being saved\n", + " # but we do want to save it\n", + " self.save_hyperparameters() # Allows instantiation from a checkpoint from class\n", "\n", " if backend == 'ray':\n", " if not isinstance(config, dict):\n", diff --git a/neuralforecast/common/_base_auto.py b/neuralforecast/common/_base_auto.py index f566a9e92..ba1bafede 100644 --- a/neuralforecast/common/_base_auto.py +++ b/neuralforecast/common/_base_auto.py @@ -4,6 +4,7 @@ __all__ = ['BaseAuto'] # %% ../../nbs/common.base_auto.ipynb 5 +import warnings from copy import deepcopy from os import cpu_count @@ -101,7 +102,11 @@ def __init__( callbacks=None, ): super(BaseAuto, self).__init__() - self.save_hyperparameters() # Allows instantiation from a checkpoint from class + with warnings.catch_warnings(record=False): + warnings.filterwarnings("ignore") + # the following line issues a warning about the loss attribute being saved + # but we do want to save it + self.save_hyperparameters() # Allows instantiation from a checkpoint from class if backend == "ray": if not isinstance(config, dict):