Skip to content

Commit

Permalink
More templates (#12)
Browse files Browse the repository at this point in the history
* add workaround for satellite

* add pad_month_day

* rename to stringify_dates

* add coordinates attribute

* template update

* use space
  • Loading branch information
malmans2 authored Jan 9, 2023
1 parent e1ff9cf commit 2c163e5
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .cruft.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"template": "https://github.com/ecmwf-projects/cookiecutter-conda-package",
"commit": "c7cd369f71b4cfc647fc7c1972cc3c45860256d5",
"commit": "c4101cc46ee7312553f246d8806aebc7160edcc4",
"checkout": null,
"context": {
"cookiecutter": {
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ repos:
args: [--autofix]
additional_dependencies: [toml-sort<0.22.0]
- repo: https://github.com/PyCQA/pydocstyle.git
rev: 6.1.1
rev: 6.2.2
hooks:
- id: pydocstyle
additional_dependencies: [toml]
additional_dependencies: [tomli]
exclude: tests|docs
55 changes: 37 additions & 18 deletions c3s_eqc_automatic_quality_control/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import calendar
import itertools
import logging
import pathlib
from collections.abc import Callable
from typing import Any

Expand All @@ -34,7 +35,9 @@

LOGGER = dashboard.get_logger()
# In the future, these kwargs should somehow be handled upstream by the toolbox.
# Default keyword arguments forwarded to ``remote.to_xarray`` for every chunk.
TO_XARRAY_KWARGS: dict[str, Any] = {
    "harmonise": True,
    "pandas_read_csv_kwargs": {"comment": "#"},
}
Expand Down Expand Up @@ -145,27 +148,26 @@ def update_request_date(
start: str | pd.Period,
stop: str | pd.Period | None = None,
switch_month_day: int | None = None,
) -> dict[str, Any] | list[dict[str, Any]]:
stringify_dates: bool = False,
) -> list[dict[str, Any]]:
"""
Return the requests defined by 'request' for the period defined by start and stop.
Parameters
----------
request: dict
Parameters of the request
start: str or pd.Period
String {start_year}-{start_month} pd.Period with freq='M'
stop: str or pd.Period
Optional string {stop_year}-{stop_month} pd.Period with freq='M'
If None the stop date is computed using the `switch_month_day`
switch_month_day: int
Used to compute the stop date in case stop is None. The stop date is computed as follows:
if current day > switch_month_day then stop_month = current_month - 1
else stop_month = current_month - 2
stringify_dates: bool
Whether to convert date to strings
Returns
-------
Expand All @@ -178,11 +180,19 @@ def update_request_date(
stop = pd.Period(stop, "M")

dates = compute_request_date(start, stop, switch_month_day=switch_month_day)
if isinstance(dates, dict):
return {**request, **dates}
requests = []

for d in dates:
requests.append({**request, **d})
padded_d = {}
if stringify_dates:
for key, value in d.items():
if key in ("year", "month", "day"):
padded_d[key] = (
f"{value:02d}"
if isinstance(value, int)
else [f"{v:02d}" for v in value]
)
requests.append({**request, **d, **padded_d})
return requests


Expand Down Expand Up @@ -274,16 +284,29 @@ def split_request(
return requests


def expand_dim_using_source(ds: xr.Dataset) -> xr.Dataset:
    """
    Add a 'source' dimension labelled with the stem of the dataset source path.

    Parameters
    ----------
    ds: xr.Dataset
        Dataset whose ``encoding`` may contain a ``"source"`` entry

    Returns
    -------
    xr.Dataset
        Dataset expanded along ``source`` when ``ds.encoding["source"]`` is set
        and non-empty, otherwise the input dataset unchanged
    """
    # TODO: workaround because the toolbox is not able to open satellite datasets
    source = ds.encoding.get("source")
    if not source:
        return ds
    label = pathlib.Path(source).stem
    return ds.expand_dims(source=[label])


@cacholote.cacheable
def download_and_transform_chunk(
    collection_id: str,
    request: dict[str, Any],
    transform_func: Callable[[xr.Dataset], xr.Dataset] | None = None,
) -> xr.Dataset:
    """
    Retrieve a single request, open it as a dataset and optionally transform it.

    Parameters
    ----------
    collection_id: str
        ID of the collection passed to the catalogue retrieval
    request: dict
        Parameters of the request
    transform_func: callable, optional
        Function applied to the opened dataset before it is returned

    Returns
    -------
    xr.Dataset
        The opened dataset, transformed when `transform_func` is provided
    """
    remote = cads_toolbox.catalogue.retrieve(collection_id, request)

    # Work on a copy so the module-level defaults are never modified.
    kwargs = dict(TO_XARRAY_KWARGS)
    if collection_id.startswith("satellite-"):
        # The outer copy is shallow: copy the nested dict as well before
        # adding the satellite-specific preprocess hook.
        open_mfdataset_kwargs = dict(kwargs.get("xarray_open_mfdataset_kwargs", {}))
        open_mfdataset_kwargs["preprocess"] = expand_dim_using_source
        kwargs["xarray_open_mfdataset_kwargs"] = open_mfdataset_kwargs

    ds: xr.Dataset = remote.to_xarray(**kwargs)
    # TODO: make cacholote add coordinates? Needed to guarantee roundtrip
    # See: https://docs.xarray.dev/en/stable/user-guide/io.html#coordinates
    ds.attrs["coordinates"] = " ".join(str(coord) for coord in ds.coords)
    return transform_func(ds) if transform_func is not None else ds


def download_and_transform(
Expand Down Expand Up @@ -323,14 +346,10 @@ def download_and_transform(
for request in ensure_list(requests):
request_list.extend(split_request(request, chunks, split_all))

func_chunk = download_and_transform_chunk
if transform_func:
func_chunk = cacholote.cacheable(func_chunk)

datasets = []
for n, request_chunk in enumerate(request_list):
logger.info(f"Gathering file {n+1} out of {len(request_list)}...")
ds = func_chunk(
ds = download_and_transform_chunk(
collection_id,
request=request_chunk,
transform_func=transform_func,
Expand Down
7 changes: 4 additions & 3 deletions c3s_eqc_automatic_quality_control/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def shaded_std(
ds_std: xr.Dataset | None = None,
hue_dim: str | None = None,
title: str | None = None,
x_dim: str = "time",
) -> go.Figure:

if isinstance(vars, str):
Expand Down Expand Up @@ -103,7 +104,7 @@ def shaded_std(
data.append(
go.Scatter(
name=label,
x=da_mean["time"],
x=da_mean[x_dim],
y=da_mean,
mode="lines",
line=dict(color=dark),
Expand All @@ -114,7 +115,7 @@ def shaded_std(
data.append(
go.Scatter(
name="Upper Bound",
x=da_mean["time"],
x=da_mean[x_dim],
y=da_mean + da_std,
mode="lines",
line=dict(width=0.25, color=dark),
Expand All @@ -124,7 +125,7 @@ def shaded_std(
data.append(
go.Scatter(
name="Lower Bound",
x=da_mean["time"],
x=da_mean[x_dim],
y=da_mean - da_std,
line=dict(width=0.25, color=dark),
mode="lines",
Expand Down
19 changes: 19 additions & 0 deletions tests/test_10_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,22 @@ def test_update_request() -> None:

requests = download.update_request_date({}, "2020-02", "2022-11")
assert len(requests) == 3


@pytest.mark.parametrize(
    "stringify_dates,expected_month,expected_day",
    [
        (False, [1], [*range(1, 32)]),
        (True, ["01"], [f"{day:02d}" for day in range(1, 32)]),
    ],
)
def test_stringify_dates(
    stringify_dates: bool,
    expected_month: list[str | int],
    expected_day: list[str | int],
) -> None:
    """Check month/day values of a single-month request, raw and stringified."""
    # January 2022 only; the first request is the one checked.
    requests = download.update_request_date(
        {}, "2022-1", "2022-1", stringify_dates=stringify_dates
    )
    first_request, *_ = requests
    assert first_request["month"] == expected_month
    assert first_request["day"] == expected_day

0 comments on commit 2c163e5

Please sign in to comment.