Skip to content

Commit

Permalink
Add support for weighted responses (#761)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto authored Nov 28, 2023
1 parent 8dd25b1 commit 2d4b260
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

* Add configuration facilities to Bambi (#745)
* Interpet submodule now outputs informative messages when computing default values (#745)
* Bambi supports weighted responses (#761)

### Maintenance and fixes

Expand Down
22 changes: 19 additions & 3 deletions bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import pytensor.tensor as pt

from bambi.backend.utils import (
has_hyperprior,
get_distribution_from_prior,
get_distribution_from_likelihood,
get_distribution_from_prior,
get_linkinv,
has_hyperprior,
make_weighted_distribution,
GP_KERNELS,
)
from bambi.families.multivariate import MultivariateFamily, Multinomial, DirichletMultinomial
Expand Down Expand Up @@ -324,6 +325,21 @@ def build_response_distribution(self, kwargs, pymc_backend):
dist_rv = pm.Truncated(
self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims
)

# Handle weighted responses
elif self.term.is_weighted:
dims = kwargs.pop("dims", None)
data_matrix = kwargs.pop("observed")

# Get values of the response variable
observed = np.squeeze(data_matrix[:, 0])

# Get weights
weights = np.squeeze(data_matrix[:, 1])

# Get a weighted version of the response distribution
weighted_dist = make_weighted_distribution(distribution)
dist_rv = weighted_dist(self.name, weights, **kwargs, observed=observed, dims=dims)
else:
dist_rv = distribution(self.name, **kwargs)

Expand All @@ -345,7 +361,7 @@ def robustify_dims(self, pymc_backend, kwargs):
if isinstance(self.family, (Multinomial, DirichletMultinomial)):
return kwargs

if self.term.is_censored or self.term.is_truncated:
if self.term.is_censored or self.term.is_truncated or self.term.is_weighted:
return kwargs

dims, data = kwargs["dims"], kwargs["observed"]
Expand Down
98 changes: 98 additions & 0 deletions bambi/backend/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import inspect

from functools import partial

import pytensor.tensor as pt
import pymc as pm

Expand Down Expand Up @@ -81,3 +85,97 @@ def matern52(sigma, ell, input_dim=1):
"Matern32": {"fn": matern32, "params": ("sigma", "ell")},
"Matern52": {"fn": matern52, "params": ("sigma", "ell")},
}


def make_weighted_logp(dist: pm.Distribution):
"""Create a function to compute a weighted logp
Parameters
----------
dist : pm.Distribution
The PyMC distribution for which we want to get the weighted logp.
Returns
-------
A function that computes the weighted logp
"""

def logp(value, *dist_params, weights):
weights = pt.as_tensor_variable(weights)
return weights * pm.logp(dist.dist(*dist_params), value)

return logp


def get_dist_args(dist: pm.Distribution) -> list[str]:
"""Get the argument names of a PyMC distribution.
The argument names are the names of the parameters of the distribution.
Parameters
----------
dist : pm.Distribution
The PyMC distribution for which we want to extract the argument names.
Returns
-------
list[str]
The names of the arguments.
"""
# Get all args but the first one which is usually 'cls'
return inspect.getfullargspec(dist.dist).args[1:]


def create_cdist(dist: pm.Distribution):
def fun(*params):
*dist_params, size = params
return dist.dist(*dist_params, size=size)

return fun


# pylint: disable=bare-except
# pylint: disable=protected-access
def make_weighted_distribution(dist: pm.Distribution):
wlogp = make_weighted_logp(dist)
dist_args = get_dist_args(dist)

try:
dname = dist.rv_op._print_name[0]
except:
dname = "Dist"

cdist = create_cdist(dist)
class_name = f"Weighted{dname}"

class WeightedDistribution:
# We pass 'logp' to get the weighted logp, and we pass 'dist' to make sure
# the random draws are generated using the correct parameter values.
# Distribution.dist is the method that handles the parameters and with this approach
# we are sure that we use it.
def __new__(cls, name, weights, **kwargs):
# Get parameter values in the order required by the distribution as they are passed
# by position to `pm.CustomDist`
dist_params = [kwargs.pop(arg) for arg in dist_args if arg in kwargs]
return pm.CustomDist(
name,
*dist_params,
logp=partial(wlogp, weights=weights),
dist=cdist,
class_name=class_name,
**kwargs,
)

@classmethod
def dist(cls, **kwargs):
dist_params = [kwargs.pop(arg) for arg in dist_args if arg in kwargs]
weights = 1 if "weights" not in kwargs else kwargs.pop("weights")
return pm.CustomDist.dist(
*dist_params,
logp=partial(wlogp, weights=weights),
dist=cdist,
class_name=class_name,
**kwargs,
)

return WeightedDistribution
3 changes: 2 additions & 1 deletion bambi/terms/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from bambi.terms.base import BaseTerm

from bambi.terms.utils import is_censored_response, is_truncated_response
from bambi.terms.utils import is_censored_response, is_truncated_response, is_weighted_response


class ResponseTerm(BaseTerm):
Expand All @@ -11,6 +11,7 @@ def __init__(self, response, family):
self.family = family
self.is_censored = is_censored_response(self.term)
self.is_truncated = is_truncated_response(self.term)
self.is_weighted = is_weighted_response(self.term)

@property
def term(self):
Expand Down
10 changes: 10 additions & 0 deletions bambi/terms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,13 @@ def is_truncated_response(term):
if not is_call_component(component):
return False
return is_call_of_kind(component, "truncated")


def is_weighted_response(term):
"""Determines if a formulae term represents a weighted response"""
if not is_single_component(term):
return False
component = term.components[0] # get the first (and single) component
if not is_call_component(component):
return False
return is_call_of_kind(component, "weighted")
31 changes: 30 additions & 1 deletion bambi/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def censored(*args):


def truncated(x, lb=None, ub=None):
"""Construct array for truncated response
"""Construct array for a truncated response
Parameters
----------
Expand Down Expand Up @@ -127,6 +127,34 @@ def truncated(x, lb=None, ub=None):

truncated.__metadata__ = {"kind": "truncated"}


def weighted(x, weights):
"""Construct array for a weighted response
Parameters
----------
x : np.ndarray
The values of the truncated variable.
weights : np.ndarray
The weight of each value in `x`.
Returns
------
np.ndarray
Array of shape (n, 2). The first column contains the values of the `x` array and the second
contains the values of `weights`.
"""
x = np.asarray(x)
weights = np.asarray(weights)

if any(weights < 0):
raise ValueError("Weights must be positive.")

return np.column_stack([x, weights])


weighted.__metadata__ = {"kind": "weighted"}

# pylint: disable = invalid-name
@register_stateful_transform
class HSGP: # pylint: disable = too-many-instance-attributes
Expand Down Expand Up @@ -376,6 +404,7 @@ def get_distance(x):
"c": c,
"censored": censored,
"truncated": truncated,
"weighted": weighted,
"log": np.log,
"log2": np.log2,
"log10": np.log10,
Expand Down
10 changes: 10 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,3 +1334,13 @@ def test_predict_new_groups(data, formula, family, df_new, request):
model = bmb.Model(formula, data, family=family)
idata = model.fit(tune=100, draws=100)
model.predict(idata, data=df_new, sample_new_groups=True)


def test_weighted():
weights = 1 + np.random.poisson(lam=3, size=100)
y = np.random.exponential(scale=3, size=100)
data = pd.DataFrame({"w": weights, "y": y})
model = bmb.Model("weighted(y, w) ~ 1", data, family="exponential")
idata = model.fit(tune=TUNE, draws=DRAWS)
model.predict(idata, kind="pps")
model.predict(idata, kind="pps", data=data)
44 changes: 43 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import numpy as np
import pandas as pd
import pymc as pm

from bambi.utils import listify
from bambi.backend.pymc import probit, cloglog
from bambi.transformations import censored, truncated
from bambi.backend.utils import make_weighted_distribution
from bambi.transformations import censored, truncated, weighted


def test_listify():
Expand Down Expand Up @@ -100,3 +102,43 @@ def test_truncated():

with pytest.raises(ValueError, match="'ub' must be 0 or 1 dimensional."):
truncated(x, ub=np.column_stack([upper_arr, upper_arr]))


def test_weighted():
rng = np.random.default_rng(1234)
weights = 1 + rng.poisson(lam=3, size=100)
weights_wrong = rng.normal(size=100)
y = rng.exponential(scale=3, size=100)

out = weighted(y, weights)
assert out.shape == (100, 2)
assert (out[:, 0] == y).all()
assert (out[:, 1] == weights).all()

with pytest.raises(ValueError, match="Weights must be positive"):
weighted(y, weights_wrong)

# Draw function works and matches the non-weighted version
WeightedNormal = make_weighted_distribution(pm.Normal)
draws1 = pm.draw(WeightedNormal.dist(mu=0, sigma=1), draws=10, random_seed=1234)
draws2 = pm.draw(pm.Normal.dist(mu=0, sigma=1), draws=10, random_seed=1234)
assert np.allclose(draws1, draws2)

WeightedExponential = make_weighted_distribution(pm.Exponential)
draws1 = pm.draw(WeightedExponential.dist(lam=2.0), draws=10, random_seed=11)
draws2 = pm.draw(pm.Exponential.dist(lam=2.0), draws=10, random_seed=11)
assert np.allclose(draws1, draws2)

# Logp works and is propertly weighted
weights = np.array([0.5, 1.0, 3.2, 4.5, 1.0])
values = np.array([-2, -1, 0, 1.0, 2.0])
logp1 = pm.logp(WeightedNormal.dist(mu=0.5, sigma=0.3, weights=weights), value=values).eval()
logp2 = pm.logp(pm.Normal.dist(mu=0.5, sigma=0.3), value=values).eval()
assert np.allclose(logp1 / logp2, weights)

weights = np.array([1, 2.5, 2.5])
values = np.array([1, 1, 4.0])
logp1 = pm.logp(WeightedExponential.dist(lam=2, weights=weights), value=values).eval()
logp2 = pm.logp(pm.Exponential.dist(2), value=values).eval()

assert np.allclose(logp1 / logp2, weights)

0 comments on commit 2d4b260

Please sign in to comment.