From 2823e7cba8c34be9a04e4a85295eb318679e3688 Mon Sep 17 00:00:00 2001 From: "Zucker, Jeremy D" Date: Sun, 3 Nov 2024 18:33:17 -0800 Subject: [PATCH] Added def instead of lambda function. Still getting recursive call --- docs/source/hierarchical_sir_model.ipynb | 4 +- pyciemss/compiled_dynamics.py | 2 +- .../mira_integration/compiled_dynamics.py | 57 ++++++++++++++++--- tests/test_compiled_dynamics.py | 8 ++- 4 files changed, 57 insertions(+), 14 deletions(-) diff --git a/docs/source/hierarchical_sir_model.ipynb b/docs/source/hierarchical_sir_model.ipynb index 3236a1dd..9cb0b245 100644 --- a/docs/source/hierarchical_sir_model.ipynb +++ b/docs/source/hierarchical_sir_model.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -18,7 +18,7 @@ "['gamma_mean', 'gamma', 'beta_mean', 'beta']" ] }, - "execution_count": 17, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } diff --git a/pyciemss/compiled_dynamics.py b/pyciemss/compiled_dynamics.py index cb3b630e..8d218274 100644 --- a/pyciemss/compiled_dynamics.py +++ b/pyciemss/compiled_dynamics.py @@ -21,7 +21,7 @@ class CompiledDynamics(pyro.nn.PyroModule): def __init__(self, src, **kwargs): super().__init__() self.src = src - + params = _compile_param_values(self.src) try: params = _compile_param_values(self.src) except Exception as e: diff --git a/pyciemss/mira_integration/compiled_dynamics.py b/pyciemss/mira_integration/compiled_dynamics.py index 2d13146f..75ef7072 100644 --- a/pyciemss/mira_integration/compiled_dynamics.py +++ b/pyciemss/mira_integration/compiled_dynamics.py @@ -3,6 +3,7 @@ import numbers from typing import Callable, Dict, Optional, Tuple, TypeVar, Union +import networkx as nx import mira import mira.metamodel import mira.modeling @@ -27,13 +28,44 @@ eval_observables, get_name, ) -from pyciemss.mira_integration.distributions import ( - mira_distribution_to_pyro -) + +from pyciemss.mira_integration.distributions import mira_distribution_to_pyro S = TypeVar("S") T = TypeVar("T") +@_sort_dependencies.register(mira.modeling.Model) +def sort_mira_dependencies(src: mira.modeling.Model) -> list: + """ + Sort the model parameters of a MIRA TemplateModel by their distribution parameter dependencies. + + Parameters + ---------- + src : mira.modeling.Model + The MIRA Model to sort. + + Returns + ------- + list + A list of parameter names in the order in which they must be evaluated. + """ + dependencies = nx.DiGraph() + for param_name, param_info in src.parameters.items(): + param_name = get_name(param_info) + #if param_info.placeholder: + # continue + param_dist = getattr(param_info, "distribution", None) + + if param_dist is None: + dependencies.add_node(param_name) + else: + for k, v in param_dist.parameters.items(): + # Check to see if the distribution parameters are sympy expressions + # and add their free symbols to the dependency graph + if isinstance(v, mira.metamodel.utils.SympyExprStr): + for free_symbol in v.free_symbols: + dependencies.add_edge(str(free_symbol), str(param_name)) + return list(nx.topological_sort(dependencies)) @_compile_deriv.register(mira.modeling.Model) def _compile_deriv_mira(src: mira.modeling.Model) -> Callable[..., Tuple[torch.Tensor]]: @@ -94,20 +126,24 @@ def _compile_param_values_mira( param_info = src.parameters[param_name] if param_info.placeholder: continue - + param_dist = getattr(param_info, "distribution", None) if param_dist is None: param_value = float(param_info.value) else: idx = sorted_dependencies.index(param_name) - param_value = lambda self: mira_distribution_to_pyro(param_dist, { - k: getattr(self, f"persistent_{k}") for k in sorted_dependencies[:idx] - } - ) + upstream_dependencies = sorted_dependencies[:idx] + + def param_value(model: pyro.nn.PyroModule) -> torch.Tensor: + return mira_distribution_to_pyro( + param_dist, { + k: getattr(model, f"persistent_{k}") + for k in upstream_dependencies + }) if isinstance(param_value, torch.nn.Parameter): param_values[param_name] = pyro.nn.PyroParam(param_value) - elif isinstance(param_value, pyro.distributions.distribution.Distribution): + elif isinstance(param_value, (pyro.distributions.distribution.Distribution, Callable)): param_values[param_name] = pyro.nn.PyroSample(param_value) elif isinstance(param_value, (numbers.Number, numpy.ndarray, torch.Tensor)): param_values[param_name] = torch.as_tensor(param_value, dtype=torch.float32) @@ -208,6 +244,9 @@ def _get_name_mira_transition(trans: mira.modeling.Transition) -> str: def _get_name_mira_modelparameter(param: mira.modeling.ModelParameter) -> str: return str(param.key) +@get_name.register +def _get_name_mira_metamodel_parameter(param: mira.metamodel.Parameter) -> str: + return str(param.name) @get_name.register def _get_name_mira_model_observable(obs: mira.modeling.ModelObservable) -> str: diff --git a/tests/test_compiled_dynamics.py b/tests/test_compiled_dynamics.py index a5bf5eaa..f07c552c 100644 --- a/tests/test_compiled_dynamics.py +++ b/tests/test_compiled_dynamics.py @@ -10,8 +10,12 @@ from mira.sources.amr import model_from_url from pyro.infer.inspect import get_dependencies -from pyciemss.compiled_dynamics import CompiledDynamics -from pyciemss.mira_integration.distributions import sort_mira_dependencies +from pyciemss.compiled_dynamics import ( + CompiledDynamics, +) + +from pyciemss.mira_integration.compiled_dynamics import sort_mira_dependencies + from .fixtures import ( ACYCLIC_MODELS,