From 253513bbc17914ad68921070ec90da44c8720808 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> Date: Thu, 29 Aug 2024 15:39:44 +0200 Subject: [PATCH] Make zip strict in `apply_function_over_dataset` --- pymc/backends/arviz.py | 2 +- tests/stats/test_log_density.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index a623cc0366..d1c27b787b 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -656,7 +656,7 @@ def apply_function_over_dataset( for idx in indices: out = fn(posterior_pts[idx]) fn.f.trust_input = True # If we arrive here the dtypes are valid - for var_name, val in zip(output_var_names, out): + for var_name, val in zip(output_var_names, out, strict=True): out_dict.insert(var_name, val, idx) progress.advance(task) diff --git a/tests/stats/test_log_density.py b/tests/stats/test_log_density.py index c7b120af25..5128913e88 100644 --- a/tests/stats/test_log_density.py +++ b/tests/stats/test_log_density.py @@ -184,9 +184,15 @@ def test_compilation_kwargs(self): Normal("y", x, observed=[0, 1, 2]) idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)})) - with patch("pymc.model.core.compile_pymc") as patched_compile_pymc: - compute_log_prior(idata, compile_kwargs={"mode": "JAX"}) - compute_log_likelihood(idata, compile_kwargs={"mode": "NUMBA"}) + with ( + # apply_function_over_dataset fails with patched `compile_pymc` + patch("pymc.stats.log_density.apply_function_over_dataset"), + patch("pymc.model.core.compile_pymc") as patched_compile_pymc, + ): + compute_log_prior(idata, compile_kwargs={"mode": "JAX"}, extend_inferencedata=False) + compute_log_likelihood( + idata, compile_kwargs={"mode": "NUMBA"}, extend_inferencedata=False + ) assert len(patched_compile_pymc.call_args_list) == 2 assert patched_compile_pymc.call_args_list[0].kwargs["mode"] == "JAX" assert patched_compile_pymc.call_args_list[1].kwargs["mode"] == "NUMBA"