diff --git a/notebooks/wp4/cmip6_energy_indices.ipynb b/notebooks/wp4/cmip6_energy_indices.ipynb index 4ce0e81..2030cd9 100644 --- a/notebooks/wp4/cmip6_energy_indices.ipynb +++ b/notebooks/wp4/cmip6_energy_indices.ipynb @@ -23,12 +23,14 @@ "metadata": {}, "outputs": [], "source": [ + "import math\n", "import tempfile\n", "\n", + "import cartopy.crs as ccrs\n", "import icclim\n", "import matplotlib.pyplot as plt\n", "import xarray as xr\n", - "from c3s_eqc_automatic_quality_control import diagnostics, download\n", + "from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils\n", "from xarrayMannKendall import Mann_Kendall_test\n", "\n", "plt.style.use(\"seaborn-v0_8-notebook\")\n", @@ -83,7 +85,11 @@ " \"inm_cm5_0\",\n", " \"miroc6\",\n", " \"mpi_esm1_2_lr\",\n", - ")" + ")\n", + "\n", + "# Colormaps\n", + "cmaps = {\"HDD15.5\": \"Blues\", \"CDD22\": \"Reds\"}\n", + "cmaps_trend = cmaps_bias = {\"HDD15.5\": \"RdBu\", \"CDD22\": \"RdBu_r\"}" ] }, { @@ -398,6 +404,431 @@ "\n", "ds_interpolated = xr.concat(interpolated_datasets, \"model\")" ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "## Mask land and change attributes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "lsm = download.download_and_transform(*request_lsm)[\"lsm\"].squeeze(drop=True)\n", + "\n", + "# Cutout\n", + "regionalise_kwargs = {\n", + " \"lon_slice\": slice(area[1], area[3]),\n", + " \"lat_slice\": slice(area[0], area[2]),\n", + "}\n", + "lsm = utils.regionalise(lsm, **regionalise_kwargs)\n", + "ds_interpolated = utils.regionalise(ds_interpolated, **regionalise_kwargs)\n", + "model_datasets = {\n", + " model: utils.regionalise(ds, **regionalise_kwargs)\n", + " for model, ds in model_datasets.items()\n", + "}\n", + "\n", + "# Mask\n", + "ds_era5 = ds_era5.where(lsm)\n", + "ds_interpolated = ds_interpolated.where(lsm)\n", + "model_datasets = {\n", + " model: ds.where(diagnostics.regrid(lsm, ds, method=\"bilinear\"))\n", + " for model, ds in model_datasets.items()\n", + "}\n", + "\n", + "# Edit attributes\n", + "for ds in (ds_era5, ds_interpolated, *model_datasets.values()):\n", + " ds[\"trend\"] *= 10\n", + " ds[\"trend\"].attrs = {\"long_name\": \"trend\"}\n", + " for index in index_timeseries:\n", + " ds[index].attrs = {\"long_name\": \"\", \"units\": ds[index].attrs[\"units\"]}" + ] + }, + { + "cell_type": "markdown", + "id": "17", + "metadata": {}, + "source": [ + "## Plotting functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "def hatch_p_value(da, ax, **kwargs):\n", + " default_kwargs = {\n", + " \"plot_func\": \"contourf\",\n", + " \"show_stats\": False,\n", + " \"cmap\": \"none\",\n", + " \"add_colorbar\": False,\n", + " \"levels\": [0, 0.05, 1],\n", + " \"hatches\": [\"\", \"/\" * 3],\n", + " }\n", + " kwargs = default_kwargs | kwargs\n", + "\n", + " title = ax.get_title()\n", + " plot_obj = plot.projected_map(da, ax=ax, **kwargs)\n", + " ax.set_title(title)\n", + " return plot_obj\n", + "\n", + "\n", + "def hatch_p_value_ensemble(trend, p_value, ax):\n", + " n_models = trend.sizes[\"model\"]\n", + " robust_ratio = (p_value <= 0.05).sum(\"model\") / n_models\n", + " robust_ratio = robust_ratio.where(p_value.notnull().any(\"model\"))\n", + " signs = xr.concat([(trend > 0).sum(\"model\"), (trend < 0).sum(\"model\")], \"sign\")\n", + " sign_ratio = signs.max(\"sign\") / n_models\n", + " robust_threshold = 0.66\n", + " sign_ratio = sign_ratio.where(robust_ratio > robust_threshold)\n", + " for da, threshold, character in zip(\n", + " [robust_ratio, sign_ratio], [robust_threshold, 0.8], [\"/\", \"\\\\\"]\n", + " ):\n", + " hatch_p_value(da, ax=ax, levels=[0, threshold, 1], hatches=[character * 3, \"\"])\n", + "\n", + "\n", + "def set_extent(da, axs, area):\n", + " extent = [area[i] for i in (1, 3, 2, 0)]\n", + " for i, coord in enumerate(extent):\n", + " extent[i] += -1 if i % 2 else +1\n", + " for ax in axs:\n", + " ax.set_extent(extent)\n", + "\n", + "\n", + "def plot_models(\n", + " data,\n", + " da_for_kwargs=None,\n", + " p_values=None,\n", + " col_wrap=3,\n", + " subplot_kw={\"projection\": ccrs.PlateCarree()},\n", + " figsize=None,\n", + " layout=\"constrained\",\n", + " area=area,\n", + " **kwargs,\n", + "):\n", + " if isinstance(data, dict):\n", + " assert da_for_kwargs is not None\n", + " model_dataarrays = data\n", + " else:\n", + " da_for_kwargs = da_for_kwargs or data\n", + " model_dataarrays = dict(data.groupby(\"model\"))\n", + "\n", + " if p_values is not None:\n", + " model_p_dataarrays = (\n", + " p_values if isinstance(p_values, dict) else dict(p_values.groupby(\"model\"))\n", + " )\n", + " else:\n", + " model_p_dataarrays = None\n", + "\n", + " # Get kwargs\n", + " default_kwargs = {\"robust\": True, \"extend\": \"both\"}\n", + " kwargs = default_kwargs | kwargs\n", + " kwargs = xr.plot.utils._determine_cmap_params(da_for_kwargs.values, **kwargs)\n", + "\n", + " fig, axs = plt.subplots(\n", + " *(col_wrap, math.ceil(len(model_dataarrays) / col_wrap)),\n", + " subplot_kw=subplot_kw,\n", + " figsize=figsize,\n", + " layout=layout,\n", + " )\n", + " axs = axs.flatten()\n", + " for (model, da), ax in zip(model_dataarrays.items(), axs):\n", + " pcm = plot.projected_map(\n", + " da, ax=ax, show_stats=False, add_colorbar=False, **kwargs\n", + " )\n", + " ax.set_title(model)\n", + " if model_p_dataarrays is not None:\n", + " hatch_p_value(model_p_dataarrays[model], ax)\n", + " set_extent(da_for_kwargs, axs, area)\n", + " fig.colorbar(\n", + " pcm,\n", + " ax=axs.flatten(),\n", + " extend=kwargs[\"extend\"],\n", + " location=\"right\",\n", + " label=f\"{da_for_kwargs.attrs.get('long_name', '')} [{da_for_kwargs.attrs.get('units', '')}]\",\n", + " )\n", + " return fig\n", + "\n", + "\n", + "def plot_ensemble(\n", + " da_models,\n", + " da_era5=None,\n", + " p_value_era5=None,\n", + " p_value_models=None,\n", + " subplot_kw={\"projection\": ccrs.PlateCarree()},\n", + " figsize=None,\n", + " layout=\"constrained\",\n", + " cbar_kwargs=None,\n", + " area=area,\n", + " cmap_bias=None,\n", + " cmap_std=None,\n", + " **kwargs,\n", + "):\n", + " # Get kwargs\n", + " default_kwargs = {\"robust\": True, \"extend\": \"both\"}\n", + " kwargs = default_kwargs | kwargs\n", + " kwargs = xr.plot.utils._determine_cmap_params(\n", + " da_models.values if da_era5 is None else da_era5.values, **kwargs\n", + " )\n", + " if da_era5 is None and cbar_kwargs is None:\n", + " cbar_kwargs = {\"orientation\": \"horizontal\"}\n", + "\n", + " # Figure\n", + " fig, axs = plt.subplots(\n", + " *(1 if da_era5 is None else 2, 2),\n", + " subplot_kw=subplot_kw,\n", + " figsize=figsize,\n", + " layout=layout,\n", + " )\n", + " axs = axs.flatten()\n", + " axs_iter = iter(axs)\n", + "\n", + " # ERA5\n", + " if da_era5 is not None:\n", + " ax = next(axs_iter)\n", + " plot.projected_map(\n", + " da_era5, ax=ax, show_stats=False, cbar_kwargs=cbar_kwargs, **kwargs\n", + " )\n", + " if p_value_era5 is not None:\n", + " hatch_p_value(p_value_era5, ax=ax)\n", + " ax.set_title(\"ERA5\")\n", + "\n", + " # Median\n", + " ax = next(axs_iter)\n", + " median = da_models.median(\"model\", keep_attrs=True)\n", + " plot.projected_map(\n", + " median, ax=ax, show_stats=False, cbar_kwargs=cbar_kwargs, **kwargs\n", + " )\n", + " if p_value_models is not None:\n", + " hatch_p_value_ensemble(trend=da_models, p_value=p_value_models, ax=ax)\n", + " ax.set_title(\"Ensemble Median\")\n", + "\n", + " # Bias\n", + " if da_era5 is not None:\n", + " ax = next(axs_iter)\n", + " with xr.set_options(keep_attrs=True):\n", + " bias = median - da_era5\n", + " plot.projected_map(\n", + " bias,\n", + " ax=ax,\n", + " show_stats=False,\n", + " center=0,\n", + " cbar_kwargs=cbar_kwargs,\n", + " **(default_kwargs | {\"cmap\": cmap_bias}),\n", + " )\n", + " ax.set_title(\"Ensemble Median Bias\")\n", + "\n", + " # Std\n", + " ax = next(axs_iter)\n", + " std = da_models.std(\"model\", keep_attrs=True)\n", + " plot.projected_map(\n", + " std,\n", + " ax=ax,\n", + " show_stats=False,\n", + " cbar_kwargs=cbar_kwargs,\n", + " **(default_kwargs | {\"cmap\": cmap_std}),\n", + " )\n", + " ax.set_title(\"Ensemble Standard Deviation\")\n", + "\n", + " set_extent(da_models, axs, area)\n", + " return fig\n", + "\n", + "\n", + "common_title = f\"{year_start=} {year_stop=}\"" + ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": {}, + "source": [ + "## Plot ensembles" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "for index in index_timeseries:\n", + " # Index\n", + " da = ds_interpolated[index]\n", + " fig = plot_ensemble(\n", + " da_models=da,\n", + " da_era5=ds_era5[index],\n", + " cmap=cmaps.get(index),\n", + " cmap_bias=cmaps_bias.get(index),\n", + " )\n", + " fig.suptitle(f\"{index}\\n{common_title}\")\n", + " plt.show()\n", + "\n", + " # Trend\n", + " da_trend = ds_interpolated[\"trend\"].sel(index=index)\n", + " da_trend.attrs[\"units\"] = f\"{da.attrs['units']} / decade\"\n", + " fig = plot_ensemble(\n", + " da_models=da_trend,\n", + " da_era5=ds_era5[\"trend\"].sel(index=index),\n", + " p_value_era5=ds_era5[\"p\"].sel(index=index),\n", + " p_value_models=ds_interpolated[\"p\"].sel(index=index),\n", + " center=0,\n", + " cmap=cmaps_trend.get(index),\n", + " cmap_bias=cmaps_bias.get(index),\n", + " )\n", + " fig.suptitle(f\"Trend of {index}\\n{common_title}\")\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, + "source": [ + "## Plot models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "for index in index_timeseries:\n", + " cmap_trend = \"RdBu\" if index.startswith(\"HDD\") else \"RdBu_r\"\n", + "\n", + " # Index\n", + " da_for_kwargs = ds_era5[index]\n", + " fig = plot_models(\n", + " data={model: ds[index] for model, ds in model_datasets.items()},\n", + " da_for_kwargs=da_for_kwargs,\n", + " cmap=cmaps.get(index),\n", + " )\n", + " fig.suptitle(f\"{index}\\n{common_title}\")\n", + " plt.show()\n", + "\n", + " # Trend\n", + " da_for_kwargs_trends = ds_era5[\"trend\"].sel(index=index)\n", + " da_for_kwargs_trends.attrs[\"units\"] = f\"{da_for_kwargs.attrs['units']} / decade\"\n", + " fig = plot_models(\n", + " data={\n", + " model: ds[\"trend\"].sel(index=index) for model, ds in model_datasets.items()\n", + " },\n", + " da_for_kwargs=da_for_kwargs_trends,\n", + " p_values={\n", + " model: ds[\"p\"].sel(index=index) for model, ds in model_datasets.items()\n", + " },\n", + " center=0,\n", + " cmap=cmaps_trend.get(index),\n", + " )\n", + " fig.suptitle(f\"Trend of {index}\\n{common_title}\")\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "23", + "metadata": {}, + "source": [ + "## Plot bias" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "with xr.set_options(keep_attrs=True):\n", + " bias = ds_interpolated - ds_era5\n", + "\n", + "for index in index_timeseries:\n", + " # Index bias\n", + " da = bias[index]\n", + " fig = plot_models(data=da, center=0, cmap=cmaps_bias.get(index))\n", + " fig.suptitle(f\"Bias of {index}\\n{common_title}\")\n", + " plt.show()\n", + "\n", + " # Trend bias\n", + " da_trend = bias[\"trend\"].sel(index=index)\n", + " da_trend.attrs[\"units\"] = f\"{da.attrs['units']} / decade\"\n", + " fig = plot_models(data=da_trend, center=0, cmap=cmaps_bias.get(index))\n", + " fig.suptitle(f\"Trend bias of {index}\\n{common_title}\")\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "25", + "metadata": {}, + "source": [ + "## Boxplot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + "weights = True\n", + "mean_datasets = [\n", + " diagnostics.spatial_weighted_mean(ds.expand_dims(model=[model]), weights=weights)\n", + " for model, ds in model_datasets.items()\n", + "]\n", + "mean_ds = xr.concat(mean_datasets, \"model\")\n", + "mean_bias_ds = diagnostics.spatial_weighted_mean(bias, weights=weights)\n", + "for is_bias, ds in zip((False, True), (mean_ds, mean_bias_ds)):\n", + " for index, da in ds[\"trend\"].groupby(\"index\"):\n", + " df_slope = da.to_dataframe()[[\"trend\"]]\n", + " ax = df_slope.boxplot()\n", + " ax.scatter(\n", + " x=[1] * len(df_slope),\n", + " y=df_slope,\n", + " color=\"grey\",\n", + " marker=\".\",\n", + " label=\"models\",\n", + " )\n", + "\n", + " # Ensemble mean\n", + " ax.scatter(\n", + " x=1,\n", + " y=da.mean(\"model\"),\n", + " marker=\"o\",\n", + " label=\"CMIP6 Ensemble Mean\",\n", + " )\n", + "\n", + " # ERA5\n", + " labels = [\"CMIP6 Ensemble\"]\n", + " if not is_bias:\n", + " da = ds_era5[\"trend\"].sel(index=index)\n", + " da = diagnostics.spatial_weighted_mean(da)\n", + " ax.scatter(\n", + " x=2,\n", + " y=da.values,\n", + " marker=\"o\",\n", + " label=\"ERA5\",\n", + " )\n", + " labels.append(\"ERA5\")\n", + "\n", + " ax.set_xticks(range(1, len(labels) + 1), labels)\n", + " ax.set_ylabel(f\"{ds[index].attrs['units']} / decade\")\n", + " plt.suptitle(f\"Trend{' bias ' if is_bias else ' '}of {index}\")\n", + " plt.legend()\n", + " plt.show()" + ] } ], "metadata": {