-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bandit dispatcher class #372
base: main
Are you sure you want to change the base?
Changes from all commits
6a71c73
a471312
d9691c9
3947528
3e31b7d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""Bandit surrogates.""" | ||
|
||
from baybe.surrogates.bandits.base import MultiArmedBanditSurrogate | ||
from baybe.surrogates.bandits.beta_bernoulli import ( | ||
BetaBernoulliMultiArmedBanditSurrogate, | ||
) | ||
|
||
__all__ = [ | ||
"BetaBernoulliMultiArmedBanditSurrogate", | ||
"MultiArmedBanditSurrogate", | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
"""Multi-armed bandit surrogates.""" | ||
|
||
from __future__ import annotations | ||
|
||
from attrs import define, field | ||
from attrs.validators import instance_of | ||
from pandas import DataFrame | ||
|
||
from baybe.objectives.base import Objective | ||
from baybe.objectives.single import SingleTargetObjective | ||
from baybe.priors.base import Prior | ||
from baybe.searchspace.core import SearchSpace | ||
from baybe.surrogates.bandits.beta_bernoulli import ( | ||
BetaBernoulliMultiArmedBanditSurrogate, | ||
) | ||
from baybe.surrogates.base import Surrogate | ||
from baybe.targets.binary import BinaryTarget | ||
|
||
|
||
def _get_bandit_class( | ||
searchspace: SearchSpace, objective: Objective | ||
) -> type[BetaBernoulliMultiArmedBanditSurrogate]: | ||
"""Retrieve the appropriate bandit class for the given modelling context.""" | ||
match searchspace, objective: | ||
case _, SingleTargetObjective(_target=BinaryTarget()): | ||
return BetaBernoulliMultiArmedBanditSurrogate | ||
case _: | ||
raise NotImplementedError( | ||
f"Currently, only a single target of type '{BinaryTarget.__name__}' " | ||
f"is supported." | ||
) | ||
|
||
|
||
@define | ||
class MultiArmedBanditSurrogate: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey @AVHopp, @Scienfitz: here a first draft of what we've discussed. At this stage, a few comments:
What do you think so far? That said, we also don't need to rush with the decision, especially if we agree on the last point. Because than we can also add the dispatcher after the release. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. re 2: re 3: At the same time, I would argue the dispatcher offers Lastly, imo it would be safer to release the special MAB as private, instead of the other way round -> should we reverse the decision no deprecation required There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The more I think about it, the more I'd go away again from the strategy design pattern again and rather use a factory approach. This would reduce one level of nesting, make typing/auto-complete more straightforward, and require no additional machinery. Related to that your question about protocols: we already have all we need, I think: there is the Surrogate base class and the SurrogateProtocol, and all specialized bandits are confirm with them. The MAB class on the other hand is NOT a Surrogate (it does not inherit from it and it shouldn't, because I'd then have all stuff twice). And while it sort of complies with the Protocol, it's only because we trick it to do so by forwarding attribute access, causing the typing issues. So I'm really not yet convinced... While I might also overlook other issues with the alternative approach, it seems a bit clearer to me. It would mean to just replace the MAB class with, say, a factory function So anyway, I have no clear decision at the moment. Your choice to decide what to do 🙃 before you to, perhaps play around with it interactively and experience the feeling and typing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the status here now? Should we discuss this in one of our upcoming meetings? Seems like this touches an important, general point that we should align on. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Either that or in a next baybathon. I think it's an important and general design decision we should be very clear about, in case we want to add it. But since it's not pressing, perhaps a baybathon with some other collected topics for it would be nice. Whiteboard sessions are more fun than teams meetings 🙃 |
||
"""A bandit surrogate class dispatching class. | ||
|
||
Follows the strategy design pattern to dispatch to the appropriate bandit model. | ||
""" | ||
|
||
prior: Prior = field(validator=instance_of(Prior)) | ||
"""The prior distribution assumed for each arm of the bandit.""" | ||
|
||
_bandit_model: Surrogate | None = field(init=False, default=None, eq=False) | ||
"""The specific bandit model to which is being dispatched.""" | ||
|
||
def fit( | ||
self, searchspace: SearchSpace, objective: Objective, measurements: DataFrame | ||
) -> None: | ||
"""Instantiate an appropriate bandit model and fit it to the data.""" | ||
cls = _get_bandit_class(searchspace, objective) | ||
self._bandit_model = cls(self.prior) | ||
self._bandit_model.fit(searchspace, objective, measurements) | ||
|
||
def __getattr__(self, name): | ||
# If the attribute is not found, try to get it from the bandit object | ||
return getattr(self._bandit_model, name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
included the searchspace in this private (so not affected by deprecation considerations) function why?