Skip to content

Commit

Permalink
Clean parsers (#506)
Browse files Browse the repository at this point in the history
*Move to an abstract base class for our parsers, so they have the ability to output an xarray dataarray or a string (for latex math), and anything downstream can expect one of those types.
* Move constraint expression setting to the expression parsing, so that we get one dataarray (of comparison expressions) instead of an lhs/rhs. Keeps the data structure in line with the other optimisation problem components.
* Rename `as_latex` to `as_math_string` / `return_type="math_string"`.
  • Loading branch information
brynpickering authored Oct 30, 2023
1 parent 45436c9 commit ea55acc
Show file tree
Hide file tree
Showing 15 changed files with 1,103 additions and 1,167 deletions.
28 changes: 17 additions & 11 deletions src/calliope/backend/backend_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(self, inputs: xr.Dataset, **config_overrides):
self.inputs.attrs["config"].build.union(
AttrDict(config_overrides), allow_override=True
)
self.valid_math_element_names: set = set()
self._solve_logger = logging.getLogger(__name__ + ".<solve>")

@abstractmethod
Expand Down Expand Up @@ -197,10 +196,6 @@ def _build(self) -> None:
"objectives",
]:
component = components.removesuffix("s")
if components in ["variables", "global_expressions"]:
self.valid_math_element_names.update(
self.inputs.math[components].keys()
)
for name in self.inputs.math[components]:
getattr(self, f"add_{component}")(name)
LOGGER.info(
Expand Down Expand Up @@ -275,8 +270,7 @@ def _add_component(
)

top_level_where = parsed_component.generate_top_level_where_array(
self.inputs,
self._dataset,
self,
align_to_foreach_sets=False,
break_early=break_early,
)
Expand All @@ -285,7 +279,7 @@ def _add_component(

self._create_obj_list(name, component_type)

equations = parsed_component.parse_equations(self.valid_math_element_names)
equations = parsed_component.parse_equations(self.valid_component_names)
if not equations:
component_da = component_setter(
parsed_component.drop_dims_not_in_foreach(top_level_where)
Expand All @@ -297,9 +291,7 @@ def _add_component(
.astype(np.dtype("O"))
)
for element in equations:
where = element.evaluate_where(
self.inputs, self._dataset, initial_where=top_level_where
)
where = element.evaluate_where(self, initial_where=top_level_where)
if break_early and not where.any():
continue

Expand Down Expand Up @@ -462,6 +454,7 @@ def _apply_func(
kwargs=kwargs,
vectorize=True,
keep_attrs=True,
dask="parallelized",
output_dtypes=[np.dtype("O")],
output_core_dims=output_core_dims,
)
Expand Down Expand Up @@ -519,6 +512,19 @@ def objectives(self):
"Slice of backend dataset to show only built objectives"
return self._dataset.filter_by_attrs(obj_type="objectives")

@property
def valid_component_names(self):
def _filter(val):
return val in ["variables", "parameters", "global_expressions"]

in_data = set(self._dataset.filter_by_attrs(obj_type=_filter).data_vars.keys())
in_math = set(
name
for component in ["variables", "global_expressions"]
for name in self.inputs.math[component].keys()
)
return in_data.union(in_math)


class BackendModel(BackendModelGenerator, Generic[T]):
def __init__(self, inputs: xr.Dataset, instance: T, **config_overrides) -> None:
Expand Down
Loading

0 comments on commit ea55acc

Please sign in to comment.