Skip to content

Commit

Permalink
support compressed multi-array snapshots
Browse files Browse the repository at this point in the history
  • Loading branch information
wpbonelli committed May 13, 2024
1 parent 027ddb1 commit 660364b
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
Binary file not shown.
13 changes: 13 additions & 0 deletions autotest/test_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,19 @@ def test_binary_array_snapshot(array_snapshot):
assert np.allclose(np.load(snapshot_path), snapshot_array)


def test_multi_array_snapshot(multi_array_snapshot):
arrays = {"ascending": snapshot_array, "descending": np.flip(snapshot_array)}
assert multi_array_snapshot == arrays
snapshot_path = (
snapshots_path
/ module_path.stem
/ f"{inspect.currentframe().f_code.co_name}.npz"
)
assert snapshot_path.is_file()
assert np.allclose(np.load(snapshot_path)["ascending"], snapshot_array)
assert np.allclose(np.load(snapshot_path)["descending"], np.flip(snapshot_array))


def test_text_array_snapshot(text_array_snapshot):
assert text_array_snapshot == snapshot_array
snapshot_path = (
Expand Down
44 changes: 39 additions & 5 deletions modflow_devtools/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
from shutil import copytree, rmtree
from typing import Dict, Generator, List, Optional

import numpy as np
import pytest
from modflow_devtools.imports import import_optional_dependency
from modflow_devtools.misc import get_namefile_paths, get_packages

np = import_optional_dependency("numpy")
pytest = import_optional_dependency("pytest")
syrupy = import_optional_dependency("syrupy")

# ruff: noqa: E402
from syrupy.extensions.single_file import (
SingleFileSnapshotExtension,
WriteMode,
Expand All @@ -19,14 +25,15 @@
SerializedData,
)

from modflow_devtools.misc import get_namefile_paths, get_packages

# snapshot extensions


def _serialize_bytes(data):
buffer = BytesIO()
np.save(buffer, data)
if isinstance(data, dict):
np.savez_compressed(buffer, **data)
else:
np.save(buffer, data)
return buffer.getvalue()


Expand All @@ -52,6 +59,28 @@ def serialize(
return _serialize_bytes(data)


class CompressedArrayExtension(SingleFileSnapshotExtension):
"""
Compressed snapshot of one or more NumPy arrays. Can be read back into
NumPy with .load(), preserving dtype and shape. Note that .load() will
return a dict mapping array names to arrays. Use this extension rather
than BinaryArrayExtension for tests requiring multiple array snapshots.
"""

_write_mode = WriteMode.BINARY
_file_extension = "npz"

def serialize(
self,
data,
*,
exclude=None,
include=None,
matcher=None,
):
return _serialize_bytes(data)


class TextArrayExtension(SingleFileSnapshotExtension):
"""
Text snapshot of a NumPy array. Flattens the array before writing.
Expand Down Expand Up @@ -181,6 +210,11 @@ def array_snapshot(snapshot):
return snapshot.use_extension(BinaryArrayExtension)


@pytest.fixture
def multi_array_snapshot(snapshot):
return snapshot.use_extension(CompressedArrayExtension)


@pytest.fixture
def text_array_snapshot(snapshot):
return snapshot.use_extension(TextArrayExtension)
Expand Down

0 comments on commit 660364b

Please sign in to comment.