coerce correct tensor datatypes #567

Closed
mwdchang opened this issue Apr 9, 2024 · 3 comments · Fixed by #604
Labels
good first issue Good for newcomers

Comments

@mwdchang

mwdchang commented Apr 9, 2024

It would be nice if the library could coerce inputs to the datatypes it wants to use, so that at the input level we don't need to explicitly distinguish between integers, longs, floats, etc.

Here is one error trace we saw. There was a type mismatch (probably on the population number): the value was a long, but the library wants a float:

ERROR:rq.worker:[Job f9e4e180-1853-49cb-8dd1-a680445bf893]: exception raised while executing (execute.run)
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/rq/worker.py", line 1428, in perform_job
    rv = job.perform()
  File "/usr/local/lib/python3.10/site-packages/rq/job.py", line 1278, in perform
    self._result = self._execute()
  File "/usr/local/lib/python3.10/site-packages/rq/job.py", line 1315, in _execute
    result = self.func(*self.args, **self.kwargs)
  File "/service/./execute.py", line 37, in run
    output = eval(operation_name)(**kwargs)
  File "/usr/local/lib/python3.10/site-packages/pyciemss/integration_utils/custom_decorators.py", line 29, in wrapped
    raise e
  File "/usr/local/lib/python3.10/site-packages/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped
    result = function(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pyciemss/interfaces.py", line 508, in sample
    samples = pyro.infer.Predictive(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pyro/infer/predictive.py", line 273, in forward
    return _predictive(
  File "/usr/local/lib/python3.10/site-packages/pyro/infer/predictive.py", line 78, in _predictive
    max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
  File "/usr/local/lib/python3.10/site-packages/pyro/infer/predictive.py", line 21, in _guess_max_plate_nesting
    model_trace = poutine.trace(model).get_trace(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pyciemss/interfaces.py", line 481, in wrapped_model
    full_trajectory = model(
  File "/usr/local/lib/python3.10/site-packages/pyro/nn/module.py", line 449, in __call__
    result = super().__call__(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pyciemss/compiled_dynamics.py", line 124, in forward
    simulate(self.deriv, self.initial_state(), start_time, end_time)
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py", line 109, in _pyro_simulate
    state, start_time, next_interruption = simulate_to_interruption(
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 89, in _pyro_simulate_to_interruption
    msg["value"] = torchdiffeq_simulate_to_interruption(
  File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 252, in torchdiffeq_simulate_to_interruption
    value = simulate_point(
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py", line 97, in _pyro_simulate_point
    trajectory: State[T] = simulate_trajectory(
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/usr/local/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 77, in _pyro_simulate_trajectory
    msg["value"] = torchdiffeq_simulate_trajectory(
  File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 169, in torchdiffeq_simulate_trajectory
    return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 71, in _torchdiffeq_ode_simulate_inner
    solns = _batched_odeint(  # torchdiffeq.odeint(
  File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 127, in _batched_odeint
    yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
  File "/usr/local/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 72, in odeint
    shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
  File "/usr/local/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 213, in _check_inputs
    _assert_floating('y0', y0)
  File "/usr/local/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 106, in _assert_floating
    raise TypeError('`{}` must be a floating point Tensor but is a {}'.format(name, t.type()))
TypeError: `y0` must be a floating point Tensor but is a torch.LongTensor
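The last frame shows the root cause: torchdiffeq's `_assert_floating` check rejects an initial state `y0` with an integral dtype, and `torch.tensor(1000)` built from a Python integer is an int64 (Long) tensor. A minimal sketch of the kind of coercion being requested, assuming the initial state is a dict of tensors keyed by state-variable name (`ensure_floating` is a hypothetical helper, not part of pyciemss, ChiRho, or torchdiffeq):

```python
import torch


def ensure_floating(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Hypothetical helper: cast any non-floating tensor to torch's default float dtype."""
    dtype = torch.get_default_dtype()
    return {
        # Leave float tensors untouched; cast int/long tensors so odeint accepts them.
        name: value if torch.is_floating_point(value) else value.to(dtype)
        for name, value in state.items()
    }


# An integer population count becomes an int64 (Long) tensor by default.
state = {"S": torch.tensor(1000), "I": torch.tensor(1.0)}
print(state["S"].dtype)  # torch.int64

fixed = ensure_floating(state)
print(fixed["S"].dtype)  # torch.float32 under torch's default dtype
```

Casting to `torch.get_default_dtype()` rather than hard-coding `torch.float32` keeps the helper consistent if a caller has switched the default to `float64`.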

@SamWitty
Contributor

@mwdchang , could you please share the AMR that was used to generate this error? I'd like to include that in the tests when we resolve this issue. Thanks!

@mwdchang
Author

@SamWitty I am just passing on the error from Pascale, so the exact circumstances of how the error was generated should be taken with a grain of salt. From what I can gather, this is a simulation with params in the following structure:

INFO:root:{'id': 'f9e4e180-1853-49cb-8dd1-a680445bf893', 'execution_payload': {'engine': 'ciemss', 'user_id': 'not_provided', 'model_config_id': '2162a841-a94f-4c86-8c7b-e34d2c934047', 'timespan': {'start': 0.0, 'end': 100.0}, 'interventions': [], 'step_size': 1.0, 'extra': {'num_samples': 10010, 'inferred_parameters': None}}, 'name': 'ff8f7867-b2ae-409d-b913-d23a84dc478e', 'description': None, 'result_files': [], 'type': 'SIMULATION', 'status': 'QUEUED', 'status_message': None, 'start_time': None, 'completed_time': None, 'engine': 'CIEMSS', 'workflow_id': 'ff8f7867-b2ae-409d-b913-d23a84dc478e', 'user_id': None, 'project_id': None, 'created_on': None, 'updated_on': None, 'deleted_on': None}
DEBUG:rq.queue:Pushed job f9e4e180-1853-49cb-8dd1-a680445bf893 into default

And the corresponding model is the one attached below:

sam-apr-11.json

@djinnome
Contributor

Hi folks, it appears that this error occurs when you provide a long instead of a float. We have some tooling for catching these issues, but maybe we need to do some more robustifying.
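Since the original request is to coerce at the input level, here is a minimal stdlib sketch, assuming model values arrive as parsed JSON: JSON integers such as `1000` load as Python `int`, which `torch.tensor` then turns into a LongTensor. The helper `coerce_numbers_to_float` and the sample payload are hypothetical illustrations, not the AMR schema or the fix that landed in #604.

```python
import json
from typing import Any


def coerce_numbers_to_float(obj: Any) -> Any:
    """Hypothetical helper: recursively convert int leaves to float."""
    if isinstance(obj, bool):
        # bool is a subclass of int in Python; keep True/False intact.
        return obj
    if isinstance(obj, int):
        return float(obj)
    if isinstance(obj, dict):
        return {key: coerce_numbers_to_float(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [coerce_numbers_to_float(item) for item in obj]
    # Floats, strings, None, etc. pass through unchanged.
    return obj


# Made-up payload: "S0" parses as an int, "beta" as a float.
raw = json.loads('{"S0": 1000, "beta": 0.3}')
print(type(raw["S0"]).__name__)  # int

coerced = coerce_numbers_to_float(raw)
print(coerced)  # {'S0': 1000.0, 'beta': 0.3}
```

Applying this blindly to a whole request would also turn genuine counts such as `num_samples` into floats, so in practice it would be limited to model values (initial conditions, parameters) rather than run options.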
