Skip to content

Commit

Permalink
Refactor rewrites and add a general sampler constructor function
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 22, 2022
1 parent 589520d commit 213e0fd
Show file tree
Hide file tree
Showing 10 changed files with 1,227 additions and 524 deletions.
4 changes: 4 additions & 0 deletions aemcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from . import _version

__version__ = _version.get_versions()["version"]

# Register rewrite databases
import aemcmc.conjugates
import aemcmc.gibbs
123 changes: 123 additions & 0 deletions aemcmc/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import Dict, Tuple

from aesara.graph.basic import Variable
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.var import TensorVariable

from aemcmc.opt import (
SamplerTracker,
construct_ir_fgraph,
expand_subsumptions,
sampler_rewrites_db,
)


def construct_sampler(
obs_rvs_to_values: Dict[TensorVariable, TensorVariable], srng: RandomStream
) -> Tuple[
Dict[TensorVariable, TensorVariable],
Dict[Variable, Variable],
Dict[TensorVariable, TensorVariable],
]:
r"""Eagerly construct a sampler for a given set of observed variables and their observations.
Parameters
==========
obs_rvs_to_values
A ``dict`` of variables that maps stochastic elements
(e.g. `RandomVariable`\s) to symbolic `Variable`\s representing their
observed values.
Returns
=======
A ``dict`` that maps each random variable to its sampler step and
any updates generated by the sampler steps.
"""

fgraph, obs_rvs_to_values, memo, new_to_old_rvs = construct_ir_fgraph(
obs_rvs_to_values
)

fgraph.attach_feature(SamplerTracker(srng))

_ = sampler_rewrites_db.query("+basic").optimize(fgraph)

random_vars = tuple(rv for rv in fgraph.outputs if rv not in obs_rvs_to_values)

discovered_samplers = fgraph.sampler_mappings.rvs_to_samplers

rvs_to_init_vals = {rv: rv.clone() for rv in random_vars}
posterior_sample_steps = rvs_to_init_vals.copy()
# Replace occurrences of observed variables with their observed values
posterior_sample_steps.update(obs_rvs_to_values)

# TODO FIXME: Get/extract `Scan`-generated updates
posterior_updates: Dict[Variable, Variable] = {}

rvs_without_samplers = set()

for rv in fgraph.outputs:

if rv in obs_rvs_to_values:
continue

rv_steps = discovered_samplers.get(rv)

if not rv_steps:
rvs_without_samplers.add(rv)
continue

# TODO FIXME: Just choosing one for now, but we should consider them all.
step_desc, step, updates = rv_steps.pop()

# Expand subsumed `DimShuffle`d inputs to `Elemwise`s
if updates:
update_keys, update_values = zip(*updates.items())
else:
update_keys, update_values = tuple(), tuple()

sfgraph = FunctionGraph(
outputs=(step,) + tuple(update_keys) + tuple(update_values),
clone=False,
copy_inputs=False,
copy_orphans=False,
)

# Update the other sampled random variables in this step's graph
sfgraph.replace_all(list(posterior_sample_steps.items()), import_missing=True)

expand_subsumptions.optimize(sfgraph)

step = sfgraph.outputs[0]

# Update the other sampled random variables in this step's graph
# (step,) = clone_replace([step], replace=posterior_sample_steps)

posterior_sample_steps[rv] = step

if updates:
keys_offset = len(update_keys) + 1
update_keys = sfgraph.outputs[1:keys_offset]
update_values = sfgraph.outputs[keys_offset:]
updates = dict(zip(update_keys, update_values))
posterior_updates.update(updates)

if rvs_without_samplers:
# TODO: Assign NUTS to these
raise NotImplementedError(
f"Could not find a posterior samplers for {rvs_without_samplers}"
)

# TODO: Track/handle "auxiliary/augmentation" variables introduced by sample
# steps?

return (
{
new_to_old_rvs[rv]: step
for rv, step in posterior_sample_steps.items()
if rv not in obs_rvs_to_values
},
posterior_updates,
{new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()},
)
68 changes: 52 additions & 16 deletions aemcmc/conjugates.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import aesara.tensor as at
from aeppl.opt import NoCallbackEquilibriumDB
from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.opt import in2out, local_optimizer
from aesara.graph.optdb import LocalGroupDB
from aesara.graph.unify import eval_if_etuple
from aesara.tensor.random.basic import BinomialRV
from etuples import etuple, etuplize
from kanren import eq, lall
from kanren import eq, lall, run
from unification import var

