From 35121e18bb9b67935fac98d0e20e920cda878181 Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 15 Mar 2024 16:54:50 -0400 Subject: [PATCH] encode nan, inf, -inf attr as strings --- .vscode/tasks.json | 11 ++++++- .vscode/tasks/quick_test.sh | 7 ++++ lindi/LindiH5Store/LindiH5Store.py | 51 +++++++++++++++++++++++++++--- pytest.ini | 5 ++- tests/test_core.py | 29 +++++++++++++++++ tests/test_with_real_data.py | 3 ++ 6 files changed, 100 insertions(+), 6 deletions(-) create mode 100755 .vscode/tasks/quick_test.sh diff --git a/.vscode/tasks.json b/.vscode/tasks.json index 8f03462..f7420d9 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -5,13 +5,22 @@ "version": "2.0.0", "tasks": [ { - "label": "Test", + "label": "Run tests", "type": "shell", "command": "bash -ic .vscode/tasks/test.sh", "presentation": { "clear": true }, "detail": "Run tests" + }, + { + "label": "Run quick tests", + "type": "shell", + "command": "bash -ic .vscode/tasks/quick_test.sh", + "presentation": { + "clear": true + }, + "detail": "Run quick tests" } ] } \ No newline at end of file diff --git a/.vscode/tasks/quick_test.sh b/.vscode/tasks/quick_test.sh new file mode 100755 index 0000000..6d8a84a --- /dev/null +++ b/.vscode/tasks/quick_test.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -ex + +# black --check . +flake8 . +# pyright +pytest --cov=lindi --cov-report=xml --cov-report=term -m "not slow" tests/ diff --git a/lindi/LindiH5Store/LindiH5Store.py b/lindi/LindiH5Store/LindiH5Store.py index 7e022de..ba58b08 100644 --- a/lindi/LindiH5Store/LindiH5Store.py +++ b/lindi/LindiH5Store/LindiH5Store.py @@ -75,9 +75,10 @@ def close(self): @staticmethod def from_file( - hdf5_file_name_or_url: str, *, + hdf5_file_name_or_url: str, + *, opts: LindiH5StoreOpts = LindiH5StoreOpts(), - url: Union[str, None] = None + url: Union[str, None] = None, ): """ Create a LindiH5Store from a file or url. 
@@ -348,7 +349,9 @@ def _get_external_array_link(self, parent_key: str, h5_item: h5py.Dataset): "name": parent_key, } else: - print(f'WARNING when creating external array link for {parent_key}: url is not set, so external array link will not work') + print( + f"WARNING when creating external array link for {parent_key}: url is not set, so external array link will not work" + ) return self._external_array_links[parent_key] def listdir(self, path: str = "") -> List[str]: @@ -484,4 +487,44 @@ def _get_chunk_names_for_dataset(chunk_coords_shape: List[int]) -> List[str]: def _reformat_json(x: Union[bytes, None]) -> Union[bytes, None]: if x is None: return None - return json.dumps(json.loads(x.decode("utf-8"))).encode("utf-8") + a = json.loads(x.decode("utf-8")) + return json.dumps(a, cls=FloatJSONEncoder).encode("utf-8") + + +# From https://github.com/rly/h5tojson/blob/b162ff7f61160a48f1dc0026acb09adafdb422fa/h5tojson/h5tojson.py#L121-L156 +class FloatJSONEncoder(json.JSONEncoder): + """JSON encoder that converts NaN, Inf, and -Inf to strings.""" + + def encode(self, obj, *args, **kwargs): # type: ignore + """Convert NaN, Inf, and -Inf to strings.""" + obj = FloatJSONEncoder._convert_nan(obj) + return super().encode(obj, *args, **kwargs) + + def iterencode(self, obj, *args, **kwargs): # type: ignore + """Convert NaN, Inf, and -Inf to strings.""" + obj = FloatJSONEncoder._convert_nan(obj) + return super().iterencode(obj, *args, **kwargs) + + @staticmethod + def _convert_nan(obj): + """Convert NaN, Inf, and -Inf from a JSON object to strings.""" + if isinstance(obj, dict): + return {k: FloatJSONEncoder._convert_nan(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [FloatJSONEncoder._convert_nan(v) for v in obj] + elif isinstance(obj, float): + return FloatJSONEncoder._nan_to_string(obj) + return obj + + @staticmethod + def _nan_to_string(obj: float): + """Convert NaN, Inf, and -Inf from a float to a string.""" + if np.isnan(obj): + return "NaN" + elif 
np.isinf(obj): + if obj > 0: + return "Infinity" + else: + return "-Infinity" + else: + return float(obj) diff --git a/pytest.ini b/pytest.ini index 82211d3..e6b231c 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,7 @@ [pytest] addopts = --verbose log_cli = true -log_cli_level = INFO \ No newline at end of file +log_cli_level = INFO +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + network: marks tests as network (deselect with '-m "not network"') \ No newline at end of file diff --git a/tests/test_core.py b/tests/test_core.py index 773cd7d..ad83cd1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -154,6 +154,30 @@ def test_attributes(): raise ValueError("Attribute mismatch") +def test_nan_inf_attr(): + print("Testing NaN, Inf, and -Inf attributes") + with tempfile.TemporaryDirectory() as tmpdir: + filename = f"{tmpdir}/test.h5" + with h5py.File(filename, "w") as f: + f.create_dataset("X", data=[1, 2, 3]) + f["X"].attrs["nan"] = np.nan + f["X"].attrs["inf"] = np.inf + f["X"].attrs["ninf"] = -np.inf + h5f = h5py.File(filename, "r") + with LindiH5Store.from_file(filename, url=filename) as store: + rfs = store.to_reference_file_system() + client = LindiClient.from_reference_file_system(rfs) + + X1 = h5f["X"] + assert isinstance(X1, h5py.Dataset) + X2 = client["X"] + assert isinstance(X2, LindiDataset) + + assert X2.attrs["nan"] == 'NaN' + assert X2.attrs["inf"] == 'Infinity' + assert X2.attrs["ninf"] == '-Infinity' + + def _check_equal(a, b): # allow comparison of bytes and strings if isinstance(a, str): @@ -184,6 +208,11 @@ def _check_equal(a, b): assert isinstance(b, np.ndarray) return _check_arrays_equal(a, b) + # test for NaNs (we need to use np.isnan because NaN != NaN in python) + if isinstance(a, float) and isinstance(b, float): + if np.isnan(a) and np.isnan(b): + return True + return a == b diff --git a/tests/test_with_real_data.py b/tests/test_with_real_data.py index 9429be0..a3d88a5 100644 --- 
a/tests/test_with_real_data.py +++ b/tests/test_with_real_data.py @@ -5,6 +5,7 @@ import remfile from lindi import LindiH5Store, LindiClient import lindi +import pytest examples = [] @@ -272,6 +273,8 @@ def _hdf5_visit_items(item, callback): return +@pytest.mark.network +@pytest.mark.slow def test_with_real_data(): example_num = 0 example = examples[example_num]