-
Notifications
You must be signed in to change notification settings - Fork 344
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FIX] Code refactoring #1023
base: main
Are you sure you want to change the base?
[FIX] Code refactoring #1023
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
this is a very cool effort @elephaint. the new features look exciting (eg losses compatibility, model unification, ...) and there are a lot of bug fixes, too! |
Thanks! At this moment the issue is mostly that there is some performance regression; little bit of work to do there.... |
Performance tests - first picture is baseline (existing main repo), second picture is this PR. Conclusion
Model performanceModel performance 2Multivariate model performance |
…neuralforecast into fix/docs_and_refactoring
This is a large refactoring PR and open for discussion. The main goal of the PR is to unify API across different model types, and unify loss functions across different loss types.
Refactoring:
BaseWindows
,BaseMultivariate
andBaseRecurrent
intoBaseModel
, removing the need for separate classes and unifying model API across different model types. Instead, this PR introduces two model attributes, yielding four possible model options:RECURRENT
(True
/False
) andMULTIVARIATE
(True
/False
). We currently have a model for every combination except a recurrent multivariate model (e.g. a multivariate LSTM), however this is now relatively simple to add. In addition, this change allows to have models that can be recurrent or not, or multivariate or not on-the-fly, based on users' input. This also allows for easier modelling going forward.domain_map
functions.loss.domain_map
outside of models toBaseModel
TSMixer
,TSMixerx
andRMoK
tocommon.modules
Features:
DistributionLoss
now supports the use ofquantile
inpredict
, allowing for easy quantile retrieval for all DistributionLosses.GMM
,PMM
andNBMM
) now support learned weights for weighted mixture distribution outputs.quantile
inpredict
, allowing for easy quantile retrieval.ISQF
by adding softplus protection around some parameters instead of using.abs
Bug fixes:
MASE
loss now works.StudentT
increase default DoF to 3 to reduce unbound variance issues.eval: false
on the examples whilst not having any other tests, causing most models to effectively not being testedBreaking changes:
input_size
to be given.TCN
andDRNN
are now windows models, not recurrent models.Tests:
common._model_checks.py
that includes a model testing function.Todo: