Skip to content

Commit

Permalink
clean laplace results (#563)
Browse files Browse the repository at this point in the history
* clean laplace results

* disable pylint false alarm

* revert pylint disable and change inv for reciprocal

* fix docstring

* Remove `model` parameter from docstring.

It's not part of the method signature.

Co-authored-by: Tomás Capretto <[email protected]>
  • Loading branch information
aloctavodia and tomicapretto authored Aug 27, 2022
1 parent 505fbb3 commit e51a992
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion bambi/backend/links.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def identity(x):


def inverse_squared(x):
return at.inv(at.sqrt(x))
return at.reciprocal(at.sqrt(x))


def arctan_2(x):
Expand Down
20 changes: 13 additions & 7 deletions bambi/backend/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class PyMCModel:
"cloglog": cloglog,
"identity": identity,
"inverse_squared": inverse_squared,
"inverse": at.inv,
"inverse": at.reciprocal,
"log": at.exp,
"logit": logit,
"probit": probit,
Expand Down Expand Up @@ -108,7 +108,7 @@ def run(
elif inference_method == "vi":
result = self._run_vi(**kwargs)
elif inference_method == "laplace":
result = self._run_laplace(draws)
result = self._run_laplace(draws, omit_offsets, include_mean)
else:
raise NotImplementedError(f"{inference_method} method has not been implemented")

Expand Down Expand Up @@ -349,10 +349,10 @@ def _run_mcmc(
f"``mcmc``, ``nuts_numpyro`` or ``nuts_blackjax``"
)

idata = self._clean_mcmc_results(idata, omit_offsets, include_mean)
idata = self._clean_results(idata, omit_offsets, include_mean)
return idata

def _clean_mcmc_results(self, idata, omit_offsets, include_mean):
def _clean_results(self, idata, omit_offsets, include_mean):
for group in idata.groups():
getattr(idata, group).attrs["modeling_interface"] = "bambi"
getattr(idata, group).attrs["modeling_interface_version"] = version.__version__
Expand Down Expand Up @@ -438,7 +438,7 @@ def _run_vi(self, **kwargs):
self.vi_approx = pm.fit(**kwargs)
return self.vi_approx

def _run_laplace(self, draws):
def _run_laplace(self, draws, omit_offsets, include_mean):
"""Fit a model using a Laplace approximation.
Mainly for pedagogical use, provides reasonable results for approximately
Expand All @@ -448,9 +448,13 @@ def _run_laplace(self, draws):
Parameters
----------
model: PyMC model
draws: int
The number of samples to draw from the posterior distribution.
omit_offsets: bool
Omits offset terms in the ``InferenceData`` object returned when the model includes
group specific effects.
include_mean: bool
Compute the posterior of the mean response.
Returns
-------
Expand All @@ -473,7 +477,9 @@ def _run_laplace(self, draws):

samples = np.random.multivariate_normal(modes, cov, size=draws)

return _posterior_samples_to_idata(samples, self.model)
idata = _posterior_samples_to_idata(samples, self.model)
idata = self._clean_results(idata, omit_offsets, include_mean)
return idata


def _posterior_samples_to_idata(samples, model):
Expand Down

0 comments on commit e51a992

Please sign in to comment.