Skip to content

Commit

Permalink
better caching of grid out
Browse files Browse the repository at this point in the history
  • Loading branch information
malmans2 committed Mar 20, 2024
1 parent cb1f289 commit eb17ed2
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 27 deletions.
34 changes: 21 additions & 13 deletions notebooks/wp4/extreme_temperature_indices.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -359,15 +359,33 @@
" return ds\n",
"\n",
"\n",
"def get_grid_out(request_grid_out, method):\n",
" ds_regrid = download.download_and_transform(*request_grid_out)\n",
" coords = [\"latitude\", \"longitude\"]\n",
" if method == \"conservative\":\n",
" ds_regrid = add_bounds(ds_regrid)\n",
" for coord in list(coords):\n",
" coords.extend(ds_regrid.cf.bounds[coord])\n",
" grid_out = ds_regrid[coords]\n",
" coords_to_drop = set(grid_out.coords) - set(coords) - set(grid_out.dims)\n",
" grid_out = ds_regrid[coords].reset_coords(coords_to_drop, drop=True)\n",
" grid_out.attrs = {}\n",
" return grid_out\n",
"\n",
"\n",
"def compute_indices_and_trends(\n",
" ds,\n",
" index_names,\n",
" timeseries,\n",
" year_start,\n",
" year_stop,\n",
" resample,\n",
" request_grid_out=None,\n",
" **regrid_kwargs,\n",
"):\n",
" assert (request_grid_out and regrid_kwargs) or not (\n",
" request_grid_out or regrid_kwargs\n",
" )\n",
" ds = ds.drop_vars([var for var, da in ds.data_vars.items() if len(da.dims) != 3])\n",
" ds = ds[list(ds.data_vars)]\n",
"\n",
Expand All @@ -389,9 +407,10 @@
" ds_trends = compute_trends(ds_indices)\n",
" ds = ds_indices.mean(\"time\", keep_attrs=True)\n",
" ds = ds.merge(ds_trends)\n",
" if regrid_kwargs:\n",
" if request_grid_out:\n",
" ds = diagnostics.regrid(\n",
" ds.merge({da.name: da for da in bounds}),\n",
" grid_out=get_grid_out(request_grid_out, regrid_kwargs[\"method\"]),\n",
" **regrid_kwargs,\n",
" )\n",
" return ds"
Expand Down Expand Up @@ -450,17 +469,6 @@
},
"outputs": [],
"source": [
"ds_regrid = ds_era5\n",
"coords = [\"latitude\", \"longitude\"]\n",
"if interpolation_method == \"conservative\":\n",
" ds_regrid = add_bounds(ds_regrid)\n",
" for coord in list(coords):\n",
" coords.extend(ds_regrid.cf.bounds[coord])\n",
"grid_out = ds_regrid[coords]\n",
"coords_to_drop = set(grid_out.coords) - set(coords) - set(grid_out.dims)\n",
"grid_out = ds_regrid[coords].reset_coords(coords_to_drop, drop=True)\n",
"grid_out.attrs = {}\n",
"\n",
"interpolated_datasets = []\n",
"model_datasets = {}\n",
"for model in models:\n",
Expand All @@ -484,7 +492,7 @@
" transform_func_kwargs=transform_func_kwargs\n",
" | {\n",
" \"resample\": False,\n",
" \"grid_out\": grid_out,\n",
" \"request_grid_out\": request_lsm,\n",
" \"method\": interpolation_method,\n",
" \"skipna\": True,\n",
" },\n",
Expand Down
42 changes: 28 additions & 14 deletions notebooks/wp4/extreme_temperature_indices_future.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
"assert timeseries in (\"annual\", \"DJF\", \"MAM\", \"JJA\", \"SON\")\n",
"\n",
"# Choose CORDEX or CMIP6\n",
"collection_id = \"CORDEX\"\n",
"collection_id = \"CMIP6\"\n",
"assert collection_id in (\"CORDEX\", \"CMIP6\")\n",
"\n",
"# Define region for analysis\n",
Expand Down Expand Up @@ -260,7 +260,12 @@
" download.split_request(request_cmip6, chunks=chunks),\n",
" )\n",
"else:\n",
" raise ValueError(f\"{collection_id=}\")"
" raise ValueError(f\"{collection_id=}\")\n",
"\n",
"request_grid_out = (\n",
" request_sim[0],\n",
" request_sim[1][0] | {model_key: model_regrid},\n",
")"
]
},
{
Expand Down Expand Up @@ -343,15 +348,33 @@
" return ds\n",
"\n",
"\n",
"def get_grid_out(request_grid_out, method):\n",
" ds_regrid = download.download_and_transform(*request_grid_out)\n",
" coords = [\"latitude\", \"longitude\"]\n",
" if method == \"conservative\":\n",
" ds_regrid = add_bounds(ds_regrid)\n",
" for coord in list(coords):\n",
" coords.extend(ds_regrid.cf.bounds[coord])\n",
" grid_out = ds_regrid[coords]\n",
" coords_to_drop = set(grid_out.coords) - set(coords) - set(grid_out.dims)\n",
" grid_out = ds_regrid[coords].reset_coords(coords_to_drop, drop=True)\n",
" grid_out.attrs = {}\n",
" return grid_out\n",
"\n",
"\n",
"def compute_indices_and_trends(\n",
" ds,\n",
" index_names,\n",
" timeseries,\n",
" year_start,\n",
" year_stop,\n",
" resample,\n",
" request_grid_out=None,\n",
" **regrid_kwargs,\n",
"):\n",
" assert (request_grid_out and regrid_kwargs) or not (\n",
" request_grid_out or regrid_kwargs\n",
" )\n",
" ds = ds.drop_vars([var for var, da in ds.data_vars.items() if len(da.dims) != 3])\n",
" ds = ds[list(ds.data_vars)]\n",
"\n",
Expand All @@ -373,9 +396,10 @@
" ds_trends = compute_trends(ds_indices)\n",
" ds = ds_indices.mean(\"time\", keep_attrs=True)\n",
" ds = ds.merge(ds_trends)\n",
" if regrid_kwargs:\n",
" if request_grid_out:\n",
" ds = diagnostics.regrid(\n",
" ds.merge({da.name: da for da in bounds}),\n",
" grid_out=get_grid_out(request_grid_out, regrid_kwargs[\"method\"]),\n",
" **regrid_kwargs,\n",
" )\n",
" return ds"
Expand Down Expand Up @@ -439,16 +463,6 @@
},
"outputs": [],
"source": [
"coords = [\"latitude\", \"longitude\"]\n",
"if interpolation_method == \"conservative\":\n",
" ds_regrid = add_bounds(ds_regrid)\n",
" for coord in list(coords):\n",
" coords.extend(ds_regrid.cf.bounds[coord])\n",
"grid_out = ds_regrid[coords]\n",
"coords_to_drop = set(grid_out.coords) - set(coords) - set(grid_out.dims)\n",
"grid_out = ds_regrid[coords].reset_coords(coords_to_drop, drop=True)\n",
"grid_out.attrs = {}\n",
"\n",
"interpolated_datasets = []\n",
"model_datasets = {}\n",
"for model in models:\n",
Expand All @@ -468,7 +482,7 @@
" **kwargs,\n",
" transform_func_kwargs=transform_func_kwargs\n",
" | {\n",
" \"grid_out\": grid_out,\n",
" \"request_grid_out\": request_grid_out,\n",
" \"method\": interpolation_method,\n",
" \"skipna\": True,\n",
" },\n",
Expand Down

0 comments on commit eb17ed2

Please sign in to comment.