Skip to content

Commit

Permalink
Add parameters to observables (#580)
Browse files Browse the repository at this point in the history
* add parameters to observables

* lint

* added test

* progress towards broadcasting

* lint
  • Loading branch information
SamWitty authored Jun 6, 2024
1 parent 7967dfa commit 422232c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
17 changes: 16 additions & 1 deletion pyciemss/mira_integration/compiled_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,22 @@ def _eval_observables_mira(
if len(src.observables) == 0:
return dict()

numeric_observables = param_module.numeric_observables_func(**X)
parameters = {
get_name(param_info): getattr(param_module, get_name(param_info))
for param_info in src.parameters.values()
if not param_info.placeholder
}

# TODO: support event_dim > 0 upstream in ChiRho
# Default to time being the rightmost dimension
parameters_expanded = {
k: torch.unsqueeze(v, -1) if len(v.size()) > 0 else v
for k, v in parameters.items()
}

numeric_observables = param_module.numeric_observables_func(
**X, **parameters_expanded
)

observables: State[torch.Tensor] = dict()
for i, obs in enumerate(src.observables.values()):
Expand Down
7 changes: 7 additions & 0 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ def __init__(
),
"u",
),
ModelFixture(
os.path.join(MODELS_PATH, "SIR_param_in_observables.json"),
"beta",
os.path.join(DATA_PATH, "SIR_data_case_hosp.csv"),
{"case": "incident_cases", "hosp": "I"},
True,
),
]

REGNET_MODELS = [
Expand Down

0 comments on commit 422232c

Please sign in to comment.