Skip to content

Commit

Permalink
Make zip strict in apply_function_over_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Sep 1, 2024
1 parent 90f20a2 commit 253513b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ def apply_function_over_dataset(
for idx in indices:
out = fn(posterior_pts[idx])
fn.f.trust_input = True # If we arrive here the dtypes are valid
for var_name, val in zip(output_var_names, out):
for var_name, val in zip(output_var_names, out, strict=True):
out_dict.insert(var_name, val, idx)

progress.advance(task)
Expand Down
12 changes: 9 additions & 3 deletions tests/stats/test_log_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,15 @@ def test_compilation_kwargs(self):
Normal("y", x, observed=[0, 1, 2])

idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
with patch("pymc.model.core.compile_pymc") as patched_compile_pymc:
compute_log_prior(idata, compile_kwargs={"mode": "JAX"})
compute_log_likelihood(idata, compile_kwargs={"mode": "NUMBA"})
with (
# apply_function_over_dataset fails with patched `compile_pymc`
patch("pymc.stats.log_density.apply_function_over_dataset"),
patch("pymc.model.core.compile_pymc") as patched_compile_pymc,
):
compute_log_prior(idata, compile_kwargs={"mode": "JAX"}, extend_inferencedata=False)
compute_log_likelihood(
idata, compile_kwargs={"mode": "NUMBA"}, extend_inferencedata=False
)
assert len(patched_compile_pymc.call_args_list) == 2
assert patched_compile_pymc.call_args_list[0].kwargs["mode"] == "JAX"
assert patched_compile_pymc.call_args_list[1].kwargs["mode"] == "NUMBA"

0 comments on commit 253513b

Please sign in to comment.