
Bandit dispatcher class #372

Draft · wants to merge 5 commits into base: main
6 changes: 5 additions & 1 deletion baybe/surrogates/__init__.py
@@ -1,6 +1,9 @@
 """BayBE surrogates."""

-from baybe.surrogates.bandit import BetaBernoulliMultiArmedBanditSurrogate
+from baybe.surrogates.bandits.base import MultiArmedBanditSurrogate
+from baybe.surrogates.bandits.beta_bernoulli import (
+    BetaBernoulliMultiArmedBanditSurrogate,
+)
 from baybe.surrogates.custom import CustomONNXSurrogate, register_custom_architecture
 from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate
 from baybe.surrogates.linear import BayesianLinearSurrogate
@@ -15,6 +18,7 @@
     "CustomONNXSurrogate",
     "GaussianProcessSurrogate",
     "MeanPredictionSurrogate",
+    "MultiArmedBanditSurrogate",
     "NGBoostSurrogate",
     "RandomForestSurrogate",
 ]
11 changes: 11 additions & 0 deletions baybe/surrogates/bandits/__init__.py
@@ -0,0 +1,11 @@
+"""Bandit surrogates."""
+
+from baybe.surrogates.bandits.base import MultiArmedBanditSurrogate
+from baybe.surrogates.bandits.beta_bernoulli import (
+    BetaBernoulliMultiArmedBanditSurrogate,
+)
+
+__all__ = [
+    "BetaBernoulliMultiArmedBanditSurrogate",
+    "MultiArmedBanditSurrogate",
+]
57 changes: 57 additions & 0 deletions baybe/surrogates/bandits/base.py
@@ -0,0 +1,57 @@
+"""Multi-armed bandit surrogates."""
+
+from __future__ import annotations
+
+from attrs import define, field
+from attrs.validators import instance_of
+from pandas import DataFrame
+
+from baybe.objectives.base import Objective
+from baybe.objectives.single import SingleTargetObjective
+from baybe.priors.base import Prior
+from baybe.searchspace.core import SearchSpace
+from baybe.surrogates.bandits.beta_bernoulli import (
+    BetaBernoulliMultiArmedBanditSurrogate,
+)
+from baybe.surrogates.base import Surrogate
+from baybe.targets.binary import BinaryTarget
+
+
+def _get_bandit_class(
+    searchspace: SearchSpace, objective: Objective
+) -> type[BetaBernoulliMultiArmedBanditSurrogate]:
+    """Retrieve the appropriate bandit class for the given modelling context."""
+    match searchspace, objective:
Collaborator

Why did you include the searchspace in this private function (which, being private, is not affected by deprecation considerations)?
+        case _, SingleTargetObjective(_target=BinaryTarget()):
+            return BetaBernoulliMultiArmedBanditSurrogate
+        case _:
+            raise NotImplementedError(
+                f"Currently, only a single target of type '{BinaryTarget.__name__}' "
+                f"is supported."
+            )
+
+
+@define
+class MultiArmedBanditSurrogate:
Collaborator Author

Hey @AVHopp, @Scienfitz: here is a first draft of what we've discussed. At this stage, a few comments:

  • In principle, this setup already works. However, I'm not yet 100% convinced that this is the best way to go. We should not rush this decision, in particular since the same question (i.e., how to dispatch to more specialized implementations) will soon also arise for our other surrogates, e.g. ApproximateGPs, SparseGPs, ...
  • While it works, it has the disadvantage that you basically lose all autocomplete when you invoke the dispatcher, which is not great. Perhaps we can improve upon this by making it inherit from a (possibly new) base class for all bandit models, where we define the interface? The problem, however, is that we can inherit neither from Surrogate nor from SurrogateProtocol, since the attribute forwarding would otherwise get very ugly.
  • Overall, looking at the code, if we decide to keep it or a similar version, I don't want to expose only the dispatcher. To me, it is more like a utility, in the sense that a user who has no clue what to use can simply throw their specifications in and get a reasonable model choice in return. However, there is no reason not to also allow full control by additionally exposing the specialized classes.

What do you think so far? That said, we also don't need to rush the decision, especially if we agree on the last point, because then we can also add the dispatcher after the release.

Collaborator

Re 2: Would it be possible to create a bandit protocol and inherit from it for that purpose? Or does autocomplete not work if it's just a protocol?

Re 3: I see your argument, but it would somehow end up in the situation of us having lots of models lying around. So when I type from baybe.surrogates import, or just search the docs for what models are available, I would get lots of models plus one model that looks like a base class due to its naming, even though, in our case, it's not a base class. I'm always annoyed when packages also expose all the small detailed objects - it will perhaps confuse less senior users.

At the same time, I would argue the dispatcher does offer full control, e.g. for users who know about the beta/Bernoulli/binary/bandit connection: they simply set up their stuff with a beta prior and a binary target.

Lastly, imo it would be safer to release the specialized MAB class as private rather than the other way round -> if we reverse the decision, no deprecation is required.
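
As a rough illustration of the protocol idea raised in the comment above: a minimal sketch, assuming a hypothetical BanditSurrogateProtocol name and a made-up posterior_means accessor (neither exists in BayBE). Members declared on a typing.Protocol are visible to static type checkers and IDE autocomplete, so annotating the dispatcher with such a protocol would restore completion for exactly those declared members.

from typing import Protocol, runtime_checkable

from pandas import DataFrame

from baybe.objectives.base import Objective
from baybe.searchspace.core import SearchSpace


@runtime_checkable
class BanditSurrogateProtocol(Protocol):
    """Hypothetical shared interface for all specialized bandit models."""

    def fit(
        self, searchspace: SearchSpace, objective: Objective, measurements: DataFrame
    ) -> None:
        """Fit the bandit model to the given measurements."""
        ...

    def posterior_means(self) -> list[float]:
        """Placeholder for a per-arm posterior accessor; the name is illustrative only."""
        ...

Anything reached only through __getattr__ forwarding would remain invisible to the type checker unless it is declared on such a protocol, which is where the autocomplete question comes from.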

Collaborator Author

The more I think about it, the more I'd move away from the strategy design pattern again and rather use a factory approach. This would remove one level of nesting, make typing/autocomplete more straightforward, and require no additional machinery. Related to your question about protocols: we already have all we need, I think. There is the Surrogate base class and the SurrogateProtocol, and all specialized bandits conform to them. The MAB class, on the other hand, is NOT a Surrogate (it does not inherit from it, and it shouldn't, because I'd then have all the stuff twice). And while it sort of complies with the protocol, that is only because we trick it into doing so by forwarding attribute access, which causes the typing issues. So I'm really not yet convinced...

While I might be overlooking other issues with the alternative approach, it seems a bit clearer to me. It would mean just replacing the MAB class with, say, a factory function make_bandit that returns an object of type Surrogate. Done. The downsides are: people need to explicitly call the factory, and they need to specify the domain specs upfront and then again later. But when regarded only as a utility for automatic model selection, as an alternative to manually specifying the model, this is completely OK.

So anyway, I have no clear decision at the moment. Your choice to decide what to do 🙃 Before you do, perhaps play around with it interactively and get a feel for the usage and typing.
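
To make the alternative concrete: a minimal sketch of the make_bandit factory described in the comment above, under the assumption that its signature takes the domain specs plus a prior; the dispatch logic simply mirrors _get_bandit_class from this diff.

from baybe.objectives.base import Objective
from baybe.objectives.single import SingleTargetObjective
from baybe.priors.base import Prior
from baybe.searchspace.core import SearchSpace
from baybe.surrogates.bandits.beta_bernoulli import (
    BetaBernoulliMultiArmedBanditSurrogate,
)
from baybe.surrogates.base import Surrogate
from baybe.targets.binary import BinaryTarget


def make_bandit(
    searchspace: SearchSpace, objective: Objective, prior: Prior
) -> Surrogate:
    """Select and instantiate a suitable bandit surrogate for the given context."""
    match searchspace, objective:
        case _, SingleTargetObjective(_target=BinaryTarget()):
            # The concrete class is returned directly, so callers keep full
            # autocomplete and type information.
            return BetaBernoulliMultiArmedBanditSurrogate(prior)
        case _:
            raise NotImplementedError(
                f"Currently, only a single target of type '{BinaryTarget.__name__}' "
                "is supported."
            )

The trade-off mentioned above remains visible here: the search space and objective must be supplied to the factory and then again to the campaign or recommender, but the return value is an ordinary Surrogate, so typing and autocomplete stay straightforward.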

Collaborator

What is the status here now? Should we discuss this in one of our upcoming meetings? It seems like this touches an important, general point that we should align on.

Collaborator Author

Either that or at the next baybathon. I think it's an important and general design decision that we should be very clear about, in case we want to add it. But since it's not pressing, perhaps a baybathon with some other collected topics would be nice. Whiteboard sessions are more fun than Teams meetings 🙃

"""A bandit surrogate class dispatching class.

Follows the strategy design pattern to dispatch to the appropriate bandit model.
"""

prior: Prior = field(validator=instance_of(Prior))
"""The prior distribution assumed for each arm of the bandit."""

_bandit_model: Surrogate | None = field(init=False, default=None, eq=False)
"""The specific bandit model to which is being dispatched."""

def fit(
self, searchspace: SearchSpace, objective: Objective, measurements: DataFrame
) -> None:
"""Instantiate an appropriate bandit model and fit it to the data."""
cls = _get_bandit_class(searchspace, objective)
self._bandit_model = cls(self.prior)
self._bandit_model.fit(searchspace, objective, measurements)

def __getattr__(self, name):
# If the attribute is not found, try to get it from the bandit object
return getattr(self._bandit_model, name)
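
A small usage sketch of the dispatcher as defined in this diff; the attribute name in the try block is purely illustrative and not part of the BayBE API.

from baybe.priors.basic import BetaPrior
from baybe.surrogates import MultiArmedBanditSurrogate

surrogate = MultiArmedBanditSurrogate(prior=BetaPrior(1, 1))

# Only `prior` and `fit` live on the dispatcher itself; everything else is looked up
# on `_bandit_model` via `__getattr__`. Before `fit`, the model is still `None`, so a
# forwarded access simply raises an AttributeError:
try:
    surrogate.some_bandit_attribute  # hypothetical attribute name
except AttributeError:
    pass

# After `surrogate.fit(searchspace, objective, measurements)`, the same access would
# be answered by the dispatched BetaBernoulliMultiArmedBanditSurrogate instance.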
baybe/surrogates/bandits/beta_bernoulli.py
@@ -1,10 +1,11 @@
-"""Multi-armed bandit surrogates."""
+"""A Beta-Bernoulli multi-armed bandit surrogate."""

 from __future__ import annotations

 from typing import TYPE_CHECKING, Any, ClassVar

 from attrs import define, field
+from attrs.validators import instance_of

 from baybe.exceptions import IncompatibleSearchSpaceError, ModelNotTrainedError
 from baybe.parameters.categorical import CategoricalParameter
@@ -28,7 +29,9 @@ class BetaBernoulliMultiArmedBanditSurrogate(Surrogate):
     supports_transfer_learning: ClassVar[bool] = False
     # See base class.

-    prior: BetaPrior = field(factory=lambda: BetaPrior(1, 1))
+    prior: BetaPrior = field(
+        factory=lambda: BetaPrior(1, 1), validator=instance_of(BetaPrior)
+    )
     """The beta prior for the win rates of the bandit arms. Uniform by default."""

     # TODO: type should be `torch.Tensor | None` but is currently
5 changes: 3 additions & 2 deletions examples/Multi_Armed_Bandit/bernoulli_multi_armed_bandit.py
@@ -17,12 +17,13 @@
 from baybe.acquisition import PosteriorStandardDeviation, qThompsonSampling
 from baybe.acquisition.base import AcquisitionFunction
 from baybe.parameters import CategoricalParameter
+from baybe.priors.basic import BetaPrior
 from baybe.recommenders import (
     BotorchRecommender,
     RandomRecommender,
     TwoPhaseMetaRecommender,
 )
-from baybe.surrogates import BetaBernoulliMultiArmedBanditSurrogate
+from baybe.surrogates import MultiArmedBanditSurrogate
 from baybe.targets import BinaryTarget

 ### An Imaginary Use Case
@@ -97,7 +98,7 @@ def draw_arm(self, arm_index: int) -> bool:
     # For bandits, one-hot parameter encoding is required:
     encoding="OHE",
 )
-surrogate = BetaBernoulliMultiArmedBanditSurrogate()
+surrogate = MultiArmedBanditSurrogate(BetaPrior(1, 1))


 # For each simulation, we report the trajectory of earned rewards and the estimated win rates of the bandit arms:
2 changes: 1 addition & 1 deletion tests/test_iterations.py
@@ -37,7 +37,7 @@
 )
 from baybe.recommenders.pure.nonpredictive.base import NonPredictiveRecommender
 from baybe.searchspace import SearchSpaceType
-from baybe.surrogates.bandit import BetaBernoulliMultiArmedBanditSurrogate
+from baybe.surrogates.bandits import BetaBernoulliMultiArmedBanditSurrogate
 from baybe.surrogates.base import IndependentGaussianSurrogate, Surrogate
 from baybe.surrogates.custom import CustomONNXSurrogate
 from baybe.surrogates.gaussian_process.presets import (