
Custom LR scheduler not supported #191

Open

ErikHartman opened this issue Aug 1, 2024 · 1 comment

@ErikHartman
Bug description
Trying to set a custom learning rate scheduler via set_lr_scheduler_class throws a NotImplementedError.

To Reproduce

from torch.optim.lr_scheduler import ReduceLROnPlateau
from your_package import LR_SchedulerInterface  # Replace with the actual import path

class CustomReduceLROnPlateau(LR_SchedulerInterface):
    def __init__(self, optimizer, num_warmup_steps, num_training_steps, patience=10, factor=0.1, mode='min', threshold=1e-4, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8, **kwargs):
        super().__init__()
        self.scheduler = ReduceLROnPlateau(
            optimizer,
            mode=mode,
            factor=factor,
            patience=patience,
            threshold=threshold,
            threshold_mode=threshold_mode,
            cooldown=cooldown,
            min_lr=min_lr,
            eps=eps,
            **kwargs
        )

    def step(self, metrics, epoch=None):
        self.scheduler.step(metrics, epoch)

    def get_last_lr(self):
        return self.scheduler.optimizer.param_groups[0]['lr']

from peptdeep.model.generic_property_prediction import (
    ModelInterface_for_Generic_AASeq_Regression,
)
from peptdeep.model.generic_property_prediction import (
    Model_for_Generic_AASeq_Regression_Transformer,
)

transformer = ModelInterface_for_Generic_AASeq_Regression(
    model_class=Model_for_Generic_AASeq_Regression_Transformer
)
transformer.target_column_to_train = 'normlogintensity'
transformer.target_column_to_predict = 'transformer_predictions'
transformer.set_lr_scheduler_class(CustomReduceLROnPlateau)  # register the custom scheduler
transformer.train(data_train, warmup_epoch=10, epoch=50, verbose=True)

Expected behavior
No error.

Version (please complete the following information):

  • Installation Type: pip
  • peptdeep version 1.2.1

Additional context
I see in the source code that this isn't implemented. It would be nice if it were.

@mo-sameh
Collaborator

Since LR_SchedulerInterface is just an interface, you don't need to call super().__init__() in your implementation. The NotImplementedError occurs because that call tries to initialize the interface class itself.
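
For context, an interface along these lines would raise as soon as a subclass calls its __init__ (this is only a rough sketch of such an interface, not the actual peptdeep source; names and signatures are assumptions):

# Hypothetical sketch of an interface like LR_SchedulerInterface -- not the
# actual peptdeep definition.
class LR_SchedulerInterface:
    def __init__(self, optimizer, **kwargs):
        # A subclass calling super().__init__() lands here and raises.
        raise NotImplementedError

    def step(self, loss, epoch=None):
        raise NotImplementedError

    def get_last_lr(self) -> list:
        raise NotImplementedError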

Additionally, two other changes are needed:

  1. ModelInterface expects the step function to have a positional argument named "loss," not "metrics."
  2. get_last_lr should return a list to handle cases with multiple parameter groups.

Sorry about the docstring incorrectly stating that the return value should be a float; it will be updated. This version of your scheduler should work fine:

from torch.optim.lr_scheduler import ReduceLROnPlateau

class CustomReduceLROnPlateau(LR_SchedulerInterface):
    def __init__(self, optimizer, num_warmup_steps, num_training_steps, patience=10, factor=0.1, mode='min', threshold=1e-4, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8, **kwargs):
        # No super().__init__() -- the interface only defines the expected methods.
        self.scheduler = ReduceLROnPlateau(
            optimizer,
            mode=mode,
            factor=factor,
            patience=patience,
            threshold=threshold,
            threshold_mode=threshold_mode,
            cooldown=cooldown,
            min_lr=min_lr,
            eps=eps,
            **kwargs
        )

    def step(self, loss, epoch=None):
        # ModelInterface passes the training loss as the first positional argument.
        self.scheduler.step(loss, epoch)

    def get_last_lr(self):
        # Return the current learning rate(s) as a list, one entry per parameter group.
        return [self.scheduler.optimizer.param_groups[0]['lr']]
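
To use it, register the class on the model interface before calling train, roughly as in the snippet from the bug report (data_train is your own training DataFrame):

transformer.set_lr_scheduler_class(CustomReduceLROnPlateau)
transformer.train(data_train, warmup_epoch=10, epoch=50, verbose=True)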

For a very similar implementation, you can check out the scheduler used in AlphaDIA (Alphadia transfer-learning).
