Skip to content

Commit

Permalink
Revise use cache (#41)
Browse files Browse the repository at this point in the history
* fix typing

* template-update
  • Loading branch information
malmans2 committed Sep 25, 2023
1 parent c94905b commit 20ee622
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 60 deletions.
2 changes: 1 addition & 1 deletion .cruft.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"template": "https://github.com/ecmwf-projects/cookiecutter-conda-package",
"commit": "46c5959d066cde32a3f39e51474fe8a058de4860",
"commit": "308afb845f02bf8e5e5af2bab993f9fc472aa04d",
"checkout": null,
"context": {
"cookiecutter": {
Expand Down
14 changes: 7 additions & 7 deletions .github/workflows/on-push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.x
Expand All @@ -32,7 +32,7 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Install conda-merge
run: |
$CONDA/bin/python -m pip install conda-merge
Expand All @@ -56,7 +56,7 @@ jobs:
python-version: ['3.10', '3.11']

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Download combined environments
uses: actions/download-artifact@v3
with:
Expand All @@ -82,7 +82,7 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Download combined environments
uses: actions/download-artifact@v3
with:
Expand All @@ -108,7 +108,7 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Download combined environments
uses: actions/download-artifact@v3
with:
Expand Down Expand Up @@ -142,7 +142,7 @@ jobs:
extra: -integration

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Download combined environments
uses: actions/download-artifact@v3
with:
Expand Down Expand Up @@ -174,7 +174,7 @@ jobs:
(needs.integration-tests.result == 'success' || needs.integration-tests.result == 'skipped')
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Install packages
run: |
$CONDA/bin/python -m pip install build twine
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
- id: debug-statements
- id: mixed-line-ending
- repo: https://github.com/psf/black
rev: 23.7.0
rev: 23.9.1
hooks:
- id: black
- repo: https://github.com/keewis/blackdoc
Expand All @@ -20,12 +20,12 @@ repos:
- id: blackdoc
additional_dependencies: [black==22.3.0]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.280
rev: v0.0.290
hooks:
- id: ruff
args: [--fix, --show-fixes]
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
rev: 0.7.17
hooks:
- id: mdformat
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
Expand All @@ -36,6 +36,6 @@ repos:
- id: pretty-format-toml
args: [--autofix]
- repo: https://github.com/gitleaks/gitleaks
rev: v8.17.0
rev: v8.18.0
hooks:
- id: gitleaks
42 changes: 20 additions & 22 deletions c3s_eqc_automatic_quality_control/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def ensure_request_gets_cached(request: dict[str, Any]) -> dict[str, Any]:


def _cached_retrieve(collection_id: str, request: dict[str, Any]) -> emohawk.Data:
with cacholote.config.set(use_cache=True, return_cache_entry=False):
with cacholote.config.set(return_cache_entry=False):
return cads_toolbox.catalogue.retrieve(collection_id, request).data


Expand Down Expand Up @@ -401,7 +401,6 @@ def get_data(source: list[str]) -> Any:
return emohwak_dir


@cacholote.cacheable
def _download_and_transform_requests(
collection_id: str,
request_list: list[dict[str, Any]],
Expand Down Expand Up @@ -525,8 +524,11 @@ def download_and_transform(
else cached_open_mfdataset_kwargs or {}
)

use_cache = transform_func is not None
func = functools.partial(
_download_and_transform_requests,
cacholote.cacheable(_download_and_transform_requests)
if use_cache
else _download_and_transform_requests,
collection_id=collection_id,
transform_func=transform_func,
transform_func_kwargs=transform_func_kwargs,
Expand All @@ -544,28 +546,24 @@ def download_and_transform(
for request in request_list
)

use_cache = transform_func is not None
with cacholote.config.set(use_cache=use_cache):
if use_cache and transform_chunks:
# Cache each chunk transformed
sources = []
for request in tqdm.tqdm(request_list):
if invalidate_cache:
cacholote.delete(
func.func, *func.args, request_list=[request], **func.keywords
)
with cacholote.config.set(return_cache_entry=True):
sources.append(
func(request_list=[request]).result["args"][0]["href"]
)
ds = xr.open_mfdataset(sources, **cached_open_mfdataset_kwargs)
else:
# Cache final dataset transformed
if use_cache and transform_chunks:
# Cache each chunk transformed
sources = []
for request in tqdm.tqdm(request_list):
if invalidate_cache:
cacholote.delete(
func.func, *func.args, request_list=request_list, **func.keywords
func.func, *func.args, request_list=[request], **func.keywords
)
ds = func(request_list=request_list)
with cacholote.config.set(return_cache_entry=True):
sources.append(func(request_list=[request]).result["args"][0]["href"])
ds = xr.open_mfdataset(sources, **cached_open_mfdataset_kwargs)
else:
# Cache final dataset transformed
if invalidate_cache:
cacholote.delete(
func.func, *func.args, request_list=request_list, **func.keywords
)
ds = func(request_list=request_list)

ds.attrs.pop("coordinates", None) # Previously added to guarantee roundtrip
return ds
23 changes: 9 additions & 14 deletions c3s_eqc_automatic_quality_control/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
import plotly.graph_objs as go
import xarray as xr
from cartopy.mpl.geocollection import GeoQuadMesh
from matplotlib.axes import Axes
from matplotlib.image import AxesImage
from matplotlib.typing import ColorType
from xarray.plot.facetgrid import FacetGrid

from . import diagnostics, utils
Expand All @@ -43,7 +42,6 @@
}

FLAGS_T = int | Iterable[int]
COLOR_T = str | Iterable[str]


def line_plot(
Expand Down Expand Up @@ -268,7 +266,7 @@ def projected_map(
return plot_obj


def _infer_legend_dict(da: xr.DataArray) -> dict[str, tuple[COLOR_T, FLAGS_T]]:
def _infer_legend_dict(da: xr.DataArray) -> dict[str, tuple[ColorType, FLAGS_T]]:
flags = list(map(int, da.attrs["flag_values"]))
colors = da.attrs["flag_colors"].split()
meanings = da.attrs["flag_meanings"].split()
Expand All @@ -279,17 +277,17 @@ def _infer_legend_dict(da: xr.DataArray) -> dict[str, tuple[COLOR_T, FLAGS_T]]:
if len(flags) - len(colors) == 1:
colors.insert(flags.index(0), "#000000")

legend_dict: dict[str, tuple[COLOR_T, FLAGS_T]] = {}
legend_dict: dict[str, tuple[ColorType, FLAGS_T]] = {}
for m, c, f in zip(meanings, colors, flags, strict=True):
legend_dict[m.replace("_", " ").title()] = (c, f)
return legend_dict


def lccs_map(
da_lccs: xr.DataArray,
legend_dict: dict[str, tuple[COLOR_T, FLAGS_T]] | None = None,
legend_dict: dict[str, tuple[ColorType, FLAGS_T]] | None = None,
**kwargs: Any,
) -> AxesImage | FacetGrid[Any]:
) -> Any:
"""
Plot LCCS map.
Expand Down Expand Up @@ -346,12 +344,12 @@ def lccs_map(
def lccs_bar(
da: xr.DataArray,
da_lccs: xr.DataArray,
labels_dict: dict[str, tuple[COLOR_T, FLAGS_T]] | None = None,
labels_dict: dict[str, tuple[ColorType, FLAGS_T]] | None = None,
reduction: str = "mean",
groupby_bins_dims: dict[str, Any] = {},
exclude_no_data: bool = True,
**kwargs: Any,
) -> Axes:
) -> Any:
"""
Plot LCCS map.
Expand Down Expand Up @@ -424,7 +422,7 @@ def lccs_bar(
" ".join([reduction.title(), "of", xr.plot.utils.label_from_attrs(da)]),
)

ax = df_or_ser.plot.bar(color=colors, **kwargs)
ax = df_or_ser.plot.bar(color=list(colors), **kwargs)
if groupby_bins_dims:
ax.legend(bbox_to_anchor=(1, 1), loc="upper left")
return ax
Expand Down Expand Up @@ -456,8 +454,5 @@ def seasonal_boxplot(

da = da.stack(stacked_dim=da.dims)
df = da.to_dataframe()
axes = df.groupby(by=da[time_dim].dt.season.values).boxplot(**kwargs)
for ax in axes:
ax.xaxis.set_ticklabels([])

return axes
return df.groupby(by=da[time_dim].dt.season.values).boxplot(**kwargs)
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies:
- numpy
- pip
- plotly
- pooch
- pydantic == 1.*
- python-dotenv
- pyyaml
Expand Down
16 changes: 4 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ module = [
"emohawk.*",
"fsspec",
"joblib",
"matplotlib.*",
"plotly.*",
"shapely",
"sklearn.*",
Expand All @@ -48,17 +47,7 @@ module = [
[tool.ruff]
ignore = [
# pydocstyle: Missing Docstrings
"D1",
# pydocstyle: numpy convention
"D107",
"D203",
"D212",
"D213",
"D402",
"D413",
"D415",
"D416",
"D417"
"D1"
]
# Black line length is 88, but black does not format comments.
line-length = 110
Expand All @@ -77,6 +66,9 @@ select = [
]
target-version = "py310"

[tool.ruff.pydocstyle]
convention = "numpy"

[tool.setuptools]
packages = ["c3s_eqc_automatic_quality_control"]

Expand Down

0 comments on commit 20ee622

Please sign in to comment.