Skip to content

Commit

Permalink
Added def instead of lambda function. Still getting recursive call
Browse files Browse the repository at this point in the history
  • Loading branch information
djinnome committed Nov 4, 2024
1 parent e6fa616 commit 2823e7c
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 14 deletions.
4 changes: 2 additions & 2 deletions docs/source/hierarchical_sir_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 22,
"metadata": {},
"outputs": [
{
Expand All @@ -18,7 +18,7 @@
"['gamma_mean', 'gamma', 'beta_mean', 'beta']"
]
},
"execution_count": 17,
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
Expand Down
2 changes: 1 addition & 1 deletion pyciemss/compiled_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class CompiledDynamics(pyro.nn.PyroModule):
def __init__(self, src, **kwargs):
super().__init__()
self.src = src

params = _compile_param_values(self.src)
try:
params = _compile_param_values(self.src)
except Exception as e:
Expand Down
57 changes: 48 additions & 9 deletions pyciemss/mira_integration/compiled_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numbers
from typing import Callable, Dict, Optional, Tuple, TypeVar, Union

import networkx as nx
import mira
import mira.metamodel
import mira.modeling
Expand All @@ -27,13 +28,44 @@
eval_observables,
get_name,
)
from pyciemss.mira_integration.distributions import (
mira_distribution_to_pyro
)

from pyciemss.mira_integration.distributions import mira_distribution_to_pyro

S = TypeVar("S")
T = TypeVar("T")

@_sort_dependencies.register(mira.modeling.Model)
def sort_mira_dependencies(src: mira.modeling.Model) -> list:
"""
Sort the model parameters of a MIRA TemplateModel by their distribution parameter dependencies.
Parameters
----------
src : mira.modeling.Model
The MIRA Model to sort.
Returns
-------
list
A list of parameter names in the order in which they must be evaluated.
"""
dependencies = nx.DiGraph()
for param_name, param_info in src.parameters.items():
param_name = get_name(param_info)
#if param_info.placeholder:
# continue
param_dist = getattr(param_info, "distribution", None)

if param_dist is None:
dependencies.add_node(param_name)
else:
for k, v in param_dist.parameters.items():
# Check to see if the distribution parameters are sympy expressions
# and add their free symbols to the dependency graph
if isinstance(v, mira.metamodel.utils.SympyExprStr):
for free_symbol in v.free_symbols:
dependencies.add_edge(str(free_symbol), str(param_name))
return list(nx.topological_sort(dependencies))

@_compile_deriv.register(mira.modeling.Model)
def _compile_deriv_mira(src: mira.modeling.Model) -> Callable[..., Tuple[torch.Tensor]]:
Expand Down Expand Up @@ -94,20 +126,24 @@ def _compile_param_values_mira(
param_info = src.parameters[param_name]
if param_info.placeholder:
continue

param_dist = getattr(param_info, "distribution", None)
if param_dist is None:
param_value = float(param_info.value)
else:
idx = sorted_dependencies.index(param_name)
param_value = lambda self: mira_distribution_to_pyro(param_dist, {
k: getattr(self, f"persistent_{k}") for k in sorted_dependencies[:idx]
}
)
upstream_dependencies = sorted_dependencies[:idx]

def param_value(model: pyro.nn.PyroModule) -> torch.Tensor:
return mira_distribution_to_pyro(
param_dist, {
k: getattr(model, f"persistent_{k}")
for k in upstream_dependencies
})

if isinstance(param_value, torch.nn.Parameter):
param_values[param_name] = pyro.nn.PyroParam(param_value)
elif isinstance(param_value, pyro.distributions.distribution.Distribution):
elif isinstance(param_value, (pyro.distributions.distribution.Distribution, Callable)):
param_values[param_name] = pyro.nn.PyroSample(param_value)
elif isinstance(param_value, (numbers.Number, numpy.ndarray, torch.Tensor)):
param_values[param_name] = torch.as_tensor(param_value, dtype=torch.float32)
Expand Down Expand Up @@ -208,6 +244,9 @@ def _get_name_mira_transition(trans: mira.modeling.Transition) -> str:
def _get_name_mira_modelparameter(param: mira.modeling.ModelParameter) -> str:
return str(param.key)

@get_name.register
def _get_name_mira_metamodel_parameter(param: mira.metamodel.Parameter) -> str:
return str(param.name)

@get_name.register
def _get_name_mira_model_observable(obs: mira.modeling.ModelObservable) -> str:
Expand Down
8 changes: 6 additions & 2 deletions tests/test_compiled_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
from mira.sources.amr import model_from_url
from pyro.infer.inspect import get_dependencies

from pyciemss.compiled_dynamics import CompiledDynamics
from pyciemss.mira_integration.distributions import sort_mira_dependencies
from pyciemss.compiled_dynamics import (
CompiledDynamics,
)

from pyciemss.mira_integration.compiled_dynamics import sort_mira_dependencies


from .fixtures import (
ACYCLIC_MODELS,
Expand Down

0 comments on commit 2823e7c

Please sign in to comment.