diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index d519ff3d..c420fe96 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -15,7 +15,7 @@ from pymc_experimental import distributions, gp, statespace, utils from pymc_experimental.inference.fit import fit -from pymc_experimental.model.marginal_model import MarginalModel +from pymc_experimental.model.marginal.marginal_model import MarginalModel from pymc_experimental.model.model_api import as_model from pymc_experimental.version import __version__ diff --git a/pymc_experimental/model/marginal/__init__.py b/pymc_experimental/model/marginal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pymc_experimental/model/marginal/distributions.py b/pymc_experimental/model/marginal/distributions.py new file mode 100644 index 00000000..a32e685f --- /dev/null +++ b/pymc_experimental/model/marginal/distributions.py @@ -0,0 +1,196 @@ +from typing import Sequence + +import numpy as np +import pytensor.tensor as pt +from pymc.distributions import ( + Bernoulli, + Categorical, + DiscreteUniform, + SymbolicRandomVariable +) +from pymc.logprob.basic import conditional_logp, logp +from pymc.logprob.abstract import _logprob +from pymc.pytensorf import constant_fold +from pytensor.graph.replace import clone_replace, graph_replace +from pytensor.scan import scan, map as scan_map +from pytensor.compile.mode import Mode +from pytensor.graph import vectorize_graph +from pytensor.tensor import TensorVariable, TensorType + +from pymc_experimental.distributions import DiscreteMarkovChain + + +class MarginalRV(SymbolicRandomVariable): + """Base class for Marginalized RVs""" + + +class FiniteDiscreteMarginalRV(MarginalRV): + """Base class for Finite Discrete Marginalized RVs""" + + +class DiscreteMarginalMarkovChainRV(MarginalRV): + """Base class for Discrete Marginal Markov Chain RVs""" + + +def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]: + op = rv.owner.op + dist_params = rv.owner.op.dist_params(rv.owner) + if isinstance(op, Bernoulli): + return (0, 1) + elif isinstance(op, Categorical): + [p_param] = dist_params + return tuple(range(pt.get_vector_length(p_param))) + elif isinstance(op, DiscreteUniform): + lower, upper = constant_fold(dist_params) + return tuple(np.arange(lower, upper + 1)) + elif isinstance(op, DiscreteMarkovChain): + P, *_ = dist_params + return tuple(range(pt.get_vector_length(P[-1]))) + + raise NotImplementedError(f"Cannot compute domain for op {op}") + + +def _add_reduce_batch_dependent_logps( + marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable] +): + """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`.""" + + mbcast = marginalized_type.broadcastable + reduced_logps = [] + for dependent_logp in dependent_logps: + dbcast = dependent_logp.type.broadcastable + dim_diff = len(dbcast) - len(mbcast) + mbcast_aligned = (True,) * dim_diff + mbcast + vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v] + reduced_logps.append(dependent_logp.sum(vbcast_axis)) + return pt.add(*reduced_logps) + + +@_logprob.register(FiniteDiscreteMarginalRV) +def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): + # Clone the inner RV graph of the Marginalized RV + marginalized_rvs_node = op.make_node(*inputs) + marginalized_rv, *inner_rvs = clone_replace( + op.inner_outputs, + replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, + 
) + + # Obtain the joint_logp graph of the inner RV graph + inner_rv_values = dict(zip(inner_rvs, values)) + marginalized_vv = marginalized_rv.clone() + rv_values = inner_rv_values | {marginalized_rv: marginalized_vv} + logps_dict = conditional_logp(rv_values=rv_values, **kwargs) + + # Reduce logp dimensions corresponding to broadcasted variables + marginalized_logp = logps_dict.pop(marginalized_vv) + joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps( + marginalized_rv.type, logps_dict.values() + ) + + # Compute the joint_logp for all possible n values of the marginalized RV. We assume + # each original dimension is independent so that it suffices to evaluate the graph + # n times, once with each possible value of the marginalized RV replicated across + # batched dimensions of the marginalized RV + + # PyMC does not allow RVs in the logp graph, even if we are just using the shape + marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False) + marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv) + marginalized_rv_domain_tensor = pt.moveaxis( + pt.full( + (*marginalized_rv_shape, len(marginalized_rv_domain)), + marginalized_rv_domain, + dtype=marginalized_rv.dtype, + ), + -1, + 0, + ) + + try: + joint_logps = vectorize_graph( + joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor} + ) + except Exception: + # Fallback to Scan + def logp_fn(marginalized_rv_const, *non_sequences): + return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const}) + + joint_logps, _ = scan_map( + fn=logp_fn, + sequences=marginalized_rv_domain_tensor, + non_sequences=[*values, *inputs], + mode=Mode().including("local_remove_check_parameter"), + ) + + joint_logps = pt.logsumexp(joint_logps, axis=0) + + # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise + return joint_logps, *(pt.constant(0),) * (len(values) - 1) + + +@_logprob.register(DiscreteMarginalMarkovChainRV) +def marginal_hmm_logp(op, values, *inputs, **kwargs): + marginalized_rvs_node = op.make_node(*inputs) + inner_rvs = clone_replace( + op.inner_outputs, + replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, + ) + + chain_rv, *dependent_rvs = inner_rvs + P, n_steps_, init_dist_, rng = chain_rv.owner.inputs + domain = pt.arange(P.shape[-1], dtype="int32") + + # Construct logp in two steps + # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission) + + # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating + # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise, + # PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step. 
+ chain_value = chain_rv.clone() + dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value}) + logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values))) + + # Reduce and add the batch dims beyond the chain dimension + reduced_logp_emissions = _add_reduce_batch_dependent_logps( + chain_rv.type, logp_emissions_dict.values() + ) + + # Add a batch dimension for the domain of the chain + chain_shape = constant_fold(tuple(chain_rv.shape)) + batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0) + batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value}) + + # Step 2: Compute the transition probabilities + # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1}) + # We do it entirely in logs, though. + + # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) + # under the initial distribution. This is robust to everything the user can throw at it. + init_dist_value = init_dist_.type() + logp_init_dist = logp(init_dist_, init_dist_value) + # There is a degerate batch dim for lags=1 (the only supported case), + # that we have to work around, by expanding the batch value and then squeezing it out of the logp + batch_logp_init_dist = vectorize_graph( + logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]} + ).squeeze(1) + log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0] + + def step_alpha(logp_emission, log_alpha, log_P): + step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0) + return logp_emission + step_log_prob + + P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2) + log_P = pt.shape_padright(pt.log(P), P_bcast_dims) + log_alpha_seq, _ = scan( + step_alpha, + non_sequences=[log_P], + outputs_info=[log_alpha_init], + # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value + sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0), + ) + # Final logp is just the sum of the last scan state + joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0) + + # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first + # return is the joint probability of everything together, but PyMC still expects one logp for each one. 
+ dummy_logps = (pt.constant(0),) * (len(values) - 1) + return joint_logp, *dummy_logps diff --git a/pymc_experimental/model/marginal/graph_analysis.py b/pymc_experimental/model/marginal/graph_analysis.py new file mode 100644 index 00000000..bbdfc9f5 --- /dev/null +++ b/pymc_experimental/model/marginal/graph_analysis.py @@ -0,0 +1,255 @@ +from itertools import zip_longest, chain +from typing import Sequence + +from pymc import SymbolicRandomVariable +from pytensor.compile import SharedVariable +from pytensor.graph import ancestors, Constant, graph_inputs, Variable +from pytensor.graph.basic import io_toposort +from pytensor.tensor import TensorVariable, TensorType +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import DimShuffle, Elemwise, CAReduce +from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.rewriting.subtensor import is_full_slice +from pytensor.tensor.shape import Shape +from pytensor.tensor.subtensor import Subtensor, get_idx_list, AdvancedSubtensor +from pytensor.tensor.type_other import NoneTypeT + + +def static_shape_ancestors(vars): + """Identify ancestor Shape Ops of static shapes (therefore constant in a valid graph).""" + return [ + var + for var in ancestors(vars) + if ( + var.owner + and isinstance(var.owner.op, Shape) + # All static dims lengths of Shape input are known + and None not in var.owner.inputs[0].type.shape + ) + ] + + +def find_conditional_input_rvs(output_rvs, all_rvs): + """Find conditionally independent input RVs.""" + blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs] + blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs)) + return [ + var + for var in ancestors(output_rvs, blockers=blockers) + if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable)) + ] + + +def is_conditional_dependent( + dependent_rv: TensorVariable, dependable_rv: TensorVariable, all_rvs +) -> bool: + """Check if dependent_rv is conditionally dependent on dependable_rv, + given all conditionally independent all_rvs""" + + return dependable_rv in find_conditional_input_rvs((dependent_rv,), all_rvs) + + +def find_conditional_dependent_rvs(dependable_rv, all_rvs): + """Find RVs that depend on dependable_rv""" + return [ + rv + for rv in all_rvs + if (rv is not dependable_rv and is_conditional_dependent(rv, dependable_rv, all_rvs)) + ] + + +def collect_shared_vars(outputs, blockers): + return [ + inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable) + ] + + +def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]: + """Find the output axis and dimensionality of the advanced indexing group (i.e., array indexing). + + There is a special case: when there are non-consecutive advanced indexing groups, the advanced indexing + group is always moved to the front.
+ + See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + """ + adv_group_axis = None + for axis, idx in enumerate(idxs): + if isinstance(idx.type, TensorType): + adv_group_axis = axis + elif adv_group_axis is not None: + # Special non-consecutive case + adv_group_axis = 0 + break + + adv_group_ndim = max(idx.type.ndim for idx in idxs if isinstance(idx.type, TensorType)) + return adv_group_axis, adv_group_ndim + + +def _broadcast_dims(inputs_dims: Sequence[tuple[tuple[int, ...], ...]]) -> tuple[tuple[int, ...], ...]: + output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0) + # Add missing dims + inputs_dims = [ + ((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims + ] + # Combine aligned dims + output_dims = tuple(tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims)) + return output_dims + + +def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[tuple[int, ...], ...]]: + """Identify how the dims of rv_to_marginalize are consumed by the dims of the output_rvs. + + Raises + ------ + NotImplementedError + If variable related to marginalized batch_dims is used in an operation that is not yet supported + + """ + + var_dims: dict[Variable, tuple[tuple[int, ...], ...]] = { + input_var: tuple((i,) for i in range(input_var.type.ndim)) + } + + for node in io_toposort([input_var, *other_inputs], output_vars): + inputs_dims = [var_dims.get(inp, ()) for inp in node.inputs] + + if not any(inputs_dims): + # None of the inputs are related to the batch_axes of the marginalized_rv + # We could set `()` for everything, but for now that doesn't seem needed + continue + + elif isinstance(node.op, DimShuffle): + [input_dims] = inputs_dims + output_dims = tuple( + input_dims[i] if isinstance(i, int) else () for i in node.op.new_order + ) + var_dims[node.outputs[0]] = output_dims + + elif isinstance(node.op, Elemwise | Blockwise | RandomVariable | SymbolicRandomVariable): + # NOTE: User-provided CustomDist may not respect core dimensions on the left. 
+ + if isinstance(node.op, Elemwise): + op_batch_ndim = node.outputs[0].type.ndim + else: + op_batch_ndim = node.op.batch_ndim(node) + + # Collapse all core_dims + core_dims = tuple(sorted(chain.from_iterable([i for input_dim in inputs_dims for i in input_dim[op_batch_ndim:]]))) + batch_dims = _broadcast_dims( + tuple( + input_dims[:op_batch_ndim] + for input_dims in inputs_dims + ) + ) + # Add batch dims to each output_dims + batch_dims = tuple(batch_dim + core_dims for batch_dim in batch_dims) + for out in node.outputs: + if isinstance(out.type, TensorType): + core_ndim = out.type.ndim - op_batch_ndim + output_dims = batch_dims + (core_dims,) * core_ndim + var_dims[out] = output_dims + + elif isinstance(node.op, CAReduce): + [input_dims] = inputs_dims + + axes = node.op.axis + if isinstance(axes, int): + axes = (axes,) + elif axes is None: + axes = tuple(range(node.inputs[0].type.ndim)) + + # Output dims contain the collapsed dims + output_dims = [dims + axes for i, dims in enumerate(input_dims) if i not in axes] + var_dims[node.outputs[0]] = tuple(output_dims) + + elif isinstance(node.op, Subtensor): + value_dims, *keys_dims = inputs_dims + # Dims in basic indexing must belong to the value variable, since indexing keys are always scalar + assert not any(keys_dims) + keys = get_idx_list(node.inputs, node.op.idx_list) + + output_dims = [] + for value_dims, idx in zip_longest(value_dims, keys, fillvalue=slice(None)): + if not isinstance(idx, slice): + # Integer indexing: Dimension dropped + continue + if idx == slice(None): + # Dim is kept + output_dims.append(value_dims) + elif value_dims: + raise NotImplementedError("Partial slicing of known dimensions not supported") + else: + # We keep an unknown / dummy dimension, nothing to worry about + output_dims.append(()) + + var_dims[node.outputs[0]] = tuple(output_dims) + + elif isinstance(node.op, AdvancedSubtensor): + # AdvancedSubtensor dimensions can show up as both the indexed variable and indexing variables + value, *keys = node.inputs + value_dims, *keys_dims = inputs_dims + + # Just to stay sane, we forbid any boolean indexing... + if any(isinstance(idx.type, TensorType) and idx.type.dtype == "bool" for idx in keys): + raise NotImplementedError( + f"Array indexing with boolean variables in node {node} not supported." + ) + + if value_dims and keys_dims: + # Both indexed variable and indexing variables have known dimensions + # I am to lazy to think through these, so we raise for now. + raise NotImplementedError( + f"Simultaneous use of known dimensions in indexed and indexing variables in node {node} not supported." + ) + + adv_group_axis, adv_group_ndim = _advanced_indexing_axis_and_ndim(keys) + + if value_dims: + # Indexed variable has known dimensions + + if any(isinstance(idx.type, NoneTypeT) for idx in keys): + # Corresponds to an expand_dims, for now not supported + raise NotImplementedError( + f"Advanced indexing in node {node} which introduces new axis is not supported" + ) + + non_adv_dims = [] + for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)): + if is_full_slice(idx): + non_adv_dims.append(value_dim) + elif value_dim: + # We are trying to partially slice or index a known dimension + raise NotImplementedError( + f"Partial slicing or advanced integer indexing of known dimensions not supported" + ) + elif isinstance(idx, slice): + # Unknown dimensions kept by partial slice. 
+ non_adv_dims.append(()) + + # Insert unknown dimensions corresponding to advanced indexing + output_dims = tuple( + non_adv_dims[:adv_group_axis] + + [()] * adv_group_ndim + + non_adv_dims[adv_group_axis:] + ) + + else: + # Indexing keys have known dimensions. + # Only array indices can have dimensions, the rest are just slices or newaxis + + # Advanced indexing variables broadcast together, so we apply same rules as in Elemwise + adv_dims = _broadcast_dims(keys_dims) + + start_non_adv_dims = ((),) * adv_group_axis + end_non_adv_dims = ((),) * ( + node.outputs[0].type.ndim - adv_group_axis - adv_group_ndim + ) + output_dims = start_non_adv_dims + adv_dims + end_non_adv_dims + + var_dims[node.outputs[0]] = output_dims + + else: + raise NotImplementedError(f"Marginalization through operation {node} not supported") + + return [var_dims[output_rv] for output_rv in output_vars] diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal/marginal_model.py similarity index 50% rename from pymc_experimental/model/marginal_model.py rename to pymc_experimental/model/marginal/marginal_model.py index 530c862b..e3f77d8c 100644 --- a/pymc_experimental/model/marginal_model.py +++ b/pymc_experimental/model/marginal/marginal_model.py @@ -1,7 +1,6 @@ import warnings from collections.abc import Sequence -from itertools import zip_longest from typing import Union import numpy as np @@ -9,32 +8,25 @@ import pytensor.tensor as pt from arviz import InferenceData, dict_to_dataset -from pymc import SymbolicRandomVariable from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform from pymc.distributions.transforms import Chain -from pymc.logprob.abstract import _logprob -from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.transforms import IntervalTransform from pymc.model import Model from pymc.pytensorf import compile_pymc, constant_fold from pymc.util import RandomState, _get_seeds_per_chain, treedict -from pytensor import Mode, scan -from pytensor.compile import SharedVariable -from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace -from pytensor.graph.basic import graph_inputs -from pytensor.graph.replace import graph_replace, vectorize_graph -from pytensor.scan import map as scan_map -from pytensor.tensor import TensorType, TensorVariable -from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.shape import Shape +from pytensor.graph import FunctionGraph, clone_replace +from pytensor.graph.replace import vectorize_graph +from pytensor.tensor import TensorVariable from pytensor.tensor.special import log_softmax -from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list -from pytensor.tensor.type_other import NoneTypeT __all__ = ["MarginalModel", "marginalize"] from pymc_experimental.distributions import DiscreteMarkovChain +from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV, DiscreteMarginalMarkovChainRV, \ get_domain_of_finite_discrete_rv, _add_reduce_batch_dependent_logps +from pymc_experimental.model.marginal.graph_analysis import find_conditional_input_rvs, is_conditional_dependent, \ find_conditional_dependent_rvs, subgraph_dim_connection, collect_shared_vars ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str] @@ -543,346 +535,6 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel: return
marginal_model -class MarginalRV(SymbolicRandomVariable): - """Base class for Marginalized RVs""" - - -class FiniteDiscreteMarginalRV(MarginalRV): - """Base class for Finite Discrete Marginalized RVs""" - - -class DiscreteMarginalMarkovChainRV(MarginalRV): - """Base class for Discrete Marginal Markov Chain RVs""" - - -def static_shape_ancestors(vars): - """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph).""" - return [ - var - for var in ancestors(vars) - if ( - var.owner - and isinstance(var.owner.op, Shape) - # All static dims lengths of Shape input are known - and None not in var.owner.inputs[0].type.shape - ) - ] - - -def find_conditional_input_rvs(output_rvs, all_rvs): - """Find conditionally indepedent input RVs.""" - blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs] - blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs)) - return [ - var - for var in ancestors(output_rvs, blockers=blockers) - if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable)) - ] - - -def is_conditional_dependent( - dependent_rv: TensorVariable, dependable_rv: TensorVariable, all_rvs -) -> bool: - """Check if dependent_rv is conditionall dependent on dependable_rv, - given all conditionally independent all_rvs""" - - return dependable_rv in find_conditional_input_rvs((dependent_rv,), all_rvs) - - -def find_conditional_dependent_rvs(dependable_rv, all_rvs): - """Find rvs than depend on dependable""" - return [ - rv - for rv in all_rvs - if (rv is not dependable_rv and is_conditional_dependent(rv, dependable_rv, all_rvs)) - ] - - -def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]: - """Find the output axis and dimensionality of the advanced indexing group (i.e., array indexing). - - There is a special case: when there are non-consecutive advanced indexing groups, the advanced indexing - group is always moved to the front. - - See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing - """ - adv_group_axis = None - for axis, idx in enumerate(idxs): - if isinstance(idx.type, TensorType): - adv_group_axis = axis - elif adv_group_axis is not None: - # Special non-consecutive case - adv_group_axis = 0 - break - - adv_group_ndim = max(idx.type.ndim for idx in idxs if isinstance(idx.type, TensorType)) - return adv_group_axis, adv_group_ndim - - -# TODO: Use typevar for batch_dims -def _broadcast_dims(inputs_dims: Sequence[tuple[int | None, ...]]) -> tuple[int | None, ...]: - output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0) - output_dims = { - ((None,) * (output_ndim - len(input_dim))) + input_dim for input_dim in inputs_dims - } - if len(output_dims) > 1: - raise ValueError - return next(iter(output_dims)) - - -def _insert_implicit_expanded_dims(inputs_dims, ndims_params, op_batch_ndim): - """Insert implicit expanded dims for RandomVariable and SymbolicRandomVariable.""" - new_inputs_dims = [] - for input_dims, core_ndim in zip(inputs_dims, ndims_params): - if input_dims is None: - new_inputs_dims.append(None) - else: - missing_ndim = (op_batch_ndim + core_ndim) - len(input_dims) - new_input_dims = ((None,) * missing_ndim) + input_dims - new_inputs_dims.append(new_input_dims) - return new_inputs_dims - - -def subgraph_dims(rv_to_marginalize, other_inputs_rvs, output_rvs) -> list[tuple[int, ...]]: - """Identify how the batch dims of rv_to_marginalize map to the batch dims of the output_rvs. 
- - Raises - ------ - ValueError - If information from batch_dims is mixed at any point in the subgraph - NotImplementedError - If variable related to marginalized batch_dims is used in an operation that is not yet supported - - """ - - batch_bcast_axes = rv_to_marginalize.type.broadcastable - if rv_to_marginalize.owner.op.ndim_supp > 0: - batch_bcast_axes = batch_bcast_axes[: -rv_to_marginalize.owner.op.ndim_supp] - - if any(batch_bcast_axes): - # Note: We could support this by distinguishing between original and broadcasted axis - # But for now this would complicate logic quite a lot for little gain. - # We could also just allow without a specific type, but we would need to check if they are ever broadcasted - # by any of the Ops that do implicit broadcasting (Alloc, Elemwise, Blockwise, AdvancedSubtensor, RandomVariable). - raise NotImplementedError( - "Marginalization of variables with broadcastable batch axes not supported." - ) - - # Batch axes for core RVs are always on the left - # We will have info on SymbolicRVs at some point in PyMC - var_dims: dict[Variable, tuple[int | None, ...]] = { - rv_to_marginalize: tuple(range(len(batch_bcast_axes))) - } - - for node in io_toposort([rv_to_marginalize, *other_inputs_rvs], output_rvs): - inputs_dims = [var_dims.get(inp, None) for inp in node.inputs] - - if not any(inputs_dims): - # None of the inputs are related to the batch_axes of the marginalized_rv - # We could pass `None` for everything, but for now that doesn't seem needed - continue - - elif isinstance(node.op, DimShuffle): - [key_dims] = inputs_dims - if any(key_dims[dropped_dim] is not None for dropped_dim in node.op.drop): - # Note: This is currently not possible as we forbid marginalized variable with broadcasted dims - raise ValueError(f"{node} drops batch axes of the marginalized variable") - - output_dims = tuple( - key_dims[i] if isinstance(i, int) else None for i in node.op.new_order - ) - var_dims[node.outputs[0]] = output_dims - - elif isinstance(node.op, Elemwise | Blockwise | RandomVariable | SymbolicRandomVariable): - # TODO: Add SymbolicRandomVariables to the mix? 
- if isinstance(node.op, Elemwise): - op_batch_ndim = node.outputs[0].type.ndim - elif isinstance(node.op, Blockwise): - op_batch_ndim = node.op.batch_ndim(node) - elif isinstance(node.op, RandomVariable): - op_batch_ndim = node.default_output().type.ndim - # The first 3 inputs (rng, size, dtype) don't behave like a regular gufunc - inputs_dims = [ - None, - None, - None, - *_insert_implicit_expanded_dims( - inputs_dims[3:], node.op.ndims_params, op_batch_ndim - ), - ] - elif isinstance(node.op, SymbolicRandomVariable): - ndims_params = getattr(node.op, "ndims_params", None) - if ndims_params is None: - raise NotImplementedError( - "Dependent SymbolicRandomVariables without gufunc_signature are not supported" - ) - op_batch_ndim = node.op.batch_ndim(node) - inputs_dims = _insert_implicit_expanded_dims( - inputs_dims, ndims_params, op_batch_ndim - ) - - if op_batch_ndim > 0: - if any( - core_dim is not None - for input_dims in inputs_dims - if input_dims is not None - for core_dim in input_dims[op_batch_ndim:] - ): - raise ValueError( - f"Node {node} uses batch dimensions of the marginalized variable as a core dimension" - ) - - # Check batch dims are not broadcasted - try: - batch_dims = _broadcast_dims( - tuple( - input_dims[:op_batch_ndim] - for input_dims in inputs_dims - if input_dims is not None - ) - ) - except ValueError: - raise NotImplementedError( - f"Node {node} mixes batch dimensions of the marginalized variable" - ) - for out in node.outputs: - if isinstance(out.type, TensorType): - core_dims = out.type.ndim - op_batch_ndim - var_dims[out] = batch_dims + (None,) * core_dims - - elif isinstance(node.op, CAReduce): - # Only non batch_axes dims can be reduced - [key_dims] = inputs_dims - - axes = node.op.axis - if isinstance(axes, int): - axes = (axes,) - - if axes is None or any(key_dims[axis] is not None for axis in axes): - raise ValueError( - f"Reduction node {node} mixes batch dimensions of the marginalized variable" - ) - - output_dims = list(key_dims) - for axis in sorted(axes, reverse=True): - output_dims.pop(axis) - var_dims[node.outputs[0]] = tuple(output_dims) - - elif isinstance(node.op, Subtensor): - value_dims, *keys_dims = inputs_dims - # Batch dims in basic indexing must belong to the value variable, since indexing keys are always scalar - assert all(key_dims is None for key_dims in keys_dims) - keys = get_idx_list(node.inputs, node.op.idx_list) - - output_dims = [] - for value_dims, idx in zip_longest(value_dims, keys, fillvalue=slice(None)): - if is_full_slice(idx): - output_dims.append(value_dims) - else: - if value_dims is not None: - # We are trying to slice or index a batch dim - # This is not necessarily problematic, unless this indexed dim is later mixed with other dims - # For now, we simply don't try to support it - raise NotImplementedError( - f"Indexing of batch dimensions of the marginalized variable in node {node} not supported" - ) - if isinstance(idx, slice): - # Slice keeps the dim, whereas integer drops it - output_dims.append(None) - - var_dims[node.outputs[0]] = tuple(output_dims) - - elif isinstance(node.op, AdvancedSubtensor): - # AdvancedSubtensor batch axis can show up as both the indexed variable and indexing variable - value, *keys = node.inputs - value_dims, *keys_dims = inputs_dims - - # Just to stay sane, we forbid any boolean indexing... - if any(isinstance(idx.type, TensorType) and idx.type.dtype == "bool" for idx in keys): - raise NotImplementedError( - f"Array indexing with boolean variables in node {node} not supported." 
- ) - - if value_dims and any(keys_dims): - # Both indexed variable and indexing variables have batch dimensions - # I am to lazy to think through these, so we raise for now. - raise NotImplementedError( - f"Simultaneous use of indexed and indexing variables in node {node}, " - f"related to batch dimensions of the marginalized variable not supported" - ) - - adv_group_axis, adv_group_ndim = _advanced_indexing_axis_and_ndim(keys) - - if value_dims: - # Indexed variable has batch dims - - if any(isinstance(idx.type, NoneTypeT) for idx in keys): - # TODO: Reason about these as AdvancedIndexing followed by ExpandDims - # (maybe that's how they should be represented in PyTensor?) - # Only complication is when the NewAxis force the AdvancedGroup to go to the front - # which changes the position of the output batch axes - raise NotImplementedError( - f"Advanced indexing in node {node} which introduces new axis is not supported" - ) - - # TODO: refactor this code which is completely shared by the Subtensor Op - non_adv_dims = [] - for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)): - if is_full_slice(idx): - non_adv_dims.append(value_dim) - else: - if value_dim is not None: - # We are trying to slice or index a batch dim - raise NotImplementedError( - f"Indexing of batch dimensions of the marginalized variable in node {node} not supported" - ) - if isinstance(idx, slice): - non_adv_dims.append(None) - - # Insert dims corresponding to advanced indexing among the remaining ones - output_dims = tuple( - non_adv_dims[adv_group_axis:] - + [None] * adv_group_ndim - + non_adv_dims[adv_group_axis:] - ) - - else: - # Indexing keys have batch dims. - # Only array indices can have batch_dims, the rest are just slices or new axis - - # Advanced indexing variables broadcast together, so we apply same rules as in Elemwise - # However indexing is implicit, so we have to add None dims - try: - adv_dims = _broadcast_dims( - tuple(key_dim for key_dim in keys_dims if key_dim is not None) - ) - except ValueError: - raise ValueError( - f"Index node {node} mixes batch dimensions of the marginalized variable" - ) - - start_non_adv_dims = (None,) * adv_group_axis - end_non_adv_dims = (None,) * ( - node.outputs[0].type.ndim - adv_group_axis - adv_group_ndim - ) - output_dims = start_non_adv_dims + adv_dims + end_non_adv_dims - - var_dims[node.outputs[0]] = output_dims - - else: - # TODO: Assert, SpecifyShape: easy, batch dims can only be on the input, - # TODO: Alloc: Easy, core batch dims stay in the same position (because we raised before for bcastable dims) - raise NotImplementedError(f"Marginalization through operation {node} not supported") - - return [var_dims[output_rv] for output_rv in output_rvs] - - -def collect_shared_vars(outputs, blockers): - return [ - inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable) - ] - - def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs): # TODO: This should eventually be integrated in a more general routine that can # identify other types of supported marginalization, of which finite discrete @@ -892,14 +544,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs if not dependent_rvs: raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}") - ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs} - if len(ndim_supp) != 1: - raise NotImplementedError( - "Marginalization of withe dependent Multivariate RVs not implemented" - ) - [ndim_supp] = 
ndim_supp - if ndim_supp > 0: - raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented") + ndim_supp = max({rv.owner.op.ndim_supp for rv in dependent_rvs}) marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) other_direct_rv_ancestors = [ @@ -908,39 +553,37 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs if rv is not rv_to_marginalize ] - # If the marginalized RV has batched dimensions, check that graph between - # marginalized RV and dependent RVs is composed strictly of Elemwise Operations. - # This implies (?) that the dimensions are completely independent and a logp graph - # can ultimately be generated that is proportional to the support domain and not - # to the variables dimensions - # We don't need to worry about this if the RV is scalar. + # If the marginalized RV has multiple dimensions, check that the graph between + # the marginalized RV and the dependent RVs does not mix information from batch dimensions + # (otherwise logp would require enumerating over all combinations of batch dimension values) if any(not bcast for bcast in rv_to_marginalize.type.broadcastable): # When there are batch dimensions, we call `batch_dims_subgraph` to make sure these are not mixed - try: - dependent_rvs_batch_dims = subgraph_dims( - rv_to_marginalize, other_direct_rv_ancestors, dependent_rvs - ) - except ValueError as err: - # This happens when information is mixed. From the user perspective this is a NotImplementedError - raise NotImplementedError from err + dependent_rvs_dims = subgraph_dim_connection( + rv_to_marginalize, other_direct_rv_ancestors, dependent_rvs + ) + + # Check that no output dimension of a dependent RV mixes distinct batch dimensions of the marginalized RV + + if any(len(dim) > 1 for dependent_dims in dependent_rvs_dims for dim in dependent_dims): + raise NotImplementedError("Different batch dimensions of the marginalized RV are mixed in the dependent RVs") - # We further check that any extra batch dimensions of dependnt RVs beyond those implied by the MarginalizedRV - # show up on the left, so that collapsing logic in logp can be more straightforward easily. + # We further check that any extra batch dimensions of dependent RVs beyond those implied by the MarginalizedRV + # show up on the left, so that collapsing logic in logp can be more straightforward.
# This also ensures the MarginalizedRV still behaves as an RV itself - marginal_batch_ndim = rv_to_marginalize.type.ndim - rv_to_marginalize.owner.op.ndim_supp - marginal_batch_dims = tuple(range(marginal_batch_ndim)) - for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_batch_dims): + marginal_batch_ndim = rv_to_marginalize.owner.op.batch_ndim(rv_to_marginalize.owner) + marginal_batch_dims = tuple((i,) for i in range(marginal_batch_ndim)) + for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_dims): extra_batch_ndim = ( dependent_rv.type.ndim - marginal_batch_ndim - dependent_rv.owner.op.ndim_supp ) - valid_dependent_batch_dims = ((None,) * extra_batch_ndim) + marginal_batch_dims + valid_dependent_batch_dims = (((),) * extra_batch_ndim) + marginal_batch_dims if dependent_rv_batch_dims != valid_dependent_batch_dims: raise NotImplementedError( "Any extra batch dimensions introduced by dependent RVs must be " "on the left of dimensions introduced by the marginalized RV" ) - for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_batch_dims): + for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_dims): shared_batch_dims = [ batch_dim for batch_dim in dependent_rv_batch_dims if batch_dim is not None ] @@ -971,166 +614,3 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs))) return rvs_to_marginalize, marginalized_rvs - -def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]: - op = rv.owner.op - dist_params = rv.owner.op.dist_params(rv.owner) - if isinstance(op, Bernoulli): - return (0, 1) - elif isinstance(op, Categorical): - [p_param] = dist_params - return tuple(range(pt.get_vector_length(p_param))) - elif isinstance(op, DiscreteUniform): - lower, upper = constant_fold(dist_params) - return tuple(np.arange(lower, upper + 1)) - elif isinstance(op, DiscreteMarkovChain): - P, *_ = dist_params - return tuple(range(pt.get_vector_length(P[-1]))) - - raise NotImplementedError(f"Cannot compute domain for op {op}") - - -def _add_reduce_batch_dependent_logps( - marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable] -): - """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`.""" - - mbcast = marginalized_type.broadcastable - reduced_logps = [] - for dependent_logp in dependent_logps: - dbcast = dependent_logp.type.broadcastable - dim_diff = len(dbcast) - len(mbcast) - mbcast_aligned = (True,) * dim_diff + mbcast - vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v] - reduced_logps.append(dependent_logp.sum(vbcast_axis)) - return pt.add(*reduced_logps) - - -@_logprob.register(FiniteDiscreteMarginalRV) -def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): - # Clone the inner RV graph of the Marginalized RV - marginalized_rvs_node = op.make_node(*inputs) - marginalized_rv, *inner_rvs = clone_replace( - op.inner_outputs, - replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, - ) - - # Obtain the joint_logp graph of the inner RV graph - inner_rv_values = dict(zip(inner_rvs, values)) - marginalized_vv = marginalized_rv.clone() - rv_values = inner_rv_values | {marginalized_rv: marginalized_vv} - logps_dict = conditional_logp(rv_values=rv_values, **kwargs) - - # Reduce logp dimensions corresponding to broadcasted variables - marginalized_logp = 
logps_dict.pop(marginalized_vv) - joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps( - marginalized_rv.type, logps_dict.values() - ) - - # Compute the joint_logp for all possible n values of the marginalized RV. We assume - # each original dimension is independent so that it suffices to evaluate the graph - # n times, once with each possible value of the marginalized RV replicated across - # batched dimensions of the marginalized RV - - # PyMC does not allow RVs in the logp graph, even if we are just using the shape - marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False) - marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv) - marginalized_rv_domain_tensor = pt.moveaxis( - pt.full( - (*marginalized_rv_shape, len(marginalized_rv_domain)), - marginalized_rv_domain, - dtype=marginalized_rv.dtype, - ), - -1, - 0, - ) - - try: - joint_logps = vectorize_graph( - joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor} - ) - except Exception: - # Fallback to Scan - def logp_fn(marginalized_rv_const, *non_sequences): - return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const}) - - joint_logps, _ = scan_map( - fn=logp_fn, - sequences=marginalized_rv_domain_tensor, - non_sequences=[*values, *inputs], - mode=Mode().including("local_remove_check_parameter"), - ) - - joint_logps = pt.logsumexp(joint_logps, axis=0) - - # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise - return joint_logps, *(pt.constant(0),) * (len(values) - 1) - - -@_logprob.register(DiscreteMarginalMarkovChainRV) -def marginal_hmm_logp(op, values, *inputs, **kwargs): - marginalized_rvs_node = op.make_node(*inputs) - inner_rvs = clone_replace( - op.inner_outputs, - replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, - ) - - chain_rv, *dependent_rvs = inner_rvs - P, n_steps_, init_dist_, rng = chain_rv.owner.inputs - domain = pt.arange(P.shape[-1], dtype="int32") - - # Construct logp in two steps - # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission) - - # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating - # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise, - # PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step. - chain_value = chain_rv.clone() - dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value}) - logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values))) - - # Reduce and add the batch dims beyond the chain dimension - reduced_logp_emissions = _add_reduce_batch_dependent_logps( - chain_rv.type, logp_emissions_dict.values() - ) - - # Add a batch dimension for the domain of the chain - chain_shape = constant_fold(tuple(chain_rv.shape)) - batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0) - batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value}) - - # Step 2: Compute the transition probabilities - # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1}) - # We do it entirely in logs, though. - - # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) - # under the initial distribution. 
This is robust to everything the user can throw at it. - init_dist_value = init_dist_.type() - logp_init_dist = logp(init_dist_, init_dist_value) - # There is a degerate batch dim for lags=1 (the only supported case), - # that we have to work around, by expanding the batch value and then squeezing it out of the logp - batch_logp_init_dist = vectorize_graph( - logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]} - ).squeeze(1) - log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0] - - def step_alpha(logp_emission, log_alpha, log_P): - step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0) - return logp_emission + step_log_prob - - P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2) - log_P = pt.shape_padright(pt.log(P), P_bcast_dims) - log_alpha_seq, _ = scan( - step_alpha, - non_sequences=[log_P], - outputs_info=[log_alpha_init], - # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value - sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0), - ) - # Final logp is just the sum of the last scan state - joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0) - - # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first - # return is the joint probability of everything together, but PyMC still expects one logp for each one. - dummy_logps = (pt.constant(0),) * (len(values) - 1) - return joint_logp, *dummy_logps diff --git a/tests/model/marginal/__init__.py b/tests/model/marginal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/model/marginal/test_distributions.py b/tests/model/marginal/test_distributions.py new file mode 100644 index 00000000..822cc34a --- /dev/null +++ b/tests/model/marginal/test_distributions.py @@ -0,0 +1,119 @@ +import numpy as np +import pymc as pm +import pytest +from pymc.logprob.abstract import _logprob +from pytensor import tensor as pt +from scipy.stats import norm + +from pymc_experimental import MarginalModel +from pymc_experimental.distributions import DiscreteMarkovChain + +from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV + + +def test_marginalized_bernoulli_logp(): + """Test logp of IR TestFiniteMarginalDiscreteRV directly""" + mu = pt.vector("mu") + + idx = pm.Bernoulli.dist(0.7, name="idx") + y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y") + marginal_rv_node = FiniteDiscreteMarginalRV( + [mu], + [idx, y], + ndim_supp=0, + n_updates=0, + # Ignore the fact we didn't specify shared RNG input/outputs for idx,y + strict=False, + )(mu)[0].owner + + y_vv = y.clone() + (logp,) = _logprob( + marginal_rv_node.op, + (y_vv,), + *marginal_rv_node.inputs, + ) + + ref_logp = pm.logp(pm.NormalMixture.dist(w=[0.3, 0.7], mu=mu, sigma=1.0), y_vv) + np.testing.assert_almost_equal( + logp.eval({mu: [-1, 1], y_vv: 2}), + ref_logp.eval({mu: [-1, 1], y_vv: 2}), + ) + + +@pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}") +@pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}") +def test_marginalized_hmm_normal_emission(batch_chain, batch_emission): + if batch_chain and not batch_emission: + pytest.skip("Redundant implicit combination") + + with MarginalModel() as m: + P = [[0, 1], [1, 0]] + init_dist = pm.Categorical.dist(p=[1, 0]) + chain = DiscreteMarkovChain( + "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None + ) + emission = pm.Normal( + "emission", 
mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None + ) + + m.marginalize([chain]) + logp_fn = m.compile_logp() + + test_value = np.array([-1, 1, -1, 1]) + expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval() + if batch_emission: + test_value = np.broadcast_to(test_value, (3, 4)) + expected_logp *= 3 + np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) + + +@pytest.mark.parametrize( + "categorical_emission", + [False, True], +) +def test_marginalized_hmm_categorical_emission(categorical_emission): + """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0""" + with MarginalModel() as m: + P = np.array([[0.5, 0.5], [0.3, 0.7]]) + init_dist = pm.Categorical.dist(p=[0.375, 0.625]) + chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2) + if categorical_emission: + emission = pm.Categorical( + "emission", p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6]) + ) + else: + emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6)) + m.marginalize([chain]) + + test_value = np.array([0, 0, 1]) + expected_logp = np.log(0.1344) # Shown at the 10m22s mark in the video + logp_fn = m.compile_logp() + np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) + + +@pytest.mark.parametrize("batch_emission1", (False, True)) +@pytest.mark.parametrize("batch_emission2", (False, True)) +def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2): + emission1_shape = (2, 4) if batch_emission1 else (4,) + emission2_shape = (2, 4) if batch_emission2 else (4,) + with MarginalModel() as m: + P = [[0, 1], [1, 0]] + init_dist = pm.Categorical.dist(p=[1, 0]) + chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3) + emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, shape=emission1_shape) + emission_2 = pm.Normal( + "emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1, shape=emission2_shape + ) + + with pytest.warns(UserWarning, match="multiple dependent variables"): + m.marginalize([chain]) + + logp_fn = m.compile_logp() + + test_value = np.array([-1, 1, -1, 1]) + multiplier = 2 + batch_emission1 + batch_emission2 + expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier + test_value_emission1 = np.broadcast_to(test_value, emission1_shape) + test_value_emission2 = np.broadcast_to(-test_value, emission2_shape) + test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2} + np.testing.assert_allclose(logp_fn(test_point), expected_logp) diff --git a/tests/model/marginal/test_graph_analysis.py b/tests/model/marginal/test_graph_analysis.py new file mode 100644 index 00000000..272bddc6 --- /dev/null +++ b/tests/model/marginal/test_graph_analysis.py @@ -0,0 +1,97 @@ +import pytensor.tensor as pt +import pytest +from pymc.distributions import CustomDist + +from pymc_experimental.model.marginal.graph_analysis import subgraph_dim_connection + + +class TestSubgraphDimConnection: + + def test_dimshuffle(self): + inp = pt.zeros(shape=(5, 1, 4, 3)) + out1 = pt.matrix_transpose(inp) + out2 = pt.expand_dims(inp, 1) + out3 = pt.squeeze(inp) + [dims1, dims2, dims3] = subgraph_dim_connection(inp, [], [out1, out2, out3]) + assert dims1 == ((0,), (1,), (3,), (2,)) + assert dims2 == ((0,), (), (1,), (2,), (3,)) + assert dims3 == ((0,), (2,), (3,)) + + def test_careduce(self): + inp = pt.zeros(shape=(4, 3, 2)) + out = pt.sum(inp, axis=(1,)) + [dims] = subgraph_dim_connection(inp, [], 
[out]) + assert dims == ((0, 1), (2, 1)) + + def test_subtensor(self): + inp = pt.zeros(shape=(4, 3, 2)) + + out = inp[0, :, 1] + [dims] = subgraph_dim_connection(inp, [], [out]) + assert dims == ((1,),) + + invalid_out = inp[0, :1] + with pytest.raises(NotImplementedError, match="Partial slicing of known dimensions not supported"): + subgraph_dim_connection(inp, [], [invalid_out]) + + # If we are slicing a dummy / unknown dimension that's fine + valid_out = pt.expand_dims(inp[:, 0], 1)[0, :1,] + [dims] = subgraph_dim_connection(inp, [], [valid_out]) + assert dims == ((), (2,)) + + def test_elemwise(self): + inp = pt.zeros(shape=(5, 5)) + + out = inp + inp + [dims] = subgraph_dim_connection(inp, [], [out]) + assert dims == ((0,), (1,)) + + out = inp + inp[0] + [dims] = subgraph_dim_connection(inp, [], [out]) + assert dims == ((0,), (1,)) + + # By removing the last dimension, we align the first and the last in the addition + out = inp + inp[:, 0] + [dims] = subgraph_dim_connection(inp, [], [out]) + assert dims == ((0,), (0, 1,)) + + out = inp + inp.T + [dims] = subgraph_dim_connection(inp, [], [out]) + assert dims == ((0, 1), (0, 1,)) + + def test_blockwise(self): + inp = pt.zeros(shape=(5, 4, 3, 2)) + out = inp @ pt.ones((2, 3)) + [dims] = subgraph_dim_connection(inp, [], [out]) + # Every dimension contains information from the core dimensions + assert dims == ((0, 2, 3), (1, 2, 3), (2, 3), (2, 3)) + + def test_random_variable(self): + inp = pt.zeros(shape=(5, 4, 3)) + out1 = pt.random.normal(loc=inp) + out2 = pt.random.categorical(p=inp) + out3 = pt.random.multivariate_normal(mean=inp, cov=pt.eye(3)) + [dims1, dims2, dims3] = subgraph_dim_connection(inp, [], [out1, out2, out3]) + assert dims1 == ((0,), (1,), (2,)) + assert dims2 == ((0, 2), (1, 2)) + assert dims3 == ((0, 2), (1, 2), (2,)) + + def test_symbolic_random_variable(self): + inp = pt.zeros(shape=(4, 3, 2)) + out = CustomDist.dist( + inp, + dist=lambda mu, size: pt.random.normal(loc=mu, size=size), + ) + [dims] = subgraph_dim_connection(inp, [], [out]) + assert dims == ((0,), (1,), (2,)) + + # Test multivariate + out = CustomDist.dist( + inp, + dist=lambda mu, size: pt.random.normal(loc=mu, size=size).sum(-1), + ) + [dims] = subgraph_dim_connection(inp, [], [out]) + assert dims == ((0, 2), (1, 2)) + + def test_advanced_indexing(self): + raise NotImplementedError() diff --git a/tests/model/test_marginal_model.py b/tests/model/marginal/test_marginal_model.py similarity index 86% rename from tests/model/test_marginal_model.py rename to tests/model/marginal/test_marginal_model.py index e2742b20..773b53a2 100644 --- a/tests/model/test_marginal_model.py +++ b/tests/model/marginal/test_marginal_model.py @@ -11,52 +11,20 @@ from arviz import InferenceData, dict_to_dataset from pymc.distributions import transforms from pymc.distributions.transforms import ordered -from pymc.logprob.abstract import _logprob from pymc.model.fgraph import fgraph_from_model from pymc.pytensorf import inputvars from pymc.util import UNSET from scipy.special import log_softmax, logsumexp from scipy.stats import halfnorm, norm -from pymc_experimental.distributions import DiscreteMarkovChain -from pymc_experimental.model.marginal_model import ( - FiniteDiscreteMarginalRV, +from pymc_experimental.model.marginal.marginal_model import ( MarginalModel, - is_conditional_dependent, marginalize, ) +from pymc_experimental.model.marginal.graph_analysis import is_conditional_dependent from tests.utils import equal_computations_up_to_root -def 
test_marginalized_bernoulli_logp(): - """Test logp of IR TestFiniteMarginalDiscreteRV directly""" - mu = pt.vector("mu") - - idx = pm.Bernoulli.dist(0.7, name="idx") - y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y") - marginal_rv_node = FiniteDiscreteMarginalRV( - [mu], - [idx, y], - ndim_supp=0, - n_updates=0, - # Ignore the fact we didn't specify shared RNG input/outputs for idx,y - strict=False, - )(mu)[0].owner - - y_vv = y.clone() - (logp,) = _logprob( - marginal_rv_node.op, - (y_vv,), - *marginal_rv_node.inputs, - ) - - ref_logp = pm.logp(pm.NormalMixture.dist(w=[0.3, 0.7], mu=mu, sigma=1.0), y_vv) - np.testing.assert_almost_equal( - logp.eval({mu: [-1, 1], y_vv: 2}), - ref_logp.eval({mu: [-1, 1], y_vv: 2}), - ) - - def test_marginalized_basic(): data = [2] * 5 @@ -526,85 +494,6 @@ def dist(idx, size): np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt)) -@pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}") -@pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}") -def test_marginalized_hmm_normal_emission(batch_chain, batch_emission): - if batch_chain and not batch_emission: - pytest.skip("Redundant implicit combination") - - with MarginalModel() as m: - P = [[0, 1], [1, 0]] - init_dist = pm.Categorical.dist(p=[1, 0]) - chain = DiscreteMarkovChain( - "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None - ) - emission = pm.Normal( - "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None - ) - - m.marginalize([chain]) - logp_fn = m.compile_logp() - - test_value = np.array([-1, 1, -1, 1]) - expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval() - if batch_emission: - test_value = np.broadcast_to(test_value, (3, 4)) - expected_logp *= 3 - np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) - - -@pytest.mark.parametrize( - "categorical_emission", - [False, True], -) -def test_marginalized_hmm_categorical_emission(categorical_emission): - """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0""" - with MarginalModel() as m: - P = np.array([[0.5, 0.5], [0.3, 0.7]]) - init_dist = pm.Categorical.dist(p=[0.375, 0.625]) - chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2) - if categorical_emission: - emission = pm.Categorical( - "emission", p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6]) - ) - else: - emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6)) - m.marginalize([chain]) - - test_value = np.array([0, 0, 1]) - expected_logp = np.log(0.1344) # Shown at the 10m22s mark in the video - logp_fn = m.compile_logp() - np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) - - -@pytest.mark.parametrize("batch_emission1", (False, True)) -@pytest.mark.parametrize("batch_emission2", (False, True)) -def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2): - emission1_shape = (2, 4) if batch_emission1 else (4,) - emission2_shape = (2, 4) if batch_emission2 else (4,) - with MarginalModel() as m: - P = [[0, 1], [1, 0]] - init_dist = pm.Categorical.dist(p=[1, 0]) - chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3) - emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, shape=emission1_shape) - emission_2 = pm.Normal( - "emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1, shape=emission2_shape - ) - - with pytest.warns(UserWarning, match="multiple dependent 
variables"): - m.marginalize([chain]) - - logp_fn = m.compile_logp() - - test_value = np.array([-1, 1, -1, 1]) - multiplier = 2 + batch_emission1 + batch_emission2 - expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier - test_value_emission1 = np.broadcast_to(test_value, emission1_shape) - test_value_emission2 = np.broadcast_to(-test_value, emission2_shape) - test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2} - np.testing.assert_allclose(logp_fn(test_point), expected_logp) - - def test_mutable_indexing_jax_backend(): pytest.importorskip("jax") from pymc.sampling.jax import get_jaxified_logp @@ -932,8 +821,3 @@ def true_sub_idx_logp(y): ) np.testing.assert_almost_equal(logsumexp(post.lp_idx, axis=-1), 0) np.testing.assert_almost_equal(logsumexp(post.lp_sub_idx, axis=-1), 0) - - -class TestSubgraphDims: - def test_baisc(self): - raise NotImplementedError("Write tests")