diff --git a/CHANGELOG.md b/CHANGELOG.md index 50c4ca3d7..23b2ca225 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ ### New features +* Add configuration facilities to Bambi (#745) +* Interpet submodule now outputs informative messages when computing default values (#745) + ### Maintenance and fixes ### Documentation diff --git a/bambi/__init__.py b/bambi/__init__.py index c8fe79692..660d6724c 100644 --- a/bambi/__init__.py +++ b/bambi/__init__.py @@ -5,6 +5,7 @@ from pymc import math from .backend import PyMCModel +from .config import config from .data import clear_data_home, load_data from .families import Family, Likelihood, Link from .formula import Formula @@ -23,6 +24,7 @@ "PyMCModel", "Formula", "clear_data_home", + "config", "load_data", "math", ] diff --git a/bambi/config.py b/bambi/config.py new file mode 100644 index 000000000..aa982115d --- /dev/null +++ b/bambi/config.py @@ -0,0 +1,51 @@ +class Config: + """Configuration variables for Bambi + + It works with a pre-specified set of configuration variables and options for those. + When a user tries to set a configuration variable to a non-supported value, it raises an error. + """ + + __FIELDS = {"INTERPRET_VERBOSE": (True, False)} + + def __init__(self, config_dict: dict = None): + config_dict = {} if config_dict is None else config_dict + # When an option is not specified at instantiation time, it uses the first value specified + # in __FIELDS. + for field, choices in Config.__FIELDS.items(): + if field in config_dict: + value = config_dict[field] + else: + value = choices[0] + self[field] = value + + def __setitem__(self, key, value): + setattr(self, key, value) + + def __setattr__(self, key, value): + if key in Config.__FIELDS: + if value in Config.__FIELDS[key]: + super().__setattr__(key, value) + else: + raise ValueError( + f"{value} is not a valid value for '{key}'" + f"Valid options are: {Config.__FIELDS[key]}" + ) + else: + raise KeyError(f"'{key}' is not a valid configuration option") + + def __getitem__(self, key): + return getattr(self, key) + + def __str__(self): # pragma: no cover + lines = [] + for field, choices in Config.__FIELDS.items(): + lines.append(f"{field}: {self[field]} (available: {list(choices)})") + header = ["Bambi configuration"] + header.append("-" * len(header[0])) + return "\n".join(header + lines) + + def __repr__(self): # pragma: no cover + return str(self) + + +config = Config() diff --git a/bambi/interpret/__init__.py b/bambi/interpret/__init__.py index 126f288ec..b487826e0 100644 --- a/bambi/interpret/__init__.py +++ b/bambi/interpret/__init__.py @@ -1,4 +1,5 @@ -from bambi.interpret.config import logger +import logging + from bambi.interpret.effects import comparisons, predictions, slopes from bambi.interpret.plotting import plot_comparisons, plot_predictions, plot_slopes @@ -11,3 +12,10 @@ "plot_predictions", "plot_slopes", ] + +logger = logging.getLogger("__bambi_interpret__") + +if not logging.root.handlers: + logger.setLevel(logging.INFO) + if len(logger.handlers) == 0: + logger.addHandler(logging.StreamHandler()) diff --git a/bambi/interpret/config.py b/bambi/interpret/config.py deleted file mode 100644 index 7a8b2fe98..000000000 --- a/bambi/interpret/config.py +++ /dev/null @@ -1,19 +0,0 @@ -import logging - - -class InterpretLogger: - def __init__(self, messages=False): - self.messages = messages - - def get_logger(self, name=None): - _log = logging.getLogger(name) - - if not logging.root.handlers: - _log.setLevel(logging.INFO) - if len(_log.handlers) == 0: - _log.addHandler(logging.StreamHandler()) - - return _log - - -logger = InterpretLogger() diff --git a/bambi/interpret/logs.py b/bambi/interpret/logs.py index 01d7443fe..0e2e026b3 100644 --- a/bambi/interpret/logs.py +++ b/bambi/interpret/logs.py @@ -1,4 +1,6 @@ -from bambi.interpret import logger +import logging + +from bambi import config def log_interpret_defaults(func): @@ -10,11 +12,11 @@ def log_interpret_defaults(func): or 'wrt' of 'comparisons' and 'slopes', as well as the 'conditional' parameter of 'comparisons', 'predictions', and 'slopes'. """ - interpret_logger = logger.get_logger("interpret") + logger = logging.getLogger("__bambi_interpret__") def wrapper(*args, **kwargs): - if not logger.messages: + if not config["INTERPRET_VERBOSE"]: return func(*args, **kwargs) arg_name = None @@ -39,7 +41,7 @@ def wrapper(*args, **kwargs): covariate_name = args[0].name if arg_name: - interpret_logger.info("Default computed for %s variable: %s", arg_name, covariate_name) + logger.info("Default computed for %s variable: %s", arg_name, covariate_name) return func(*args, **kwargs) diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 000000000..0fb339980 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,21 @@ +import pytest + +from bambi.config import Config + + +def test_config(): + config = Config() + + # Evaluate getters + assert config["INTERPRET_VERBOSE"] is True + assert config.INTERPRET_VERBOSE is True + + config.INTERPRET_VERBOSE = False + assert config.INTERPRET_VERBOSE is False + + # Evaluate setters + with pytest.raises(ValueError, match="anything is not a valid value for 'INTERPRET_VERBOSE'"): + config.INTERPRET_VERBOSE = "anything" + + with pytest.raises(KeyError, match="'DOESNT_EXIST' is not a valid configuration option"): + config.DOESNT_EXIST = "anything" \ No newline at end of file diff --git a/tests/test_interpret_messages.py b/tests/test_interpret_messages.py index 904174f74..b57bfc6dd 100644 --- a/tests/test_interpret_messages.py +++ b/tests/test_interpret_messages.py @@ -1,13 +1,9 @@ -import numpy as np +import bambi as bmb import pandas as pd -import matplotlib.pyplot as plt import pytest -import bambi as bmb from bambi.interpret import plot_comparisons, plot_predictions, plot_slopes -bmb.interpret.logger.messages = True - @pytest.fixture(scope="module") def mtcars(): @@ -22,7 +18,7 @@ def mtcars(): def test_predictions_list(mtcars, caplog): model, idata = mtcars - caplog.set_level("INFO", logger="interpret") + caplog.set_level("INFO", logger="__bambi_interpret__") # List of values with no unspecified covariates conditional = ["hp", "drat", "am"] @@ -39,7 +35,7 @@ def test_predictions_list(mtcars, caplog): def test_predictions_list_unspecified(mtcars, caplog): model, idata = mtcars - caplog.set_level("INFO", logger="interpret") + caplog.set_level("INFO", logger="__bambi_interpret__") # List of values with unspecified covariates conditional = ["hp", "drat"] @@ -57,7 +53,7 @@ def test_predictions_list_unspecified(mtcars, caplog): def test_predictions_dict_unspecified(mtcars, caplog): model, idata = mtcars - caplog.set_level("INFO", logger="interpret") + caplog.set_level("INFO", logger="__bambi_interpret__") # User passed values with unspecified covariates conditional = {"hp": [110, 175], "am": [0, 1]} @@ -79,7 +75,7 @@ def test_predictions_dict_unspecified(mtcars, caplog): def test_comparisons_contrast_default(mtcars, caplog): model, idata = mtcars - caplog.set_level("INFO", logger="interpret") + caplog.set_level("INFO", logger="__bambi_interpret__") # List of values with no unspecified covariates plot_comparisons(model, idata, "hp", conditional=None, average_by="am") @@ -93,7 +89,7 @@ def test_comparisons_contrast_default(mtcars, caplog): def test_slopes_wrt_default(mtcars, caplog): model, idata = mtcars - caplog.set_level("INFO", logger="interpret") + caplog.set_level("INFO", logger="__bambi_interpret__") # List of values with no unspecified covariates plot_slopes(model, idata, "hp", conditional=None, average_by="am")