Skip to content

Commit

Permalink
Integrate ZarrTrace into pymc.sample
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Oct 24, 2024
1 parent a3b5d57 commit a890ddb
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 7 deletions.
15 changes: 14 additions & 1 deletion pymc/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
from pymc.backends.base import BaseTrace, IBaseTrace
from pymc.backends.ndarray import NDArray
from pymc.backends.zarr import ZarrTrace
from pymc.model import Model
from pymc.step_methods.compound import BlockedStep, CompoundStep

Expand Down Expand Up @@ -118,15 +119,27 @@ def _init_trace(

def init_traces(
*,
backend: TraceOrBackend | None,
backend: TraceOrBackend | ZarrTrace | None,
chains: int,
expected_length: int,
step: BlockedStep | CompoundStep,
initial_point: Mapping[str, np.ndarray],
model: Model,
trace_vars: list[TensorVariable] | None = None,
tune: int = 0,
) -> tuple[RunType | None, Sequence[IBaseTrace]]:
"""Initialize a trace recorder for each chain."""
if isinstance(backend, ZarrTrace):
backend.init_trace(

Check warning on line 133 in pymc/backends/__init__.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/__init__.py#L133

Added line #L133 was not covered by tests
chains=chains,
draws=expected_length - tune,
tune=tune,
step=step,
model=model,
vars=trace_vars,
test_point=initial_point,
)
return None, backend.straces

Check warning on line 142 in pymc/backends/__init__.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/__init__.py#L142

Added line #L142 was not covered by tests
if HAS_MCB and isinstance(backend, Backend):
return init_chain_adapters(
backend=backend,
Expand Down
65 changes: 59 additions & 6 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
find_observations,
)
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
from pymc.backends.zarr import ZarrTrace
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
Expand Down Expand Up @@ -475,7 +476,7 @@ def sample(
blas_cores: int | None | Literal["auto"] = "auto",
model: Model | None = None,
**kwargs,
) -> InferenceData | MultiTrace:
) -> InferenceData | MultiTrace | ZarrTrace:
r"""Draw samples from the posterior using the given step methods.
Multiple step methods are supported via compound step methods.
Expand Down Expand Up @@ -808,6 +809,7 @@ def joined_blas_limiter():
trace_vars=trace_vars,
initial_point=ip,
model=model,
tune=tune,
)

sample_args = {
Expand Down Expand Up @@ -890,7 +892,7 @@ def joined_blas_limiter():
# into a function to make it easier to test and refactor.
return _sample_return(
run=run,
traces=traces,
traces=trace if isinstance(trace, ZarrTrace) else traces,
tune=tune,
t_sampling=t_sampling,
discard_tuned_samples=discard_tuned_samples,
Expand All @@ -905,7 +907,7 @@ def joined_blas_limiter():
def _sample_return(
*,
run: RunType | None,
traces: Sequence[IBaseTrace],
traces: Sequence[IBaseTrace] | ZarrTrace,
tune: int,
t_sampling: float,
discard_tuned_samples: bool,
Expand All @@ -914,18 +916,70 @@ def _sample_return(
keep_warning_stat: bool,
idata_kwargs: dict[str, Any],
model: Model,
) -> InferenceData | MultiTrace:
) -> InferenceData | MultiTrace | ZarrTrace:
"""Pick/slice chains, run diagnostics and convert to the desired return type.
Final step of `pm.sampler`.
"""
if isinstance(traces, ZarrTrace):
# Split warmup from posterior samples
traces.split_warmup_groups()

Check warning on line 926 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L926

Added line #L926 was not covered by tests

# Set sampling time
traces._sampling_state.sampling_time[:] = t_sampling

Check warning on line 929 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L929

Added line #L929 was not covered by tests

# Compute number of actual draws per chain
total_draws_per_chain = traces._sampling_state.draw_idx[:]
n_chains = len(traces.straces)
desired_tune = traces.tuning_steps
desired_draw = len(traces.posterior.draw)
tuning_steps_per_chain = np.clip(total_draws_per_chain, 0, desired_tune)
draws_per_chain = total_draws_per_chain - tuning_steps_per_chain

Check warning on line 937 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L932-L937

Added lines #L932 - L937 were not covered by tests

total_n_tune = tuning_steps_per_chain.sum()
total_draws = draws_per_chain.sum()

Check warning on line 940 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L939-L940

Added lines #L939 - L940 were not covered by tests

_log.info(

Check warning on line 942 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L942

Added line #L942 was not covered by tests
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations '
f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) "
f"took {t_sampling:.0f} seconds."
)

if compute_convergence_checks or return_inferencedata:
idata = traces.to_inferencedata(save_warmup=not discard_tuned_samples)
log_likelihood = idata_kwargs.pop("log_likelihood", False)
if log_likelihood:
from pymc.stats.log_density import compute_log_likelihood

Check warning on line 952 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L948-L952

Added lines #L948 - L952 were not covered by tests

idata = compute_log_likelihood(

Check warning on line 954 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L954

Added line #L954 was not covered by tests
idata,
var_names=None if log_likelihood is True else log_likelihood,
extend_inferencedata=True,
model=model,
sample_dims=["chain", "draw"],
progressbar=False,
)

if compute_convergence_checks:
warns = run_convergence_checks(idata, model)
for warn in warns:
traces._sampling_state.global_warnings.append(warn)
log_warnings(warns)

Check warning on line 967 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L963-L967

Added lines #L963 - L967 were not covered by tests

if return_inferencedata:

Check warning on line 969 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L969

Added line #L969 was not covered by tests
# By default we drop the "warning" stat which contains `SamplerWarning`
# objects that can not be stored with `.to_netcdf()`.
if not keep_warning_stat:
return drop_warning_stat(idata)
return idata
return traces

Check warning on line 975 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L972-L975

Added lines #L972 - L975 were not covered by tests

# Pick and slice chains to keep the maximum number of samples
if discard_tuned_samples:
traces, length = _choose_chains(traces, tune)
else:
traces, length = _choose_chains(traces, 0)
mtrace = MultiTrace(traces)[:length]

# count the number of tune/draw iterations that happened
# ideally via the "tune" statistic, but not all samplers record it!
if "tune" in mtrace.stat_names:
Expand Down Expand Up @@ -954,7 +1008,6 @@ def _sample_return(
f"took {t_sampling:.0f} seconds."
)

idata = None
if compute_convergence_checks or return_inferencedata:
ikwargs: dict[str, Any] = {"model": model, "save_warmup": not discard_tuned_samples}
ikwargs.update(idata_kwargs)
Expand Down

0 comments on commit a890ddb

Please sign in to comment.