diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py index c594e8ac..530c862b 100644 --- a/pymc_experimental/model/marginal_model.py +++ b/pymc_experimental/model/marginal_model.py @@ -1,6 +1,7 @@ import warnings from collections.abc import Sequence +from itertools import zip_longest from typing import Union import numpy as np @@ -28,6 +29,8 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.shape import Shape from pytensor.tensor.special import log_softmax +from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list +from pytensor.tensor.type_other import NoneTypeT __all__ = ["MarginalModel", "marginalize"] @@ -595,48 +598,283 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs): ] -def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs): - # TODO: No need to consider apply nodes outside the subgraph... - fg = FunctionGraph(outputs=output_rvs, clone=False) - - non_elemwise_blockers = [ - o - for node in fg.apply_nodes - if not ( - isinstance(node.op, Elemwise) - # Allow expand_dims on the left - or ( - isinstance(node.op, DimShuffle) - and not node.op.drop - and node.op.shuffle == sorted(node.op.shuffle) - ) - ) - for o in node.outputs - ] - blocker_candidates = [rv_to_marginalize, *other_input_rvs, *non_elemwise_blockers] - blockers = [var for var in blocker_candidates if var not in output_rvs] +def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]: + """Find the output axis and dimensionality of the advanced indexing group (i.e., array indexing). - truncated_inputs = [ - var - for var in ancestors(output_rvs, blockers=blockers) - if ( - var in blockers - or (var.owner is None and not isinstance(var, Constant | SharedVariable)) + There is a special case: when there are non-consecutive advanced indexing groups, the advanced indexing + group is always moved to the front. + + See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + """ + adv_group_idxs = [axis for axis, idx in enumerate(idxs) if isinstance(idx.type, TensorType)] + if adv_group_idxs == list(range(adv_group_idxs[0], adv_group_idxs[-1] + 1)): + # Consecutive case: the group is placed at the position of the first advanced index + adv_group_axis = adv_group_idxs[0] + else: + # Special non-consecutive case + # (the whole advanced group is moved to the front) + adv_group_axis = 0 + + adv_group_ndim = max(idx.type.ndim for idx in idxs if isinstance(idx.type, TensorType)) + return adv_group_axis, adv_group_ndim + + +# TODO: Use typevar for batch_dims +def _broadcast_dims(inputs_dims: Sequence[tuple[int | None, ...]]) -> tuple[int | None, ...]: + output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0) + output_dims = { + ((None,) * (output_ndim - len(input_dim))) + input_dim for input_dim in inputs_dims + } + if len(output_dims) > 1: + raise ValueError + return next(iter(output_dims)) + + +def _insert_implicit_expanded_dims(inputs_dims, ndims_params, op_batch_ndim): + """Insert implicit expanded dims for RandomVariable and SymbolicRandomVariable.""" + new_inputs_dims = [] + for input_dims, core_ndim in zip(inputs_dims, ndims_params): + if input_dims is None: + new_inputs_dims.append(None) + else: + missing_ndim = (op_batch_ndim + core_ndim) - len(input_dims) + new_input_dims = ((None,) * missing_ndim) + input_dims + new_inputs_dims.append(new_input_dims) + return new_inputs_dims + + +def subgraph_dims(rv_to_marginalize, other_inputs_rvs, output_rvs) -> list[tuple[int | None, ...]]: + """Identify how the batch dims of rv_to_marginalize map to the batch dims of the output_rvs. 
+ + Raises + ------ + ValueError + If information from batch_dims is mixed at any point in the subgraph + NotImplementedError + If variable related to marginalized batch_dims is used in an operation that is not yet supported + + """ + + batch_bcast_axes = rv_to_marginalize.type.broadcastable + if rv_to_marginalize.owner.op.ndim_supp > 0: + batch_bcast_axes = batch_bcast_axes[: -rv_to_marginalize.owner.op.ndim_supp] + + if any(batch_bcast_axes): + # Note: We could support this by distinguishing between original and broadcasted axis + # But for now this would complicate logic quite a lot for little gain. + # We could also just allow without a specific type, but we would need to check if they are ever broadcasted + # by any of the Ops that do implicit broadcasting (Alloc, Elemwise, Blockwise, AdvancedSubtensor, RandomVariable). + raise NotImplementedError( + "Marginalization of variables with broadcastable batch axes not supported." ) - ] - # Check that we reach the marginalized rv following a pure elemwise graph - if rv_to_marginalize not in truncated_inputs: - return False + # Batch axes for core RVs are always on the left + # We will have info on SymbolicRVs at some point in PyMC + var_dims: dict[Variable, tuple[int | None, ...]] = { + rv_to_marginalize: tuple(range(len(batch_bcast_axes))) + } - # Check that none of the truncated inputs depends on the marginalized_rv - other_truncated_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize] - # TODO: We don't need to go all the way to the root variables - if rv_to_marginalize in ancestors( - other_truncated_inputs, blockers=[rv_to_marginalize, *other_input_rvs] - ): - return False - return True + for node in io_toposort([rv_to_marginalize, *other_inputs_rvs], output_rvs): + inputs_dims = [var_dims.get(inp, None) for inp in node.inputs] + + if not any(inputs_dims): + # None of the inputs are related to the batch_axes of the marginalized_rv + # We could pass `None` for everything, but for now that doesn't seem needed + continue + + elif isinstance(node.op, DimShuffle): + [key_dims] = inputs_dims + if any(key_dims[dropped_dim] is not None for dropped_dim in node.op.drop): + # Note: This is currently not possible as we forbid marginalized variable with broadcasted dims + raise ValueError(f"{node} drops batch axes of the marginalized variable") + + output_dims = tuple( + key_dims[i] if isinstance(i, int) else None for i in node.op.new_order + ) + var_dims[node.outputs[0]] = output_dims + + elif isinstance(node.op, Elemwise | Blockwise | RandomVariable | SymbolicRandomVariable): + # TODO: Add SymbolicRandomVariables to the mix? 
+ if isinstance(node.op, Elemwise): + op_batch_ndim = node.outputs[0].type.ndim + elif isinstance(node.op, Blockwise): + op_batch_ndim = node.op.batch_ndim(node) + elif isinstance(node.op, RandomVariable): + op_batch_ndim = node.default_output().type.ndim + # The first 3 inputs (rng, size, dtype) don't behave like a regular gufunc + inputs_dims = [ + None, + None, + None, + *_insert_implicit_expanded_dims( + inputs_dims[3:], node.op.ndims_params, op_batch_ndim + ), + ] + elif isinstance(node.op, SymbolicRandomVariable): + ndims_params = getattr(node.op, "ndims_params", None) + if ndims_params is None: + raise NotImplementedError( + "Dependent SymbolicRandomVariables without gufunc_signature are not supported" + ) + op_batch_ndim = node.op.batch_ndim(node) + inputs_dims = _insert_implicit_expanded_dims( + inputs_dims, ndims_params, op_batch_ndim + ) + + if op_batch_ndim > 0: + if any( + core_dim is not None + for input_dims in inputs_dims + if input_dims is not None + for core_dim in input_dims[op_batch_ndim:] + ): + raise ValueError( + f"Node {node} uses batch dimensions of the marginalized variable as a core dimension" + ) + + # Check batch dims are not broadcasted + try: + batch_dims = _broadcast_dims( + tuple( + input_dims[:op_batch_ndim] + for input_dims in inputs_dims + if input_dims is not None + ) + ) + except ValueError: + raise NotImplementedError( + f"Node {node} mixes batch dimensions of the marginalized variable" + ) + for out in node.outputs: + if isinstance(out.type, TensorType): + core_dims = out.type.ndim - op_batch_ndim + var_dims[out] = batch_dims + (None,) * core_dims + + elif isinstance(node.op, CAReduce): + # Only non-batch_axes dims can be reduced + [key_dims] = inputs_dims + + axes = node.op.axis + if isinstance(axes, int): + axes = (axes,) + + if axes is None or any(key_dims[axis] is not None for axis in axes): + raise ValueError( + f"Reduction node {node} mixes batch dimensions of the marginalized variable" + ) + + output_dims = list(key_dims) + for axis in sorted(axes, reverse=True): + output_dims.pop(axis) + var_dims[node.outputs[0]] = tuple(output_dims) + + elif isinstance(node.op, Subtensor): + value_dims, *keys_dims = inputs_dims + # Batch dims in basic indexing must belong to the value variable, since indexing keys are always scalar + assert all(key_dims is None for key_dims in keys_dims) + keys = get_idx_list(node.inputs, node.op.idx_list) + + output_dims = [] + for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)): + if is_full_slice(idx): + output_dims.append(value_dim) + else: + if value_dim is not None: + # We are trying to slice or index a batch dim + # This is not necessarily problematic, unless this indexed dim is later mixed with other dims + # For now, we simply don't try to support it + raise NotImplementedError( + f"Indexing of batch dimensions of the marginalized variable in node {node} not supported" + ) + if isinstance(idx, slice): + # Slice keeps the dim, whereas integer drops it + output_dims.append(None) + + var_dims[node.outputs[0]] = tuple(output_dims) + + elif isinstance(node.op, AdvancedSubtensor): + # AdvancedSubtensor batch axis can show up as both the indexed variable and indexing variable + value, *keys = node.inputs + value_dims, *keys_dims = inputs_dims + + # Just to stay sane, we forbid any boolean indexing... + if any(isinstance(idx.type, TensorType) and idx.type.dtype == "bool" for idx in keys): + raise NotImplementedError( + f"Array indexing with boolean variables in node {node} not supported." 
) + + if value_dims and any(keys_dims): + # Both indexed variable and indexing variables have batch dimensions + # We don't attempt to reason through these cases, so we raise for now. + raise NotImplementedError( + f"Simultaneous use of indexed and indexing variables in node {node}, " + f"related to batch dimensions of the marginalized variable, is not supported" + ) + + adv_group_axis, adv_group_ndim = _advanced_indexing_axis_and_ndim(keys) + + if value_dims: + # Indexed variable has batch dims + + if any(isinstance(idx.type, NoneTypeT) for idx in keys): + # TODO: Reason about these as AdvancedIndexing followed by ExpandDims + # (maybe that's how they should be represented in PyTensor?) + # Only complication is when the NewAxis forces the AdvancedGroup to go to the front, + # which changes the position of the output batch axes + raise NotImplementedError( + f"Advanced indexing in node {node} which introduces new axis is not supported" + ) + + # TODO: Refactor this code, which is largely shared with the Subtensor branch above + non_adv_dims = [] + for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)): + if is_full_slice(idx): + non_adv_dims.append(value_dim) + else: + if value_dim is not None: + # We are trying to slice or index a batch dim + raise NotImplementedError( + f"Indexing of batch dimensions of the marginalized variable in node {node} not supported" + ) + if isinstance(idx, slice): + non_adv_dims.append(None) + + # Insert dims corresponding to advanced indexing among the remaining ones + output_dims = tuple( + non_adv_dims[:adv_group_axis] + + [None] * adv_group_ndim + + non_adv_dims[adv_group_axis:] + ) + + else: + # Indexing keys have batch dims. + # Only array indices can have batch_dims, the rest are just slices or new axis + + # Advanced indexing variables broadcast together, so we apply same rules as in Elemwise + # However indexing is implicit, so we have to add None dims + try: + adv_dims = _broadcast_dims( + tuple(key_dim for key_dim in keys_dims if key_dim is not None) + ) + except ValueError: + raise ValueError( + f"Index node {node} mixes batch dimensions of the marginalized variable" + ) + + start_non_adv_dims = (None,) * adv_group_axis + end_non_adv_dims = (None,) * ( + node.outputs[0].type.ndim - adv_group_axis - adv_group_ndim + ) + output_dims = start_non_adv_dims + adv_dims + end_non_adv_dims + + var_dims[node.outputs[0]] = output_dims + + else: + # TODO: Assert, SpecifyShape: easy, batch dims can only be on the input + # TODO: Alloc: easy, core batch dims stay in the same position (because we raised before for bcastable dims) + raise NotImplementedError(f"Marginalization through operation {node} not supported") + + return [var_dims[output_rv] for output_rv in output_rvs] def collect_shared_vars(outputs, blockers): @@ -657,14 +895,14 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs} if len(ndim_supp) != 1: raise NotImplementedError( - "Marginalization with dependent variables of different support dimensionality not implemented" + "Marginalization with dependent variables of different support dimensionality is not implemented" ) [ndim_supp] = ndim_supp if ndim_supp > 0: raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented") marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) - dependent_rvs_input_rvs = [ + other_direct_rv_ancestors = [ rv for rv in find_conditional_input_rvs(dependent_rvs, all_rvs) if rv is not rv_to_marginalize @@ 
-676,14 +914,42 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs # can ultimately be generated that is proportional to the support domain and not # to the variables dimensions # We don't need to worry about this if the RV is scalar. - if np.prod(constant_fold(tuple(rv_to_marginalize.shape), raise_not_constant=False)) != 1: - if not is_elemwise_subgraph(rv_to_marginalize, dependent_rvs_input_rvs, dependent_rvs): - raise NotImplementedError( - "The subgraph between a marginalized RV and its dependents includes non Elemwise operations. " - "This is currently not supported", + if any(not bcast for bcast in rv_to_marginalize.type.broadcastable): + # When there are batch dimensions, we call `subgraph_dims` to make sure these are not mixed + try: + dependent_rvs_batch_dims = subgraph_dims( + rv_to_marginalize, other_direct_rv_ancestors, dependent_rvs ) + except ValueError as err: + # This happens when information is mixed. From the user perspective this is a NotImplementedError + raise NotImplementedError from err + + # We further check that any extra batch dimensions of dependent RVs beyond those implied by the MarginalizedRV + # show up on the left, so that the collapsing logic in logp can stay straightforward. + # This also ensures the MarginalizedRV still behaves as an RV itself + marginal_batch_ndim = rv_to_marginalize.type.ndim - rv_to_marginalize.owner.op.ndim_supp + marginal_batch_dims = tuple(range(marginal_batch_ndim)) + for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_batch_dims): + extra_batch_ndim = ( + dependent_rv.type.ndim - marginal_batch_ndim - dependent_rv.owner.op.ndim_supp + ) + valid_dependent_batch_dims = ((None,) * extra_batch_ndim) + marginal_batch_dims + if dependent_rv_batch_dims != valid_dependent_batch_dims: + raise NotImplementedError( + "Any extra batch dimensions introduced by dependent RVs must be " + "to the left of the dimensions introduced by the marginalized RV" + ) + + for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_batch_dims): + shared_batch_dims = [ + batch_dim for batch_dim in dependent_rv_batch_dims if batch_dim is not None + ] + if shared_batch_dims != sorted(shared_batch_dims): + raise NotImplementedError( + "Shared batch dimensions between marginalized RV and dependent RVs must be aligned positionally" + ) - input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs] + input_rvs = [*marginalized_rv_input_rvs, *other_direct_rv_ancestors] rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs] outputs = rvs_to_marginalize diff --git a/tests/model/test_marginal_model.py b/tests/model/test_marginal_model.py index 93df0ec4..e2742b20 100644 --- a/tests/model/test_marginal_model.py +++ b/tests/model/test_marginal_model.py @@ -10,6 +10,7 @@ from arviz import InferenceData, dict_to_dataset from pymc.distributions import transforms +from pymc.distributions.transforms import ordered from pymc.logprob.abstract import _logprob from pymc.model.fgraph import fgraph_from_model from pymc.pytensorf import inputvars @@ -201,51 +202,133 @@ def test_nested_marginalized_rvs(): ) +def test_marginalized_index_as_value_and_key(): + """Test we can marginalize graphs where the marginalized_rv is indexed.""" + + def build_model(batch: bool) -> MarginalModel: + with MarginalModel() as m: + if batch: + latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,)) + else: + latent_state = pm.math.stack( + [pm.Bernoulli(f"latent_state_{i}", p=0.3) for i in range(4)] + ) + # latent 
state is used as the indexed variable + latent_intensities = pt.where(latent_state[:, None], [0.0, 1.0, 2.0], [0.0, 10.0, 20.0]) + picked_intensity = pm.Categorical("picked_intensity", p=[0.2, 0.2, 0.6]) + # picked intensity is used as the indexing variable + pm.Normal( + "intensity", + mu=latent_intensities[:, picked_intensity], + observed=[0.5, 1.5, 5.0, 15.0], + ) + return m + + # We compare against the equivalent but less efficient unbatched model + m = build_model(batch=True) + ref_m = build_model(batch=False) + + m.marginalize(["latent_state"]) + ref_m.marginalize([f"latent_state_{i}" for i in range(4)]) + test_point = {"picked_intensity": 1} + np.testing.assert_allclose( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + + m.marginalize(["picked_intensity"]) + ref_m.marginalize(["picked_intensity"]) + test_point = {} + np.testing.assert_allclose( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + + +@pytest.mark.parametrize("advanced_indexing", (False, True)) +def test_marginalized_index_as_key(advanced_indexing): + """Test we can marginalize graphs where indexing is used as a mapping.""" + + w = [0.1, 0.3, 0.6] + mu = pt.as_tensor([-1, 0, 1]) + + if advanced_indexing: + y_val = pt.as_tensor([[-1, -1], [0, 1]]) + shape = (2, 2) + else: + y_val = -1 + shape = () + + with MarginalModel() as m: + x = pm.Categorical("x", p=w, shape=shape) + y = pm.Normal("y", mu[x], sigma=1, observed=y_val) + + m.marginalize(x) + + marginal_logp = m.compile_logp(sum=False)({})[0] + ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu, sigma=1, shape=shape), y_val).eval() + + np.testing.assert_allclose(marginal_logp, ref_logp) + + def test_not_supported_marginalized(): - """Marginalized graphs with non-Elemwise Operations are not supported as they - would violate the batching logp assumption""" - mu = pt.constant([-1, 1]) + """Test lack of support for models where batch dims of marginalized variables are mixed.""" - # Allowed, as only elemwise operations connect idx to y with MarginalModel() as m: - p = pm.Beta("p", 1, 1) - idx = pm.Bernoulli("idx", p=p, size=2) - y = pm.Normal("y", mu=pm.math.switch(idx, 0, 1)) - m.marginalize([idx]) + idx = pm.Bernoulli("idx", p=0.7, shape=2) + y = pm.Normal("y", mu=idx @ idx.T) + with pytest.raises(NotImplementedError): + m.marginalize(idx) - # ALlowed, as index operation does not connext idx to y with MarginalModel() as m: - p = pm.Beta("p", 1, 1) - idx = pm.Bernoulli("idx", p=p, size=2) - y = pm.Normal("y", mu=pm.math.switch(idx, mu[0], mu[1])) - m.marginalize([idx]) + mean = pt.as_tensor([[0.1, 0.9], [0.6, 0.4]]) + idx = pm.Bernoulli("idx", p=0.7, shape=2) + y = pm.Normal("y", mu=mean[idx, :] + mean[:, idx]) + with pytest.raises(NotImplementedError): + m.marginalize(idx) - # Not allowed, as index operation connects idx to y with MarginalModel() as m: - p = pm.Beta("p", 1, 1) - idx = pm.Bernoulli("idx", p=p, size=2) - # Not allowed - y = pm.Normal("y", mu=mu[idx]) + mean = pt.as_tensor([[0.1, 0.9], [0.6, 0.4]]) + idx = pm.Bernoulli("idx", p=0.7, shape=2) + y = pm.Normal("y", mu=mean[idx, None] + mean[None, idx]) with pytest.raises(NotImplementedError): m.marginalize(idx) - # Not allowed, as index operation connects idx to y, even though there is a - # pure Elemwise connection between the two with MarginalModel() as m: - p = pm.Beta("p", 1, 1) - idx = pm.Bernoulli("idx", p=p, size=2) - y = pm.Normal("y", mu=mu[idx] + idx) + mean = pt.as_tensor([[0.1, 0.9], [0.6, 0.4]]) + idx = pm.Bernoulli("idx", p=0.7, shape=2) + mu = ( + # 
FIXME: PyTensor does not figure out this static broadcasting! + # FIXME: specify_broadcastable does not handle negative axis correctly + pt.specify_broadcastable(mean[:, None][idx], 1) + + pt.specify_broadcastable(mean[None, :][:, idx], 0) + ) + y = pm.Normal("y", mu=mu) with pytest.raises(NotImplementedError): m.marginalize(idx) - # Multivariate dependent RVs not supported with MarginalModel() as m: - x = pm.Bernoulli("x", p=0.7) - y = pm.Dirichlet("y", a=pm.math.switch(x, [1, 1, 1], [10, 10, 10])) - with pytest.raises( - NotImplementedError, - match="Marginalization with dependent Multivariate RVs not implemented", - ): + idx = pm.Bernoulli("idx", p=0.7, shape=2) + y = pm.Normal("y", mu=idx[0] + idx[1]) + with pytest.raises(NotImplementedError): + m.marginalize(idx) + + with MarginalModel() as m: + idx = pm.Bernoulli("idx", p=0.7, shape=2) + y = pm.Normal("y", mu=idx[[0, 1, 0, 0]]) + with pytest.raises(NotImplementedError): + m.marginalize(idx) + + with MarginalModel() as m: + idx = pm.Categorical("key", p=[0.1, 0.3, 0.6], shape=(2, 2)) + y = pm.Normal("y", pt.as_tensor([[0, 1], [2, 3]])[idx.astype(bool)]) + with pytest.raises(NotImplementedError): + m.marginalize(idx) + + with MarginalModel() as m: + x = pm.Bernoulli("x", p=0.7, shape=3) + y = pm.Dirichlet("y", a=x * 10 + 1) + with pytest.raises(NotImplementedError): m.marginalize(x) @@ -641,6 +724,61 @@ def test_change_point_model_sampling(self, disaster_model): rtol=1e-2, ) + def test_k_censored_clusters_model(self): + def build_model(batch: bool) -> MarginalModel: + data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]).T + nobs = data.shape[-1] + n_clusters = 5 + coords = { + "cluster": range(n_clusters), + "ndim": ("x", "y"), + "obs": range(nobs), + } + with MarginalModel(coords=coords) as m: + if batch: + idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"]) + else: + idx = pm.math.stack( + [ + pm.Categorical(f"idx_{i}", p=np.ones(n_clusters) / n_clusters) + for i in range(nobs) + ] + ) + + mu_x = pm.Normal( + "mu_x", + dims=["cluster"], + transform=ordered, + initval=np.linspace(-1, 1, n_clusters), + ) + mu_y = pm.Normal("mu_y", dims=["cluster"]) + mu = pm.math.concatenate([mu_x[None], mu_y[None]], axis=0) # (ndim, cluster) + + sigma = pm.HalfNormal("sigma") + + y = pm.Censored( + "y", + dist=pm.Normal.dist(mu[:, idx], sigma), + lower=-3, + upper=3, + observed=data, + dims=["ndim", "obs"], + ) + + return m + + m = build_model(batch=True) + ref_m = build_model(batch=False) + + m.marginalize([m["idx"]]) + ref_m.marginalize([n for n in ref_m.named_vars if n.startswith("idx_")]) + + test_point = m.initial_point() + np.testing.assert_almost_equal( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + class TestRecoverMarginals: def test_basic(self): @@ -794,3 +932,8 @@ def true_sub_idx_logp(y): ) np.testing.assert_almost_equal(logsumexp(post.lp_idx, axis=-1), 0) np.testing.assert_almost_equal(logsumexp(post.lp_sub_idx, axis=-1), 0) + + +class TestSubgraphDims: + def test_basic(self): + raise NotImplementedError("Write tests")
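The diff leaves TestSubgraphDims.test_basic as a placeholder. A minimal sketch of what it could eventually assert is given below; it is not part of the diff, assumes subgraph_dims is imported from pymc_experimental.model.marginal_model, and relies only on the dims convention used there (integer positions for batch dims that originate in the marginalized RV, None for all other dims). The example itself (Bernoulli/HalfNormal/Normal chain) is illustrative, not taken from the PR.

import pymc as pm

from pymc_experimental.model.marginal_model import subgraph_dims


class TestSubgraphDims:
    def test_basic(self):
        # A purely Elemwise chain keeps batch dims in place, so both batch axes of the
        # to-be-marginalized RV should map one-to-one onto the dependent RV.
        idx = pm.Bernoulli.dist(p=0.5, shape=(2, 3))
        sigma = pm.HalfNormal.dist()
        dep = pm.Normal.dist(mu=idx * 2.0 + 1.0, sigma=sigma)

        [dep_dims] = subgraph_dims(idx, [sigma], [dep])
        assert dep_dims == (0, 1)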