from aemcmc.opt import sampler_finder_db

def beta_binomial_conjugateo(observed_rv_expr, posterior_expr):

def beta_binomial_conjugateo(observed_val, observed_rv_expr, posterior_expr):
r"""Produce a goal that represents the application of Bayes theorem
for a beta prior with a binomial observation model.
Expand All @@ -22,15 +26,15 @@ def beta_binomial_conjugateo(observed_rv_expr, posterior_expr):
Parameters
----------
observed_val
The observed value.
observed_rv_expr
A tuple that contains expressions that represent the observed variable
and it observed value respectively.
An expression that represents the observed variable.
posterior_exp
An expression that represents the posterior distribution of the latent
variable.
"""

# Beta-binomial observation model
alpha_lv, beta_lv = var(), var()
p_rng_lv = var()
Expand All @@ -42,12 +46,10 @@ def beta_binomial_conjugateo(observed_rv_expr, posterior_expr):
n_lv = var()
Y_et = etuple(etuplize(at.random.binomial), var(), var(), var(), n_lv, p_et)

y_lv = var() # observation

# Posterior distribution for p
new_alpha_et = etuple(etuplize(at.add), alpha_lv, y_lv)
new_alpha_et = etuple(etuplize(at.add), alpha_lv, observed_val)
new_beta_et = etuple(
etuplize(at.sub), etuple(etuplize(at.add), beta_lv, n_lv), y_lv
etuplize(at.sub), etuple(etuplize(at.add), beta_lv, n_lv), observed_val
)
p_posterior_et = etuple(
etuplize(at.random.beta),
Expand All @@ -59,13 +61,47 @@ def beta_binomial_conjugateo(observed_rv_expr, posterior_expr):
)

return lall(
eq(observed_rv_expr[0], Y_et),
eq(observed_rv_expr[1], y_lv),
eq(observed_rv_expr, Y_et),
eq(posterior_expr, p_posterior_et),
)


conjugates_db = NoCallbackEquilibriumDB()
conjugates_db.register(
"beta_binomial", KanrenRelationSub(beta_binomial_conjugateo), -5, "basic"
@local_optimizer([BinomialRV])
def local_beta_binomial_posterior(fgraph, node):

sampler_mappings = getattr(fgraph, "sampler_mappings", None)

rv_var = node.outputs[1]
key = ("local_beta_binomial_posterior", rv_var)

if sampler_mappings is None or key in sampler_mappings.rvs_seen:
return None # pragma: no cover

q = var()

rv_et = etuplize(rv_var)

res = run(None, q, beta_binomial_conjugateo(rv_var, rv_et, q))
res = next(res, None)

if res is None:
return None # pragma: no cover

beta_rv = rv_et[-1].evaled_obj
beta_posterior = eval_if_etuple(res)

sampler_mappings.rvs_to_samplers.setdefault(beta_rv, []).append(
("local_beta_binomial_posterior", beta_posterior, None)
)
sampler_mappings.rvs_seen.add(key)

return rv_var.owner.outputs


conjugates_db = LocalGroupDB(apply_all_opts=True)
conjugates_db.name = "conjugates_db"
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")

sampler_finder_db.register(
"conjugates", in2out(conjugates_db.query("+basic"), name="gibbs"), "basic"
)
2 changes: 2 additions & 0 deletions aemcmc/dists.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def multivariate_normal_rue2005(rng, b, Q):
w = at.slinalg.solve_triangular(L, b, lower=True)
u = at.slinalg.solve_triangular(L.T, w, lower=False)
z = rng.standard_normal(size=L.shape[0])
z.owner.outputs[0].name = "z_rng"
v = at.slinalg.solve_triangular(L.T, z, lower=False)
return u + v

Expand Down Expand Up @@ -135,6 +136,7 @@ def multivariate_normal_cong2017(
A_inv = 1 / A
a_rows = A.shape[0]
z = rng.standard_normal(size=a_rows + omega.shape[0])
z.owner.outputs[0].name = "z_rng"
y1 = at.sqrt(A_inv) * z[:a_rows]
y2 = (1 / at.sqrt(omega)) * z[a_rows:]
Ainv_phi = A_inv[:, None] * phi.T
Expand Down
Loading

0 comments on commit 213e0fd

Please sign in to comment.