diff --git a/notebooks/wp4/extreme_temperature_indices.ipynb b/notebooks/wp4/extreme_temperature_indices.ipynb index acf3244..5092b05 100644 --- a/notebooks/wp4/extreme_temperature_indices.ipynb +++ b/notebooks/wp4/extreme_temperature_indices.ipynb @@ -626,8 +626,8 @@ "\n", "\n", "def plot_ensemble(\n", - " da_era5,\n", " da_models,\n", + " da_era5=None,\n", " p_value_era5=None,\n", " p_value_models=None,\n", " subplot_kw={\"projection\": ccrs.PlateCarree()},\n", @@ -638,40 +638,49 @@ " # Get kwargs\n", " default_kwargs = {\"robust\": True, \"extend\": \"both\"}\n", " kwargs = default_kwargs | kwargs\n", - " kwargs = xr.plot.utils._determine_cmap_params(da_era5.values, **kwargs)\n", + " kwargs = xr.plot.utils._determine_cmap_params(\n", + " da_models.values if da_era5 is None else da_era5.values, **kwargs\n", + " )\n", "\n", " # Figure\n", " fig, axs = plt.subplots(\n", - " *(2, 2),\n", + " *(2, 1 if da_era5 is None else 2),\n", " subplot_kw=subplot_kw,\n", " figsize=figsize,\n", " layout=layout,\n", " )\n", " axs = axs.flatten()\n", + " axs_iter = iter(axs)\n", "\n", " # ERA5\n", - " plot.projected_map(da_era5, ax=axs[0], show_stats=False, **kwargs)\n", - " if p_value_era5 is not None:\n", - " hatch_p_value(p_value_era5, ax=axs[0])\n", - " axs[0].set_title(\"ERA5\")\n", + " if da_era5 is not None:\n", + " ax = next(axs_iter)\n", + " plot.projected_map(da_era5, ax=ax, show_stats=False, **kwargs)\n", + " if p_value_era5 is not None:\n", + " hatch_p_value(p_value_era5, ax=ax)\n", + " ax.set_title(\"ERA5\")\n", "\n", " # Median\n", + " ax = next(axs_iter)\n", " median = da_models.median(\"model\", keep_attrs=True)\n", - " plot.projected_map(median, ax=axs[1], show_stats=False, **kwargs)\n", + " plot.projected_map(median, ax=ax, show_stats=False, **kwargs)\n", " if p_value_models is not None:\n", - " hatch_p_value_ensemble(trend=da_models, p_value=p_value_models, ax=axs[1])\n", - " axs[1].set_title(\"Ensemble Median\")\n", + " hatch_p_value_ensemble(trend=da_models, p_value=p_value_models, ax=ax)\n", + " ax.set_title(\"Ensemble Median\")\n", "\n", " # Bias\n", - " with xr.set_options(keep_attrs=True):\n", - " bias = median - da_era5\n", - " plot.projected_map(bias, ax=axs[2], show_stats=False, center=0, **default_kwargs)\n", - " axs[2].set_title(\"Ensemble Median Bias\")\n", + " if da_era5 is not None:\n", + " ax = next(axs_iter)\n", + " with xr.set_options(keep_attrs=True):\n", + " bias = median - da_era5\n", + " plot.projected_map(bias, ax=ax, show_stats=False, center=0, **default_kwargs)\n", + " ax.set_title(\"Ensemble Median Bias\")\n", "\n", " # Std\n", + " ax = next(axs_iter)\n", " std = da_models.std(\"model\", keep_attrs=True)\n", - " plot.projected_map(std, ax=axs[3], show_stats=False, **default_kwargs)\n", - " axs[3].set_title(\"Ensemble Standard Deviation\")\n", + " plot.projected_map(std, ax=ax, show_stats=False, **default_kwargs)\n", + " ax.set_title(\"Ensemble Standard Deviation\")\n", "\n", " set_extent(da_era5, axs)\n", " return fig\n", @@ -695,14 +704,14 @@ "source": [ "for index in index_names:\n", " # Index\n", - " fig = plot_ensemble(da_era5=ds_era5[index], da_models=ds_interpolated[index])\n", + " fig = plot_ensemble(da_models=ds_interpolated[index], da_era5=ds_era5[index])\n", " fig.suptitle(f\"{index}\\n{common_title}\")\n", " plt.show()\n", "\n", " # Trend\n", " fig = plot_ensemble(\n", - " da_era5=ds_era5[\"trend\"].sel(index=index),\n", " da_models=ds_interpolated[\"trend\"].sel(index=index),\n", + " da_era5=ds_era5[\"trend\"].sel(index=index),\n", " p_value_era5=ds_era5[\"p\"].sel(index=index),\n", " p_value_models=ds_interpolated[\"p\"].sel(index=index),\n", " center=0,\n",