diff --git a/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py b/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py index 46e4d871..ff65405d 100644 --- a/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py +++ b/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py @@ -64,8 +64,8 @@ def _step(carry, t): carry = (m_post[-1], P_post[-1]) _, (m_sm, P_sm) = lax.scan(_step, carry, jnp.arange(num_timesteps - 1), reverse=True) - m_sm = jnp.concatenate((jnp.array([m_post[-1]]), m_sm)) - P_sm = jnp.concatenate((jnp.array([P_post[-1]]), P_sm)) + m_sm = jnp.concatenate((m_sm, jnp.array([m_post[-1]]))) + P_sm = jnp.concatenate((P_sm, jnp.array([P_post[-1]]))) return m_sm, P_sm @@ -197,7 +197,7 @@ def compute_sigmas(m, P, n, lamb): carry = (m_post[-1], P_post[-1]) _, (m_sm, P_sm) = lax.scan(_step, carry, jnp.arange(num_timesteps - 1), reverse=True) - m_sm = jnp.concatenate((jnp.array([m_post[-1]]), m_sm)) - P_sm = jnp.concatenate((jnp.array([P_post[-1]]), P_sm)) + m_sm = jnp.concatenate((m_sm, jnp.array([m_post[-1]]))) + P_sm = jnp.concatenate((P_sm, jnp.array([P_post[-1]]))) - return m_sm, P_sm \ No newline at end of file + return m_sm, P_sm