Skip to content

Commit

Permalink
Add set serialisation & check for lists in attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
brynpickering committed Oct 23, 2023
1 parent 4069505 commit e2c3028
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 49 deletions.
102 changes: 76 additions & 26 deletions src/calliope/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,9 @@ def read_netcdf(path):
calliope_version, __version__
)
)
for attr in _pop_serialised_list(model_data.attrs, "serialised_dicts"):
model_data.attrs[attr] = AttrDict.from_yaml_string(model_data.attrs[attr])
for attr in _pop_serialised_list(model_data.attrs, "serialised_bools"):
model_data.attrs[attr] = bool(model_data.attrs[attr])
for attr in _pop_serialised_list(model_data.attrs, "serialised_nones"):
model_data.attrs[attr] = None
_deserialise(model_data.attrs)
for var in model_data.data_vars.values():
_deserialise(var.attrs)

# Convert empty strings back to np.NaN
# TODO: revert when this issue is solved: https://github.com/pydata/xarray/issues/1647
Expand All @@ -58,17 +55,82 @@ def read_netcdf(path):
return model_data


def _pop_serialised_list(attribute_dict, serialised_items):
serialised_ = attribute_dict.pop(serialised_items, [])
def _pop_serialised_list(
attrs: dict, serialised_items: Union[str, list, np.ndarray]
) -> Union[list, np.ndarray]:
serialised_ = attrs.pop(serialised_items, [])
if not isinstance(serialised_, (list, np.ndarray)):
return [serialised_]
else:
return serialised_


def _serialise(attrs: dict) -> None:
"""Convert troublesome datatypes to nicer ones in xarray attribute dictionaries.
This will tackle dictionaries (to string), booleans (to int), None (to string), and sets (to list).
Args:
attrs (dict):
Attribute dictionary from an xarray Dataset/DataArray.
Changes will be made in-place, so be sure to supply a copy of your dictionary if you want access to its original state.
"""
# Convert dicts attrs to yaml strings
dict_attrs = [k for k, v in attrs.items() if isinstance(v, dict)]
attrs["serialised_dicts"] = dict_attrs
for attr in dict_attrs:
attrs[attr] = AttrDict(attrs[attr]).to_yaml()

# Convert boolean attrs to ints
bool_attrs = [k for k, v in attrs.items() if isinstance(v, bool)]
attrs["serialised_bools"] = bool_attrs
for attr in bool_attrs:
attrs[attr] = int(attrs[attr])

# Convert None attrs to 'None'
none_attrs = [k for k, v in attrs.items() if v is None]
attrs["serialised_nones"] = none_attrs
for attr in none_attrs:
attrs[attr] = "None"

# Convert set attrs to lists
set_attrs = [k for k, v in attrs.items() if isinstance(v, set)]
for attr in set_attrs:
attrs[attr] = list(attrs[attr])

list_attrs = [k for k, v in attrs.items() if isinstance(v, list)]
for attr in list_attrs:
if any(not isinstance(i, str) for i in attrs[attr]):
raise TypeError(
f"Cannot serialise a sequence of values stored in a model attribute unless all values are strings, found: {attrs[attr]}"
)
else:
attrs["serialised_sets"] = set_attrs


def _deserialise(attrs: dict) -> None:
    """Convert troublesome datatypes in xarray attribute dictionaries from their stored data type to the data types expected by Calliope.

    This will tackle dictionaries (from string), booleans (from int), None (from string), and sets (from list).
    The bookkeeping keys "serialised_dicts", "serialised_bools", "serialised_nones" and
    "serialised_sets" are removed from `attrs` as they are processed.

    Args:
        attrs (dict):
            Attribute dictionary from an xarray Dataset/DataArray.
            Changes will be made in-place, so be sure to supply a copy of your dictionary if you want access to its original state.
    """
    for attr in _pop_serialised_list(attrs, "serialised_dicts"):
        attrs[attr] = AttrDict.from_yaml_string(attrs[attr])
    for attr in _pop_serialised_list(attrs, "serialised_bools"):
        attrs[attr] = bool(attrs[attr])
    for attr in _pop_serialised_list(attrs, "serialised_nones"):
        attrs[attr] = None
    for attr in _pop_serialised_list(attrs, "serialised_sets"):
        members = attrs[attr]
        # A lone string is one set member; passing it straight to set() would
        # split it into its characters.
        # NOTE(review): presumably a one-element list can come back from file as a bare string - confirm.
        if isinstance(members, str):
            members = [members]
        attrs[attr] = set(members)


def save_netcdf(model_data, path, model=None):
original_model_data_attrs = model_data.attrs
model_data_attrs = model_data.attrs.copy()
model_data_attrs = original_model_data_attrs.copy()

if model is not None and hasattr(model, "_model_run"):
# Attach _model_run and _debug_data to _model_data
Expand All @@ -79,23 +141,9 @@ def save_netcdf(model_data, path, model=None):
if hasattr(model, "_debug_data"):
model_data_attrs["_debug_data"] = model._debug_data.to_yaml()

# Convert dicts attrs to yaml strings
dict_attrs = [k for k, v in model_data_attrs.items() if isinstance(v, dict)]
model_data_attrs["serialised_dicts"] = dict_attrs
for k in dict_attrs:
model_data_attrs[k] = AttrDict(model_data_attrs[k]).to_yaml()

# Convert boolean attrs to ints
bool_attrs = [k for k, v in model_data_attrs.items() if isinstance(v, bool)]
model_data_attrs["serialised_bools"] = bool_attrs
for k in bool_attrs:
model_data_attrs[k] = int(model_data_attrs[k])

