Skip to content

Commit

Permalink
Add HMC sampling state
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Oct 7, 2024
1 parent db32421 commit 465d8ac
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 20 deletions.
34 changes: 28 additions & 6 deletions pymc/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,19 @@
from pymc.model import Point, modelcontext
from pymc.pytensorf import floatX
from pymc.stats.convergence import SamplerWarning, WarningType
from pymc.step_methods import step_sizes
from pymc.step_methods.arraystep import GradientSharedStep
from pymc.step_methods.compound import StepMethodState
from pymc.step_methods.hmc import integration
from pymc.step_methods.hmc.integration import IntegrationError, State
from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
from pymc.step_methods.hmc.quadpotential import (
PotentialState,
QuadPotentialDiagAdapt,
quad_potential,
)
from pymc.step_methods.state import dataclass_state
from pymc.step_methods.step_sizes import DualAverageAdaptation, StepSizeState
from pymc.tuning import guess_scaling
from pymc.util import get_value_vars_from_user_vars
from pymc.util import RandomGenerator, get_random_generator, get_value_vars_from_user_vars

logger = logging.getLogger(__name__)

Expand All @@ -53,12 +58,27 @@ class HMCStepData(NamedTuple):
stats: dict[str, Any]


@dataclass_state
class BaseHMCState(StepMethodState):
adapt_step_size: bool
Emax: float
iter_count: int
step_size: np.ndarray
step_adapt: StepSizeState
target_accept: float
tune: bool
potential: PotentialState
_num_divs_sample: int


class BaseHMC(GradientSharedStep):
"""Superclass to implement Hamiltonian/hybrid monte carlo."""

integrator: integration.CpuLeapfrogIntegrator
default_blocked = True

_state_class = BaseHMCState

