From 2a5969ab37e192e5bb20f170b066446b670130cb Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Sun, 22 Sep 2024 18:06:37 -0300 Subject: [PATCH] Use number of rows from out-of-sample data in multivariate families --- bambi/families/multivariate.py | 43 ++++++++++++++++++++++++++++++---- tests/test_models.py | 8 +++++++ 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/bambi/families/multivariate.py b/bambi/families/multivariate.py index 41c463f0..6ecebbc8 100644 --- a/bambi/families/multivariate.py +++ b/bambi/families/multivariate.py @@ -36,9 +36,18 @@ def transform_coords(self, model, mean): return mean def posterior_predictive(self, model, posterior, **kwargs): - n = model.response_component.term.data.sum(1).astype(int) + data = kwargs["data"] + if data is None: + y = model.response_component.term.data + trials = model.response_component.term.data.sum(1).astype(int) + else: + y = response_evaluate_new_data(model, data).astype(int) + trials = y.sum(1).astype(int) + + # Prepend 'draw' and 'chain' dimensions + trials = trials[np.newaxis, np.newaxis, :] dont_reshape = ["n"] - return super().posterior_predictive(model, posterior, n=n, dont_reshape=dont_reshape) + return super().posterior_predictive(model, posterior, n=trials, dont_reshape=dont_reshape) def log_likelihood(self, model, posterior, data, **kwargs): if data is None: @@ -91,9 +100,35 @@ class DirichletMultinomial(MultivariateFamily): SUPPORTED_LINKS = {"a": ["log"]} def posterior_predictive(self, model, posterior, **kwargs): - n = model.response_component.term.data.sum(1).astype(int) + data = kwargs["data"] + if data is None: + y = model.response_component.term.data + trials = model.response_component.term.data.sum(1).astype(int) + else: + y = response_evaluate_new_data(model, data).astype(int) + trials = y.sum(1).astype(int) + + # Prepend 'draw' and 'chain' dimensions + trials = trials[np.newaxis, np.newaxis, :] dont_reshape = ["n"] - return super().posterior_predictive(model, posterior, n=n, dont_reshape=dont_reshape) + return super().posterior_predictive(model, posterior, n=trials, dont_reshape=dont_reshape) + + def log_likelihood(self, model, posterior, data, **kwargs): + if data is None: + y = model.response_component.term.data + trials = model.response_component.term.data.sum(1).astype(int) + else: + y = response_evaluate_new_data(model, data).astype(int) + trials = y.sum(1).astype(int) + + # Prepend 'draw' and 'chain' dimensions + y = y[np.newaxis, np.newaxis, :] + trials = trials[np.newaxis, np.newaxis, :] + + dont_reshape = ["n"] + return super().log_likelihood( + model, posterior, data=None, y=y, n=trials, dont_reshape=dont_reshape, **kwargs + ) def get_coords(self, response): name = get_aliased_name(response) + "_dim" diff --git a/tests/test_models.py b/tests/test_models.py index 0bc2e1db..b7b89b44 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1186,6 +1186,10 @@ def test_intercept_only(self, multinomial_data): idata = self.predict_oos(model, idata, data=model.data) self.assert_posterior_predictive(model, idata) + # Out of sample with different number of rows, see issue #845 + idata = self.predict_oos(model, idata, data=model.data.sample(frac=0.8, random_state=1211)) + self.assert_posterior_predictive(model, idata) + def test_numerical_predictors(self, multinomial_data): model = bmb.Model( "c(y1, y2, y3, y4) ~ treat + carry", multinomial_data, family="multinomial" @@ -1242,6 +1246,10 @@ def test_intercept_only(self, multinomial_data): idata = self.predict_oos(model, idata, model.data) self.assert_posterior_predictive(model, idata) + # Out of sample with different number of rows, see issue #845 + idata = self.predict_oos(model, idata, data=model.data.sample(frac=0.8, random_state=1211)) + self.assert_posterior_predictive(model, idata) + def test_predictor(self, multinomial_data): model = bmb.Model( "c(y1, y2, y3, y4) ~ 0 + treat", multinomial_data, family="dirichlet_multinomial"