From 0615c9587294b3df746b80d5c20eab891d580c5e Mon Sep 17 00:00:00 2001 From: gileshd Date: Mon, 15 Jul 2024 22:49:55 +0100 Subject: [PATCH] Fix import path for `tree_map` `jax.tree_map` has been deprecated, see: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-26-april-3-2024 --- docs/notebooks/slds/rbpf_maneuver.ipynb | 3 ++- dynamax/hidden_markov_model/models/arhmm.py | 3 ++- dynamax/utils/distributions_test.py | 2 +- dynamax/utils/optimize.py | 3 ++- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/notebooks/slds/rbpf_maneuver.ipynb b/docs/notebooks/slds/rbpf_maneuver.ipynb index e7f45858..0e30070d 100644 --- a/docs/notebooks/slds/rbpf_maneuver.ipynb +++ b/docs/notebooks/slds/rbpf_maneuver.ipynb @@ -32,7 +32,8 @@ "from functools import partial\n", "import sys \n", "sys.path.append('/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax')\n", - "from jax import vmap, tree_map, jit\n", + "from jax import vmap, jit\n", + "from jax.tree_util import tree_map\n", "from dynamax.slds.inference import ParamsSLDS, LGParamsSLDS, DiscreteParamsSLDS, rbpfilter, rbpfilter_optimal\n", "from dynamax.slds.models import SLDS\n", "# import MVN from tfd\n", diff --git a/dynamax/hidden_markov_model/models/arhmm.py b/dynamax/hidden_markov_model/models/arhmm.py index 9eb5b2da..2eff832b 100644 --- a/dynamax/hidden_markov_model/models/arhmm.py +++ b/dynamax/hidden_markov_model/models/arhmm.py @@ -1,6 +1,7 @@ import jax.numpy as jnp import jax.random as jr -from jax import lax, tree_map +from jax import lax +from jax.tree_util import tree_map from jaxtyping import Float, Array from dynamax.hidden_markov_model.models.abstractions import HMM, HMMParameterSet, HMMPropertySet from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState diff --git a/dynamax/utils/distributions_test.py b/dynamax/utils/distributions_test.py index c4fe0fad..88355e19 100644 --- a/dynamax/utils/distributions_test.py +++ b/dynamax/utils/distributions_test.py @@ -1,7 +1,7 @@ import pytest import jax.numpy as jnp import jax.random as jr -from jax import tree_map +from jax.tree_util import tree_map from jax.scipy.stats import norm from scipy.stats import invgamma from tensorflow_probability.substrates import jax as tfp diff --git a/dynamax/utils/optimize.py b/dynamax/utils/optimize.py index a0c42bf2..d997111e 100644 --- a/dynamax/utils/optimize.py +++ b/dynamax/utils/optimize.py @@ -1,7 +1,8 @@ import jax.numpy as jnp import jax.random as jr import optax -from jax import lax, value_and_grad, tree_map +from jax import lax, value_and_grad +from jax.tree_util import tree_map from dynamax.utils.utils import pytree_len