diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py deleted file mode 100644 index 21b5961a9c..0000000000 --- a/pymc/sampling_jax.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2024 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file exists only for backward-compatibility with imports like -# `import pymc.sampling_jax` or `from pymc import sampling_jax`. - -import warnings - -warnings.warn("This module is deprecated, use pymc.sampling.jax", DeprecationWarning) -from pymc.sampling.jax import * # noqa: E402, F403 diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index e9db8e7000..d6a8d1021b 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -47,13 +47,6 @@ ) -def test_old_import_route(): - import pymc.sampling.jax as new_sj - import pymc.sampling_jax as old_sj - - assert set(new_sj.__all__) <= set(dir(old_sj)) - - def test_jax_PosDefMatrix(): x = pt.tensor(name="x", shape=(2, 2), dtype="float32") matrix_pos_def = PosDefMatrix()