Skip to content

Commit

Permalink
Add configuration and update Changelog
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Nov 5, 2023
1 parent eafb393 commit 4dc02c2
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 34 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

### New features

* Add configuration facilities to Bambi (#745)
* Interpet submodule now outputs informative messages when computing default values (#745)

### Maintenance and fixes

### Documentation
Expand Down
2 changes: 2 additions & 0 deletions bambi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pymc import math

from .backend import PyMCModel
from .config import config
from .data import clear_data_home, load_data
from .families import Family, Likelihood, Link
from .formula import Formula
Expand All @@ -23,6 +24,7 @@
"PyMCModel",
"Formula",
"clear_data_home",
"config",
"load_data",
"math",
]
Expand Down
51 changes: 51 additions & 0 deletions bambi/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
class Config:
"""Configuration variables for Bambi
It works with a pre-specified set of configuration variables and options for those.
When a user tries to set a configuration variable to a non-supported value, it raises an error.
"""

__FIELDS = {"INTERPRET_VERBOSE": (True, False)}

def __init__(self, config_dict: dict = None):
config_dict = {} if config_dict is None else config_dict
# When an option is not specified at instantiation time, it uses the first value specified
# in __FIELDS.
for field, choices in Config.__FIELDS.items():
if field in config_dict:
value = config_dict[field]
else:
value = choices[0]
self[field] = value

def __setitem__(self, key, value):
setattr(self, key, value)

def __setattr__(self, key, value):
if key in Config.__FIELDS:
if value in Config.__FIELDS[key]:
super().__setattr__(key, value)
else:
raise ValueError(
f"{value} is not a valid value for '{key}'"
f"Valid options are: {Config.__FIELDS[key]}"
)
else:
raise KeyError(f"'{key}' is not a valid configuration option")

def __getitem__(self, key):
return getattr(self, key)

def __str__(self): # pragma: no cover
lines = []
for field, choices in Config.__FIELDS.items():
lines.append(f"{field}: {self[field]} (available: {list(choices)})")
header = ["Bambi configuration"]
header.append("-" * len(header[0]))
return "\n".join(header + lines)

def __repr__(self): # pragma: no cover
return str(self)


config = Config()
10 changes: 9 additions & 1 deletion bambi/interpret/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from bambi.interpret.config import logger
import logging

from bambi.interpret.effects import comparisons, predictions, slopes
from bambi.interpret.plotting import plot_comparisons, plot_predictions, plot_slopes

Expand All @@ -11,3 +12,10 @@
"plot_predictions",
"plot_slopes",
]

logger = logging.getLogger("__bambi_interpret__")

if not logging.root.handlers:
logger.setLevel(logging.INFO)
if len(logger.handlers) == 0:
logger.addHandler(logging.StreamHandler())
19 changes: 0 additions & 19 deletions bambi/interpret/config.py

This file was deleted.

10 changes: 6 additions & 4 deletions bambi/interpret/logs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from bambi.interpret import logger
import logging

from bambi import config


def log_interpret_defaults(func):
Expand All @@ -10,11 +12,11 @@ def log_interpret_defaults(func):
or 'wrt' of 'comparisons' and 'slopes', as well as the 'conditional'
parameter of 'comparisons', 'predictions', and 'slopes'.
"""
interpret_logger = logger.get_logger("interpret")
logger = logging.getLogger("__bambi_interpret__")

def wrapper(*args, **kwargs):

if not logger.messages:
if not config["INTERPRET_VERBOSE"]:
return func(*args, **kwargs)

arg_name = None
Expand All @@ -39,7 +41,7 @@ def wrapper(*args, **kwargs):
covariate_name = args[0].name

if arg_name:
interpret_logger.info("Default computed for %s variable: %s", arg_name, covariate_name)
logger.info("Default computed for %s variable: %s", arg_name, covariate_name)

return func(*args, **kwargs)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest

from bambi.config import Config


def test_config():
config = Config()

# Evaluate getters
assert config["INTERPRET_VERBOSE"] is True
assert config.INTERPRET_VERBOSE is True

config.INTERPRET_VERBOSE = False
assert config.INTERPRET_VERBOSE is False

# Evaluate setters
with pytest.raises(ValueError, match="anything is not a valid value for 'INTERPRET_VERBOSE'"):
config.INTERPRET_VERBOSE = "anything"

with pytest.raises(KeyError, match="'DOESNT_EXIST' is not a valid configuration option"):
config.DOESNT_EXIST = "anything"
16 changes: 6 additions & 10 deletions tests/test_interpret_messages.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import numpy as np
import bambi as bmb
import pandas as pd
import matplotlib.pyplot as plt
import pytest

import bambi as bmb
from bambi.interpret import plot_comparisons, plot_predictions, plot_slopes

bmb.interpret.logger.messages = True


@pytest.fixture(scope="module")
def mtcars():
Expand All @@ -22,7 +18,7 @@ def mtcars():

def test_predictions_list(mtcars, caplog):
model, idata = mtcars
caplog.set_level("INFO", logger="interpret")
caplog.set_level("INFO", logger="__bambi_interpret__")

# List of values with no unspecified covariates
conditional = ["hp", "drat", "am"]
Expand All @@ -39,7 +35,7 @@ def test_predictions_list(mtcars, caplog):

def test_predictions_list_unspecified(mtcars, caplog):
model, idata = mtcars
caplog.set_level("INFO", logger="interpret")
caplog.set_level("INFO", logger="__bambi_interpret__")

# List of values with unspecified covariates
conditional = ["hp", "drat"]
Expand All @@ -57,7 +53,7 @@ def test_predictions_list_unspecified(mtcars, caplog):

def test_predictions_dict_unspecified(mtcars, caplog):
model, idata = mtcars
caplog.set_level("INFO", logger="interpret")
caplog.set_level("INFO", logger="__bambi_interpret__")

# User passed values with unspecified covariates
conditional = {"hp": [110, 175], "am": [0, 1]}
Expand All @@ -79,7 +75,7 @@ def test_predictions_dict_unspecified(mtcars, caplog):

def test_comparisons_contrast_default(mtcars, caplog):
model, idata = mtcars
caplog.set_level("INFO", logger="interpret")
caplog.set_level("INFO", logger="__bambi_interpret__")

# List of values with no unspecified covariates
plot_comparisons(model, idata, "hp", conditional=None, average_by="am")
Expand All @@ -93,7 +89,7 @@ def test_comparisons_contrast_default(mtcars, caplog):

def test_slopes_wrt_default(mtcars, caplog):
model, idata = mtcars
caplog.set_level("INFO", logger="interpret")
caplog.set_level("INFO", logger="__bambi_interpret__")

# List of values with no unspecified covariates
plot_slopes(model, idata, "hp", conditional=None, average_by="am")
Expand Down

0 comments on commit 4dc02c2

Please sign in to comment.