diff --git a/aemcmc/__init__.py b/aemcmc/__init__.py index 4d52a61..1a60207 100644 --- a/aemcmc/__init__.py +++ b/aemcmc/__init__.py @@ -1,3 +1,7 @@ from . import _version __version__ = _version.get_versions()["version"] + +# Register rewrite databases +import aemcmc.conjugates +import aemcmc.gibbs diff --git a/aemcmc/basic.py b/aemcmc/basic.py new file mode 100644 index 0000000..92d5be9 --- /dev/null +++ b/aemcmc/basic.py @@ -0,0 +1,123 @@ +from typing import Dict, Tuple + +from aesara.graph.basic import Variable +from aesara.graph.fg import FunctionGraph +from aesara.tensor.random.utils import RandomStream +from aesara.tensor.var import TensorVariable + +from aemcmc.opt import ( + SamplerTracker, + construct_ir_fgraph, + expand_subsumptions, + sampler_rewrites_db, +) + + +def construct_sampler( + obs_rvs_to_values: Dict[TensorVariable, TensorVariable], srng: RandomStream +) -> Tuple[ + Dict[TensorVariable, TensorVariable], + Dict[Variable, Variable], + Dict[TensorVariable, TensorVariable], +]: + r"""Eagerly construct a sampler for a given set of observed variables and their observations. + + Parameters + ========== + obs_rvs_to_values + A ``dict`` of variables that maps stochastic elements + (e.g. `RandomVariable`\s) to symbolic `Variable`\s representing their + observed values. + + Returns + ======= + A ``dict`` mapping each random variable to its posterior sampling step, + a ``dict`` of updates generated by the sampler steps, and a ``dict`` + mapping each random variable to the variable representing its initial value. + """ + + fgraph, obs_rvs_to_values, memo, new_to_old_rvs = construct_ir_fgraph( + obs_rvs_to_values + ) + + fgraph.attach_feature(SamplerTracker(srng)) + + _ = sampler_rewrites_db.query("+basic").optimize(fgraph) + + random_vars = tuple(rv for rv in fgraph.outputs if rv not in obs_rvs_to_values) + + discovered_samplers = fgraph.sampler_mappings.rvs_to_samplers + + rvs_to_init_vals = {rv: rv.clone() for rv in random_vars} + posterior_sample_steps = rvs_to_init_vals.copy() + # Replace occurrences of observed variables with their observed values + posterior_sample_steps.update(obs_rvs_to_values) + + # TODO FIXME: Get/extract `Scan`-generated updates + posterior_updates: Dict[Variable, Variable] = {} + + rvs_without_samplers = set() + + for rv in fgraph.outputs: + + if rv in obs_rvs_to_values: + continue + + rv_steps = discovered_samplers.get(rv) + + if not rv_steps: + rvs_without_samplers.add(rv) + continue + + # TODO FIXME: Just choosing one for now, but we should consider them all.
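+ # Each entry in `rv_steps` is a `(description, step variable, updates)` tuple recorded by `SamplerTracker`.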
+ step_desc, step, updates = rv_steps.pop() + + # Expand subsumed `DimShuffle`d inputs to `Elemwise`s + if updates: + update_keys, update_values = zip(*updates.items()) + else: + update_keys, update_values = tuple(), tuple() + + sfgraph = FunctionGraph( + outputs=(step,) + tuple(update_keys) + tuple(update_values), + clone=False, + copy_inputs=False, + copy_orphans=False, + ) + + # Update the other sampled random variables in this step's graph + sfgraph.replace_all(list(posterior_sample_steps.items()), import_missing=True) + + expand_subsumptions.optimize(sfgraph) + + step = sfgraph.outputs[0] + + # Update the other sampled random variables in this step's graph + # (step,) = clone_replace([step], replace=posterior_sample_steps) + + posterior_sample_steps[rv] = step + + if updates: + keys_offset = len(update_keys) + 1 + update_keys = sfgraph.outputs[1:keys_offset] + update_values = sfgraph.outputs[keys_offset:] + updates = dict(zip(update_keys, update_values)) + posterior_updates.update(updates) + + if rvs_without_samplers: + # TODO: Assign NUTS to these + raise NotImplementedError( + f"Could not find posterior samplers for {rvs_without_samplers}" + ) + + # TODO: Track/handle "auxiliary/augmentation" variables introduced by sample + # steps? + + return ( + { + new_to_old_rvs[rv]: step + for rv, step in posterior_sample_steps.items() + if rv not in obs_rvs_to_values + }, + posterior_updates, + {new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()}, + ) diff --git a/aemcmc/conjugates.py b/aemcmc/conjugates.py index 5272369..a767927 100644 --- a/aemcmc/conjugates.py +++ b/aemcmc/conjugates.py @@ -1,12 +1,16 @@ import aesara.tensor as at -from aeppl.opt import NoCallbackEquilibriumDB -from aesara.graph.kanren import KanrenRelationSub +from aesara.graph.opt import in2out, local_optimizer +from aesara.graph.optdb import LocalGroupDB +from aesara.graph.unify import eval_if_etuple +from aesara.tensor.random.basic import BinomialRV from etuples import etuple, etuplize -from kanren import eq, lall +from kanren import eq, lall, run from unification import var +from aemcmc.opt import sampler_finder_db -def beta_binomial_conjugateo(observed_rv_expr, posterior_expr): + +def beta_binomial_conjugateo(observed_val, observed_rv_expr, posterior_expr): r"""Produce a goal that represents the application of Bayes theorem for a beta prior with a binomial observation model. @@ -22,15 +26,15 @@ def beta_binomial_conjugateo(observed_rv_expr, posterior_expr): Parameters ---------- + observed_val + The observed value. observed_rv_expr - A tuple that contains expressions that represent the observed variable - and it observed value respectively. + An expression that represents the observed variable. posterior_expr An expression that represents the posterior distribution of the latent variable.
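+ Returns + ------- + A goal that unifies `posterior_expr` with a beta-distribution expression for the posterior of the latent probability implied by `observed_rv_expr` and `observed_val`.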
""" - # Beta-binomial observation model alpha_lv, beta_lv = var(), var() p_rng_lv = var() @@ -42,12 +46,10 @@ def beta_binomial_conjugateo(observed_rv_expr, posterior_expr): n_lv = var() Y_et = etuple(etuplize(at.random.binomial), var(), var(), var(), n_lv, p_et) - y_lv = var() # observation - # Posterior distribution for p - new_alpha_et = etuple(etuplize(at.add), alpha_lv, y_lv) + new_alpha_et = etuple(etuplize(at.add), alpha_lv, observed_val) new_beta_et = etuple( - etuplize(at.sub), etuple(etuplize(at.add), beta_lv, n_lv), y_lv + etuplize(at.sub), etuple(etuplize(at.add), beta_lv, n_lv), observed_val ) p_posterior_et = etuple( etuplize(at.random.beta), @@ -59,13 +61,47 @@ def beta_binomial_conjugateo(observed_rv_expr, posterior_expr): ) return lall( - eq(observed_rv_expr[0], Y_et), - eq(observed_rv_expr[1], y_lv), + eq(observed_rv_expr, Y_et), eq(posterior_expr, p_posterior_et), ) -conjugates_db = NoCallbackEquilibriumDB() -conjugates_db.register( - "beta_binomial", KanrenRelationSub(beta_binomial_conjugateo), -5, "basic" +@local_optimizer([BinomialRV]) +def local_beta_binomial_posterior(fgraph, node): + + sampler_mappings = getattr(fgraph, "sampler_mappings", None) + + rv_var = node.outputs[1] + key = ("local_beta_binomial_posterior", rv_var) + + if sampler_mappings is None or key in sampler_mappings.rvs_seen: + return None # pragma: no cover + + q = var() + + rv_et = etuplize(rv_var) + + res = run(None, q, beta_binomial_conjugateo(rv_var, rv_et, q)) + res = next(res, None) + + if res is None: + return None # pragma: no cover + + beta_rv = rv_et[-1].evaled_obj + beta_posterior = eval_if_etuple(res) + + sampler_mappings.rvs_to_samplers.setdefault(beta_rv, []).append( + ("local_beta_binomial_posterior", beta_posterior, None) + ) + sampler_mappings.rvs_seen.add(key) + + return rv_var.owner.outputs + + +conjugates_db = LocalGroupDB(apply_all_opts=True) +conjugates_db.name = "conjugates_db" +conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic") + +sampler_finder_db.register( + "conjugates", in2out(conjugates_db.query("+basic"), name="gibbs"), "basic" ) diff --git a/aemcmc/dists.py b/aemcmc/dists.py index 8352082..ecf5386 100644 --- a/aemcmc/dists.py +++ b/aemcmc/dists.py @@ -53,6 +53,7 @@ def multivariate_normal_rue2005(rng, b, Q): w = at.slinalg.solve_triangular(L, b, lower=True) u = at.slinalg.solve_triangular(L.T, w, lower=False) z = rng.standard_normal(size=L.shape[0]) + z.owner.outputs[0].name = "z_rng" v = at.slinalg.solve_triangular(L.T, z, lower=False) return u + v @@ -135,6 +136,7 @@ def multivariate_normal_cong2017( A_inv = 1 / A a_rows = A.shape[0] z = rng.standard_normal(size=a_rows + omega.shape[0]) + z.owner.outputs[0].name = "z_rng" y1 = at.sqrt(A_inv) * z[:a_rows] y2 = (1 / at.sqrt(omega)) * z[a_rows:] Ainv_phi = A_inv[:, None] * phi.T diff --git a/aemcmc/gibbs.py b/aemcmc/gibbs.py index 86444a6..1de8e4f 100644 --- a/aemcmc/gibbs.py +++ b/aemcmc/gibbs.py @@ -1,18 +1,17 @@ -from typing import Dict, List, Mapping, Tuple, Union +from typing import List, Mapping, Optional, Tuple import aesara import aesara.tensor as at -from aesara.graph import optimize_graph from aesara.graph.basic import Variable -from aesara.graph.opt import EquilibriumOptimizer +from aesara.graph.opt import in2out +from aesara.graph.optdb import LocalGroupDB from aesara.graph.unify import eval_if_etuple from aesara.ifelse import ifelse from aesara.tensor.math import Dot from aesara.tensor.random import RandomStream -from aesara.tensor.random.opt import local_dimshuffle_rv_lift 
+from aesara.tensor.random.basic import BernoulliRV, NegBinomialRV, NormalRV from aesara.tensor.var import TensorVariable from etuples import etuple, etuplize -from etuples.core import ExpressionTuple from unification import unify, var from aemcmc.dists import ( @@ -20,21 +19,13 @@ multivariate_normal_rue2005, polyagamma, ) +from aemcmc.opt import sampler_finder, sampler_finder_db - -def canonicalize_and_tuplize(graph: TensorVariable) -> ExpressionTuple: - """Canonicalize and etuple-ize a graph.""" - graph_opt = optimize_graph( - graph, - custom_opt=EquilibriumOptimizer( - [local_dimshuffle_rv_lift], max_use_ratio=aesara.config.optdb__max_use_ratio - ), - ) - graph_et = etuplize(graph_opt) - return graph_et +gibbs_db = LocalGroupDB(apply_all_opts=True) +gibbs_db.name = "gibbs_db" -def update_beta_low_dimension( +def normal_regression_overdetermined_posterior( srng: RandomStream, omega: TensorVariable, lmbdatau_inv: TensorVariable, @@ -57,7 +48,7 @@ def update_beta_low_dimension( return multivariate_normal_rue2005(srng, X.T @ (omega * z), Q) -def update_beta_high_dimension( +def normal_regression_underdetermined_posterior( srng: RandomStream, omega: TensorVariable, lmbdatau_inv: TensorVariable, @@ -74,7 +65,7 @@ def update_beta_high_dimension( return multivariate_normal_cong2017(srng, lmbdatau_inv, omega, X, z) -def update_beta( +def normal_regression_posterior( srng: RandomStream, omega: TensorVariable, lmbdatau_inv: TensorVariable, @@ -100,7 +91,7 @@ def update_beta( \begin{align*} \left( \beta \mid Z = z \right) &\sim - \operatorname{N}\left( A^{-1} X^{\top} \Omega, A^{-1} \right) \\ + \operatorname{N}\left( A^{-1} X^{\top} \Omega z, A^{-1} \right) \\ A &= X^{\top} X + \Lambda^{-1}_{*} \\ \Lambda_{*} &= \tau^2 \Lambda \end{align*} @@ -131,8 +122,8 @@ def update_beta( """ return ifelse( X.shape[1] > X.shape[0], - update_beta_high_dimension(srng, omega, lmbdatau_inv, X, z), - update_beta_low_dimension(srng, omega, lmbdatau_inv, X, z), + normal_regression_underdetermined_posterior(srng, omega, lmbdatau_inv, X, z), + normal_regression_overdetermined_posterior(srng, omega, lmbdatau_inv, X, z), ) @@ -148,26 +139,9 @@ def update_beta( ) -def horseshoe_model(srng: TensorVariable) -> TensorVariable: - """Horseshoe shrinkage prior [1]_. - - References - ---------- - .. [1]: Carvalho, C. M., Polson, N. G., & Scott, J. G. (2010). - The horseshoe estimator for sparse signals. - Biometrika, 97(2), 465-480. - - """ - size = at.scalar("size", dtype="int32") - tau_rv = srng.halfcauchy(0, 1, size=1) - lmbda_rv = srng.halfcauchy(0, 1, size=size) - beta_rv = srng.normal(0, tau_rv * lmbda_rv, size=size) - return beta_rv - - def horseshoe_match(graph: TensorVariable) -> Tuple[TensorVariable, TensorVariable]: - graph_et = canonicalize_and_tuplize(graph) + graph_et = etuplize(graph) s = unify(graph_et, horseshoe_pattern) if s is False: @@ -192,10 +166,10 @@ def horseshoe_match(graph: TensorVariable) -> Tuple[TensorVariable, TensorVariab + "in your model is not half-Cauchy distributed." 
) - if halfcauchy_1.type.shape == (1,): + if halfcauchy_1.type.ndim == 0 or all(s == 1 for s in halfcauchy_1.type.shape): lmbda_rv = halfcauchy_2 tau_rv = halfcauchy_1 - elif halfcauchy_2.type.shape == (1,): + elif halfcauchy_2.type.ndim == 0 or all(s == 1 for s in halfcauchy_2.type.shape): lmbda_rv = halfcauchy_1 tau_rv = halfcauchy_2 else: @@ -207,7 +181,7 @@ def horseshoe_match(graph: TensorVariable) -> Tuple[TensorVariable, TensorVariab return (lmbda_rv, tau_rv) -def horseshoe_step( +def horseshoe_posterior( srng: RandomStream, beta: TensorVariable, sigma: TensorVariable, @@ -269,6 +243,51 @@ def horseshoe_step( return lmbda_inv_new, tau_inv_new +@sampler_finder([NormalRV]) +def normal_horseshoe_finder(fgraph, node, srng): + r"""Find and construct a Gibbs sampler for the normal-Horseshoe model. + + The implementation follows the sampler described in [1]_. It is designed to + sample efficiently from the following model: + + .. math:: + + \begin{align*} + \beta_i &\sim \operatorname{N}(0, \lambda_i^2 \tau^2) \\ + \lambda_i &\sim \operatorname{C}^{+}(0, 1) \\ + \tau &\sim \operatorname{C}^{+}(0, 1) + \end{align*} + + References + ---------- + .. [1] Makalic, Enes & Schmidt, Daniel. (2015). A Simple Sampler for the + Horseshoe Estimator. 10.1109/LSP.2015.2503725. + + """ + + rv_var = node.outputs[1] + + try: + lambda_rv, tau_rv = horseshoe_match(node.outputs[1]) + except ValueError: # pragma: no cover + return None + + lambda_posterior, tau_posterior = horseshoe_posterior( + srng, rv_var, 1.0, lambda_rv, tau_rv + ) + + if lambda_rv.name: + lambda_posterior.name = f"{lambda_rv.name}_posterior" + + if tau_rv.name: + tau_posterior.name = f"{tau_rv.name}_posterior" + + return [(lambda_rv, lambda_posterior, None), (tau_rv, tau_posterior, None)] + + +gibbs_db.register("normal_horseshoe", normal_horseshoe_finder, "basic") + + X_lv = var() beta_lv = var() neg_one_lv = var() @@ -283,11 +302,9 @@ def horseshoe_step( gamma_pattern = etuple(etuplize(at.random.gamma), var(), var(), var(), a_lv, b_lv) -def gamma_match( - graph: TensorVariable, -) -> Tuple[TensorVariable, TensorVariable]: - graph_opt = optimize_graph(graph) - graph_et = etuplize(graph_opt) +def gamma_match(graph: TensorVariable) -> Tuple[TensorVariable, TensorVariable]: + + graph_et = etuplize(graph) s = unify(graph_et, gamma_pattern) if s is False: raise ValueError("Not a gamma prior.") @@ -307,8 +324,8 @@ def gamma_match( def nbinom_sigmoid_dot_match( graph: TensorVariable, ) -> Tuple[TensorVariable, TensorVariable, TensorVariable]: - graph_opt = optimize_graph(graph) - graph_et = etuplize(graph_opt) + + graph_et = etuplize(graph) s = unify(graph_et, nbinom_sigmoid_dot_pattern) if s is False: raise ValueError("Not a negative binomial regression.") @@ -326,29 +343,6 @@ def nbinom_sigmoid_dot_match( return X, h, beta_rv -def nbinom_horseshoe_model(srng: RandomStream) -> TensorVariable: - """Negative binomial regression model with a horseshoe shrinkage prior.""" - X = at.matrix("X") - h = at.scalar("h") - - beta_rv = horseshoe_model(srng) - eta = X @ beta_rv - p = at.sigmoid(-eta) - Y_rv = srng.nbinom(h, p) - - return Y_rv - - -def nbinom_horseshoe_match( - Y_rv: TensorVariable, -) -> Tuple[ - TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable -]: - X, h, beta_rv = nbinom_sigmoid_dot_match(Y_rv) - lmbda_rv, tau_rv = horseshoe_match(beta_rv) - return h, X, beta_rv, lmbda_rv, tau_rv - - def sample_CRT( srng: RandomStream, y: TensorVariable, h: TensorVariable ) -> Tuple[TensorVariable, Mapping[Variable, 
Variable]]: @@ -385,14 +379,14 @@ def single_sample_CRT(y_i: TensorVariable, h: TensorVariable): return res, updates -def h_step( +def nbinom_dispersion_posterior( srng: RandomStream, h_last: TensorVariable, p: TensorVariable, a: TensorVariable, b: TensorVariable, y: TensorVariable, -) -> Tuple[TensorVariable, Mapping[Variable, Variable]]: +) -> Tuple[TensorVariable, Optional[Mapping[Variable, Variable]]]: r"""Sample the conditional posterior for the dispersion parameter under a negative-binomial and gamma prior. In other words, this draws a sample from :math:`h \mid Y = y` per @@ -404,7 +398,7 @@ def h_step( h &\sim \operatorname{Gamma}(a, b) \end{align*} - where `y` is a sample from :math:`y \sim Y`. + where :math:`\operatorname{NB}` is a negative-binomial distribution. The conditional posterior sample step is derived from the following decomposition: @@ -416,7 +410,8 @@ def h_step( where :math:`\operatorname{Log}` is the logarithmic distribution. Under a gamma prior, :math:`h` is conjugate to :math:`l`. We draw samples from - :math:`l` according to :math:`l \sim \operatorname{CRT(y, h)}`. + :math:`l` according to :math:`l \sim \operatorname{CRT(y, h)}`, where + :math:`y` is a sample from :math:`y \sim Y`. The resulting posterior is @@ -426,40 +421,42 @@ def h_step( \left(h \mid Y = y\right) \sim \operatorname{Gamma}\left(a + \sum_{i=1}^N l_i, \frac{1}{1/b + \sum_{i=1}^N \log(1 - p_i)} \right) \end{gather*} + Parameters + ---------- + srng + The random number generator from which samples are drawn. + h_last + The previous sample value of :math:`h`. + p + The success probability parameter in the negative-binomial distribution of :math:`Y`. + a + The shape parameter in the :math:`\operatorname{Gamma}` prior on :math:`h`. + b + The rate parameter in the :math:`\operatorname{Gamma}` prior on :math:`h`. + y + A sample from :math:`Y`. + + Returns + ------- + A sample from the posterior :math:`h \mid y`. References ---------- - .. [1] Zhou, Mingyuan, Lingbo Li, David Dunson, and Lawrence Carin. 2012. “Lognormal and Gamma Mixed Negative Binomial Regression.” Proceedings of the International Conference on Machine Learning. International Conference on Machine Learning 2012: 1343–50. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4180062/. + .. [1] Zhou, Mingyuan, Lingbo Li, David Dunson, and Lawrence Carin. 2012. + “Lognormal and Gamma Mixed Negative Binomial Regression.” + Proceedings of the International Conference on Machine Learning. + 2012: 1343–50. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4180062/. """ Ls, updates = sample_CRT(srng, y, h_last) L_sum = Ls.sum(axis=-1) h = srng.gamma(a + L_sum, at.reciprocal(b) - at.sum(at.log(1 - p), axis=-1)) - h.name = f"{h_last.name or 'h'}-posterior" + h.name = f"{h_last.name or 'h'}_posterior" return h, updates -def nbinom_horseshoe_with_dispersion_match( - Y_rv: TensorVariable, -) -> Tuple[ - TensorVariable, - TensorVariable, - TensorVariable, - TensorVariable, - TensorVariable, - TensorVariable, - TensorVariable, -]: - X, h_rv, beta_rv = nbinom_sigmoid_dot_match(Y_rv) - lmbda_rv, tau_rv = horseshoe_match(beta_rv) - a, b = gamma_match(h_rv) - return X, beta_rv, lmbda_rv, tau_rv, h_rv, a, b - - -def nbinom_horseshoe_gibbs( - srng: RandomStream, Y_rv: TensorVariable, y: TensorVariable, num_samples: int -) -> Tuple[Union[TensorVariable, List[TensorVariable]], Dict]: - r"""Build a Gibbs sampler for the negative binomial regression with a horseshoe prior. 
+def nbinom_normal_posterior(srng, beta, beta_std, X, h, y): + r"""Produce a Gibbs sample step for a negative binomial logistic-link regression with a normal prior. The implementation follows the sampler described in [1]_. It is designed to sample efficiently from the following negative binomial regression model: @@ -467,37 +464,38 @@ def nbinom_horseshoe_gibbs( .. math:: \begin{align*} - Y_i &\sim \operatorname{NB}\left(h, p_i\right) \\ - p_i &= \frac{\exp(\psi_i)}{1 + \exp(\psi_i)} \\ - \psi_i &= x_i^\top \beta \\ - \beta_j &\sim \operatorname{N}(0, \lambda_j^2 \tau^2) \\ - \lambda_j &\sim \operatorname{HalfCauchy}(0, 1) \\ - \tau &\sim \operatorname{HalfCauchy}(0, 1) + Y &\sim \operatorname{NB}\left(h, p\right) \\ + p &= \frac{\exp(\psi)}{1 + \exp(\psi)} \\ + \psi &= X^\top \beta \\ + \beta &\sim \operatorname{N}(0, \lambda^2) \\ \end{align*} + where :math:`\operatorname{NB}` is a negative-binomial distribution. Parameters ---------- - srng: symbolic random number generator - The random number generating object to be used during sampling. - Y_rv - Model graph. - y: TensorVariable - The observed count data. - n_samples: TensorVariable - A tensor describing the number of posterior samples to generate. + srng + The random number generator from which samples are drawn. + beta + The current/previous value of the regression parameter :math:`beta`. + beta_std + The std. dev. of the regression parameter :math:`beta`. + X + The regression matrix. + h + The :math:`h` parameter in the negative-binomial distribution of :math:`Y`. + y + A sample from the observation distribution :math:`y \sim Y`. Returns ------- - (outputs, updates): tuple - A symbolic description of the sampling result to be used to - compile a sampling function. + A sample from the posterior :math:`\beta \mid y`. Notes ----- The :math:`z` expression in Section 2.2 of [1]_ seems to omit division by the Polya-Gamma auxilliary variables whereas [2]_ and [3]_ - explicitely include it. We found that including the division results in + explicitly include it. We found that including the division results in accurate posterior samples for the regression coefficients. It is also worth noting that the :math:`\sigma^2` parameter is not sampled directly in the negative binomial regression problem and thus set to 1 [2]_. @@ -513,153 +511,83 @@ def nbinom_horseshoe_gibbs( 2019 September ; 14(3): 829–855. doi:10.1214/18-ba1132. """ + # This effectively introduces a new term, `w`, to the model. + # TODO: We could/should generate a graph that uses this scale-mixture + # "expanded" form and find/create the posteriors from there + w = srng.gen(polyagamma, y + h, X @ beta) + z = 0.5 * (y - h) / w - def nbinom_horseshoe_step( - beta: TensorVariable, - lmbda: TensorVariable, - tau: TensorVariable, - y: TensorVariable, - X: TensorVariable, - h: TensorVariable, - ) -> Tuple[TensorVariable, TensorVariable, TensorVariable]: - """Complete one full update of the Gibbs sampler and return the new state - of the posterior conditional parameters. - - Parameters - ---------- - beta: Tensorvariable - Coefficients (other than intercept) of the regression model. - lmbda - Inverse of the local shrinkage parameter of the horseshoe prior. - tau - Inverse of the global shrinkage parameters of the horseshoe prior. - y: TensorVariable - The observed count data. - X: TensorVariable - The covariate matrix. - h: TensorVariable - The "number of successes" parameter of the negative binomial disribution - used to model the data. 
- - """ - xb = X @ beta - w = srng.gen(polyagamma, y + h, xb) - z = 0.5 * (y - h) / w - - lmbda_inv = 1.0 / lmbda - tau_inv = 1.0 / tau - beta_new = update_beta(srng, w, lmbda_inv * tau_inv, X, z) - - lmbda_inv_new, tau_inv_new = horseshoe_step( - srng, beta_new, 1.0, lmbda_inv, tau_inv - ) - return beta_new, 1.0 / lmbda_inv_new, 1.0 / tau_inv_new - - h, X, beta_rv, lmbda_rv, tau_rv = nbinom_horseshoe_match(Y_rv) + tau_beta = at.reciprocal(beta_std) - outputs, updates = aesara.scan( - nbinom_horseshoe_step, - outputs_info=[beta_rv, lmbda_rv, tau_rv], - non_sequences=[y, X, h], - n_steps=num_samples, - strict=True, - ) + beta_posterior = normal_regression_posterior(srng, w, tau_beta, X, z) - return outputs, updates + if beta.name: + beta_posterior.name = f"{beta.name}_posterior" + return beta_posterior -def nbinom_horseshoe_gibbs_with_dispersion( - srng: RandomStream, - Y_rv: TensorVariable, - y: TensorVariable, - num_samples: TensorVariable, -) -> Tuple[Union[TensorVariable, List[TensorVariable]], Mapping[Variable, Variable]]: - r"""Build a Gibbs sampler for the negative binomial regression with a horseshoe prior and gamma prior dispersion. - This is a direct extension of `nbinom_horseshoe_gibbs_with_dispersion` that - adds a gamma prior assumption to the :math:`h` parameter in the - negative-binomial and samples according to [1]_. +@sampler_finder([NegBinomialRV]) +def nbinom_logistic_finder(fgraph, node, srng): + r"""Find and construct a Gibbs sampler for a negative-binomial logistic-link regression. - In other words, this model is the same as `nbinom_horseshoe_gibbs` except - for the addition assumption: + The implementation follows the sampler described in `nbinom_normal_posterior`. It is designed to + sample efficiently from the following negative binomial regression model: .. math:: - \begin{gather*} + \begin{align*} + Y &\sim \operatorname{NB}\left(h, p\right) \\ + p &= \frac{\exp(\psi)}{1 + \exp(\psi)} \\ + \psi &= X^\top \beta \\ + \beta_j &\sim \operatorname{N}(0, \lambda_j^2) \\ h \sim \operatorname{Gamma}\left(a, b\right) - \end{gather*} + \end{align*} + If :math:`h` doesn't take the above form, a sampler is produced with steps + for all the other terms; otherwise, sampling for :math:`h` is performed + in accordance with [1]_. References ---------- - .. [1] Zhou, Mingyuan, Lingbo Li, David Dunson, and Lawrence Carin. 2012. “Lognormal and Gamma Mixed Negative Binomial Regression.” Proceedings of the International Conference on Machine Learning. International Conference on Machine Learning 2012: 1343–50. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4180062/. + .. [1] Zhou, Mingyuan, Lingbo Li, David Dunson, and Lawrence Carin. 2012. + Lognormal and Gamma Mixed Negative Binomial Regression. + Proceedings of the International Conference on Machine Learning. + 2012: 1343–50. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4180062/. """ - def nbinom_horseshoe_step( - beta: TensorVariable, - lmbda: TensorVariable, - tau: TensorVariable, - h: TensorVariable, - y: TensorVariable, - X: TensorVariable, - a: TensorVariable, - b: TensorVariable, - ): - """Complete one full update of the Gibbs sampler and return the new state - of the posterior conditional parameters. - - Parameters - ---------- - beta - Coefficients (other than intercept) of the regression model. - lmbda - Inverse of the local shrinkage parameter of the horseshoe prior. - tau - Inverse of the global shrinkage parameters of the horseshoe prior. 
- h - The "number of successes" parameter of the negative-binomial distribution - used to model the data. - y - The observed count data. - X - The covariate matrix. - a - The shape parameter for the :math:`h` gamma prior. - b - The rate parameter for the :math:`h` gamma prior. - - """ - xb = X @ beta - w = srng.gen(polyagamma, y + h, xb) - z = 0.5 * (y - h) / w - - lmbda_inv = 1.0 / lmbda - tau_inv = 1.0 / tau - beta_new = update_beta(srng, w, lmbda_inv * tau_inv, X, z) - - lmbda_inv_new, tau_inv_new = horseshoe_step( - srng, beta_new, 1.0, lmbda_inv, tau_inv - ) - eta = X @ beta_new - p = at.sigmoid(-eta) - h_new, h_updates = h_step(srng, h, p, a, b, y) + y = node.outputs[1] - return (beta_new, 1.0 / lmbda_inv_new, 1.0 / tau_inv_new, h_new), h_updates + try: + X, h, beta_rv = nbinom_sigmoid_dot_match(node.outputs[1]) + except ValueError: # pragma: no cover + return None - X, beta_rv, lmbda_rv, tau_rv, h_rv, a, b = nbinom_horseshoe_with_dispersion_match( - Y_rv + beta_posterior = nbinom_normal_posterior( + srng, beta_rv, beta_rv.owner.inputs[4], X, h, y ) - outputs, updates = aesara.scan( - nbinom_horseshoe_step, - outputs_info=[beta_rv, lmbda_rv, tau_rv, h_rv], - non_sequences=[y, X, a, b], - n_steps=num_samples, - strict=True, - ) + res: List[ + Tuple[TensorVariable, TensorVariable, Optional[Mapping[Variable, Variable]]] + ] = [(beta_rv, beta_posterior, None)] + + # TODO: Should this be in a separate rewriter? + try: + a, b = gamma_match(h) + except ValueError: # pragma: no cover + return res - return outputs, updates + p = at.sigmoid(-X @ beta_posterior) + + h_posterior, updates = nbinom_dispersion_posterior(srng, h, p, a, b, y) + + res.append((h, h_posterior, updates)) + + return res + + +gibbs_db.register("nbinom_logistic_regression", nbinom_logistic_finder, "basic") bernoulli_sigmoid_dot_pattern = etuple( @@ -667,11 +595,12 @@ def nbinom_horseshoe_step( ) -def bernoulli_sigmoid_dot_match( +def bern_sigmoid_dot_match( graph: TensorVariable, ) -> Tuple[TensorVariable, TensorVariable]: - graph_opt = optimize_graph(graph) - graph_et = etuplize(graph_opt) + + graph_et = etuplize(graph) + s = unify(graph_et, bernoulli_sigmoid_dot_pattern) if s is False: raise ValueError("Not a Bernoulli regression.") @@ -688,31 +617,14 @@ def bernoulli_sigmoid_dot_match( return X, beta_rv -def bernoulli_horseshoe_model(srng: RandomStream) -> TensorVariable: - """Bernoulli regression model with a horseshoe shrinkage prior.""" - X = at.matrix("X") - - beta_rv = horseshoe_model(srng) - eta = X @ beta_rv - p = at.sigmoid(-eta) - Y_rv = srng.bernoulli(p) - - return Y_rv - - -def bernoulli_horseshoe_match( - Y_rv: TensorVariable, -) -> Tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]: - X, beta_rv = bernoulli_sigmoid_dot_match(Y_rv) - lmbda_rv, tau_rv = horseshoe_match(beta_rv) - - return X, beta_rv, lmbda_rv, tau_rv - - -def bernoulli_horseshoe_gibbs( - srng: RandomStream, Y_rv: TensorVariable, y: TensorVariable, num_samples: int -) -> Tuple[Union[TensorVariable, List[TensorVariable]], Dict]: - r"""Build a Gibbs sampler for Bernoulli (logistic) regression with a horseshoe prior. +def bern_normal_posterior( + srng: RandomStream, + beta: TensorVariable, + beta_std: TensorVariable, + X: TensorVariable, + y: TensorVariable, +) -> Tuple[TensorVariable, TensorVariable, TensorVariable]: + r"""Produce a Gibbs sample step for a bernoulli logistic-link regression with a normal prior. The implementation follows the sampler described in [1]_. 
It is designed to sample efficiently from the following binary logistic regression model: @@ -720,88 +632,66 @@ .. math:: \begin{align*} - Y_i &\sim \operatorname{Bern}\left(p_i\right) \\ - p_i &= \frac{1}{1 + \exp\left(-(\beta_0 + x_i^\top \beta)\right)} \\ - \beta_j &\sim \operatorname{N}(0, \lambda_j^2 \tau^2) \\ - \lambda_j &\sim \operatorname{HalfCauchy}(0, 1) \\ - \tau &\sim \operatorname{HalfCauchy}(0, 1) + Y &\sim \operatorname{Bern}\left( p \right) \\ + p &= \frac{1}{1 + \exp\left( -X^\top \beta\right)} \\ + \beta_j &\sim \operatorname{N}\left( 0, \lambda_j^2 \right) \end{align*} + Parameters ---------- - srng - The random number generating object to be used during sampling. - Y_rv - Model graph. - y - The observed binary data. + srng + The random number generator from which samples are drawn. + beta + The current/previous value of the regression parameter :math:`\beta`. + beta_std + The std. dev. of the regression parameter :math:`\beta`. X - The covariate matrix. - n_samples - A tensor describing the number of posterior samples to generate. + The regression matrix. + y + A sample from the observation distribution :math:`y \sim Y`. Returns ------- - (outputs, updates): tuple - A symbolic description of the sampling result to be used to - compile a sampling function. - + A sample from :math:`\beta \mid y`. References ---------- - .. [1] Makalic, Enes & Schmidt, Daniel. (2015). A Simple Sampler for the - Horseshoe Estimator. 10.1109/LSP.2015.2503725. - .. [2] Makalic, Enes & Schmidt, Daniel. (2016). High-Dimensional Bayesian + .. [1] Makalic, Enes & Schmidt, Daniel. (2016). High-Dimensional Bayesian Regularised Regression with the BayesReg Package. """ - def bernoulli_horseshoe_step( - beta: TensorVariable, - lmbda: TensorVariable, - tau: TensorVariable, - y: TensorVariable, - X: TensorVariable, - ) -> Tuple[TensorVariable, TensorVariable, TensorVariable]: - """Complete one full update of the Gibbs sampler and return the new - state of the posterior conditional parameters. - - Parameters - ---------- - beta - Coefficients (other than intercept) of the regression model. - lmbda - Square of the local shrinkage parameter of the horseshoe prior. - tau - Square of the global shrinkage parameters of the horseshoe prior. - y - The observed binary data - X - The covariate matrix.
- - """ - xb = X @ beta - w = srng.gen(polyagamma, 1, xb) - z = 0.5 * y / w - - lmbda_inv = 1.0 / lmbda - tau_inv = 1.0 / tau - beta_new = update_beta(srng, w, lmbda_inv * tau_inv, X, z) - - lmbda_inv_new, tau_inv_new = horseshoe_step( - srng, beta_new, 1.0, lmbda_inv, tau_inv - ) + w = srng.gen(polyagamma, 1, X @ beta) + z = (y - 0.5) / w - return beta_new, 1 / lmbda_inv_new, 1.0 / tau_inv_new + tau_beta = at.reciprocal(beta_std) - X, beta_rv, lmbda_rv, tau_rv = bernoulli_horseshoe_match(Y_rv) + beta_posterior = normal_regression_posterior(srng, w, tau_beta, X, z) - outputs, updates = aesara.scan( - bernoulli_horseshoe_step, - outputs_info=[beta_rv, lmbda_rv, tau_rv], - non_sequences=[y, X], - n_steps=num_samples, - strict=True, - ) + if beta.name: + beta_posterior.name = f"{beta.name}_posterior" + + return beta_posterior + + +@sampler_finder([BernoulliRV]) +def bern_logistic_finder(fgraph, node, srng): + r"""Find and construct a Gibbs sampler for a negative binomial logistic-link regression.""" + + y = node.outputs[1] - return outputs, updates + try: + X, beta_rv = bern_sigmoid_dot_match(node.outputs[1]) + except ValueError: # pragma: no cover + return None + + beta_posterior = bern_normal_posterior(srng, beta_rv, beta_rv.owner.inputs[4], X, y) + + return [(beta_rv, beta_posterior, None)] + + +gibbs_db.register("bern_logistic_finder", bern_logistic_finder, "basic") + + +sampler_finder_db.register( + "gibbs_db", in2out(gibbs_db.query("+basic"), name="gibbs"), "basic" +) diff --git a/aemcmc/opt.py b/aemcmc/opt.py new file mode 100644 index 0000000..e5a303f --- /dev/null +++ b/aemcmc/opt.py @@ -0,0 +1,388 @@ +from collections.abc import Mapping +from functools import wraps +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union + +from aeppl.opt import PreserveRVMappings +from aesara.compile.builders import OpFromGraph +from aesara.compile.mode import optdb +from aesara.graph.basic import Apply, Variable, clone_replace, io_toposort +from aesara.graph.features import AlreadyThere, Feature +from aesara.graph.fg import FunctionGraph +from aesara.graph.op import Op +from aesara.graph.opt import in2out, local_optimizer +from aesara.graph.optdb import SequenceDB +from aesara.tensor.basic_opt import ShapeFeature +from aesara.tensor.elemwise import DimShuffle, Elemwise +from aesara.tensor.random.op import RandomVariable +from aesara.tensor.random.utils import RandomStream +from aesara.tensor.var import TensorVariable +from cons.core import _car +from unification.core import _unify + +SamplerFunctionReturnType = Optional[ + Iterable[Tuple[Variable, Variable, Union[Dict[Variable, Variable]]]] +] +SamplerFunctionType = Callable[ + [FunctionGraph, Apply, RandomStream], SamplerFunctionReturnType +] +LocalOptimizerReturnType = Optional[Union[Dict[Variable, Variable], Sequence[Variable]]] + +sampler_ir_db = SequenceDB() +sampler_ir_db.name = "sampler_ir_db" +sampler_ir_db.register( + "sampler_canonicalize", + optdb.query("+canonicalize"), + "basic", +) + +sampler_rewrites_db = SequenceDB() +sampler_rewrites_db.name = "sampler_rewrites_db" + +sampler_finder_db = SequenceDB() +sampler_finder_db.name = "sampler_finder_db" + +sampler_rewrites_db.register( + "sampler_finders", + sampler_finder_db, + "basic", + position=0, +) + + +def construct_ir_fgraph( + obs_rvs_to_values: Dict[Variable, Variable] +) -> Tuple[ + FunctionGraph, + Dict[Variable, Variable], + Dict[Variable, Variable], + Dict[Variable, Variable], +]: + r"""Construct a `FunctionGraph` in measurable IR form for 
the keys in `obs_rvs_to_values`. + + Returns + ------- + A `FunctionGraph` of the measurable IR, a copy of `obs_rvs_to_values` containing + the new, cloned versions of the original variables in `obs_rvs_to_values`, and + a ``dict`` mapping all the original variables to their cloned values in + `FunctionGraph`. + """ + memo = {v: v for v in obs_rvs_to_values.values()} + + rv_outputs = tuple( + node.outputs[1] + for node in io_toposort([], list(obs_rvs_to_values.keys())) + if isinstance(node.op, RandomVariable) + ) + + observed_vars = tuple(obs_rvs_to_values.keys()) + + assert all(obs_rv in obs_rvs_to_values for obs_rv in observed_vars) + + fgraph = FunctionGraph( + outputs=rv_outputs, + clone=True, + memo=memo, + copy_orphans=False, + copy_inputs=False, + features=[ShapeFeature(), PreserveRVMappings(obs_rvs_to_values)], + ) + + # Update `obs_rvs_to_values` so that it uses the new cloned variables + obs_rvs_to_values = {memo[k]: v for k, v in obs_rvs_to_values.items()} + + sampler_ir_db.query("+basic").optimize(fgraph) + + new_to_old_rvs = { + new_rv: old_rv for old_rv, new_rv in zip(rv_outputs, fgraph.outputs) + } + + return fgraph, obs_rvs_to_values, memo, new_to_old_rvs + + +class SamplerTracker(Feature): + """A `Feature` that tracks potential sampler steps in a graph.""" + + def __init__(self, srng: RandomStream): + self.srng: RandomStream = srng + # Maps variables to a list of tuples that each provide a description of + # the posterior step, the posterior step's graph/output variable, and + # any updates generated for the posterior step + self.rvs_to_samplers: Dict[ + TensorVariable, + List[Tuple[str, TensorVariable, Optional[Dict[Variable, Variable]]]], + ] = {} + self.rvs_seen: Set[TensorVariable] = set() + + def on_attach(self, fgraph: FunctionGraph): + if hasattr(fgraph, "sampler_mappings"): # pragma: no cover + raise AlreadyThere( + f"{fgraph} already has the `SamplerTracker` feature attached." + ) + + fgraph.sampler_mappings = self + + +def sampler_finder(tracks: Optional[Sequence[Union[Op, type]]]): + """Construct a `LocalOptimizer` that identifies sample steps. + + This is a decorator that is used as follows: + + @sampler_finder([NormalRV]) + def local_horseshoe_posterior(fgraph, node, srng): + # Determine if this normal is the root of a Horseshoe + # prior graph. + ... + # If it is, construct the posterior steps for its parameters and + # return them as a list of tuples like `(rv, posterior_rv, updates)`. + ... 
+ return [(lambda_rv, lambda_posterior, None), (tau_rv, tau_posterior, None)] + + """ + + def decorator(f: SamplerFunctionType): + @local_optimizer(tracks) + @wraps(f) + def sampler_finder( + fgraph: FunctionGraph, node: Apply + ) -> LocalOptimizerReturnType: + sampler_mappings = getattr(fgraph, "sampler_mappings", None) + + # TODO: This assumes that `node` is a `RandomVariable`-generated `Apply` node + rv_var = node.outputs[1] + key = (f.__name__, rv_var) + + if sampler_mappings is None or key in sampler_mappings.rvs_seen: + return None # pragma: no cover + + srng = sampler_mappings.srng + + rvs_and_posteriors: SamplerFunctionReturnType = f(fgraph, node, srng) + + if not rvs_and_posteriors: + return None # pragma: no cover + + for rv, posterior_rv, updates in rvs_and_posteriors: + sampler_mappings.rvs_to_samplers.setdefault(rv, []).append( + (f.__name__, posterior_rv, updates) + ) + sampler_mappings.rvs_seen.add(key) + + return rv_var.owner.outputs + + return sampler_finder + + return decorator + + +class SubsumingElemwise(OpFromGraph, Elemwise): + r"""A class representing an `Elemwise` with `DimShuffle`\ed arguments.""" + + def __init__(self, inputs, outputs, *args, **kwargs): + # TODO: Mock the `Elemwise` interface just enough for our purposes + self.elemwise_op = outputs[0].owner.op + self.scalar_op = self.elemwise_op.scalar_op + self.nfunc_spec = self.elemwise_op.nfunc_spec + self.inplace_pattern = self.elemwise_op.inplace_pattern + # self.destroy_map = self.elemwise_op.destroy_map + self.ufunc = None + self.nfunc = None + OpFromGraph.__init__(self, inputs, outputs, *args, **kwargs) + + def make_node(self, *inputs): + node = super().make_node(*inputs) + # Remove shared variable inputs. We aren't going to compute anything + # with this `Op`, so they're not needed + real_inputs = node.inputs[: len(node.inputs) - len(self.shared_inputs)] + return Apply(self, real_inputs, [o.clone() for o in node.outputs]) + + def perform(self, *args, **kwargs): + raise NotImplementedError( # pragma: no cover + "This `OpFromGraph` should have been in-line expanded." + ) + + def clone(self): + res = OpFromGraph.clone(self) + res.elemwise_op = self.elemwise_op + res.scalar_op = self.scalar_op + res.nfunc_spec = self.nfunc_spec + res.inplace_pattern = self.inplace_pattern + # res.destroy_map = self.destroy_map + res.ufunc = self.ufunc + res.nfunc = self.nfunc + return res + + def __str__(self): + return repr(self) + + def __repr__(self): + return f"{type(self).__name__}{{{self.scalar_op}}}" + + def __eq__(self, other): + return OpFromGraph.__eq__(self, other) + + def __hash__(self): + return OpFromGraph.__hash__(self) + + +def _unify_SubsumingElemwise(u: Elemwise, v: SubsumingElemwise, s: Mapping): + yield _unify(u, v.elemwise_op, s) + + +_unify.add( + (Elemwise, SubsumingElemwise, Mapping), + lambda u, v, s: _unify_SubsumingElemwise(u, v, s), +) +_unify.add( + (SubsumingElemwise, Elemwise, Mapping), + lambda u, v, s: _unify_SubsumingElemwise(v, u, s), +) +_unify.add( + (SubsumingElemwise, SubsumingElemwise, Mapping), + lambda u, v, s: _unify(v.elemwise_op, u.elemwise_op, s), +) + + +def car_SubsumingElemwise(x): + return type(x.elemwise_op) + + +_car.add((SubsumingElemwise,), car_SubsumingElemwise) + + +@local_optimizer([Elemwise]) +def local_elemwise_dimshuffle_subsume(fgraph, node): + r"""This rewrite converts `DimShuffle`s in the `Elemwise` inputs into a single `Op`. + + The replacement rule is + + .. 
math: + + \frac{ + \operatorname{Elemwise}_{o}\left( + \operatorname{DimShuffle}_{z_i}(x_i), \dots + \right) + }{ + \operatorname{OpFromGraph}_{\operatorname{Elemwise}_{o}\left( + \operatorname{DimShuffle}_{z_i}(y_i), \dots + \right)}\left( + x_i, \dots + \right) + } + //, \quad + // x_i \text{ is a } \operatorname{RandomVariable} + + where :math:`o` is a scalar `Op`, :math:`z_i` are the `DimShuffle` settings + for the inputs at index :math:`i`. + + """ + + if isinstance(node.op, SubsumingElemwise): + return None + + new_inputs = [] + subsumed_inputs = [] + + out_ndim = node.outputs[0].type.ndim + + found_subsumable_ds = False + for i in node.inputs: + if i.owner and isinstance(i.owner.op, DimShuffle): + # TODO FIXME: Only do this when the `DimShuffle`s are adding + # broadcastable dimensions. If they're doing more + # (e.g. transposing), separate the broadcasting from everything + # else. + ds_order = i.owner.op.new_order + dim_shuffle_input = i.owner.inputs[0] + + ndim_diff = out_ndim - dim_shuffle_input.type.ndim + + # The `DimShuffle`ing added by `Elemwise` + el_order = ds_order[:ndim_diff] + # The remaining `DimShuffle`ing that was added by something else + new_ds_order = ds_order[ndim_diff:] + + # Only consider broadcast dimensions added on the left as + # having come from `Elemwise.make_node` + if len(el_order) == 0 or not all(d == "x" for d in el_order): + # In this case, the necessary broadcast elements were most + # likely not added by `Elemwise.make_node` (e.g. broadcasts are + # interspersed with transposes, or there are none at all), so + # we don't want to mess with them. + # TODO: We could still subsume some of these `DimShuffle`s, + # though + subsumed_inputs.append(i) + new_inputs.append(i) + continue + + # if dim_shuffle_input.owner and isinstance( + # dim_shuffle_input.owner.op, RandomVariable + # ): + found_subsumable_ds = True + + if new_ds_order and not new_ds_order == tuple(range(len(new_ds_order))): + # The remaining `DimShuffle`ing is substantial, so we need to + # apply it separately + new_dim_shuffle_input = dim_shuffle_input.dimshuffle(new_ds_order) + new_subsumed_input = new_dim_shuffle_input.dimshuffle( + el_order + tuple(range(new_dim_shuffle_input.type.ndim)) + ) + + subsumed_inputs.append(new_subsumed_input) + new_inputs.append(new_dim_shuffle_input) + else: + subsumed_inputs.append(i) + new_inputs.append(dim_shuffle_input) + + else: + subsumed_inputs.append(i) + new_inputs.append(i) + + if not found_subsumable_ds: + return None # pragma: no cover + + assert len(subsumed_inputs) == len(node.inputs) + new_outputs = node.op.make_node(*subsumed_inputs).outputs + new_op = SubsumingElemwise(new_inputs, new_outputs, inline=True) + + new_out = new_op(*new_inputs) + + assert len(new_out.owner.inputs) == len(node.inputs) + + return new_out.owner.outputs + + +sampler_ir_db.register( + "elemwise_dimshuffle_subsume", + in2out(local_elemwise_dimshuffle_subsume), + "basic", + position=-10, +) + + +@local_optimizer([Elemwise]) +def inline_SubsumingElemwise(fgraph, node): + + op = node.op + + if not isinstance(op, SubsumingElemwise): + return False + + if not op.is_inline: + return False # pragma: no cover + + res = clone_replace( + op.inner_outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)} + ) + + return res + + +expand_subsumptions = in2out(inline_SubsumingElemwise) + +# This step undoes `elemwise_dimshuffle_subsume` +sampler_rewrites_db.register( + "expand_subsumptions", + expand_subsumptions, + "basic", + position="last", +) diff --git a/tests/test_basic.py 
b/tests/test_basic.py new file mode 100644 index 0000000..c2d464e --- /dev/null +++ b/tests/test_basic.py @@ -0,0 +1,132 @@ +import aesara +import aesara.tensor as at +import numpy as np +import pytest +from aesara.graph.basic import graph_inputs, io_toposort +from aesara.ifelse import IfElse +from aesara.tensor.random import RandomStream +from aesara.tensor.random.basic import BetaRV, ExponentialRV, GammaRV + +from aemcmc.basic import construct_sampler +from aemcmc.opt import SubsumingElemwise + + +def test_closed_form_posterior(): + srng = RandomStream(0) + + alpha_tt = at.scalar("alpha") + beta_tt = at.scalar("beta") + p_rv = srng.beta(alpha_tt, beta_tt, name="p") + + n_tt = at.iscalar("n") + Y_rv = srng.binomial(n_tt, p_rv, name="Y") + + y_vv = Y_rv.clone() + y_vv.name = "y" + + sample_steps, updates, initial_values = construct_sampler({Y_rv: y_vv}, srng) + + p_posterior_step = sample_steps[p_rv] + assert isinstance(p_posterior_step.owner.op, BetaRV) + + +def test_no_samplers(): + srng = RandomStream(0) + + size = at.lscalar("size") + tau_rv = srng.halfcauchy(0, 1, name="tau") + Y_rv = srng.halfcauchy(0, tau_rv, size=size, name="Y") + + y_vv = Y_rv.clone() + y_vv.name = "y" + + with pytest.raises(NotImplementedError): + construct_sampler({Y_rv: y_vv}, srng) + + +def test_create_gibbs(): + srng = RandomStream(0) + + X = at.matrix("X") + + # Horseshoe `beta_rv` + tau_rv = srng.halfcauchy(0, 1, name="tau") + lmbda_rv = srng.halfcauchy(0, 1, size=X.shape[1], name="lambda") + beta_rv = srng.normal(0, lmbda_rv * tau_rv, size=X.shape[1], name="beta") + + a = at.scalar("a") + b = at.scalar("b") + h_rv = srng.gamma(a, b, name="h") + + # Negative-binomial regression + eta = X @ beta_rv + p = at.sigmoid(-eta) + Y_rv = srng.nbinom(h_rv, p, name="Y") + + y_vv = Y_rv.clone() + y_vv.name = "y" + + sample_vars = [tau_rv, lmbda_rv, beta_rv, h_rv] + + sample_steps, updates, initial_values = construct_sampler({Y_rv: y_vv}, srng) + + assert len(sample_steps) == 4 + assert updates + + tau_post_step = sample_steps[tau_rv] + assert isinstance(tau_post_step.owner.op, GammaRV) + + lmbda_post_step = sample_steps[lmbda_rv] + assert isinstance(lmbda_post_step.owner.op, ExponentialRV) + + beta_post_step = sample_steps[beta_rv] + assert isinstance(beta_post_step.owner.op, IfElse) + + assert y_vv in graph_inputs([beta_post_step]) + + inputs = [X, a, b, y_vv] + [initial_values[rv] for rv in sample_vars] + outputs = [sample_steps[rv] for rv in sample_vars] + + subsuming_elemwises = [ + n for n in io_toposort([], outputs) if isinstance(n.op, SubsumingElemwise) + ] + assert not any(subsuming_elemwises) + + sample_step = aesara.function( + inputs, + outputs, + updates=updates, + on_unused_input="ignore", + ) + + rng = np.random.default_rng(2309) + + X_val = rng.normal(0, 1, size=(10, 10)) + X_val = X_val.dot(X_val.T) + X_val = X_val[:, :2] + a_val, b_val = 1.0, 10.0 + beta_true = beta_val = np.array([1.0, 0.5]) + tau_val, lmbda_val, h_val = 1.0, np.zeros(2), 10.0 + + y_fn = aesara.function([X, a, b, beta_rv], Y_rv) + y_val = y_fn(X_val, a_val, b_val, beta_val) + + tau_pst_val, lmbda_pst_val, beta_pst_val, h_pst_val = ( + tau_val, + lmbda_val, + beta_val, + h_val, + ) + for i in range(10): + tau_pst_val, lmbda_pst_val, beta_pst_val, h_pst_val = sample_step( + X_val, + a_val, + b_val, + y_val, + tau_pst_val, + lmbda_pst_val, + beta_pst_val, + h_pst_val, + ) + + assert np.allclose(beta_pst_val, beta_true, rtol=1e-1) diff --git a/tests/test_conjugates.py b/tests/test_conjugates.py index fbae116..5bfda32 100644 --- 
a/tests/test_conjugates.py +++ b/tests/test_conjugates.py @@ -26,7 +26,7 @@ def test_beta_binomial_conjugate_contract(): y_vv.tag.name = "y" q_lv = var() - (posterior_expr,) = run(1, q_lv, beta_binomial_conjugateo((Y_rv, y_vv), q_lv)) + (posterior_expr,) = run(1, q_lv, beta_binomial_conjugateo(y_vv, Y_rv, q_lv)) posterior = eval_if_etuple(posterior_expr) assert isinstance(posterior.owner.op, type(at.random.beta)) diff --git a/tests/test_gibbs.py b/tests/test_gibbs.py index 8cd636b..a796572 100644 --- a/tests/test_gibbs.py +++ b/tests/test_gibbs.py @@ -4,24 +4,22 @@ import pytest import scipy.special from aesara.graph.basic import equal_computations +from aesara.graph.opt_utils import optimize_graph from aesara.tensor.random.utils import RandomStream from scipy.linalg import toeplitz from aemcmc.gibbs import ( - bernoulli_horseshoe_gibbs, - bernoulli_horseshoe_match, - bernoulli_horseshoe_model, + bern_normal_posterior, + bern_sigmoid_dot_match, gamma_match, - h_step, horseshoe_match, - horseshoe_model, - nbinom_horseshoe_gibbs, - nbinom_horseshoe_gibbs_with_dispersion, - nbinom_horseshoe_match, - nbinom_horseshoe_model, - nbinom_horseshoe_with_dispersion_match, + horseshoe_posterior, + nbinom_dispersion_posterior, + nbinom_normal_posterior, + normal_regression_posterior, sample_CRT, ) +from aemcmc.opt import SamplerTracker, construct_ir_fgraph, sampler_rewrites_db @pytest.fixture @@ -37,31 +35,33 @@ def test_horseshoe_match(srng): lmbda_rv = srng.halfcauchy(0, 1, size=size, name="lambda") beta_rv = srng.normal(0, lmbda_rv * tau_rv, size=size, name="beta") + fgraph, _, memo, _ = construct_ir_fgraph({beta_rv: beta_rv}) + beta_rv = fgraph.outputs[-1] + lambda_res, tau_res = horseshoe_match(beta_rv) - assert lambda_res is lmbda_rv - assert tau_res is tau_rv + assert lambda_res is memo[lmbda_rv] + assert tau_res is memo[tau_rv] # Scalar tau tau_rv = srng.halfcauchy(0, 1, name="tau") lmbda_rv = srng.halfcauchy(0, 1, size=size, name="lambda") beta_rv = srng.normal(0, lmbda_rv * tau_rv, size=size, name="beta") + fgraph, _, memo, _ = construct_ir_fgraph({beta_rv: beta_rv}) + beta_rv = fgraph.outputs[-1] + lambda_res, tau_res = horseshoe_match(beta_rv) - assert lambda_res is lmbda_rv + assert lambda_res is memo[lmbda_rv] # `tau_res` should've had its `DimShuffle` lifted, so it's not identical to `tau_rv` assert isinstance(tau_res.owner.op, type(tau_rv.owner.op)) - assert tau_res.type.ndim == 1 + assert tau_res.type.ndim == 0 - -def test_horseshoe_match_wrong_graph(srng): beta_rv = srng.normal(0, 1) with pytest.raises(ValueError): horseshoe_match(beta_rv) - -def test_horseshoe_match_wrong_local_scale_dist(srng): size = at.scalar("size", dtype="int32") tau_rv = srng.halfcauchy(0, 1, size=1) lmbda_rv = srng.normal(0, 1, size=size) @@ -69,8 +69,6 @@ def test_horseshoe_match_wrong_local_scale_dist(srng): with pytest.raises(ValueError): horseshoe_match(beta_rv) - -def test_horseshoe_match_wrong_global_scale_dist(srng): size = at.scalar("size", dtype="int32") tau_rv = srng.normal(0, 1, size=1) lmbda_rv = srng.halfcauchy(0, 1, size=size) @@ -78,8 +76,6 @@ def test_horseshoe_match_wrong_global_scale_dist(srng): with pytest.raises(ValueError): horseshoe_match(beta_rv) - -def test_horseshoe_match_wrong_dimensions(srng): size = at.scalar("size", dtype="int32") tau_rv = srng.halfcauchy(0, 1, size=size) lmbda_rv = srng.halfcauchy(0, 1, size=size) @@ -89,69 +85,56 @@ def test_horseshoe_match_wrong_dimensions(srng): horseshoe_match(beta_rv) -def test_match_nbinom_horseshoe(srng): - 
nbinom_horseshoe_match(nbinom_horseshoe_model(srng)) - - -def test_match_binom_horseshoe_wrong_graph(srng): - beta = at.vector("beta") - X = at.matrix("X") - Y = X @ beta - - with pytest.raises(ValueError): - nbinom_horseshoe_match(Y) - - -def test_match_nbinom_horseshoe_wrong_sign(srng): - X = at.matrix("X") - h = at.scalar("h") - - beta_rv = horseshoe_model(srng) - eta = X @ beta_rv - p = at.sigmoid(2 * eta) - Y_rv = srng.nbinom(h, p) - - with pytest.raises(ValueError): - nbinom_horseshoe_match(Y_rv) - +@pytest.mark.parametrize( + "N, p, nonzero_atol", + [ + (50, 10, np.array([1.0, 0.5, 0.5, 3e-1, 3e-1])), + (50, 55, np.array([1.5, 0.5, 0.5, 0.7, 3e-1])), + ], +) +def test_normal_horseshoe_sampler(srng, N, p, nonzero_atol): + """Check the results of a normal regression model with a Horseshoe prior. -def test_horseshoe_nbinom(srng): - """ This test example is modified from section 3.2 of Makalic & Schmidt (2016) + """ - h = 2 - p = 10 - N = 50 + rng = np.random.default_rng(9420) - # generate synthetic data true_beta = np.array([5, 3, 3, 1, 1] + [0] * (p - 5)) S = toeplitz(0.5 ** np.arange(p)) - X = srng.multivariate_normal(np.zeros(p), cov=S, size=N) - y = srng.nbinom(h, at.sigmoid(-(X.dot(true_beta)))) + X = rng.multivariate_normal(np.zeros(p), cov=S, size=N) + y = rng.normal(X @ true_beta, np.ones(N)) - # build the model - tau_rv = srng.halfcauchy(0, 1, size=1) + tau_rv = srng.halfcauchy(0, 1) lambda_rv = srng.halfcauchy(0, 1, size=p) - beta_rv = srng.normal(0, tau_rv * lambda_rv, size=p) - eta_tt = X @ beta_rv - p_tt = at.sigmoid(-eta_tt) - Y_rv = srng.nbinom(h, p_tt) + tau_inv_vv = tau_rv.clone() + lambda_inv_vv = lambda_rv.clone() - # sample from the posterior distributions - num_samples = at.scalar("num_samples", dtype="int32") - outputs, updates = nbinom_horseshoe_gibbs(srng, Y_rv, y, num_samples) - sample_fn = aesara.function((num_samples,), outputs, updates=updates) + beta_post = normal_regression_posterior( + srng, np.ones(N), tau_inv_vv * lambda_inv_vv, at.as_tensor(X), y + ) - beta, lmbda, tau = sample_fn(2000) + lambda_post, tau_post = horseshoe_posterior( + srng, beta_post, 1.0, lambda_inv_vv, tau_inv_vv + ) - assert beta.shape == (2000, p) - assert lmbda.shape == (2000, p) - assert tau.shape == (2000, 1) + outputs = (beta_post, lambda_post, tau_post) + sample_fn = aesara.function((tau_inv_vv, lambda_inv_vv), outputs) - # test distribution domains - assert np.all(tau > 0) - assert np.all(lmbda > 0) + beta_post_vals = [] + lambda_inv_post_val, tau_inv_post_val = np.ones(p), 1.0 + for i in range(3000): + beta_post_val, lambda_inv_post_val, tau_inv_post_val = sample_fn( + tau_inv_post_val, lambda_inv_post_val + ) + beta_post_vals += [beta_post_val] + assert np.all(tau_inv_post_val >= 0) + assert np.all(lambda_inv_post_val >= 0) + + beta_post_median = np.median(beta_post_vals[100::2], axis=0) + assert np.allclose(beta_post_median[:5], true_beta[:5], atol=nonzero_atol) + assert np.all(np.abs(beta_post_median[5:]) < 1) @pytest.mark.parametrize( @@ -183,12 +166,39 @@ def test_sample_CRT_mean(srng, h_val, y_val): assert np.allclose(crt_mean_val, crt_exp_val, rtol=1e-1) -def test_h_step(srng): +def test_nbinom_normal_posterior(srng): + M = 10 + N = 50 true_h = 10 + true_beta = np.array([2, 0.02, 0.2, 0.1, 1] + [0.0] * (M - 5)) + S = toeplitz(0.5 ** np.arange(M)) + X_at = srng.multivariate_normal(np.zeros(M), cov=S, size=N) + p_at = at.sigmoid(-(X_at.dot(true_beta))) + X, p, y = aesara.function([], [X_at, p_at, srng.nbinom(true_h, p_at)])() + + beta_vv = at.vector("beta") + beta_post 
= nbinom_normal_posterior( + srng, beta_vv, 200 * np.ones(M), at.as_tensor(X), true_h, y + ) + + beta_post_fn = aesara.function([beta_vv], beta_post) + + beta_post_vals = [] + beta_post_val = np.zeros(M) + for i in range(1000): + beta_post_val = beta_post_fn(beta_post_val) + beta_post_vals += [beta_post_val] + + beta_post_mean = np.mean(beta_post_vals, axis=0) + assert np.allclose(beta_post_mean, true_beta, atol=3e-1) + + +def test_nbinom_dispersion_posterior(srng): M = 10 N = 50 + true_h = 10 true_beta = np.array([2, 0.02, 0.2, 0.1, 1] + [0.1] * (M - 5)) S = toeplitz(0.5 ** np.arange(M)) X = srng.multivariate_normal(np.zeros(M), cov=S, size=N) @@ -201,7 +211,8 @@ def test_h_step(srng): b = at.as_tensor(b_val) h_samples, h_updates = aesara.scan( - lambda: h_step(srng, at.as_tensor(true_h), p, a, b, y), n_steps=1000 + lambda: nbinom_dispersion_posterior(srng, at.as_tensor(true_h), p, a, b, y), + n_steps=1000, ) h_mean_fn = aesara.function([], h_samples.mean(), updates=h_updates) @@ -216,128 +227,72 @@ def test_h_step(srng): assert np.allclose(h_mean_val, true_h, rtol=2e-1) -def test_horseshoe_nbinom_w_dispersion(srng): - """ - This test example is modified from section 3.2 of Makalic & Schmidt (2016) - """ - true_h = 10 - M = 10 - N = 50 - - # generate synthetic data - true_beta = np.array([2, 0.02, 0.2, 0.1, 1] + [0.1] * (M - 5)) - S = toeplitz(0.5 ** np.arange(M)) - X_at = srng.multivariate_normal(np.zeros(M), cov=S, size=N) - X, y = aesara.function( - [], [X_at, srng.nbinom(true_h, at.sigmoid(-(X_at.dot(true_beta))))] - )() - X = at.as_tensor(X) - y = at.as_tensor(y) - - # build the model - tau_rv = srng.halfcauchy(0, 1, name="tau") - lambda_rv = srng.halfcauchy(0, 1, size=M, name="lambda") - beta_rv = srng.normal(0, tau_rv * lambda_rv, size=M, name="beta") - - eta_tt = X @ beta_rv - p_tt = at.sigmoid(-eta_tt) - p_tt.name = "p" - - h_rv = srng.gamma(100, 1, name="h") - - Y_rv = srng.nbinom(h_rv, p_tt, name="Y") - - # sample from the posterior distributions - num_samples = at.lscalar("num_samples") - outputs, updates = nbinom_horseshoe_gibbs_with_dispersion( - srng, Y_rv, y, num_samples - ) - - sample_fn = aesara.function((num_samples,), outputs, updates=updates) - - sample_num = 2000 - beta, lmbda, tau, h = sample_fn(sample_num) - - assert beta.shape == (sample_num, M) - assert lmbda.shape == (sample_num, M) - assert tau.shape == (sample_num, 1) - assert h.shape == (sample_num,) - - assert np.all(tau > 0) - assert np.all(lmbda > 0) - assert np.all(h > 0) - - assert np.allclose(h.mean(), true_h, rtol=1e-1) +def test_bern_sigmoid_dot_match(srng): + X = at.matrix("X") + beta_rv = srng.normal(0, 1, size=X.shape[1], name="beta") + eta = X @ beta_rv + p = at.sigmoid(-eta) + Y_rv = srng.bernoulli(p) -def test_match_bernoulli_horseshoe(srng): - bernoulli_horseshoe_match(bernoulli_horseshoe_model(srng)) + Y_rv = optimize_graph(Y_rv) + assert bern_sigmoid_dot_match(Y_rv) -def test_match_bernoulli_horseshoe_wrong_graph(srng): beta = at.vector("beta") X = at.matrix("X") Y = X @ beta with pytest.raises(ValueError): - bernoulli_horseshoe_match(Y) + bern_sigmoid_dot_match(Y) - -def test_match_bernoulli_horseshoe_wrong_sign(srng): X = at.matrix("X") - - beta_rv = horseshoe_model(srng) + beta_rv = srng.normal(0, 1, name="beta") eta = X @ beta_rv p = at.sigmoid(2 * eta) Y_rv = srng.bernoulli(p) with pytest.raises(ValueError): - bernoulli_horseshoe_match(Y_rv) + bern_sigmoid_dot_match(Y_rv) -def test_bernoulli_horseshoe(srng): - p = 10 +def test_bern_normal_posterior(srng): + M = 10 N = 50 - # generate 
-    true_beta = np.array([5, 3, 3, 1, 1] + [0] * (p - 5))
-    S = toeplitz(0.5 ** np.arange(p))
-    X = srng.multivariate_normal(np.zeros(p), cov=S, size=N)
-    y = srng.bernoulli(at.sigmoid(-X.dot(true_beta)))
-
-    # build the model
-    tau_rv = srng.halfcauchy(0, 1, size=1)
-    lambda_rv = srng.halfcauchy(0, 1, size=p)
-    beta_rv = srng.normal(0, tau_rv * lambda_rv, size=p)
-
-    eta_tt = X @ beta_rv
-    p_tt = at.sigmoid(-eta_tt)
-    Y_rv = srng.bernoulli(p_tt)
+    true_beta = np.array([2, 0.02, 0.2, 0.1, 1] + [0.1] * (M - 5))
+    S = toeplitz(0.5 ** np.arange(M))
+    X_at = srng.multivariate_normal(np.zeros(M), cov=S, size=N)
+    p_at = at.sigmoid(X_at.dot(true_beta))
+    X, p, y = aesara.function([], [X_at, p_at, srng.bernoulli(p_at)])()

-    # sample from the posterior distributions
-    num_samples = at.scalar("num_samples", dtype="int32")
-    outputs, updates = bernoulli_horseshoe_gibbs(srng, Y_rv, y, num_samples)
-    sample_fn = aesara.function((num_samples,), outputs, updates=updates)
+    beta_vv = at.vector("beta")
+    beta_post = bern_normal_posterior(srng, beta_vv, np.ones(M), at.as_tensor(X), y)

-    beta, lmbda, tau = sample_fn(2000)
+    beta_post_fn = aesara.function([beta_vv], beta_post)

-    assert beta.shape == (2000, p)
-    assert lmbda.shape == (2000, p)
-    assert tau.shape == (2000, 1)
+    beta_post_vals = []
+    beta_post_val = np.zeros(M)
+    for i in range(3000):
+        beta_post_val = beta_post_fn(beta_post_val)
+        beta_post_vals += [beta_post_val]

-    # test distribution domains
-    assert np.all(tau > 0)
-    assert np.all(lmbda > 0)
+    beta_post_mean = np.mean(beta_post_vals, axis=0)
+    assert np.allclose(beta_post_mean, true_beta, atol=0.7)


 def test_gamma_match(srng):
     beta_rv = srng.normal(0, 1)
+
     with pytest.raises(ValueError):
         gamma_match(beta_rv)

     a = at.scalar("a")
     b = at.scalar("b")
     beta_rv = srng.gamma(a, b)
+
+    beta_rv = optimize_graph(beta_rv)
+
     a_m, b_m = gamma_match(beta_rv)

     assert a_m is a
@@ -346,20 +301,84 @@
     assert equal_computations([b_m], [b_exp])


-def test_nbinom_horseshoe_with_dispersion_match(srng):
+def test_nbinom_logistic_horseshoe_finders():
+    """Make sure `nbinom_logistic_finder` and `normal_horseshoe_finder` work."""
+    srng = RandomStream(0)
+
+    X = at.matrix("X")
+
+    # Horseshoe `beta_rv`
+    tau_rv = srng.halfcauchy(0, 1, name="tau")
+    lmbda_rv = srng.halfcauchy(0, 1, size=X.shape[1], name="lambda")
+    beta_rv = srng.normal(0, lmbda_rv * tau_rv, size=X.shape[1], name="beta")
+
     a = at.scalar("a")
     b = at.scalar("b")
+    h_rv = srng.gamma(a, b, name="h")
+
+    # Negative-binomial regression
+    eta = X @ beta_rv
+    p = at.sigmoid(-eta)
+    Y_rv = srng.nbinom(h_rv, p, name="Y")
+
+    y_vv = Y_rv.clone()
+    y_vv.name = "y"
+
+    fgraph, obs_rvs_to_values, memo, new_to_old_rvs = construct_ir_fgraph({Y_rv: y_vv})
+
+    fgraph.attach_feature(SamplerTracker(srng))
+
+    _ = sampler_rewrites_db.query("+basic").optimize(fgraph)
+
+    discovered_samplers = fgraph.sampler_mappings.rvs_to_samplers
+    discovered_samplers = {
+        new_to_old_rvs[rv]: discovered_samplers.get(rv)
+        for rv in fgraph.outputs
+        if rv not in obs_rvs_to_values
+    }
+
+    assert len(discovered_samplers) == 4
+
+    assert discovered_samplers[tau_rv][0][0] == "normal_horseshoe_finder"
+    assert discovered_samplers[lmbda_rv][0][0] == "normal_horseshoe_finder"
+    assert discovered_samplers[beta_rv][0][0] == "nbinom_logistic_finder"
+    assert discovered_samplers[h_rv][0][0] == "nbinom_logistic_finder"
+
+
+def test_bern_logistic_horseshoe_finders():
+    """Make sure `bern_logistic_finder` and `normal_horseshoe_finder` work."""
+    srng = RandomStream(0)
+
     X = at.matrix("X")

-    beta_rv = horseshoe_model(srng)
+    # Horseshoe `beta_rv`
+    tau_rv = srng.halfcauchy(0, 1, name="tau")
+    lmbda_rv = srng.halfcauchy(0, 1, size=X.shape[1], name="lambda")
+    beta_rv = srng.normal(0, lmbda_rv * tau_rv, size=X.shape[1], name="beta")
+
+    # Bernoulli (logistic) regression
     eta = X @ beta_rv
     p = at.sigmoid(-eta)
-    h = srng.gamma(a, b)
-    Y_rv = srng.nbinom(h, p)
+    Y_rv = srng.bernoulli(p, name="Y")

-    X_m, beta_m, lmbda_m, tau_m, h_m, a_m, b_m = nbinom_horseshoe_with_dispersion_match(
-        Y_rv
-    )
+    y_vv = Y_rv.clone()
+    y_vv.name = "y"

-    assert a_m is a
-    assert X_m is X
+    fgraph, obs_rvs_to_values, memo, new_to_old_rvs = construct_ir_fgraph({Y_rv: y_vv})
+
+    fgraph.attach_feature(SamplerTracker(srng))
+
+    _ = sampler_rewrites_db.query("+basic").optimize(fgraph)
+
+    discovered_samplers = fgraph.sampler_mappings.rvs_to_samplers
+    discovered_samplers = {
+        new_to_old_rvs[rv]: discovered_samplers.get(rv)
+        for rv in fgraph.outputs
+        if rv not in obs_rvs_to_values
+    }
+
+    assert len(discovered_samplers) == 3
+
+    assert discovered_samplers[tau_rv][0][0] == "normal_horseshoe_finder"
+    assert discovered_samplers[lmbda_rv][0][0] == "normal_horseshoe_finder"
+    assert discovered_samplers[beta_rv][0][0] == "bern_logistic_finder"
diff --git a/tests/test_opt.py b/tests/test_opt.py
new file mode 100644
index 0000000..2e08ec4
--- /dev/null
+++ b/tests/test_opt.py
@@ -0,0 +1,109 @@
+import aesara.tensor as at
+import numpy as np
+from aesara.graph.basic import equal_computations
+from aesara.tensor.elemwise import DimShuffle, Elemwise
+from cons import car, cdr
+from etuples import etuple, etuplize
+from unification import unify
+
+from aemcmc.opt import SubsumingElemwise, local_elemwise_dimshuffle_subsume
+
+
+def test_SubsumingElemwise_basics():
+    a = at.vector("a")
+    b = at.scalar("b")
+
+    x = a * b
+
+    assert isinstance(x.owner.op, Elemwise)
+    b_ds = x.owner.inputs[1].owner.op
+    assert isinstance(b_ds, DimShuffle)
+
+    ee_mul_op = SubsumingElemwise([a, b], [x])
+
+    assert ee_mul_op != ee_mul_op.clone()
+    assert str(ee_mul_op) == "SubsumingElemwise{mul}"
+
+    s = unify(at.mul, ee_mul_op)
+    assert s is not False
+
+    assert car(ee_mul_op) == car(x.owner.op)
+    assert cdr(ee_mul_op) == cdr(x.owner.op)
+
+    s = unify(etuplize(at.mul), etuplize(ee_mul_op))
+    assert s is not False
+
+    ee_et = etuplize(ee_mul_op(a, b))
+    x_et = etuple(etuplize(at.mul), a, b)
+
+    s = unify(ee_et, x_et)
+    assert s is not False
+
+    # TODO: Consider making this possible
+    # s = unify(ee_mul(a, b), x)
+    # assert s is not False
+
+
+def test_local_elemwise_dimshuffle_subsume_basic():
+    srng = at.random.RandomStream(2398)
+
+    a = at.vector("a")
+    b = srng.normal(0, 1, name="b")
+
+    x = a * b
+
+    node = x.owner
+    assert isinstance(node.op, Elemwise)
+    b_ds = node.inputs[1].owner.op
+    assert isinstance(b_ds, DimShuffle)
+
+    (res,) = local_elemwise_dimshuffle_subsume.transform(None, node)
+    assert isinstance(res.owner.op, SubsumingElemwise)
+    assert equal_computations(
+        [res.owner.op.inner_outputs[0]], [x], res.owner.op.inner_inputs[:2], [a, b]
+    )
+    assert res.owner.inputs == [a, b]
+
+
+def test_local_elemwise_dimshuffle_subsume_transpose():
+    """Make sure that `local_elemwise_dimshuffle_subsume` is applied selectively."""
+    srng = at.random.RandomStream(2398)
+
+    a = at.vector("a")
+    # This transpose shouldn't be subsumed, but the one applied to `a` by
+    # `Elemwise.make_node` should
+    b = srng.normal(at.arange(4).reshape((2, 2)), 1, name="b").T
+
+    x = a * b
+
+    node = x.owner
+    assert isinstance(node.op, Elemwise)
+    b_ds = node.inputs[1].owner.op
+    assert isinstance(b_ds, DimShuffle)
+
+    (res,) = local_elemwise_dimshuffle_subsume.transform(None, node)
+    assert isinstance(res.owner.op, SubsumingElemwise)
+    assert equal_computations(
+        [res.owner.op.inner_outputs[0]], [x], res.owner.op.inner_inputs[:2], [a, b]
+    )
+    assert res.owner.inputs == [a, b]
+
+    a = at.tensor(np.float64, shape=(None, None, None), name="a")
+    # Again, the transpose part shouldn't be subsumed, but the added broadcast
+    # dimension should
+    b = srng.normal(at.arange(4).reshape((2, 2)), 1, name="b")
+    b_ds = b.dimshuffle(("x", 1, 0))
+
+    x = a * b_ds
+
+    node = x.owner
+    assert isinstance(node.op, Elemwise)
+    b_ds = node.inputs[1].owner.op
+    assert isinstance(b_ds, DimShuffle)
+
+    (res,) = local_elemwise_dimshuffle_subsume.transform(None, node)
+    assert isinstance(res.owner.op, SubsumingElemwise)
+    assert res.owner.inputs[0] == a
+    # The input corresponding to `b`/`b_ds` should be equivalent to `b.T`
+    assert isinstance(res.owner.inputs[1].owner.op, DimShuffle)
+    assert equal_computations([b.T], [res.owner.inputs[1]])
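Editor's note (illustrative sketch, not part of the patch): the Gibbs sweep that `test_normal_horseshoe_sampler` above exercises can also be written as a standalone loop, as below. The import path `aemcmc.gibbs` for `normal_regression_posterior` and `horseshoe_posterior` is an assumption; the signatures, the value-variable cloning, and the update order all mirror the test code.

import aesara
import aesara.tensor as at
import numpy as np
from aesara.tensor.random.utils import RandomStream

# Assumed module path for the conditional-posterior constructors used in the tests
from aemcmc.gibbs import horseshoe_posterior, normal_regression_posterior

srng = RandomStream(0)
rng = np.random.default_rng(0)

# Synthetic sparse-regression data, as in the test above
N, p = 50, 10
true_beta = np.array([5.0, 3.0, 3.0, 1.0, 1.0] + [0.0] * (p - 5))
X = rng.normal(size=(N, p))
y = rng.normal(X @ true_beta, 1.0)

# Horseshoe scale priors and the value variables used as loop state
tau_rv = srng.halfcauchy(0, 1)
lambda_rv = srng.halfcauchy(0, 1, size=p)
tau_vv = tau_rv.clone()
lambda_vv = lambda_rv.clone()

# beta | tau, lambda, y: normal-regression conditional posterior
beta_post = normal_regression_posterior(
    srng, np.ones(N), tau_vv * lambda_vv, at.as_tensor(X), y
)
# (lambda, tau) | beta: Horseshoe scale conditional posteriors
lambda_post, tau_post = horseshoe_posterior(srng, beta_post, 1.0, lambda_vv, tau_vv)

step_fn = aesara.function((tau_vv, lambda_vv), (beta_post, lambda_post, tau_post))

# Gibbs sweep: feed the previous scale draws back in at each step
lambda_val, tau_val = np.ones(p), 1.0
beta_draws = []
for _ in range(3000):
    beta_val, lambda_val, tau_val = step_fn(tau_val, lambda_val)
    beta_draws.append(beta_val)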