
Error in the Sample function when running with num_samples > 1 #496

Closed
liunelson opened this issue Feb 22, 2024 · 2 comments

Comments

@liunelson
Contributor

I'm trying to run a baseline scenario to compare with the results of an Optimize run.

import os

import pyciemss

MODELS_PATH = "https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/"
model3 = os.path.join(MODELS_PATH, "SIR_stockflow.json")

start_time = 0.0
end_time = 50.0
logging_step_size = 1.0
num_samples = 1  # works; any value > 1 triggers the error below

results_baseline = pyciemss.sample(
    model3,
    end_time,
    logging_step_size,
    num_samples,
    start_time=start_time,
    # Uncommenting this requires `import torch` and a parameter name bound to `intervened_params`:
    # static_parameter_interventions={torch.tensor(0.0): {intervened_params: torch.tensor(0.35)}},
    solver_method="euler",
)

This call works, but it raises the error below whenever num_samples is greater than 1.
However, it runs without error with, for example, num_samples = 100 if I also supply static_parameter_interventions.

ERROR:root:
                ###############################

                There was an exception in pyciemss

                Error occured in function: sample

                Function docs : 
    Load a model from a file, compile it into a probabilistic program, and sample from it.

    Args:
        model_path_or_json: Union[str, Dict]
            - A path to a AMR model file or JSON containing a model in AMR form.
        end_time: float
            - The end time of the sampled simulation.
        logging_step_size: float
            - The step size to use for logging the trajectory.
        num_samples: int
            - The number of samples to draw from the model.
        solver_method: str
            - The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
            - If performance is incredibly slow, we suggest using `euler` to debug.
              If using `euler` results in faster simulation, the issue is likely that the model is stiff.
        solver_options: Dict[str, Any]
            - Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
        start_time: float
            - The start time of the model. This is used to align the `start_state` from the
              AMR model with the simulation timepoints.
            - By default we set the `start_time` to be 0.
        inferred_parameters: Optional[pyro.nn.PyroModule]
            - A Pyro module that contains the inferred parameters of the model.
              This is typically the result of `calibrate`.
            - If not provided, we will use the default values from the AMR model.
        static_state_interventions: Dict[float, Dict[str, Intervention]]
            - A dictionary of static interventions to apply to the model.
            - Each key is the time at which the intervention is applied.
            - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        static_parameter_interventions: Dict[float, Dict[str, Intervention]]
            - A dictionary of static interventions to apply to the model.
            - Each key is the time at which the intervention is applied.
            - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        dynamic_state_interventions: Dict[
                                        Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
                                        Dict[str, Intervention]
                                        ]
            - A dictionary of dynamic interventions to apply to the model.
            - Each key is a function that takes in the current state of the model and returns a tensor.
              When this function crosses 0, the dynamic intervention is applied.
            - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        dynamic_parameter_interventions: Dict[
                                            Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
                                            Dict[str, Intervention]
                                            ]
            - A dictionary of dynamic interventions to apply to the model.
            - Each key is a function that takes in the current state of the model and returns a tensor.
              When this function crosses 0, the dynamic intervention is applied.
            - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.

    Returns:
        result: Dict[str, torch.Tensor]
            - Dictionary of outputs from the model.
                - Each key is the name of a parameter or state variable in the model.
                - Each value is a tensor of shape (num_samples, num_timepoints) for state variables
                    and (num_samples,) for parameters.
    

                ################################
            
Traceback (most recent call last):
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/nliu/projects/askem/pyciemss/pyciemss/interfaces.py", line 282, in wrapped_model
    full_trajectory = model(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/nn/module.py", line 449, in __call__
    result = super().__call__(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nliu/projects/askem/pyciemss/pyciemss/compiled_dynamics.py", line 77, in forward
    simulate(self.deriv, self.initial_state(), start_time, end_time)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py", line 109, in _pyro_simulate
    state, start_time, next_interruption = simulate_to_interruption(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 89, in _pyro_simulate_to_interruption
    msg["value"] = torchdiffeq_simulate_to_interruption(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 248, in torchdiffeq_simulate_to_interruption
    value = simulate_point(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py", line 97, in _pyro_simulate_point
    trajectory: State[T] = simulate_trajectory(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 77, in _pyro_simulate_trajectory
    msg["value"] = torchdiffeq_simulate_trajectory(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 165, in torchdiffeq_simulate_trajectory
    return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 71, in _torchdiffeq_ode_simulate_inner
    solns = _batched_odeint(  # torchdiffeq.odeint(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 123, in _batched_odeint
    yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 77, in odeint
    solution = solver.integrate(t)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 106, in integrate
    y1 = y0 + dy
RuntimeError: The size of tensor a (3) must match the size of tensor b (6) at non-singleton dimension 0

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/nliu/projects/askem/pyciemss/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped
    result = function(*args, **kwargs)
  File "/home/nliu/projects/askem/pyciemss/pyciemss/interfaces.py", line 298, in sample
    samples = pyro.infer.Predictive(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/infer/predictive.py", line 273, in forward
    return _predictive(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/infer/predictive.py", line 137, in _predictive
    trace = poutine.trace(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
    raise exc from e
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/nliu/projects/askem/pyciemss/pyciemss/interfaces.py", line 282, in wrapped_model
    full_trajectory = model(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/nn/module.py", line 449, in __call__
    result = super().__call__(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nliu/projects/askem/pyciemss/pyciemss/compiled_dynamics.py", line 77, in forward
    simulate(self.deriv, self.initial_state(), start_time, end_time)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py", line 109, in _pyro_simulate
    state, start_time, next_interruption = simulate_to_interruption(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 89, in _pyro_simulate_to_interruption
    msg["value"] = torchdiffeq_simulate_to_interruption(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 248, in torchdiffeq_simulate_to_interruption
    value = simulate_point(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py", line 97, in _pyro_simulate_point
    trajectory: State[T] = simulate_trajectory(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 77, in _pyro_simulate_trajectory
    msg["value"] = torchdiffeq_simulate_trajectory(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 165, in torchdiffeq_simulate_trajectory
    return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 71, in _torchdiffeq_ode_simulate_inner
    solns = _batched_odeint(  # torchdiffeq.odeint(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 123, in _batched_odeint
    yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 77, in odeint
    solution = solver.integrate(t)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 106, in integrate
    y1 = y0 + dy
RuntimeError: The size of tensor a (3) must match the size of tensor b (6) at non-singleton dimension 0
                               Trace Shapes:    
                                Param Sites:    
numeric_deriv_func$$$_nodes.0._args.0._value    
numeric_deriv_func$$$_nodes.1._args.0._value    
                               Sample Sites:    
                     persistent_p_cbeta dist 2 |
                                       value 2 |
                        persistent_p_tr dist 2 |
                                       value 2 |
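The final `RuntimeError` is raised by `y1 = y0 + dy` inside torchdiffeq's fixed-grid step (the `euler` solver). The trace shapes show the sampled parameters carrying a batch dimension of 2, and the SIR model presumably has three state variables (S, I, R), so one plausible reading is that the flattened initial state was not broadcast to the sample batch (3 entries) while the derivative was (3 × 2 = 6 entries). The sketch below reproduces that mismatch in plain PyTorch under those assumptions; the state names, parameter values, `deriv`/`flatten` helpers and the `expand`-based fix are all hypothetical illustrations, not ChiRho's or pyciemss's actual code, and not the fix in #491.

```python
import torch

# Hypothetical SIR initial state: three scalar state variables (no batch dimension).
initial_state = {
    "S": torch.tensor(990.0),
    "I": torch.tensor(10.0),
    "R": torch.tensor(0.0),
}

# With num_samples = 2, each sampled parameter has a batch dimension of 2,
# mirroring the "dist 2" sample sites in the trace shapes. Values are made up.
num_samples = 2
beta = torch.full((num_samples,), 0.3)
gamma = torch.full((num_samples,), 0.1)
N = 1000.0


def deriv(state):
    """SIR right-hand side; outputs broadcast to the parameters' batch shape."""
    infection = beta * state["S"] * state["I"] / N
    recovery = gamma * state["I"]
    return {"S": -infection, "I": infection - recovery, "R": recovery}


def flatten(state):
    """Concatenate state variables into one 1-D vector, as a solver sees it."""
    return torch.cat([value.reshape(-1) for value in state.values()])


dt = 1.0
y0 = flatten(initial_state)              # shape (3,): not broadcast to the batch
dy = dt * flatten(deriv(initial_state))  # shape (6,): 3 variables x 2 samples

# Reproduce the failing Euler update y1 = y0 + dy.
try:
    _ = y0 + dy
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)

# Illustrative fix: broadcast every state variable to the batch shape first.
batched_state = {k: v.expand(beta.shape) for k, v in initial_state.items()}
y0_batched = flatten(batched_state)  # shape (6,)
y1 = y0_batched + dt * flatten(deriv(batched_state))
print(y1.shape)  # torch.Size([6])
```

This only shows why the two lengths can disagree; how the actual fixes handle batching is defined by #491 and BasisResearch/chirho#525.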
@SamWitty
Contributor

@liunelson, are you using the version of pyciemss on main or the last tagged release? I believe we addressed this issue in #491. If that doesn't do it, a more robust upstream fix in ChiRho was merged this morning and should be coming soon: BasisResearch/chirho#525
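To answer the main-versus-release question, a small sketch that reports the installed versions of the packages appearing in the traceback. It assumes the PyPI distribution names listed in `PACKAGES`, and `installed_versions` is a hypothetical helper name. Note that the traceback paths suggest pyciemss runs from a local checkout (`/home/nliu/projects/askem/pyciemss`) while chirho comes from `site-packages`; for a checkout, the reported version string alone may not distinguish main from the tagged release, so the checkout's git history is the more reliable signal.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names for the packages seen in the traceback.
PACKAGES = ["pyciemss", "chirho", "pyro-ppl", "torchdiffeq", "torch"]


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```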

@liunelson
Contributor Author

@SamWitty Just confirming here that your fix resolved this issue.