# Convert None attrs to 'None'
none_attrs = [k for k, v in model_data_attrs.items() if v is None]
model_data_attrs["serialised_nones"] = none_attrs
for k in none_attrs:
model_data_attrs[k] = "None"
_serialise(model_data_attrs)
for var in model_data.data_vars.values():
_serialise(var.attrs)

encoding = {
k: {"zlib": False, "_FillValue": None}
Expand All @@ -110,6 +158,8 @@ def save_netcdf(model_data, path, model=None):
model_data.close() # Force-close NetCDF file after writing
finally: # Revert model_data.attrs back
model_data.attrs = original_model_data_attrs
for var in model_data.data_vars.values():
_deserialise(var.attrs)


def save_csv(model_data, path, dropna=True):
Expand Down
73 changes: 50 additions & 23 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,33 @@
import pytest # noqa: F401
import xarray as xr
from calliope import exceptions
from calliope.core import io
from calliope.test.common.util import check_error_or_warning


class TestIO:
@pytest.fixture(scope="module")
def model(self):
def vars_to_add_attrs(self):
return ["resource", "energy_cap"]

@pytest.fixture(scope="module")
def model(self, vars_to_add_attrs):
model = calliope.examples.national_scale()
model._model_data = model._model_data.assign_attrs(
foo_true=True,
foo_false=False,
foo_none=None,
foo_dict={"foo": {"a": 1}},
foo_attrdict=calliope.AttrDict({"foo": {"a": 1}}),
)
attrs = {
"foo_true": True,
"foo_false": False,
"foo_none": None,
"foo_dict": {"foo": {"a": 1}},
"foo_attrdict": calliope.AttrDict({"foo": {"a": 1}}),
"foo_set": set(["foo", "bar"]),
}
model._model_data = model._model_data.assign_attrs(**attrs)
model.build()
model.solve()

for var in vars_to_add_attrs:
model._model_data[var] = model._model_data[var].assign_attrs(**attrs)

return model

@pytest.fixture(scope="module")
Expand Down Expand Up @@ -53,31 +65,33 @@ def test_save_netcdf(self, model_file):
("foo_none", type(None), None),
("foo_dict", dict, {"foo": {"a": 1}}),
("foo_attrdict", calliope.AttrDict, calliope.AttrDict({"foo": {"a": 1}})),
("foo_set", set, set(["foo", "bar"])),
],
)
@pytest.mark.parametrize("model_name", ["model", "model_from_file"])
def test_serialised_attrs(
self, request, attr, expected_type, expected_val, model_name
self, request, attr, expected_type, expected_val, model_name, vars_to_add_attrs
):
model = request.getfixturevalue(model_name)
# Ensure that boolean attrs have not changed

assert isinstance(model._model_data.attrs[attr], expected_type)
if expected_val is None:
assert model._model_data.attrs[attr] is None
else:
assert model._model_data.attrs[attr] == expected_val
var_attrs = [model._model_data[var].attrs for var in vars_to_add_attrs]
for attrs in [model._model_data.attrs, *var_attrs]:
assert isinstance(attrs[attr], expected_type)
if expected_val is None:
assert attrs[attr] is None
else:
assert attrs[attr] == expected_val

@pytest.mark.parametrize(
"serialised_list", ["serialised_bools", "serialised_nones", "serialised_dicts"]
"serialised_list",
["serialised_bools", "serialised_nones", "serialised_dicts", "serialised_sets"],
)
@pytest.mark.parametrize("model_name", ["model", "model_from_file"])
def test_serialised_list_popped(self, request, serialised_list, model_name):
model = request.getfixturevalue(model_name)
assert serialised_list not in model._model_data.attrs.keys()

@pytest.mark.parametrize(
["serialised_list", "list_elements"],
["serialisation_list_name", "list_elements"],
[
("serialised_bools", ["foo_true", "foo_false"]),
("serialised_nones", ["foo_none", "scenario"]),
Expand All @@ -92,14 +106,27 @@ def test_serialised_list_popped(self, request, serialised_list, model_name):
"math",
],
),
("serialised_sets", ["foo_set"]),
],
)
def test_serialised_list(
self, model_from_file_no_processing, serialised_list, list_elements
def test_serialisation_lists(
self, model_from_file_no_processing, serialisation_list_name, list_elements
):
assert not set(
model_from_file_no_processing.attrs[serialised_list]
).symmetric_difference(list_elements)
serialisation_list = io._pop_serialised_list(
model_from_file_no_processing.attrs, serialisation_list_name
)
assert not set(serialisation_list).symmetric_difference(list_elements)

@pytest.mark.parametrize(
    "attrs", [{"foo": [1]}, {"foo": [None]}, {"foo": [1, "bar"]}, {"foo": {1}}]
)
def test_non_strings_in_serialised_lists(self, attrs):
    """Serialising a list or set attribute holding non-string values raises a TypeError."""
    with pytest.raises(TypeError) as excinfo:
        io._serialise(attrs)
    # Built after the call on purpose: the message is compared against the
    # attribute as it stands once `_serialise` has run.
    expected_message = f"Cannot serialise a sequence of values stored in a model attribute unless all values are strings, found: {attrs['foo']}"
    assert check_error_or_warning(excinfo, expected_message)

def test_save_csv_dir_mustnt_exist(self, model):
with tempfile.TemporaryDirectory() as tempdir:
Expand Down

0 comments on commit e2c3028

Please sign in to comment.