From 465d8ac1675f65f28051423bd7fd84dc5cd2a715 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Thu, 19 Sep 2024 09:55:46 +0200 Subject: [PATCH] Add HMC sampling state --- pymc/step_methods/hmc/base_hmc.py | 34 +++++-- pymc/step_methods/hmc/hmc.py | 10 +- pymc/step_methods/hmc/nuts.py | 10 +- pymc/step_methods/hmc/quadpotential.py | 123 ++++++++++++++++++++++--- pymc/step_methods/step_sizes.py | 21 ++++- 5 files changed, 178 insertions(+), 20 deletions(-) diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index b320ed81944..87daff649cf 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -27,14 +27,19 @@ from pymc.model import Point, modelcontext from pymc.pytensorf import floatX from pymc.stats.convergence import SamplerWarning, WarningType -from pymc.step_methods import step_sizes from pymc.step_methods.arraystep import GradientSharedStep from pymc.step_methods.compound import StepMethodState from pymc.step_methods.hmc import integration from pymc.step_methods.hmc.integration import IntegrationError, State -from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential +from pymc.step_methods.hmc.quadpotential import ( + PotentialState, + QuadPotentialDiagAdapt, + quad_potential, +) +from pymc.step_methods.state import dataclass_state +from pymc.step_methods.step_sizes import DualAverageAdaptation, StepSizeState from pymc.tuning import guess_scaling -from pymc.util import get_value_vars_from_user_vars +from pymc.util import RandomGenerator, get_random_generator, get_value_vars_from_user_vars logger = logging.getLogger(__name__) @@ -53,12 +58,27 @@ class HMCStepData(NamedTuple): stats: dict[str, Any] +@dataclass_state +class BaseHMCState(StepMethodState): + adapt_step_size: bool + Emax: float + iter_count: int + step_size: np.ndarray + step_adapt: StepSizeState + target_accept: float + tune: bool + potential: PotentialState + _num_divs_sample: int + + class BaseHMC(GradientSharedStep): """Superclass to implement Hamiltonian/hybrid monte carlo.""" integrator: integration.CpuLeapfrogIntegrator default_blocked = True + _state_class = BaseHMCState + def __init__( self, vars=None, @@ -134,9 +154,7 @@ def __init__( size = sum(v.size for v in nuts_vars) self.step_size = step_scale / (size**0.25) - self.step_adapt = step_sizes.DualAverageAdaptation( - self.step_size, target_accept, gamma, k, t0 - ) + self.step_adapt = DualAverageAdaptation(self.step_size, target_accept, gamma, k, t0) self.target_accept = target_accept self.tune = True @@ -268,3 +286,7 @@ def reset_tuning(self, start=None): def reset(self, start=None): self.tune = True self.potential.reset() + + def set_rng(self, rng: RandomGenerator): + self.rng = get_random_generator(rng, copy=False) + self.potential.set_rng(self.rng.spawn(1)[0]) diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 106faee501d..a5ebbd7a8c1 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -14,14 +14,16 @@ from __future__ import annotations +from dataclasses import field from typing import Any import numpy as np from pymc.stats.convergence import SamplerWarning from pymc.step_methods.compound import Competence -from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData +from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData from pymc.step_methods.hmc.integration import IntegrationError, State +from pymc.step_methods.state import dataclass_state from pymc.vartypes import discrete_types __all__ = ["HamiltonianMC"] @@ -31,6 +33,12 @@ def unif(step_size, elow=0.85, ehigh=1.15, rng: np.random.Generator | None = Non return (rng or np.random).uniform(elow, ehigh) * step_size +@dataclass_state +class HamiltonianMCState(BaseHMCState): + path_length: float = field(metadata={"frozen": True}) + max_steps: int = field(metadata={"frozen": True}) + + class HamiltonianMC(BaseHMC): R"""A sampler for continuous variables based on Hamiltonian mechanics. diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 3c4b4e68003..9bcde951041 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections import namedtuple +from dataclasses import field import numpy as np @@ -23,13 +24,20 @@ from pymc.stats.convergence import SamplerWarning from pymc.step_methods.compound import Competence from pymc.step_methods.hmc import integration -from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData +from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData from pymc.step_methods.hmc.integration import IntegrationError, State +from pymc.step_methods.state import dataclass_state from pymc.vartypes import continuous_types __all__ = ["NUTS"] +@dataclass_state +class NUTSState(BaseHMCState): + max_treedepth: int = field(metadata={"frozen": True}) + early_max_treedepth: int = field(metadata={"frozen": True}) + + class NUTS(BaseHMC): r"""A sampler for continuous variables based on Hamiltonian mechanics. diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py index abddaaf35f1..05da188f9b3 100644 --- a/pymc/step_methods/hmc/quadpotential.py +++ b/pymc/step_methods/hmc/quadpotential.py @@ -16,7 +16,8 @@ import warnings -from typing import overload +from dataclasses import field +from typing import Any, overload import numpy as np import pytensor @@ -25,6 +26,8 @@ from scipy.sparse import issparse from pymc.pytensorf import floatX +from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state +from pymc.util import RandomGenerator, get_random_generator __all__ = [ "quad_potential", @@ -100,11 +103,18 @@ def __str__(self): return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}." -class QuadPotential: +@dataclass_state +class PotentialState(DataClassState): + rng: np.random.Generator + + +class QuadPotential(WithSamplingState): dtype: np.dtype + _state_class = PotentialState + def __init__(self, rng=None): - self.rng = np.random.default_rng(rng) + self.rng = get_random_generator(rng) @overload def velocity(self, x: np.ndarray, out: None) -> np.ndarray: ... @@ -157,15 +167,42 @@ def reset(self): def stats(self): return {"largest_eigval": np.nan, "smallest_eigval": np.nan} + def set_rng(self, rng: RandomGenerator): + self.rng = get_random_generator(rng, copy=False) + def isquadpotential(value): """Check whether an object might be a QuadPotential object.""" return isinstance(value, QuadPotential) +@dataclass_state +class QuadPotentialDiagAdaptState(PotentialState): + _var: np.ndarray + _stds: np.ndarray + _inv_stds: np.ndarray + _foreground_var: WeightedVarianceState + _background_var: WeightedVarianceState + _n_samples: int + adaptation_window: int + _mass_trace: list[np.ndarray] | None + + dtype: Any = field(metadata={"frozen": True}) + _n: int = field(metadata={"frozen": True}) + _discard_window: int = field(metadata={"frozen": True}) + _early_update: int = field(metadata={"frozen": True}) + _initial_mean: np.ndarray = field(metadata={"frozen": True}) + _initial_diag: np.ndarray = field(metadata={"frozen": True}) + _initial_weight: np.ndarray = field(metadata={"frozen": True}) + adaptation_window_multiplier: float = field(metadata={"frozen": True}) + _store_mass_matrix_trace: bool = field(metadata={"frozen": True}) + + class QuadPotentialDiagAdapt(QuadPotential): """Adapt a diagonal mass matrix from the sample variances.""" + _state_class = QuadPotentialDiagAdaptState + def __init__( self, n, @@ -346,9 +383,20 @@ def raise_ok(self, map_info): raise ValueError("\n".join(errmsg)) -class _WeightedVariance: +@dataclass_state +class WeightedVarianceState(DataClassState): + n_samples: int + mean: np.ndarray + raw_var: np.ndarray + + _dtype: Any = field(metadata={"frozen": True}) + + +class _WeightedVariance(WithSamplingState): """Online algorithm for computing mean of variance.""" + _state_class = WeightedVarianceState + def __init__( self, nelem, initial_mean=None, initial_variance=None, initial_weight=0, dtype="d" ): @@ -390,7 +438,16 @@ def current_mean(self): return self.mean.copy(dtype=self._dtype) -class _ExpWeightedVariance: +@dataclass_state +class ExpWeightedVarianceState(DataClassState): + _alpha: float + _mean: np.ndarray + _var: np.ndarray + + +class _ExpWeightedVariance(WithSamplingState): + _state_class = ExpWeightedVarianceState + def __init__(self, n_vars, *, init_mean, init_var, alpha): self._variance = init_var self._mean = init_mean @@ -415,7 +472,18 @@ def current_mean(self, out=None): return out +@dataclass_state +class QuadPotentialDiagAdaptExpState(QuadPotentialDiagAdaptState): + _alpha: float + _stop_adaptation: float + _variance_estimator: ExpWeightedVarianceState + + _variance_estimator_grad: ExpWeightedVarianceState | None = None + + class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt): + _state_class = QuadPotentialDiagAdaptExpState + def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, rng=None, **kwargs): """Set up a diagonal mass matrix. @@ -526,7 +594,7 @@ def __init__(self, v, dtype=None, rng=None): self.s = s self.inv_s = 1.0 / s self.v = v - self.rng = np.random.default_rng(rng) + self.rng = get_random_generator(rng) def velocity(self, x, out=None): """Compute the current velocity at a position in parameter space.""" @@ -572,7 +640,7 @@ def __init__(self, A, dtype=None, rng=None): dtype = pytensor.config.floatX self.dtype = dtype self.L = floatX(scipy.linalg.cholesky(A, lower=True)) - self.rng = np.random.default_rng(rng) + self.rng = get_random_generator(rng) def velocity(self, x, out=None): """Compute the current velocity at a position in parameter space.""" @@ -621,7 +689,7 @@ def __init__(self, cov, dtype=None, rng=None): self._cov = np.array(cov, dtype=self.dtype, copy=True) self._chol = scipy.linalg.cholesky(self._cov, lower=True) self._n = len(self._cov) - self.rng = np.random.default_rng(rng) + self.rng = get_random_generator(rng) def velocity(self, x, out=None): """Compute the current velocity at a position in parameter space.""" @@ -646,9 +714,31 @@ def velocity_energy(self, x, v_out): __call__ = random +@dataclass_state +class QuadPotentialFullAdaptState(PotentialState): + _previous_update: int + _cov: np.ndarray + _chol: np.ndarray + _chol_error: scipy.linalg.LinAlgError | ValueError | None = None + _foreground_cov: WeightedCovarianceState + _background_cov: WeightedCovarianceState + _n_samples: int + adaptation_window: int + + dtype: Any = field(metadata={"frozen": True}) + _n: int = field(metadata={"frozen": True}) + _update_window: int = field(metadata={"frozen": True}) + _initial_mean: np.ndarray = field(metadata={"frozen": True}) + _initial_cov: np.ndarray = field(metadata={"frozen": True}) + _initial_weight: np.ndarray = field(metadata={"frozen": True}) + adaptation_window_multiplier: float = field(metadata={"frozen": True}) + + class QuadPotentialFullAdapt(QuadPotentialFull): """Adapt a dense mass matrix using the sample covariances.""" + _state_class = QuadPotentialFullAdaptState + def __init__( self, n, @@ -689,7 +779,7 @@ def __init__( self.adaptation_window_multiplier = float(adaptation_window_multiplier) self._update_window = int(update_window) - self.rng = np.random.default_rng(rng) + self.rng = get_random_generator(rng) self.reset() @@ -742,7 +832,16 @@ def raise_ok(self, vmap): raise ValueError(str(self._chol_error)) -class _WeightedCovariance: +@dataclass_state +class WeightedCovarianceState(DataClassState): + n_samples: float + mean: np.ndarray + raw_cov: np.ndarray + + _dtype: Any = field(metadata={"frozen": True}) + + +class _WeightedCovariance(WithSamplingState): """Online algorithm for computing mean and covariance This implements the `Welford's algorithm @@ -752,6 +851,8 @@ class _WeightedCovariance: """ + _state_class = WeightedCovarianceState + def __init__( self, nelem, @@ -827,7 +928,7 @@ def __init__(self, A, rng=None): self.size = A.shape[0] self.factor = factor = cholmod.cholesky(A) self.d_sqrt = np.sqrt(factor.D()) - self.rng = np.random.default_rng(rng) + self.rng = get_random_generator(rng) def velocity(self, x): """Compute the current velocity at a position in parameter space.""" diff --git a/pymc/step_methods/step_sizes.py b/pymc/step_methods/step_sizes.py index 6c2b7340fdf..c0fdb934a36 100644 --- a/pymc/step_methods/step_sizes.py +++ b/pymc/step_methods/step_sizes.py @@ -12,14 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. + import numpy as np from scipy import stats from pymc.stats.convergence import SamplerWarning, WarningType +from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state + + +@dataclass_state +class StepSizeState(DataClassState): + _log_step: np.ndarray + _log_bar: np.ndarray + _hbar: float + _count: int + _mu: np.ndarray + _tuned_stats: list + _initial_step: np.ndarray + _target: float + _k: float + _t0: float + _gamma: float + +class DualAverageAdaptation(WithSamplingState): + _state_class = StepSizeState -class DualAverageAdaptation: def __init__(self, initial_step, target, gamma, k, t0): self._initial_step = initial_step self._target = target