Skip to content

Commit

Permalink
Add helper function tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brynpickering committed Oct 26, 2023
1 parent 8ec5296 commit 1e6dfad
Showing 1 changed file with 111 additions and 0 deletions.
111 changes: 111 additions & 0 deletions tests/test_backend_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def where_any(where, parsing_kwargs):
return where["any"](**parsing_kwargs)


@pytest.fixture(scope="class")
def where_defined(where, parsing_kwargs):
return where["defined"](**parsing_kwargs)


@pytest.fixture(scope="class")
def expression_sum(expression, parsing_kwargs):
return expression["sum"](**parsing_kwargs)
Expand Down Expand Up @@ -62,6 +67,28 @@ class TestAsArray:
def parsing_kwargs(self, dummy_model_data):
return {"model_data": dummy_model_data}

@pytest.fixture(scope="function")
def is_defined_any(self, dummy_model_data):
def _is_defined(drop_dims, dims):
return (
dummy_model_data.definition_matrix.any(drop_dims)
.sel(**dims)
.any(dims.keys())
)

return _is_defined

@pytest.fixture(scope="function")
def is_defined_all(self, dummy_model_data):
def _is_defined(drop_dims, dims):
return (
dummy_model_data.definition_matrix.any(drop_dims)
.sel(**dims)
.all(dims.keys())
)

return _is_defined

@pytest.mark.parametrize(
["string_type", "func_name"], [("where", "inheritance"), ("expression", "sum")]
)
Expand Down Expand Up @@ -110,6 +137,49 @@ def test_any_exists(self, where_any, dummy_model_data, var, over, expected):
summed = where_any(var, over=over)
assert summed.equals(dummy_model_data[expected])

def test_defined_any_one_dim_one_val(self, is_defined_any, where_defined):
dims = {"techs": "foobar"}
dims_check = {"techs": ["foobar"]}
defined = where_defined(within="nodes", how="any", **dims)
assert defined.equals(is_defined_any(["carriers", "carrier_tiers"], dims_check))
assert defined.dtype.kind == "b"

def test_defined_any_two_dim_one_val(self, is_defined_any, where_defined):
dims = {"techs": "foobar", "carriers": "foo"}
dims_check = {"techs": ["foobar"], "carriers": ["foo"]}
defined = where_defined(within="nodes", how="any", **dims)
assert defined.equals(is_defined_any(["carrier_tiers"], dims_check))

def test_defined_any_one_dim_multi_val(self, is_defined_any, where_defined):
dims = {"techs": ["foobar", "foobaz"]}
defined = where_defined(within="nodes", how="any", **dims)
assert defined.equals(is_defined_any(["carriers", "carrier_tiers"], dims))
assert defined.dtype.kind == "b"

def test_defined_any_one_dim_multi_val_techs_within(
self, is_defined_any, where_defined
):
dims = {"carriers": ["foo", "bar"]}
defined = where_defined(within="techs", how="any", **dims)
assert defined.equals(is_defined_any(["nodes", "carrier_tiers"], dims))

def test_defined_any_two_dim_multi_val(self, is_defined_any, where_defined):
dims = {"techs": ["foobar", "foobaz"], "carriers": ["foo", "bar"]}
defined = where_defined(within="nodes", how="any", **dims)
assert defined.equals(is_defined_any(["carrier_tiers"], dims))
assert defined.dtype.kind == "b"

def test_defined_all_one_dim_one_val(self, is_defined_all, where_defined):
dims = {"techs": ["foobar"]}
defined = where_defined(within="nodes", how="all", **dims)
assert defined.equals(is_defined_all(["carriers", "carrier_tiers"], dims))
assert defined.dtype.kind == "b"

def test_defined_all_two_dim_one_val(self, is_defined_all, where_defined):
dims = {"techs": ["foobar"], "carriers": ["foo"]}
defined = where_defined(within="nodes", how="all", **dims)
assert defined.equals(is_defined_all(["carrier_tiers"], dims))

@pytest.mark.parametrize("over", ["techs", ["techs"]])
def test_sum_one_dim(self, expression_sum, dummy_model_data, over):
summed_array = expression_sum(dummy_model_data.only_techs, over=over)
Expand Down Expand Up @@ -247,6 +317,47 @@ def test_any_not_exists(self, where_any):
summed_string = where_any("foo", over="techs")
assert summed_string == r"\bigvee\limits_{\text{tech} \in \text{techs}} (foo)"

def test_defined_any(self, where_defined):
defined_string = where_defined(within="nodes", how="any", techs="foobar")
assert (
defined_string
== r"\bigvee\limits_{\substack{\text{tech} \in \text{[foobar]}}}\text{tech defined in node}"
)

def test_defined_any_multi_val(self, where_defined):
defined_string = where_defined(
within="nodes", how="any", techs=["foobar", "foobaz"]
)
assert (
defined_string
== r"\bigvee\limits_{\substack{\text{tech} \in \text{[foobar,foobaz]}}}\text{tech defined in node}"
)

def test_defined_any_multi_dim(self, where_defined):
defined_string = where_defined(
within="nodes", how="any", techs="foobar", carriers="foo"
)
assert (
defined_string
== r"\bigwedge(\bigvee\limits_{\substack{\text{tech} \in \text{[foobar]}}}\text{tech defined in node}, \bigvee\limits_{\substack{\text{carrier} \in \text{[foo]}}}\text{carrier defined in node})"
)

def test_defined_all(self, where_defined):
defined_string = where_defined(within="nodes", how="all", techs="foobar")
assert (
defined_string
== r"\bigwedge\limits_{\substack{\text{tech} \in \text{[foobar]}}}\text{tech defined in node}"
)

def test_defined_all_multi_dim(self, where_defined):
defined_string = where_defined(
within="nodes", how="all", techs="foobar", carriers="foo"
)
assert (
defined_string
== r"\bigwedge(\bigwedge\limits_{\substack{\text{tech} \in \text{[foobar]}}}\text{tech defined in node}, \bigwedge\limits_{\substack{\text{carrier} \in \text{[foo]}}}\text{carrier defined in node})"
)

@pytest.mark.parametrize(
["over", "expected_substring"],
[
Expand Down

0 comments on commit 1e6dfad

Please sign in to comment.