
Commit

Merge branch 'master' into pr/1624
fkiraly committed Sep 7, 2024
2 parents f5ab183 + e2bf3fd commit e3c6d29
Showing 48 changed files with 529 additions and 168 deletions.
1 change: 0 additions & 1 deletion .env

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/pypi_release.yml
@@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.8
python-version: 3.11

- name: Install poetry
shell: bash
51 changes: 47 additions & 4 deletions .github/workflows/test.yml
@@ -25,18 +25,61 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[dev,github-actions,graph,mqf2]"
python -m pip install ".[dev,all_extras,github-actions]"
- name: Show dependencies
run: python -m pip list

- name: Run example notebooks
run: build_tools/run_examples.sh
shell: bash

pytest-nosoftdeps:
name: no-softdeps
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-13, windows-latest]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Setup macOS
if: runner.os == 'macOS'
run: |
brew install libomp # https://github.com/pytorch/pytorch/issues/20030
- name: Get full Python version
id: full-python-version
shell: bash
run: echo version=$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))") >> $GITHUB_OUTPUT

- name: Install dependencies
shell: bash
run: |
pip install ".[dev,github-actions]"
- name: Show dependencies
run: python -m pip list

- name: Run pytest
shell: bash
run: python -m pytest tests

pytest:
name: Run pytest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-13] # add windows-2019 when poetry allows installation with `-f` flag
os: [ubuntu-latest, macos-13, windows-latest]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
@@ -60,7 +103,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
pip install ".[dev,github-actions,graph,mqf2]"
pip install ".[dev,all_extras,github-actions]"
- name: Show dependencies
run: python -m pip list
@@ -99,7 +142,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.8
python-version: 3.11

- name: Cache pip
uses: actions/cache@v2
2 changes: 1 addition & 1 deletion .gitignore
@@ -104,7 +104,7 @@ celerybeat.pid
*.sage.py

# Environments
# .env
.env
.venv
env/
venv/
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -2,26 +2,26 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-ast
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
rev: 7.1.1
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.10.1
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 23.9.0
rev: 24.8.0
hooks:
- id: black
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.0
rev: 1.8.7
hooks:
- id: nbqa-black
- id: nbqa-isort
4 changes: 2 additions & 2 deletions .readthedocs.yml
@@ -9,7 +9,7 @@ version: 2
# reference: https://docs.readthedocs.io/en/stable/config-file/v2.html#sphinx
sphinx:
configuration: docs/source/conf.py
fail_on_warning: true
# fail_on_warning: true

# Build documentation with MkDocs
#mkdocs:
@@ -21,6 +21,6 @@ formats:

# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.8
version: 3.11
install:
- requirements: docs/requirements.txt
3 changes: 3 additions & 0 deletions codecov.yml
@@ -6,3 +6,6 @@ coverage:
project:
default:
threshold: 0.2%
patch:
default:
informational: true
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -18,3 +18,4 @@ nbconvert >=6.3.0
recommonmark >=0.7.1
pytorch-optimizer >=2.5.1
fastapi >0.80
cpflows
4 changes: 2 additions & 2 deletions docs/source/_templates/custom-module-template.rst
@@ -53,14 +53,14 @@
{% endblock %}

{% block modules %}
{% if modules %}
{% if all_modules %}
.. rubric:: Modules

.. autosummary::
:toctree:
:template: custom-module-template.rst
:recursive:
{% for item in modules %}
{% for item in all_modules %}
{{ item }}
{%- endfor %}
{% endif %}
4 changes: 4 additions & 0 deletions docs/source/conf.py
@@ -176,3 +176,7 @@ def setup(app: Sphinx):
intersphinx_mapping = {
"sklearn": ("https://scikit-learn.org/stable/", None),
}

suppress_warnings = [
"autosummary.import_cycle",
]
5 changes: 1 addition & 4 deletions docs/source/tutorials/ar.ipynb
@@ -14,12 +14,9 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"os.chdir(\"../../..\")"
"warnings.filterwarnings(\"ignore\")"
]
},
{
21 changes: 12 additions & 9 deletions docs/source/tutorials/building.ipynb
@@ -57,12 +57,9 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"os.chdir(\"../../..\")"
"warnings.filterwarnings(\"ignore\")"
]
},
{
@@ -1034,16 +1031,19 @@
{
"cell_type": "raw",
"metadata": {
"raw_mimetype": "text/restructuredtext"
"raw_mimetype": "text/restructuredtext",
"vscode": {
"languageId": "raw"
}
},
"source": [
"While not required, to give the user transparancy over these additional hyperparameters, it is worth passing them explicitly instead of implicitly in ``**kwargs``\n",
"\n",
"They are described in detail in the :py:class:`~pytorch_forecasting.models.base_model.BaseModel`. \n",
"They are described in detail in the :py:class:`~pytorch_forecasting.models.base_model.BaseModel`.\n",
"\n",
".. automethod:: pytorch_forecasting.models.base_model.BaseModel.__init__\n",
" :noindex:\n",
" \n",
"\n",
"You can simply copy this docstring into your model implementation:"
]
},
@@ -2238,15 +2238,18 @@
{
"cell_type": "raw",
"metadata": {
"raw_mimetype": "text/restructuredtext"
"raw_mimetype": "text/restructuredtext",
"vscode": {
"languageId": "raw"
}
},
"source": [
"Now that we have established the basics, we can move on to more advanced use cases, e.g. how can we make use of covariates - static and continuous alike. We can leverage the :py:class:`~pytorch_forecasting.models.base_model.BaseModelWithCovariates` for this. The difference to the :py:class:`~pytorch_forecasting.models.base_model.BaseModel` is a :py:meth:`~pytorch_forecasting.models.base_model.BaseModelWithCovariates.from_dataset` method that pre-defines hyperparameters for architectures with covariates.\n",
"\n",
".. autoclass:: pytorch_forecasting.models.base_model.BaseModelWithCovariates\n",
" :noindex:\n",
" :members: from_dataset\n",
" \n",
"\n",
"\n",
"Here is a from the BaseModelWithCovariates docstring to copy:"
]
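The building.ipynb cell above introduces BaseModelWithCovariates and its `from_dataset` method, which pre-fills covariate-related hyperparameters from a `TimeSeriesDataSet`. As a rough usage sketch (assuming a `TimeSeriesDataSet` named `training`, as in the tutorials), a covariate-aware model such as the Temporal Fusion Transformer is typically instantiated through `from_dataset` rather than via its bare constructor:

    from pytorch_forecasting import TemporalFusionTransformer
    from pytorch_forecasting.metrics import QuantileLoss

    # from_dataset reads encoder/decoder variable lists, embedding sizes, etc.
    # from the dataset, so only the remaining hyperparameters need to be passed
    tft = TemporalFusionTransformer.from_dataset(
        training,            # a TimeSeriesDataSet
        learning_rate=0.03,
        hidden_size=16,
        loss=QuantileLoss(),
    )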
5 changes: 1 addition & 4 deletions docs/source/tutorials/deepar.ipynb
@@ -14,12 +14,9 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"os.chdir(\"../../..\")"
"warnings.filterwarnings(\"ignore\")"
]
},
{
5 changes: 1 addition & 4 deletions docs/source/tutorials/nhits.ipynb
@@ -14,12 +14,9 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"os.chdir(\"../../..\")"
"warnings.filterwarnings(\"ignore\")"
]
},
{
16 changes: 8 additions & 8 deletions docs/source/tutorials/stallion.ipynb
@@ -39,12 +39,9 @@
},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\") # avoid printing out absolute paths\n",
"\n",
"os.chdir(\"../../..\")"
"warnings.filterwarnings(\"ignore\") # avoid printing out absolute paths"
]
},
{
@@ -1557,16 +1554,19 @@
{
"cell_type": "raw",
"metadata": {
"raw_mimetype": "text/restructuredtext"
"raw_mimetype": "text/restructuredtext",
"vscode": {
"languageId": "raw"
}
},
"source": [
"Hyperparamter tuning with [optuna](https://optuna.org/) is directly build into pytorch-forecasting. For example, we can use the \n",
"Hyperparamter tuning with [optuna](https://optuna.org/) is directly build into pytorch-forecasting. For example, we can use the\n",
":py:func:`~pytorch_forecasting.models.temporal_fusion_transformer.tuning.optimize_hyperparameters` function to optimize the TFT's hyperparameters.\n",
"\n",
".. code-block:: python\n",
"\n",
" import pickle\n",
" \n",
"\n",
" from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters\n",
"\n",
" # create study\n",
@@ -1917,7 +1917,7 @@
"source": [
"# calcualte metric by which to display\n",
"predictions = best_tft.predict(val_dataloader, return_y=True)\n",
"mean_losses = SMAPE(reduction=\"none\")(predictions.output, predictions.y).mean(1)\n",
"mean_losses = SMAPE(reduction=\"none\").loss(predictions.output, predictions.y[0]).mean(1)\n",
"indices = mean_losses.argsort(descending=True) # sort losses\n",
"for idx in range(10): # plot 10 examples\n",
" best_tft.plot_prediction(\n",
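A note on the stallion.ipynb change above: with `predict(..., return_y=True)`, the ground truth is returned alongside the output, and `predictions.y[0]` is the target tensor (the second element carries weights where applicable), so the unreduced SMAPE is now computed through the metric's `.loss` method and averaged over the time dimension to rank series by error. A minimal sketch of that pattern, reusing the notebook's `best_tft` and `val_dataloader` names:

    from pytorch_forecasting.metrics import SMAPE

    # predict on the validation set and keep the ground truth
    predictions = best_tft.predict(val_dataloader, return_y=True)

    # unreduced SMAPE: .loss returns one value per series and time step,
    # so averaging over dimension 1 yields one error per series
    per_series_smape = SMAPE(reduction="none").loss(predictions.output, predictions.y[0]).mean(1)

    # rank series from worst to best fit, e.g. to plot the worst predictions
    worst_first = per_series_smape.argsort(descending=True)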
8 changes: 1 addition & 7 deletions examples/ar.py
@@ -1,22 +1,16 @@
from pathlib import Path
import pickle
import warnings

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner import Tuner
import numpy as np
import pandas as pd
from pandas.core.common import SettingWithCopyWarning
import torch

from pytorch_forecasting import EncoderNormalizer, GroupNormalizer, TimeSeriesDataSet
from pytorch_forecasting import GroupNormalizer, TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.data.examples import generate_ar_data
from pytorch_forecasting.metrics import NormalDistributionLoss
from pytorch_forecasting.models.deepar import DeepAR
from pytorch_forecasting.utils import profile

warnings.simplefilter("error", category=SettingWithCopyWarning)

2 changes: 0 additions & 2 deletions examples/nbeats.py
@@ -2,9 +2,7 @@

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.tuner import Tuner
import pandas as pd
from sklearn.preprocessing import scale

from pytorch_forecasting import NBeats, TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder
7 changes: 1 addition & 6 deletions examples/stallion.py
@@ -1,21 +1,16 @@
from pathlib import Path
import pickle
import warnings

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner import Tuner
import numpy as np
import pandas as pd
from pandas.core.common import SettingWithCopyWarning
import torch

from pytorch_forecasting import GroupNormalizer, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data.examples import get_stallion_data
from pytorch_forecasting.metrics import MAE, RMSE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
from pytorch_forecasting.utils import profile

warnings.simplefilter("error", category=SettingWithCopyWarning)

(Diffs for the remaining changed files are not shown.)
