Skip to content

Commit

Permalink
Avoid code duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege committed Jun 16, 2024
1 parent 1540224 commit f2666dd
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 38 deletions.
2 changes: 1 addition & 1 deletion mcbackend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
A framework agnostic implementation for storage of MCMC draws.
"""

from .backends.numpy import NumPyBackend
from .backends.null import NullBackend
from .backends.numpy import NumPyBackend
from .core import Backend, Chain, Run
from .meta import ChainMeta, Coordinate, DataVariable, ExtendedValue, RunMeta, Variable

Expand Down
26 changes: 8 additions & 18 deletions mcbackend/backends/null.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@

import numpy

from ..core import Backend, Chain, Run, is_rigid
from ..core import Backend, Chain, Run
from ..meta import ChainMeta, RunMeta
from .numpy import grow_append, prepare_storage

from .numpy import grow_append

class NullChain(Chain):
"""A null storage: discards values immediately and allocates no memory.
Expand Down Expand Up @@ -52,26 +52,14 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
where the correct amount of memory cannot be pre-allocated.
In these cases object arrays are used.
"""
self._stat_is_rigid: Dict[str, bool] = {}
self._stats: Dict[str, numpy.ndarray] = {}
self._draw_idx = 0

# Create storage ndarrays for each model variable and sampler stat.
for target_dict, rigid_dict, variables in [
(self._stats, self._stat_is_rigid, rmeta.sample_stats),
]:
for var in variables:
rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str"
rigid_dict[var.name] = rigid
if rigid:
reserve = (preallocate, *var.shape)
target_dict[var.name] = numpy.empty(reserve, var.dtype)
else:
target_dict[var.name] = numpy.array([None] * preallocate, dtype=object)
# Create storage ndarrays only for sampler stats.
self._stats, self._stat_is_rigid = prepare_storage(rmeta.sample_stats, preallocate)

super().__init__(cmeta, rmeta)

def append(
def append( # pylint: disable=duplicate-code
self, draw: Mapping[str, numpy.ndarray], stats: Optional[Mapping[str, numpy.ndarray]] = None
):
if stats:
Expand All @@ -88,7 +76,9 @@ def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
def get_draws_at(self, idx: int, var_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
raise RuntimeError("NullChain does not save draws.")

def get_stats(self, stat_name: str, slc: slice = slice(None)) -> numpy.ndarray:
def get_stats( # pylint: disable=duplicate-code
self, stat_name: str, slc: slice = slice(None)
) -> numpy.ndarray:
data = self._stats[stat_name][: self._draw_idx][slc]
if self.sample_stats[stat_name].dtype == "str":
return numpy.array(data.tolist(), dtype=str)
Expand Down
38 changes: 20 additions & 18 deletions mcbackend/backends/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
"""

import math
from typing import Dict, List, Mapping, Optional, Sequence, Tuple
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple

import numpy

from ..core import Backend, Chain, Run, is_rigid
from ..meta import ChainMeta, RunMeta
from ..meta import ChainMeta, RunMeta, Variable


def grow_append(
Expand All @@ -34,6 +34,22 @@ def grow_append(
return


def prepare_storage(
variables: Iterable[Variable], preallocate: int
) -> Tuple[Dict[str, numpy.ndarray], Dict[str, bool]]:
storage: Dict[str, numpy.ndarray] = {}
rigid_dict: Dict[str, bool] = {}
for var in variables:
rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str"
rigid_dict[var.name] = rigid
if rigid:
reserve = (preallocate, *var.shape)
storage[var.name] = numpy.empty(reserve, var.dtype)
else:
storage[var.name] = numpy.array([None] * preallocate, dtype=object)
return storage, rigid_dict


class NumPyChain(Chain):
"""Stores value draws in NumPy arrays and can pre-allocate memory."""

Expand All @@ -54,25 +70,11 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
where the correct amount of memory cannot be pre-allocated.
In these cases object arrays are used.
"""
self._var_is_rigid: Dict[str, bool] = {}
self._samples: Dict[str, numpy.ndarray] = {}
self._stat_is_rigid: Dict[str, bool] = {}
self._stats: Dict[str, numpy.ndarray] = {}
self._draw_idx = 0

# Create storage ndarrays for each model variable and sampler stat.
for target_dict, rigid_dict, variables in [
(self._samples, self._var_is_rigid, rmeta.variables),
(self._stats, self._stat_is_rigid, rmeta.sample_stats),
]:
for var in variables:
rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str"
rigid_dict[var.name] = rigid
if rigid:
reserve = (preallocate, *var.shape)
target_dict[var.name] = numpy.empty(reserve, var.dtype)
else:
target_dict[var.name] = numpy.array([None] * preallocate, dtype=object)
self._samples, self._var_is_rigid = prepare_storage(rmeta.variables, preallocate)
self._stats, self._stat_is_rigid = prepare_storage(rmeta.sample_stats, preallocate)

super().__init__(cmeta, rmeta)

Expand Down
10 changes: 9 additions & 1 deletion mcbackend/test_backend_null.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from mcbackend.backends.null import NullBackend, NullChain, NullRun
from mcbackend.core import RunMeta, is_rigid
from mcbackend.meta import Variable
from mcbackend.test_utils import CheckBehavior, CheckPerformance, make_runmeta, make_draw
from mcbackend.test_utils import (
CheckBehavior,
CheckPerformance,
make_draw,
make_runmeta,
)


class CheckNullBehavior(CheckBehavior):
"""
Expand Down Expand Up @@ -152,6 +158,7 @@ def test__to_inferencedata(self):
"""
pass


class TestNullBackend(CheckNullBehavior, CheckPerformance):
cls_backend = NullBackend
cls_run = NullRun
Expand Down Expand Up @@ -207,6 +214,7 @@ def test_growing(self, preallocate):
# TODO: Check dimensions of stats array ?
pass


if __name__ == "__main__":
tc = TestNullBackend()
df = tc.run_all_benchmarks()
Expand Down

0 comments on commit f2666dd

Please sign in to comment.