feat: allow custom optimizer #300

Merged · 3 commits · Nov 16, 2023
6 changes: 4 additions & 2 deletions src/pytorch_tabular/config/config.py
@@ -655,7 +655,8 @@ class OptimizerConfig:
"""Optimizer and Learning Rate Scheduler configuration.
Args:
optimizer (str): Any of the standard optimizers from
- [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms).
+ [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms) or provide full python path,
+ for example "torch_optimizer.RAdam".

optimizer_params (Dict): The parameters for the optimizer. If left blank, will use default
parameters.
@@ -675,7 +676,8 @@ class OptimizerConfig:
default="Adam",
metadata={
"help": "Any of the standard optimizers from"
- " [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms)."
+ " [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms) or provide full python path,"
+ " for example 'torch_optimizer.RAdam'."
},
)
optimizer_params: Dict = field(
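A minimal usage sketch of the option this hunk documents, assuming the optional `torch_optimizer` package is installed; the `optimizer_params` values are illustrative:

```python
from pytorch_tabular.config import OptimizerConfig

# Existing behaviour: a bare name is looked up on torch.optim
opt_config = OptimizerConfig(optimizer="Adam")

# New behaviour from this PR: a dotted full python path is imported dynamically
opt_config = OptimizerConfig(
    optimizer="torch_optimizer.RAdam",  # requires the torch_optimizer package
    optimizer_params={"weight_decay": 1e-5},
)
```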
8 changes: 7 additions & 1 deletion src/pytorch_tabular/ssl_models/base_model.py
@@ -2,6 +2,7 @@
# Author: Manu Joseph <[email protected]>
# For license information, see LICENSE.TXT
"""SSL Base Model."""
+ import importlib
import warnings
from abc import ABCMeta, abstractmethod
from typing import Dict, Optional
@@ -172,7 +173,12 @@ def configure_optimizers(self):
if self.custom_optimizer is None:
# Loading from the config
try:
- self._optimizer = getattr(torch.optim, self.hparams.optimizer)
+ if "." in self.hparams.optimizer:
+     py_path, cls_name = self.hparams.optimizer.rsplit(".", 1)
+     module = importlib.import_module(py_path)
+     self._optimizer = getattr(module, cls_name)
+ else:
+     self._optimizer = getattr(torch.optim, self.hparams.optimizer)
opt = self._optimizer(
self.parameters(),
lr=self.hparams.learning_rate,
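The same resolution logic as a standalone sketch, for clarity; the helper name `resolve_optimizer` is mine and not part of the library:

```python
import importlib

import torch


def resolve_optimizer(name: str):
    """Return an optimizer class from torch.optim or from a full python path.

    Mirrors the branch added in configure_optimizers above: a dotted name such
    as "torch_optimizer.RAdam" is imported dynamically; anything else is
    treated as an attribute of torch.optim.
    """
    if "." in name:
        module_path, cls_name = name.rsplit(".", 1)
        module = importlib.import_module(module_path)
        return getattr(module, cls_name)
    return getattr(torch.optim, name)


# resolve_optimizer("Adam")                  -> torch.optim.Adam
# resolve_optimizer("torch_optimizer.RAdam") -> torch_optimizer.RAdam (if installed)
```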
16 changes: 8 additions & 8 deletions src/pytorch_tabular/tabular_model.py
@@ -510,10 +510,10 @@ def prepare_model(
The length of the list should be equal to the number of metrics. Defaults to None.

optimizer (Optional[torch.optim.Optimizer], optional):
- Custom optimizers which are a drop in replacements for standard PyToch optimizers.
+ Custom optimizers which are a drop in replacements for standard PyTorch optimizers.
This should be the Class and not the initialized object

- optimizer_params (Optional[Dict], optional): The parmeters to initialize the custom optimizer.
+ optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

Returns:
BaseModel: The prepared model
@@ -631,9 +631,9 @@ def fit(

optimizer (Optional[torch.optim.Optimizer], optional):
Custom optimizers which are a drop in replacements for
- standard PyToch optimizers. This should be the Class and not the initialized object
+ standard PyTorch optimizers. This should be the Class and not the initialized object

- optimizer_params (Optional[Dict], optional): The parmeters to initialize the custom optimizer.
+ optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

train_sampler (Optional[torch.utils.data.Sampler], optional):
Custom PyTorch batch samplers which will be passed
@@ -717,9 +717,9 @@ def pretrain(
Defaults to None.

optimizer (Optional[torch.optim.Optimizer], optional): Custom optimizers which are a drop in replacements
- for standard PyToch optimizers. This should be the Class and not the initialized object
+ for standard PyTorch optimizers. This should be the Class and not the initialized object

- optimizer_params (Optional[Dict], optional): The parmeters to initialize the custom optimizer.
+ optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

max_epochs (Optional[int]): Overwrite maximum number of epochs to be run. Defaults to None.

@@ -731,7 +731,7 @@
Defaults to None.

datamodule (Optional[TabularDatamodule], optional): The datamodule. If provided, will ignore the rest of the
- parameters like train, test etc and use the datamodule. Defaults to None.
+ parameters like train, test etc. and use the datamodule. Defaults to None.

Returns:
pl.Trainer: The PyTorch Lightning Trainer instance
@@ -808,7 +808,7 @@ def create_finetune_model(

loss (Optional[torch.nn.Module], optional):
If provided, will be used as the loss function for the fine-tuning.
- By Default it is MSELoss for regression and CrossEntropyLoss for classification.
+ By default, it is MSELoss for regression and CrossEntropyLoss for classification.

metrics (Optional[List[Callable]], optional): List of metrics (either callables or str) to be used for the
fine-tuning stage. If str, it should be one of the functional metrics implemented in
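For contrast with the config-path route above, a sketch of the pre-existing drop-in path these docstrings describe: pass the optimizer class itself (not an instance) to `fit`. Here `tabular_model` and `train_df` are assumed to be an already-prepared `TabularModel` and a training DataFrame:

```python
from torch.optim import AdamW

tabular_model.fit(
    train=train_df,
    optimizer=AdamW,  # the class, not AdamW(...)
    optimizer_params={"weight_decay": 1e-2},
)
```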