def __init__(
self,
vars=None,
Expand Down Expand Up @@ -134,9 +154,7 @@ def __init__(
size = sum(v.size for v in nuts_vars)

self.step_size = step_scale / (size**0.25)
self.step_adapt = step_sizes.DualAverageAdaptation(
self.step_size, target_accept, gamma, k, t0
)
self.step_adapt = DualAverageAdaptation(self.step_size, target_accept, gamma, k, t0)
self.target_accept = target_accept
self.tune = True

Expand Down Expand Up @@ -268,3 +286,7 @@ def reset_tuning(self, start=None):
def reset(self, start=None):
self.tune = True
self.potential.reset()

def set_rng(self, rng: RandomGenerator):
self.rng = get_random_generator(rng, copy=False)
self.potential.set_rng(self.rng.spawn(1)[0])
10 changes: 9 additions & 1 deletion pymc/step_methods/hmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@

from __future__ import annotations

from dataclasses import field
from typing import Any

import numpy as np

from pymc.stats.convergence import SamplerWarning
from pymc.step_methods.compound import Competence
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.integration import IntegrationError, State
from pymc.step_methods.state import dataclass_state
from pymc.vartypes import discrete_types

__all__ = ["HamiltonianMC"]
Expand All @@ -31,6 +33,12 @@ def unif(step_size, elow=0.85, ehigh=1.15, rng: np.random.Generator | None = Non
return (rng or np.random).uniform(elow, ehigh) * step_size


@dataclass_state
class HamiltonianMCState(BaseHMCState):
path_length: float = field(metadata={"frozen": True})
max_steps: int = field(metadata={"frozen": True})


class HamiltonianMC(BaseHMC):
R"""A sampler for continuous variables based on Hamiltonian mechanics.
Expand Down
10 changes: 9 additions & 1 deletion pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from collections import namedtuple
from dataclasses import field

import numpy as np

Expand All @@ -23,13 +24,20 @@
from pymc.stats.convergence import SamplerWarning
from pymc.step_methods.compound import Competence
from pymc.step_methods.hmc import integration
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.integration import IntegrationError, State
from pymc.step_methods.state import dataclass_state
from pymc.vartypes import continuous_types

__all__ = ["NUTS"]


@dataclass_state
class NUTSState(BaseHMCState):
max_treedepth: int = field(metadata={"frozen": True})
early_max_treedepth: int = field(metadata={"frozen": True})


class NUTS(BaseHMC):
r"""A sampler for continuous variables based on Hamiltonian mechanics.
Expand Down
123 changes: 112 additions & 11 deletions pymc/step_methods/hmc/quadpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

import warnings

from typing import overload
from dataclasses import field
from typing import Any, overload

import numpy as np
import pytensor
Expand All @@ -25,6 +26,8 @@
from scipy.sparse import issparse

from pymc.pytensorf import floatX
from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state
from pymc.util import RandomGenerator, get_random_generator

__all__ = [
"quad_potential",
Expand Down Expand Up @@ -100,11 +103,18 @@ def __str__(self):
return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}."


class QuadPotential:
@dataclass_state
class PotentialState(DataClassState):
rng: np.random.Generator


class QuadPotential(WithSamplingState):
dtype: np.dtype

_state_class = PotentialState

def __init__(self, rng=None):
self.rng = np.random.default_rng(rng)
self.rng = get_random_generator(rng)

@overload
def velocity(self, x: np.ndarray, out: None) -> np.ndarray: ...
Expand Down Expand Up @@ -157,15 +167,42 @@ def reset(self):
def stats(self):
return {"largest_eigval": np.nan, "smallest_eigval": np.nan}

def set_rng(self, rng: RandomGenerator):
self.rng = get_random_generator(rng, copy=False)


def isquadpotential(value):
"""Check whether an object might be a QuadPotential object."""
return isinstance(value, QuadPotential)


@dataclass_state
class QuadPotentialDiagAdaptState(PotentialState):
_var: np.ndarray
_stds: np.ndarray
_inv_stds: np.ndarray
_foreground_var: WeightedVarianceState
_background_var: WeightedVarianceState
_n_samples: int
adaptation_window: int
_mass_trace: list[np.ndarray] | None

dtype: Any = field(metadata={"frozen": True})
_n: int = field(metadata={"frozen": True})
_discard_window: int = field(metadata={"frozen": True})
_early_update: int = field(metadata={"frozen": True})
_initial_mean: np.ndarray = field(metadata={"frozen": True})
_initial_diag: np.ndarray = field(metadata={"frozen": True})
_initial_weight: np.ndarray = field(metadata={"frozen": True})
adaptation_window_multiplier: float = field(metadata={"frozen": True})
_store_mass_matrix_trace: bool = field(metadata={"frozen": True})


class QuadPotentialDiagAdapt(QuadPotential):
"""Adapt a diagonal mass matrix from the sample variances."""

_state_class = QuadPotentialDiagAdaptState

def __init__(
self,
n,
Expand Down Expand Up @@ -346,9 +383,20 @@ def raise_ok(self, map_info):
raise ValueError("\n".join(errmsg))


class _WeightedVariance:
@dataclass_state
class WeightedVarianceState(DataClassState):
n_samples: int
mean: np.ndarray
raw_var: np.ndarray

_dtype: Any = field(metadata={"frozen": True})


class _WeightedVariance(WithSamplingState):
"""Online algorithm for computing mean of variance."""

_state_class = WeightedVarianceState

def __init__(
self, nelem, initial_mean=None, initial_variance=None, initial_weight=0, dtype="d"
):
Expand Down Expand Up @@ -390,7 +438,16 @@ def current_mean(self):
return self.mean.copy(dtype=self._dtype)


class _ExpWeightedVariance:
@dataclass_state
class ExpWeightedVarianceState(DataClassState):
_alpha: float
_mean: np.ndarray
_var: np.ndarray


class _ExpWeightedVariance(WithSamplingState):
_state_class = ExpWeightedVarianceState

def __init__(self, n_vars, *, init_mean, init_var, alpha):
self._variance = init_var
self._mean = init_mean
Expand All @@ -415,7 +472,18 @@ def current_mean(self, out=None):
return out


@dataclass_state
class QuadPotentialDiagAdaptExpState(QuadPotentialDiagAdaptState):
_alpha: float
_stop_adaptation: float
_variance_estimator: ExpWeightedVarianceState

_variance_estimator_grad: ExpWeightedVarianceState | None = None


class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt):
_state_class = QuadPotentialDiagAdaptExpState

def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, rng=None, **kwargs):
"""Set up a diagonal mass matrix.
Expand Down Expand Up @@ -526,7 +594,7 @@ def __init__(self, v, dtype=None, rng=None):
self.s = s
self.inv_s = 1.0 / s
self.v = v
self.rng = np.random.default_rng(rng)
self.rng = get_random_generator(rng)

def velocity(self, x, out=None):
"""Compute the current velocity at a position in parameter space."""
Expand Down Expand Up @@ -572,7 +640,7 @@ def __init__(self, A, dtype=None, rng=None):
dtype = pytensor.config.floatX
self.dtype = dtype
self.L = floatX(scipy.linalg.cholesky(A, lower=True))
self.rng = np.random.default_rng(rng)
self.rng = get_random_generator(rng)

def velocity(self, x, out=None):
"""Compute the current velocity at a position in parameter space."""
Expand Down Expand Up @@ -621,7 +689,7 @@ def __init__(self, cov, dtype=None, rng=None):
self._cov = np.array(cov, dtype=self.dtype, copy=True)
self._chol = scipy.linalg.cholesky(self._cov, lower=True)
self._n = len(self._cov)
self.rng = np.random.default_rng(rng)
self.rng = get_random_generator(rng)

def velocity(self, x, out=None):
"""Compute the current velocity at a position in parameter space."""
Expand All @@ -646,9 +714,31 @@ def velocity_energy(self, x, v_out):
__call__ = random


@dataclass_state
class QuadPotentialFullAdaptState(PotentialState):
_previous_update: int
_cov: np.ndarray
_chol: np.ndarray
_chol_error: scipy.linalg.LinAlgError | ValueError | None = None
_foreground_cov: WeightedCovarianceState
_background_cov: WeightedCovarianceState
_n_samples: int
adaptation_window: int

dtype: Any = field(metadata={"frozen": True})
_n: int = field(metadata={"frozen": True})
_update_window: int = field(metadata={"frozen": True})
_initial_mean: np.ndarray = field(metadata={"frozen": True})
_initial_cov: np.ndarray = field(metadata={"frozen": True})
_initial_weight: np.ndarray = field(metadata={"frozen": True})
adaptation_window_multiplier: float = field(metadata={"frozen": True})


class QuadPotentialFullAdapt(QuadPotentialFull):
"""Adapt a dense mass matrix using the sample covariances."""

_state_class = QuadPotentialFullAdaptState

def __init__(
self,
n,
Expand Down Expand Up @@ -689,7 +779,7 @@ def __init__(
self.adaptation_window_multiplier = float(adaptation_window_multiplier)
self._update_window = int(update_window)

self.rng = np.random.default_rng(rng)
self.rng = get_random_generator(rng)

self.reset()

Expand Down Expand Up @@ -742,7 +832,16 @@ def raise_ok(self, vmap):
raise ValueError(str(self._chol_error))


class _WeightedCovariance:
@dataclass_state
class WeightedCovarianceState(DataClassState):
n_samples: float
mean: np.ndarray
raw_cov: np.ndarray

_dtype: Any = field(metadata={"frozen": True})


class _WeightedCovariance(WithSamplingState):
"""Online algorithm for computing mean and covariance
This implements the `Welford's algorithm
Expand All @@ -752,6 +851,8 @@ class _WeightedCovariance:
"""

_state_class = WeightedCovarianceState

def __init__(
self,
nelem,
Expand Down Expand Up @@ -827,7 +928,7 @@ def __init__(self, A, rng=None):
self.size = A.shape[0]
self.factor = factor = cholmod.cholesky(A)
self.d_sqrt = np.sqrt(factor.D())
self.rng = np.random.default_rng(rng)
self.rng = get_random_generator(rng)

def velocity(self, x):
"""Compute the current velocity at a position in parameter space."""
Expand Down
Loading

0 comments on commit 465d8ac

Please sign in to comment.