From e51a99215dcc44ce056bbdaf7338b78924fd1a1c Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Sat, 27 Aug 2022 12:35:49 -0300 Subject: [PATCH] clean laplace results (#563) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * clean laplace results * disable pylint false alarm * revert pylint disable and change inv for reciprocal * fix docstring * Remove `model` parameter from docstring. It's not part of the method signature. Co-authored-by: Tomás Capretto --- bambi/backend/links.py | 2 +- bambi/backend/pymc.py | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/bambi/backend/links.py b/bambi/backend/links.py index f347aeaf1..759fae619 100644 --- a/bambi/backend/links.py +++ b/bambi/backend/links.py @@ -36,7 +36,7 @@ def identity(x): def inverse_squared(x): - return at.inv(at.sqrt(x)) + return at.reciprocal(at.sqrt(x)) def arctan_2(x): diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 36fb4760e..bd7709d12 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -23,7 +23,7 @@ class PyMCModel: "cloglog": cloglog, "identity": identity, "inverse_squared": inverse_squared, - "inverse": at.inv, + "inverse": at.reciprocal, "log": at.exp, "logit": logit, "probit": probit, @@ -108,7 +108,7 @@ def run( elif inference_method == "vi": result = self._run_vi(**kwargs) elif inference_method == "laplace": - result = self._run_laplace(draws) + result = self._run_laplace(draws, omit_offsets, include_mean) else: raise NotImplementedError(f"{inference_method} method has not been implemented") @@ -349,10 +349,10 @@ def _run_mcmc( f"``mcmc``, ``nuts_numpyro`` or ``nuts_blackjax``" ) - idata = self._clean_mcmc_results(idata, omit_offsets, include_mean) + idata = self._clean_results(idata, omit_offsets, include_mean) return idata - def _clean_mcmc_results(self, idata, omit_offsets, include_mean): + def _clean_results(self, idata, omit_offsets, include_mean): for group in idata.groups(): getattr(idata, group).attrs["modeling_interface"] = "bambi" getattr(idata, group).attrs["modeling_interface_version"] = version.__version__ @@ -438,7 +438,7 @@ def _run_vi(self, **kwargs): self.vi_approx = pm.fit(**kwargs) return self.vi_approx - def _run_laplace(self, draws): + def _run_laplace(self, draws, omit_offsets, include_mean): """Fit a model using a Laplace approximation. Mainly for pedagogical use, provides reasonable results for approximately @@ -448,9 +448,13 @@ def _run_laplace(self, draws): Parameters ---------- - model: PyMC model draws: int The number of samples to draw from the posterior distribution. + omit_offsets: bool + Omits offset terms in the ``InferenceData`` object returned when the model includes + group specific effects. + include_mean: bool + Compute the posterior of the mean response. Returns ------- @@ -473,7 +477,9 @@ def _run_laplace(self, draws): samples = np.random.multivariate_normal(modes, cov, size=draws) - return _posterior_samples_to_idata(samples, self.model) + idata = _posterior_samples_to_idata(samples, self.model) + idata = self._clean_results(idata, omit_offsets, include_mean) + return idata def _posterior_samples_to_idata(samples, model):