From 329bd5866c28918f539c0e7631b2cf6f9516f738 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Wed, 26 Jun 2024 13:54:14 -0400 Subject: [PATCH] Fix concat logic --- dynamax/nonlinear_gaussian_ssm/sarkka_lib.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py b/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py index 46e4d871..ff65405d 100644 --- a/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py +++ b/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py @@ -64,8 +64,8 @@ def _step(carry, t): carry = (m_post[-1], P_post[-1]) _, (m_sm, P_sm) = lax.scan(_step, carry, jnp.arange(num_timesteps - 1), reverse=True) - m_sm = jnp.concatenate((jnp.array([m_post[-1]]), m_sm)) - P_sm = jnp.concatenate((jnp.array([P_post[-1]]), P_sm)) + m_sm = jnp.concatenate((m_sm, jnp.array([m_post[-1]]))) + P_sm = jnp.concatenate((P_sm, jnp.array([P_post[-1]]))) return m_sm, P_sm @@ -197,7 +197,7 @@ def compute_sigmas(m, P, n, lamb): carry = (m_post[-1], P_post[-1]) _, (m_sm, P_sm) = lax.scan(_step, carry, jnp.arange(num_timesteps - 1), reverse=True) - m_sm = jnp.concatenate((jnp.array([m_post[-1]]), m_sm)) - P_sm = jnp.concatenate((jnp.array([P_post[-1]]), P_sm)) + m_sm = jnp.concatenate((m_sm, jnp.array([m_post[-1]]))) + P_sm = jnp.concatenate((P_sm, jnp.array([P_post[-1]]))) - return m_sm, P_sm \ No newline at end of file + return m_sm, P_sm