From f2666dde1c8bfcc96718809e66d158f9d85e8f87 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 16 Jun 2024 12:32:45 +0200 Subject: [PATCH] Avoid code duplication --- mcbackend/__init__.py | 2 +- mcbackend/backends/null.py | 26 +++++++---------------- mcbackend/backends/numpy.py | 38 ++++++++++++++++++---------------- mcbackend/test_backend_null.py | 10 ++++++++- 4 files changed, 38 insertions(+), 38 deletions(-) diff --git a/mcbackend/__init__.py b/mcbackend/__init__.py index d7c2655..a0eb735 100644 --- a/mcbackend/__init__.py +++ b/mcbackend/__init__.py @@ -2,8 +2,8 @@ A framework agnostic implementation for storage of MCMC draws. """ -from .backends.numpy import NumPyBackend from .backends.null import NullBackend +from .backends.numpy import NumPyBackend from .core import Backend, Chain, Run from .meta import ChainMeta, Coordinate, DataVariable, ExtendedValue, RunMeta, Variable diff --git a/mcbackend/backends/null.py b/mcbackend/backends/null.py index f62fe29..03cb9b0 100644 --- a/mcbackend/backends/null.py +++ b/mcbackend/backends/null.py @@ -10,10 +10,10 @@ import numpy -from ..core import Backend, Chain, Run, is_rigid +from ..core import Backend, Chain, Run from ..meta import ChainMeta, RunMeta +from .numpy import grow_append, prepare_storage -from .numpy import grow_append class NullChain(Chain): """A null storage: discards values immediately and allocates no memory. @@ -52,26 +52,14 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non where the correct amount of memory cannot be pre-allocated. In these cases object arrays are used. """ - self._stat_is_rigid: Dict[str, bool] = {} - self._stats: Dict[str, numpy.ndarray] = {} self._draw_idx = 0 - # Create storage ndarrays for each model variable and sampler stat. - for target_dict, rigid_dict, variables in [ - (self._stats, self._stat_is_rigid, rmeta.sample_stats), - ]: - for var in variables: - rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str" - rigid_dict[var.name] = rigid - if rigid: - reserve = (preallocate, *var.shape) - target_dict[var.name] = numpy.empty(reserve, var.dtype) - else: - target_dict[var.name] = numpy.array([None] * preallocate, dtype=object) + # Create storage ndarrays only for sampler stats. + self._stats, self._stat_is_rigid = prepare_storage(rmeta.sample_stats, preallocate) super().__init__(cmeta, rmeta) - def append( + def append( # pylint: disable=duplicate-code self, draw: Mapping[str, numpy.ndarray], stats: Optional[Mapping[str, numpy.ndarray]] = None ): if stats: @@ -88,7 +76,9 @@ def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray: def get_draws_at(self, idx: int, var_names: Sequence[str]) -> Dict[str, numpy.ndarray]: raise RuntimeError("NullChain does not save draws.") - def get_stats(self, stat_name: str, slc: slice = slice(None)) -> numpy.ndarray: + def get_stats( # pylint: disable=duplicate-code + self, stat_name: str, slc: slice = slice(None) + ) -> numpy.ndarray: data = self._stats[stat_name][: self._draw_idx][slc] if self.sample_stats[stat_name].dtype == "str": return numpy.array(data.tolist(), dtype=str) diff --git a/mcbackend/backends/numpy.py b/mcbackend/backends/numpy.py index 6245b4c..c4e17c8 100644 --- a/mcbackend/backends/numpy.py +++ b/mcbackend/backends/numpy.py @@ -3,12 +3,12 @@ """ import math -from typing import Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple import numpy from ..core import Backend, Chain, Run, is_rigid -from ..meta import ChainMeta, RunMeta +from ..meta import ChainMeta, RunMeta, Variable def grow_append( @@ -34,6 +34,22 @@ def grow_append( return +def prepare_storage( + variables: Iterable[Variable], preallocate: int +) -> Tuple[Dict[str, numpy.ndarray], Dict[str, bool]]: + storage: Dict[str, numpy.ndarray] = {} + rigid_dict: Dict[str, bool] = {} + for var in variables: + rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str" + rigid_dict[var.name] = rigid + if rigid: + reserve = (preallocate, *var.shape) + storage[var.name] = numpy.empty(reserve, var.dtype) + else: + storage[var.name] = numpy.array([None] * preallocate, dtype=object) + return storage, rigid_dict + + class NumPyChain(Chain): """Stores value draws in NumPy arrays and can pre-allocate memory.""" @@ -54,25 +70,11 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non where the correct amount of memory cannot be pre-allocated. In these cases object arrays are used. """ - self._var_is_rigid: Dict[str, bool] = {} - self._samples: Dict[str, numpy.ndarray] = {} - self._stat_is_rigid: Dict[str, bool] = {} - self._stats: Dict[str, numpy.ndarray] = {} self._draw_idx = 0 # Create storage ndarrays for each model variable and sampler stat. - for target_dict, rigid_dict, variables in [ - (self._samples, self._var_is_rigid, rmeta.variables), - (self._stats, self._stat_is_rigid, rmeta.sample_stats), - ]: - for var in variables: - rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str" - rigid_dict[var.name] = rigid - if rigid: - reserve = (preallocate, *var.shape) - target_dict[var.name] = numpy.empty(reserve, var.dtype) - else: - target_dict[var.name] = numpy.array([None] * preallocate, dtype=object) + self._samples, self._var_is_rigid = prepare_storage(rmeta.variables, preallocate) + self._stats, self._stat_is_rigid = prepare_storage(rmeta.sample_stats, preallocate) super().__init__(cmeta, rmeta) diff --git a/mcbackend/test_backend_null.py b/mcbackend/test_backend_null.py index 043f4f2..033ab64 100644 --- a/mcbackend/test_backend_null.py +++ b/mcbackend/test_backend_null.py @@ -7,7 +7,13 @@ from mcbackend.backends.null import NullBackend, NullChain, NullRun from mcbackend.core import RunMeta, is_rigid from mcbackend.meta import Variable -from mcbackend.test_utils import CheckBehavior, CheckPerformance, make_runmeta, make_draw +from mcbackend.test_utils import ( + CheckBehavior, + CheckPerformance, + make_draw, + make_runmeta, +) + class CheckNullBehavior(CheckBehavior): """ @@ -152,6 +158,7 @@ def test__to_inferencedata(self): """ pass + class TestNullBackend(CheckNullBehavior, CheckPerformance): cls_backend = NullBackend cls_run = NullRun @@ -207,6 +214,7 @@ def test_growing(self, preallocate): # TODO: Check dimensions of stats array ? pass + if __name__ == "__main__": tc = TestNullBackend() df = tc.run_all_benchmarks()