Skip to content

Commit

Permalink
Merge pull request #368 from edeno/main
Browse files Browse the repository at this point in the history
Fix concat logic in uks, eks
  • Loading branch information
slinderman authored Jun 26, 2024
2 parents 9b3fb2f + 329bd58 commit fceca0e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions dynamax/nonlinear_gaussian_ssm/sarkka_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def _step(carry, t):

carry = (m_post[-1], P_post[-1])
_, (m_sm, P_sm) = lax.scan(_step, carry, jnp.arange(num_timesteps - 1), reverse=True)
m_sm = jnp.concatenate((jnp.array([m_post[-1]]), m_sm))
P_sm = jnp.concatenate((jnp.array([P_post[-1]]), P_sm))
m_sm = jnp.concatenate((m_sm, jnp.array([m_post[-1]])))
P_sm = jnp.concatenate((P_sm, jnp.array([P_post[-1]])))

return m_sm, P_sm

Expand Down Expand Up @@ -197,7 +197,7 @@ def compute_sigmas(m, P, n, lamb):

carry = (m_post[-1], P_post[-1])
_, (m_sm, P_sm) = lax.scan(_step, carry, jnp.arange(num_timesteps - 1), reverse=True)
m_sm = jnp.concatenate((jnp.array([m_post[-1]]), m_sm))
P_sm = jnp.concatenate((jnp.array([P_post[-1]]), P_sm))
m_sm = jnp.concatenate((m_sm, jnp.array([m_post[-1]])))
P_sm = jnp.concatenate((P_sm, jnp.array([P_post[-1]])))

return m_sm, P_sm
return m_sm, P_sm

0 comments on commit fceca0e

Please sign in to comment.