This function call works, but it raises this error whenever `num_samples` is not `1`.
However, I can get it to run without error with, for example, `num_samples = 100` if I supply a `static_parameter_interventions` argument.
ERROR:root:
###############################
There was an exception in pyciemss
Error occured in function: sample
Function docs :
Load a model from a file, compile it into a probabilistic program, and sample from it.
Args:
model_path_or_json: Union[str, Dict]
- A path to a AMR model file or JSON containing a model in AMR form.
end_time: float
- The end time of the sampled simulation.
logging_step_size: float
- The step size to use for logging the trajectory.
num_samples: int
- The number of samples to draw from the model.
solver_method: str
- The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
- By default we set the `start_time` to be 0.
inferred_parameters: Optional[pyro.nn.PyroModule]
- A Pyro module that contains the inferred parameters of the model.
This is typically the result of `calibrate`.
- If not provided, we will use the default values from the AMR model.
static_state_interventions: Dict[float, Dict[str, Intervention]]
- A dictionary of static interventions to apply to the model.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
static_parameter_interventions: Dict[float, Dict[str, Intervention]]
- A dictionary of static interventions to apply to the model.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {parameter_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
dynamic_state_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, Intervention]
]
- A dictionary of dynamic interventions to apply to the model.
- Each key is a function that takes in the current state of the model and returns a tensor.
When this function crosses 0, the dynamic intervention is applied.
- Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
dynamic_parameter_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, Intervention]
]
- A dictionary of dynamic interventions to apply to the model.
- Each key is a function that takes in the current state of the model and returns a tensor.
When this function crosses 0, the dynamic intervention is applied.
- Each value is a dictionary of the form {parameter_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
Returns:
result: Dict[str, torch.Tensor]
- Dictionary of outputs from the model.
- Each key is the name of a parameter or state variable in the model.
- Each value is a tensor of shape (num_samples, num_timepoints) for state variables
and (num_samples,) for parameters.
################################
Traceback (most recent call last):
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
ret = self.fn(*args, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "/home/nliu/projects/askem/pyciemss/pyciemss/interfaces.py", line 282, in wrapped_model
full_trajectory = model(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/nn/module.py", line 449, in __call__
result = super().__call__(*args, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/nliu/projects/askem/pyciemss/pyciemss/compiled_dynamics.py", line 77, in forward
simulate(self.deriv, self.initial_state(), start_time, end_time)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
apply_stack(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
frame._process_message(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
return method(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py", line 109, in _pyro_simulate
state, start_time, next_interruption = simulate_to_interruption(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
apply_stack(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
frame._process_message(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
return method(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 89, in _pyro_simulate_to_interruption
msg["value"] = torchdiffeq_simulate_to_interruption(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 248, in torchdiffeq_simulate_to_interruption
value = simulate_point(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
apply_stack(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
frame._process_message(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
return method(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py", line 97, in _pyro_simulate_point
trajectory: State[T] = simulate_trajectory(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
apply_stack(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
frame._process_message(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
return method(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 77, in _pyro_simulate_trajectory
msg["value"] = torchdiffeq_simulate_trajectory(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 165, in torchdiffeq_simulate_trajectory
return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 71, in _torchdiffeq_ode_simulate_inner
solns = _batched_odeint( # torchdiffeq.odeint(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 123, in _batched_odeint
yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 77, in odeint
solution = solver.integrate(t)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 106, in integrate
y1 = y0 + dy
RuntimeError: The size of tensor a (3) must match the size of tensor b (6) at non-singleton dimension 0
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/nliu/projects/askem/pyciemss/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped
result = function(*args, **kwargs)
File "/home/nliu/projects/askem/pyciemss/pyciemss/interfaces.py", line 298, in sample
samples = pyro.infer.Predictive(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/infer/predictive.py", line 273, in forward
return _predictive(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/infer/predictive.py", line 137, in _predictive
trace = poutine.trace(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
self(*args, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
raise exc from e
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
ret = self.fn(*args, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "/home/nliu/projects/askem/pyciemss/pyciemss/interfaces.py", line 282, in wrapped_model
full_trajectory = model(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/nn/module.py", line 449, in __call__
result = super().__call__(*args, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/nliu/projects/askem/pyciemss/pyciemss/compiled_dynamics.py", line 77, in forward
simulate(self.deriv, self.initial_state(), start_time, end_time)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
apply_stack(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
frame._process_message(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
return method(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py", line 109, in _pyro_simulate
state, start_time, next_interruption = simulate_to_interruption(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
apply_stack(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
frame._process_message(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
return method(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 89, in _pyro_simulate_to_interruption
msg["value"] = torchdiffeq_simulate_to_interruption(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 248, in torchdiffeq_simulate_to_interruption
value = simulate_point(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
apply_stack(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
frame._process_message(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
return method(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py", line 97, in _pyro_simulate_point
trajectory: State[T] = simulate_trajectory(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
apply_stack(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
frame._process_message(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
return method(msg)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 77, in _pyro_simulate_trajectory
msg["value"] = torchdiffeq_simulate_trajectory(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 165, in torchdiffeq_simulate_trajectory
return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 71, in _torchdiffeq_ode_simulate_inner
solns = _batched_odeint( # torchdiffeq.odeint(
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 123, in _batched_odeint
yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 77, in odeint
solution = solver.integrate(t)
File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 106, in integrate
y1 = y0 + dy
RuntimeError: The size of tensor a (3) must match the size of tensor b (6) at non-singleton dimension 0
Trace Shapes:
Param Sites:
numeric_deriv_func$$$_nodes.0._args.0._value
numeric_deriv_func$$$_nodes.1._args.0._value
Sample Sites:
persistent_p_cbeta dist 2 |
value 2 |
persistent_p_tr dist 2 |
value 2 |
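For readers unfamiliar with the final `RuntimeError`: it is a broadcasting failure in `y1 = y0 + dy`, where the two tensors have sizes 3 and 6 along dimension 0. The sketch below is a pure-Python illustration of the standard broadcasting rule that PyTorch applies to shapes. It is not code from pyciemss, ChiRho, or torchdiffeq, and the helper names `broadcast_shape` and `try_broadcast` are hypothetical. It does not diagnose why the shapes differ in this issue; it only shows why 3 and 6 cannot be combined.

```python
from itertools import zip_longest
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def broadcast_shape(a: Shape, b: Shape) -> Shape:
    """Return the broadcast shape of `a` and `b`, or raise ValueError.

    Dimensions are compared from the trailing end. Two sizes are
    compatible when they are equal or when one of them is 1; a missing
    leading dimension is treated as size 1.
    """
    result = []
    for dim_a, dim_b in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if dim_a == dim_b or dim_b == 1:
            result.append(dim_a)
        elif dim_a == 1:
            result.append(dim_b)
        else:
            raise ValueError(
                f"size of tensor a ({dim_a}) must match size of tensor b ({dim_b})"
            )
    return tuple(reversed(result))


def try_broadcast(a: Shape, b: Shape) -> Optional[Shape]:
    """Like broadcast_shape, but return None when the shapes are incompatible."""
    try:
        return broadcast_shape(a, b)
    except ValueError:
        return None


if __name__ == "__main__":
    # A batch of 2 samples with 3 values each combines with a single 3-vector.
    print(try_broadcast((2, 3), (3,)))
    # Sizes 3 and 6 along the same dimension are incompatible, as in the error above.
    print(try_broadcast((3,), (6,)))
```

Under this rule, a size-3 tensor can be added to a tensor of shape `(2, 3)` but not to a flat tensor of size 6, so a batch dimension that is flattened or expanded on only one side of the addition would produce an error of this form.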
@liunelson, are you using the version of pyciemss on main or the last tagged release? I believe we addressed this issue with #491. If that doesn't do it, there should be a more robust upstream fix coming from ChiRho, which was just merged this morning: BasisResearch/chirho#525
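One way to answer the version question is to print what is installed in the active environment. This is a minimal sketch using only the standard library. The distribution names passed to `version` (`pyciemss`, `chirho`, `torchdiffeq`, `pyro-ppl`) are assumptions about how the packages are registered, and `installed_versions` is a hypothetical helper, not part of any of these libraries.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Assumed distribution names; adjust if the packages are registered differently.
DEFAULT_PACKAGES = ("pyciemss", "chirho", "torchdiffeq", "pyro-ppl")


def installed_versions(
    names: Iterable[str] = DEFAULT_PACKAGES,
) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

In the traceback above, pyciemss is loaded from a source checkout (`/home/nliu/projects/askem/pyciemss`) while ChiRho comes from `site-packages`, so the reported pyciemss version may not identify the exact commit. Running `git log -1` inside the checkout is likely a more reliable check for that package.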
I'm trying to run a baseline scenario to compare with the results of an Optimize run.