diff --git a/.github/workflows/workflows.yaml b/.github/workflows/workflows.yaml index 4ad85be6..7fd4816f 100644 --- a/.github/workflows/workflows.yaml +++ b/.github/workflows/workflows.yaml @@ -13,19 +13,12 @@ jobs: matrix: os: [ubuntu-latest] python-version: ["3.10", "3.11", "3.12"] - torch-version: [2.0.0, 2.1.0, 2.2.0] + torch-version: [2.3.0, 2.4.0] include: - - torch-version: 2.1.0 - torchvision-version: 0.16.0 - - torch-version: 2.0.0 - torchvision-version: 0.15.1 - - torch-version: 2.2.0 - torchvision-version: 0.17.0 - exclude: - - python-version: "3.12" - torch-version: 2.0.0 - - python-version: "3.12" - torch-version: 2.1.0 + - torch-version: 2.3.0 + torchvision-version: 0.18.0 + - torch-version: 2.4.0 + torchvision-version: 0.19.0 steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -39,6 +32,7 @@ jobs: - name: Install internal dependencies run: | pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${{ matrix.torch-version}}+cpu.html + if [ ${{ matrix.torch-version}} == 2.3.0 ]; then pip install dgl -f https://data.dgl.ai/wheels/torch-2.3/repo.html; fi - name: Install main package run: | pip install -e . @@ -71,6 +65,8 @@ jobs: echo "PYTESTCOV=$PYTESTCOV" >> $GITHUB_ENV - name: Run pytest run: | + # import dgl to initialize backend + if [ ${{ matrix.torch-version}} == 2.3.0 ]; then python3 -c "import dgl"; fi export PYTEST_COMMAND="pytest $PYTESTCOV $PYTESTXDIST -s" echo "Will be running this command: $PYTEST_COMMAND" eval $PYTEST_COMMAND diff --git a/environment_cpu.yml b/environment_cpu.yml index 5784bbb4..db246dba 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -24,7 +24,7 @@ dependencies: - zarr - h3-py - numpy - - pyshtools + - torch_harmonics - pip: - datasets - einops diff --git a/environment_cuda.yml b/environment_cuda.yml index 9f76251a..0b86f8fb 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -25,7 +25,7 @@ dependencies: - zarr - h3-py - numpy - - pyshtools + - torch_harmonics - pip: - datasets - einops diff --git a/graph_weather/models/gencast/README.md b/graph_weather/models/gencast/README.md index 5d963790..cb8f1a4c 100644 --- a/graph_weather/models/gencast/README.md +++ b/graph_weather/models/gencast/README.md @@ -1 +1,216 @@ # GenCast + +## Overview +This repository offers an unofficial implementation of [GenCast](https://arxiv.org/abs/2312.15796), a cutting-edge model designed to enhance weather forecasting accuracy. GenCast integrates diffusion models, sparse transformers, and graph neural networks (GNNs) to improve upon the GraphCast model with a score-based generative approach. This innovative combination aims to revolutionize weather prediction, significantly contributing to climate research and mitigation efforts. Additionally, GenCast supports ensemble predictions to assess the probability of extreme weather events. + +Below is an illustration of the full model architecture: + +

+ +

+ +After 20 epochs, the diffusion process for generating low resolution (128x64) 12h residuals predictions, looks like this: + +

+ +

+ +A longer time range can be achieved by using the model's own predictions as inputs in an autoregressive manner: + +

+ +

+ +## The Denoiser +### Description +The core component of GenCast i is the `Denoiser` module: it takes as inputs the previous two timesteps, the corrupted target residual, and the +noise level, and outputs the denoised predictions. The `Denoiser` operates as follows: +- initializes the graph using the `GraphBuilder` class, +- combines `encoder`, `processor`, and `decoder`, +- preconditions inputs and outputs on the noise levels using the parametrization from [Karras et al. (2022)](https://arxiv.org/abs/2206.00364). + +The code is modular, allowing for easy swapping of the graph, encoder, processor, and decoder with custom architectures. The main arguments are: +- `grid_lon` (np.ndarray): array of longitudes. +- `grid_lat` (np.ndarray): array of latitudes. +- `input_features_dim` (int): dimension of the input features for a single timestep. +- `output_features_dim` (int): dimension of the target features. +- `hidden_dims` (list[int], optional): list of dimensions for the hidden layers in the MLPs used in GenCast. This also determines the latent dimension. Defaults to [512, 512]. +- `num_blocks` (int, optional): number of transformer blocks in Processor. Defaults to 16. +- `num_heads` (int, optional): number of heads for each transformer. Defaults to 4. +- `splits` (int, optional): number of time to split the icosphere during graph building. Defaults to 5. +- `num_hops` (int, optional): the transformes will attention to the (2^num_hops)-neighbours of each node. Defaults to 16. +- `device` (torch.device, optional): device on which we want to build graph. Defaults to torch.device("cpu"). +- `sparse` (bool): if true the processor will apply Sparse Attention using DGL backend. Defaults to False. +- `use_edges_features` (bool): if true use mesh edges features inside the Processor. Defaults to True. +- `scale_factor` (float): the message in the Encoder is multiplied by the scale factor. Defaults to 1.0. + +> [!NOTE] +> If the graph has many edges, setting `sparse = True` may perform better in terms of memory and speed. Note that `sparse = False` uses PyG as the backend, while `sparse = True` uses DGL. The two implementations are not exactly equivalent: the former is described in the paper _"Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification"_ and can also handle edge features, while the latter is a classical transformer that performs multi-head attention utilizing the mask's sparsity and does not include edge features in the computations. + +> [!WARNING] +> The sparse implementation currently does not support `Float16/BFloat16` precision. + +> [!NOTE] +> To fine-tune a pretrained model with a higher resolution dataset, the `scale_factor` should be set accordingly. For example: if the starting resolution is 1 deg and the final resolution is 0.25 deg, then the scale factor is 1/16. + +### Example of usage +```python +import torch +import numpy as np +from graph_weather.models.gencast import Denoiser + +grid_lat = np.arange(-90, 90, 1) +grid_lon = np.arange(0, 360, 1) +input_features_dim = 10 +output_features_dim = 5 +batch_size = 16 + +denoiser = Denoiser( + grid_lon=grid_lon, + grid_lat=grid_lat, + input_features_dim=input_features_dim, + output_features_dim=output_features_dim, + hidden_dims=[32, 32], + num_blocks=8, + num_heads=4, + splits=4, + num_hops=8, + device=torch.device("cpu"), + sparse=True, + use_edges_features=False, + ).eval() + +corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim)) +prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2 * input_features_dim)) +noise_levels = torch.rand((batch_size, 1)) + +preds = denoiser( + corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels + ) +``` +## The Sampler +The sampler employs the second-order `DPMSolver++2S` solver, augmented with stochastic churn and noise inflation techniques from [Karras et al. (2022)](https://arxiv.org/abs/2206.00364), to introduce additional stochasticity into the sampling process. When conditioning on previous timesteps, it follows the Conditional Denoising Estimator approach as outlined by Batzolis et al. (2021). + +### Example of usage +```python +import torch +import numpy as np +from graph_weather.models.gencast import Sampler, Denoiser + +grid_lat = np.arange(-90, 90, 1) +grid_lon = np.arange(0, 360, 1) +input_features_dim = 10 +output_features_dim = 5 + +denoiser = Denoiser( + grid_lon=grid_lon, + grid_lat=grid_lat, + input_features_dim=input_features_dim, + output_features_dim=output_features_dim, + hidden_dims=[16, 32], + num_blocks=3, + num_heads=4, + splits=0, + num_hops=1, + device=torch.device("cpu"), + ).eval() + +prev_inputs = torch.randn((1, len(grid_lon), len(grid_lat), 2 * input_features_dim)) + +sampler = Sampler() +preds = sampler.sample(denoiser, prev_inputs) +``` +> [!NOTE] +> The Sampler class supports modifying all the sampling parameters ($S_{churn}$, $S_{noise}$, $\sigma_{max}$, $\sigma_{min}$ ...). The defaults values are detailed in Gencast's paper. + +## Training +The script `train.py` provides a basic setup for training the model. It follows guidelines from GraphCast and GenCast, combining a linear warmup phase with a cosine scheduler. The script supports multi-device DDP training, gradient accumulation, and WandB logging. + +## 🤗 Hugging Face pretrained models +You can easily test our models using the pretrained versions available on Hugging Face. These models have been trained on the [WeatherBench2](https://weatherbench2.readthedocs.io/en/latest/) dataset. + + +> [!WARNING] +> Currently, the following resolutions are supported: +> - **128x64**: `openclimatefix/gencast-128x64` (early stage), +> - **240x121**: `openclimatefix/gencast-240x121` (coming soon). + +```python +import matplotlib.pyplot as plt +import numpy as np +import torch + +from graph_weather.data.gencast_dataloader import GenCastDataset +from graph_weather.models.gencast import Denoiser, Sampler + +if torch.cuda.is_available(): + device = torch.device('cuda:0') +else: + device = torch.device('cpu') + +# load dataset +OBS_PATH = "gs://weatherbench2/datasets/era5/1959-2022-6h-128x64_equiangular_conservative.zarr" +atmospheric_features = [ + "geopotential", + "specific_humidity", + "temperature", + "u_component_of_wind", + "v_component_of_wind", + "vertical_velocity", +] +single_features = [ + "2m_temperature", + "10m_u_component_of_wind", + "10m_v_component_of_wind", + "mean_sea_level_pressure", + # "sea_surface_temperature", + "total_precipitation_12hr", +] +static_features = [ + "geopotential_at_surface", + "land_sea_mask", +] + +dataset = GenCastDataset( + obs_path=OBS_PATH, + atmospheric_features=atmospheric_features, + single_features=single_features, + static_features=static_features, + max_year=2018, + time_step=2, +) + +# download weights from HF +print("> Downloading model's weights...") +denoiser=Denoiser.from_pretrained("openclimatefix/gencast-128x64", + grid_lon=dataset.grid_lon, + grid_lat=dataset.grid_lat).to(device) + +# load inputs and targets +print("> Loading inputs and target...") +data = dataset[0] +_, prev_inputs, _, target_residuals = data +prev_inputs = torch.tensor(prev_inputs).unsqueeze(0).to(device) +target_residuals = torch.tensor(target_residuals).unsqueeze(0).to(device) + +# predict +print("> Making predictions...") +sampler = Sampler() +preds = sampler.sample(denoiser, prev_inputs) + +print("Done!") + +# plot results +var_id = 78 # 2m_temperature +fig1, ax = plt.subplots(2) +ax[0].imshow(preds[0, :, :, var_id].T.cpu(), origin="lower", cmap="RdBu", vmin=-5, vmax=5) +ax[0].set_xticks([]) +ax[0].set_yticks([]) +ax[0].set_title("Diffusion sampling prediction") + +ax[1].imshow(target_residuals[0, :, :, var_id].T.cpu(), origin="lower", cmap="RdBu", vmin=-5, vmax=5) +ax[1].set_xticks([]) +ax[1].set_yticks([]) +ax[1].set_title("Ground truth") +plt.show() +``` diff --git a/graph_weather/models/gencast/denoiser.py b/graph_weather/models/gencast/denoiser.py index 9599bc25..60148657 100644 --- a/graph_weather/models/gencast/denoiser.py +++ b/graph_weather/models/gencast/denoiser.py @@ -10,16 +10,17 @@ import einops import numpy as np import torch -from torch_geometric.data import Batch +from huggingface_hub import PyTorchModelHubMixin from graph_weather.models.gencast.graph.graph_builder import GraphBuilder from graph_weather.models.gencast.layers.decoder import Decoder from graph_weather.models.gencast.layers.encoder import Encoder from graph_weather.models.gencast.layers.processor import Processor +from graph_weather.models.gencast.utils.batching import batch, hetero_batch from graph_weather.models.gencast.utils.noise import Preconditioner -class Denoiser(torch.nn.Module): +class Denoiser(torch.nn.Module, PyTorchModelHubMixin): """GenCast's Denoiser.""" def __init__( @@ -36,6 +37,7 @@ def __init__( device: torch.device = torch.device("cpu"), sparse: bool = False, use_edges_features: bool = True, + scale_factor: float = 1.0, ): """Initialize the Denoiser. @@ -58,6 +60,9 @@ def __init__( Defaults to False. use_edges_features (bool): if true use mesh edges features inside the Processor. Defaults to True. + scale_factor (float): in the Encoder the message passing output is multiplied by the + scale factor. This is important when you want to fine-tune a pretrained model to a + higher resolution. Defaults to 1. """ super().__init__() self.num_lon = len(grid_lon) @@ -76,6 +81,8 @@ def __init__( add_edge_features_to_khop=use_edges_features, ) + self._register_graph() + # Initialize Encoder self.encoder = Encoder( grid_dim=output_features_dim + 2 * input_features_dim + self.graphs.grid_nodes_dim, @@ -84,6 +91,7 @@ def __init__( hidden_dims=hidden_dims, activation_layer=torch.nn.SiLU, use_layer_norm=True, + scale_factor=scale_factor, ) # Initialize Processor @@ -138,16 +146,19 @@ def _check_shapes(self, corrupted_targets, prev_inputs, noise_levels): def _run_encoder(self, grid_features): # build big graph with batch_size disconnected copies of the graph, with features [(b n) f]. batch_size = grid_features.shape[0] - g2m_batched = Batch.from_data_list([self.graphs.g2m_graph] * batch_size) - + batched_senders, batched_receivers, batched_edge_index, batched_edge_attr = hetero_batch( + self.g2m_grid_nodes, + self.g2m_mesh_nodes, + self.g2m_edge_index, + self.g2m_edge_attr, + batch_size, + ) # load features. grid_features = einops.rearrange(grid_features, "b n f -> (b n) f") - input_grid_nodes = torch.cat([grid_features, g2m_batched["grid_nodes"].x], dim=-1).type( - torch.float32 - ) - input_mesh_nodes = g2m_batched["mesh_nodes"].x - input_edge_attr = g2m_batched["grid_nodes", "to", "mesh_nodes"].edge_attr - edge_index = g2m_batched["grid_nodes", "to", "mesh_nodes"].edge_index + input_grid_nodes = torch.cat([grid_features, batched_senders], dim=-1) + input_mesh_nodes = batched_receivers + input_edge_attr = batched_edge_attr + edge_index = batched_edge_index # run the encoder. latent_grid_nodes, latent_mesh_nodes = self.encoder( @@ -168,13 +179,19 @@ def _run_encoder(self, grid_features): def _run_decoder(self, latent_mesh_nodes, latent_grid_nodes): # build big graph with batch_size disconnected copies of the graph, with features [(b n) f]. batch_size = latent_mesh_nodes.shape[0] - m2g_batched = Batch.from_data_list([self.graphs.m2g_graph] * batch_size) + _, _, batched_edge_index, batched_edge_attr = hetero_batch( + self.m2g_mesh_nodes, + self.m2g_grid_nodes, + self.m2g_edge_index, + self.m2g_edge_attr, + batch_size, + ) # load features. input_mesh_nodes = einops.rearrange(latent_mesh_nodes, "b n f -> (b n) f") input_grid_nodes = einops.rearrange(latent_grid_nodes, "b n f -> (b n) f") - input_edge_attr = m2g_batched["mesh_nodes", "to", "grid_nodes"].edge_attr - edge_index = m2g_batched["mesh_nodes", "to", "grid_nodes"].edge_index + input_edge_attr = batched_edge_attr + edge_index = batched_edge_index # run the decoder. output_grid_nodes = self.decoder( @@ -194,12 +211,17 @@ def _run_processor(self, latent_mesh_nodes, noise_levels): # build big graph with batch_size disconnected copies of the graph, with features [(b n) f]. batch_size = latent_mesh_nodes.shape[0] num_nodes = latent_mesh_nodes.shape[1] - mesh_batched = Batch.from_data_list([self.graphs.khop_mesh_graph] * batch_size) + _, batched_edge_index, batched_edge_attr = batch( + self.khop_mesh_nodes, + self.khop_mesh_edge_index, + self.khop_mesh_edge_attr if self.use_edges_features else None, + batch_size, + ) # load features. latent_mesh_nodes = einops.rearrange(latent_mesh_nodes, "b n f -> (b n) f") - input_edge_attr = mesh_batched.edge_attr if self.use_edges_features else None - edge_index = mesh_batched.edge_index + input_edge_attr = batched_edge_attr + edge_index = batched_edge_index # repeat noise levels for each node. noise_levels = einops.repeat(noise_levels, "b f -> (b n) f", n=num_nodes) @@ -272,3 +294,54 @@ def forward( # restore lon/lat dimensions. out = einops.rearrange(out, "b (lon lat) f -> b lon lat f", lon=self.num_lon) return out + + def _register_graph(self): + # we need to register all the tensors associated with the graph as buffers. In this way they + # will move to the same device of the model. These tensors won't be part of the state since + # persistent is set to False. + + self.register_buffer( + "g2m_grid_nodes", self.graphs.g2m_graph["grid_nodes"].x, persistent=False + ) + self.register_buffer( + "g2m_mesh_nodes", self.graphs.g2m_graph["mesh_nodes"].x, persistent=False + ) + self.register_buffer( + "g2m_edge_attr", + self.graphs.g2m_graph["grid_nodes", "to", "mesh_nodes"].edge_attr, + persistent=False, + ) + self.register_buffer( + "g2m_edge_index", + self.graphs.g2m_graph["grid_nodes", "to", "mesh_nodes"].edge_index, + persistent=False, + ) + + self.register_buffer("mesh_nodes", self.graphs.mesh_graph.x, persistent=False) + self.register_buffer("mesh_edge_attr", self.graphs.mesh_graph.edge_attr, persistent=False) + self.register_buffer("mesh_edge_index", self.graphs.mesh_graph.edge_index, persistent=False) + + self.register_buffer("khop_mesh_nodes", self.graphs.khop_mesh_graph.x, persistent=False) + self.register_buffer( + "khop_mesh_edge_attr", self.graphs.khop_mesh_graph.edge_attr, persistent=False + ) + self.register_buffer( + "khop_mesh_edge_index", self.graphs.khop_mesh_graph.edge_index, persistent=False + ) + + self.register_buffer( + "m2g_grid_nodes", self.graphs.m2g_graph["grid_nodes"].x, persistent=False + ) + self.register_buffer( + "m2g_mesh_nodes", self.graphs.m2g_graph["mesh_nodes"].x, persistent=False + ) + self.register_buffer( + "m2g_edge_attr", + self.graphs.m2g_graph["mesh_nodes", "to", "grid_nodes"].edge_attr, + persistent=False, + ) + self.register_buffer( + "m2g_edge_index", + self.graphs.m2g_graph["mesh_nodes", "to", "grid_nodes"].edge_index, + persistent=False, + ) diff --git a/graph_weather/models/gencast/images/animated.gif b/graph_weather/models/gencast/images/animated.gif new file mode 100644 index 00000000..c8f31778 Binary files /dev/null and b/graph_weather/models/gencast/images/animated.gif differ diff --git a/graph_weather/models/gencast/images/autoregressive.gif b/graph_weather/models/gencast/images/autoregressive.gif new file mode 100644 index 00000000..32d39f8d Binary files /dev/null and b/graph_weather/models/gencast/images/autoregressive.gif differ diff --git a/graph_weather/models/gencast/images/fullmodel.png b/graph_weather/models/gencast/images/fullmodel.png new file mode 100644 index 00000000..2c2bdb59 Binary files /dev/null and b/graph_weather/models/gencast/images/fullmodel.png differ diff --git a/graph_weather/models/gencast/images/readme.md b/graph_weather/models/gencast/images/readme.md new file mode 100644 index 00000000..e69de29b diff --git a/graph_weather/models/gencast/layers/encoder.py b/graph_weather/models/gencast/layers/encoder.py index e062752e..15b0a645 100644 --- a/graph_weather/models/gencast/layers/encoder.py +++ b/graph_weather/models/gencast/layers/encoder.py @@ -22,6 +22,7 @@ def __init__( hidden_dims: list[int], activation_layer: torch.nn.Module = torch.nn.ReLU, use_layer_norm: bool = True, + scale_factor: float = 1.0, ): """Initialize the Encoder. @@ -34,6 +35,9 @@ def __init__( Defaults to torch.nn.ReLU. use_layer_norm (bool, optional): if true add a LayerNorm at the end of each MLP. Defaults to True. + scale_factor (float): the message of the interaction network between the grid and the + the mesh is multiplied by the scale factor. Useful when fine-tuning a pretrained + model to a higher resolution. Defaults to 1. """ super().__init__() @@ -79,6 +83,7 @@ def __init__( hidden_dims=hidden_dims, use_layer_norm=use_layer_norm, activation_layer=activation_layer, + scale_factor=scale_factor, ) # Final grid nodes update diff --git a/graph_weather/models/gencast/layers/experimental/sparse_transformer.py b/graph_weather/models/gencast/layers/experimental/sparse_transformer.py index 387a64da..f822eeeb 100644 --- a/graph_weather/models/gencast/layers/experimental/sparse_transformer.py +++ b/graph_weather/models/gencast/layers/experimental/sparse_transformer.py @@ -75,6 +75,7 @@ def __init__( output_dim: int, num_heads: int, activation_layer: torch.nn.Module = nn.ReLU, + norm_first: bool = True, ): """Initialize SparseTransformer module. @@ -87,6 +88,7 @@ def __init__( num_heads (int): number of heads for multi-head attention. activation_layer (torch.nn.Module): activation function applied before returning the output. + norm_first (bool): if True apply layer normalization before attention. Defaults to True. """ super().__init__() @@ -109,6 +111,8 @@ def __init__( conditioning_dim=conditioning_dim, features_dim=output_dim ) + self.norm_first = norm_first + def forward( self, x: torch.Tensor, @@ -129,10 +133,21 @@ def forward( **kwargs: ignored by the module. """ - x = x + self.sparse_attention.forward( - x=x, adj=dglsp.spmatrix(indices=edge_index, shape=(x.shape[0], x.shape[0])) - ) - x = self.cond_norm_1(x, cond_param) - x = x + self.mlp(x) - x = self.cond_norm_2(x, cond_param) + if self.norm_first: + x1 = self.cond_norm_1(x, cond_param) + x = x + self.sparse_attention( + x=x1, adj=dglsp.spmatrix(indices=edge_index, shape=(x.shape[0], x.shape[0])) + ) + else: + x = x + self.sparse_attention( + x=x, adj=dglsp.spmatrix(indices=edge_index, shape=(x.shape[0], x.shape[0])) + ) + x = self.cond_norm_1(x, cond_param) + + if self.norm_first: + x2 = self.cond_norm_2(x, cond_param) + x = x + self.mlp(x2) + else: + x = x + self.mlp(x) + x = self.cond_norm_2(x, cond_param) return x diff --git a/graph_weather/models/gencast/layers/modules.py b/graph_weather/models/gencast/layers/modules.py index d0e353f2..6a07f348 100644 --- a/graph_weather/models/gencast/layers/modules.py +++ b/graph_weather/models/gencast/layers/modules.py @@ -87,6 +87,7 @@ def __init__( hidden_dims: list[int], use_layer_norm: bool = False, activation_layer: nn.Module = nn.ReLU, + scale_factor: float = 1.0, ): """Initialize the Interaction Network. @@ -98,6 +99,7 @@ def __init__( use_layer_norm (bool): if true add layer normalization to MLP's last layer. Defaults to False. activation_layer (torch.nn.Module): activation function. Defaults to nn.ReLU. + scale_factor (float): the message is multiplied by this value. Defaults to 1.0. """ super().__init__(aggr="add", flow="source_to_target") self.mlp_edges = MLP( @@ -116,12 +118,13 @@ def __init__( bias=True, activate_final=False, ) + self.scale_factor = scale_factor def message(self, x_i, x_j, edge_attr): """Message-passing step.""" x = torch.cat((x_i, x_j, edge_attr), dim=-1) x = self.mlp_edges(x) - return x + return self.scale_factor * x def forward( self, diff --git a/graph_weather/models/gencast/layers/processor.py b/graph_weather/models/gencast/layers/processor.py index 86e77e2d..50fb7c69 100644 --- a/graph_weather/models/gencast/layers/processor.py +++ b/graph_weather/models/gencast/layers/processor.py @@ -120,7 +120,7 @@ def __init__( if not has_dgl: raise ValueError("Please install DGL to use sparsity.") - for _ in range(num_blocks - 1): + for _ in range(num_blocks): # concatenating multi-head attention self.cond_transformers.append( SparseTransformer( diff --git a/graph_weather/models/gencast/train.py b/graph_weather/models/gencast/train.py new file mode 100644 index 00000000..a6bdb127 --- /dev/null +++ b/graph_weather/models/gencast/train.py @@ -0,0 +1,300 @@ +""" +Training script for GenCast. +""" + +import os + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0,3" + +import lightning as L # noqa: E402 +import matplotlib.pyplot as plt # noqa: E402 +import numpy as np # noqa: E402 +import torch # noqa: E402 +from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint # noqa: E402 +from lightning.pytorch.loggers import WandbLogger # noqa: E402 +from torch.utils.data import DataLoader # noqa: E402 + +from graph_weather.data.gencast_dataloader import GenCastDataset # noqa: E402 +from graph_weather.models.gencast import Denoiser, Sampler, WeightedMSELoss # noqa: E402 + +torch.set_float32_matmul_precision("high") + +############################################## SETTINGS ############################################ + +# training settings +NUM_EPOCHS = 20 +NUM_DEVICES = 2 +NUM_ACC_GRAD = 1 +INITIAL_LR = 1e-3 +BATCH_SIZE = 16 # true batch size: BATCH_SIZE*NUM_DEVICES*NUM_ACC_GRAD +WARMUP = 1000 + +# dataloader setting +NUM_WORKERS = 8 +PREFETCH_FACTOR = 3 +PERSISTENT_WORKERS = True + +# model configs +CHECKPOINT_PATH = "checkpoints/epoch=3-step=10776.ckpt" +CFG = { + "hidden_dims": [512, 512], + "num_blocks": 16, + "num_heads": 4, + "splits": 4, + "num_hops": 8, + "sparse": True, + "use_edges_features": False, + "scale_factor": 1.0, +} + +# dataset configs +atmospheric_features = [ + "geopotential", + "specific_humidity", + "temperature", + "u_component_of_wind", + "v_component_of_wind", + "vertical_velocity", +] +single_features = [ + "2m_temperature", + "10m_u_component_of_wind", + "10m_v_component_of_wind", + "mean_sea_level_pressure", + # "sea_surface_temperature", + "total_precipitation_12hr", +] +static_features = [ + "geopotential_at_surface", + "land_sea_mask", +] + +OBS_PATH = "dataset.zarr" +# OBS_PATH = "gs://weatherbench2/datasets/era5/1959-2022-6h-128x64_equiangular_conservative.zarr" +# OBS_PATH = 'gs://weatherbench2/datasets/era5/1959-2022-6h-1440x721.zarr' +# OBS_PATH = 'gs://weatherbench2/datasets/era5/1959-2022-6h-512x256_equiangular_conservative.zarr' + +################################################################################################# + + +class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler): + """Cosine Scheduler with Warmup""" + + def __init__(self, optimizer, warmup, max_iters): + """Initialize the scheduler""" + self.warmup = warmup + self.max_num_iters = max_iters + super().__init__(optimizer) + + def get_lr(self): + """Return the learning rates""" + lr_factor = self.get_lr_factor(epoch=self.last_epoch) + return [base_lr * lr_factor for base_lr in self.base_lrs] + + def get_lr_factor(self, epoch): + """Return the scaling factor for the learning rate at a given iteration""" + lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters)) + if epoch <= self.warmup: + lr_factor *= epoch * 1.0 / self.warmup + return lr_factor + + +class LitModel(L.LightningModule): + """Lightning wrapper for Gencast""" + + def __init__( + self, + warmup, + learning_rate, + cosine_t_max, + pressure_levels, + grid_lon, + grid_lat, + input_features_dim, + output_features_dim, + hidden_dims, + num_blocks, + num_heads, + splits, + num_hops, + sparse, + use_edges_features, + scale_factor=1.0, + ): + """Initialize the module""" + super().__init__() + + self.model = Denoiser( + grid_lon=grid_lon, + grid_lat=grid_lat, + input_features_dim=input_features_dim, + output_features_dim=output_features_dim, + hidden_dims=hidden_dims, + num_blocks=num_blocks, + num_heads=num_heads, + splits=splits, + num_hops=num_hops, + device=self.device, + sparse=sparse, + use_edges_features=use_edges_features, + scale_factor=scale_factor, + ) + + self.criterion = WeightedMSELoss( + grid_lat=torch.tensor(grid_lat).to(self.device), + pressure_levels=torch.tensor(pressure_levels).to(self.device), + num_atmospheric_features=len(atmospheric_features), + single_features_weights=torch.tensor([1.0, 0.1, 0.1, 0.1, 0.1]).to(self.device), + ) + + self.learning_rate = learning_rate + self.cosine_t_max = cosine_t_max + self.warmup = warmup + + def forward(self, corrupted_targets, prev_inputs, noise_levels): + """Compute forward pass""" + return self.model(corrupted_targets, prev_inputs, noise_levels) + + def training_step(self, batch): + """Single training step""" + corrupted_targets, prev_inputs, noise_levels, target_residuals = batch + + preds = self.model( + corrupted_targets=corrupted_targets, + prev_inputs=prev_inputs, + noise_levels=noise_levels, + ) + loss = self.criterion(preds, noise_levels, target_residuals) + self.log("loss", loss) + return loss + + def configure_optimizers(self): + """Initialize the optimizer""" + opt = torch.optim.AdamW( + self.parameters(), lr=self.learning_rate, weight_decay=0.1, betas=(0.9, 0.95) + ) + # sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=self.cosine_t_max) + sch = CosineWarmupScheduler(opt, warmup=self.warmup, max_iters=self.cosine_t_max) + return { + "optimizer": opt, + "lr_scheduler": { + "scheduler": sch, + "monitor": "train_loss", + "interval": "step", # step means "batch" here, default: epoch + "frequency": 1, # default + }, + } + + def plot_sample(self, prev_inputs, target_residuals): + """Plot 2m_temperature and geopotential""" + prev_inputs = prev_inputs[:1, :, :, :] + target = target_residuals[:1, :, :, :] + sampler = Sampler() + preds = sampler.sample(self.model, prev_inputs) + + fig1, ax = plt.subplots(2) + ax[0].imshow(preds[0, :, :, 78].T.cpu(), origin="lower", cmap="RdBu", vmin=-5, vmax=5) + ax[0].set_xticks([]) + ax[0].set_yticks([]) + ax[0].set_title("Diffusion sampling prediction") + + ax[1].imshow(target[0, :, :, 78].T.cpu(), origin="lower", cmap="RdBu", vmin=-5, vmax=5) + ax[1].set_xticks([]) + ax[1].set_yticks([]) + ax[1].set_title("Ground truth") + + fig2, ax = plt.subplots(2) + ax[0].imshow(preds[0, :, :, 12].T.cpu(), origin="lower", cmap="RdBu", vmin=-5, vmax=5) + ax[0].set_xticks([]) + ax[0].set_yticks([]) + ax[0].set_title("Diffusion sampling prediction") + + ax[1].imshow(target[0, :, :, 12].T.cpu(), origin="lower", cmap="RdBu", vmin=-5, vmax=5) + ax[1].set_xticks([]) + ax[1].set_yticks([]) + ax[1].set_title("Ground truth") + + return fig1, fig2 + + +class SamplingCallback(Callback): + """Callback for sampling when a new epoch starts""" + + def __init__(self, data): + """Initialize the callback""" + _, prev_inputs, _, target_residuals = data + self.prev_inputs = torch.tensor(prev_inputs).unsqueeze(0) + self.target_residuals = torch.tensor(target_residuals).unsqueeze(0) + + def on_train_epoch_start(self, trainer, pl_module): + """Sample and log predictions""" + print("Epoch is starting") + fig1, fig2 = pl_module.plot_sample( + self.prev_inputs.to(pl_module.device), self.target_residuals.to(pl_module.device) + ) + trainer.logger.log_image( + key="samples", images=[fig1, fig2], caption=["2m_temperature", "geopotential"] + ) + print("Uploaded samples") + + +if __name__ == "__main__": + # define dataloader + dataset = GenCastDataset( + obs_path=OBS_PATH, + atmospheric_features=atmospheric_features, + single_features=single_features, + static_features=static_features, + max_year=2018, + time_step=2, + ) + + dataloader = DataLoader( + dataset, + shuffle=True, + batch_size=BATCH_SIZE, + num_workers=NUM_WORKERS, + prefetch_factor=PREFETCH_FACTOR, + persistent_workers=PERSISTENT_WORKERS, + multiprocessing_context="forkserver", + ) + + # define/resume model + num_steps = NUM_EPOCHS * len(dataloader) // (NUM_DEVICES * NUM_ACC_GRAD) + initial_lr = INITIAL_LR + + denoiser = LitModel.load_from_checkpoint( + checkpoint_path=CHECKPOINT_PATH, + warmup=WARMUP, + learning_rate=initial_lr, + cosine_t_max=num_steps, + pressure_levels=dataset.pressure_levels, + grid_lon=dataset.grid_lon, + grid_lat=dataset.grid_lat, + input_features_dim=dataset.input_features_dim, + output_features_dim=dataset.output_features_dim, + **CFG, + ) + # denoiser = torch.compile(denoiser) + + # define trainer + wandb_logger = WandbLogger(project="gencast") + checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/") + lr_monitor = LearningRateMonitor(logging_interval="step") + + trainer = L.Trainer( + accumulate_grad_batches=NUM_ACC_GRAD, + accelerator="gpu", + devices=NUM_DEVICES, + strategy="DDP", + # precision=16, + max_epochs=NUM_EPOCHS, + logger=wandb_logger, + callbacks=[checkpoint_callback, SamplingCallback(data=dataset[0]), lr_monitor], + log_every_n_steps=1, + ) + + # start training + print("Starting training") + trainer.fit(model=denoiser, train_dataloaders=dataloader) diff --git a/graph_weather/models/gencast/utils/batching.py b/graph_weather/models/gencast/utils/batching.py new file mode 100644 index 00000000..58800c3c --- /dev/null +++ b/graph_weather/models/gencast/utils/batching.py @@ -0,0 +1,69 @@ +"""Utils for batching graphs.""" + +import torch + + +def batch(senders, edge_index, edge_attr=None, batch_size=1): + """Build big batched graph. + + Returns nodes and edges of a big graph with batch_size disconnected copies of the original + graph, with features shape [(b n) f]. + + Args: + senders (torch.Tensor): nodes' features. + edge_index (torch.Tensor): edge index tensor. + edge_attr (torch.Tensor, optional): edge attributes tensor, if None returns None. + Defaults to None. + batch_size (int): batch size. Defaults to 1. + + Returns: + batched_senders, batched_edge_index, batched_edge_attr + """ + ns = senders.shape[0] + batched_senders = senders + batched_edge_attr = edge_attr + batched_edge_index = edge_index + + for i in range(1, batch_size): + batched_senders = torch.cat([batched_senders, senders], dim=0) + batched_edge_index = torch.cat([batched_edge_index, edge_index + i * ns], dim=1) + + if edge_attr is not None: + batched_edge_attr = torch.cat([batched_edge_attr, edge_attr], dim=0) + + return batched_senders, batched_edge_index, batched_edge_attr + + +def hetero_batch(senders, receivers, edge_index, edge_attr=None, batch_size=1): + """Build big batched heterogenous graph. + + Returns nodes and edges of a big graph with batch_size disconnected copies of the original + graph, with features shape [(b n) f]. + + Args: + senders (torch.Tensor): senders' features. + receivers (torch.Tensor): receivers' features. + edge_index (torch.Tensor): edge index tensor. + edge_attr (torch.Tensor, optional): edge attributes tensor, if None returns None. + Defaults to None. + batch_size (int): batch size. Defaults to 1. + + Returns: + batched_senders, batched_edge_index, batched_edge_attr + """ + ns = senders.shape[0] + nr = receivers.shape[0] + nodes_shape = torch.tensor([[ns], [nr]]).to(edge_index) + batched_senders = senders + batched_receivers = receivers + batched_edge_attr = edge_attr + batched_edge_index = edge_index + + for i in range(1, batch_size): + batched_senders = torch.cat([batched_senders, senders], dim=0) + batched_receivers = torch.cat([batched_receivers, receivers], dim=0) + batched_edge_index = torch.cat([batched_edge_index, edge_index + i * nodes_shape], dim=1) + if edge_attr is not None: + batched_edge_attr = torch.cat([batched_edge_attr, edge_attr], dim=0) + + return batched_senders, batched_receivers, batched_edge_index, batched_edge_attr diff --git a/graph_weather/models/gencast/utils/noise.py b/graph_weather/models/gencast/utils/noise.py index 9262407f..d6e24378 100644 --- a/graph_weather/models/gencast/utils/noise.py +++ b/graph_weather/models/gencast/utils/noise.py @@ -1,8 +1,9 @@ """Noise generation utils.""" +import einops import numpy as np -import pyshtools as pysh import torch +import torch_harmonics as th def generate_isotropic_noise(num_lon: int, num_lat: int, num_samples=1, isotropic=True): @@ -35,15 +36,16 @@ def generate_isotropic_noise(num_lon: int, num_lat: int, num_samples=1, isotropi ) if isotropic: - l_max = num_lat // 2 - power = np.ones(l_max, dtype=float) / l_max**2 # normalized to get each point with std 1 - grid = np.zeros((num_lon, num_lat, num_samples)) - for i in range(num_samples): - clm = pysh.SHCoeffs.from_random(power, power_unit="per_lm") - grid[:, :, i] = ( - clm.expand(grid="DH2", extend=extend).to_array().transpose()[:num_lon, :num_lat] - ) - noise = grid.astype(np.float32) + lmax = num_lat - 1 if extend else num_lat + mmax = lmax + 1 + coeffs = torch.randn(num_samples, lmax, mmax, dtype=torch.complex64) / np.sqrt( + (num_lat**2) // 2 + ) + isht = th.InverseRealSHT( + nlat=num_lat, nlon=num_lon, lmax=lmax, mmax=mmax, grid="equiangular" + ) + noise = isht(coeffs) * np.sqrt(2 * np.pi) + noise = einops.rearrange(noise, "b lat lon -> lon lat b").numpy() else: noise = np.random.randn(num_lon, num_lat, num_samples) return noise @@ -88,7 +90,7 @@ def __init__(self, sigma_data: float = 1): def c_skip(self, sigma): """Scaling factor for skip connection.""" - return self.sigma_data / (sigma**2 + self.sigma_data**2) + return self.sigma_data**2 / (sigma**2 + self.sigma_data**2) def c_out(self, sigma): """Scaling factor for output.""" diff --git a/graph_weather/models/gencast/weighted_mse_loss.py b/graph_weather/models/gencast/weighted_mse_loss.py index baa2b21c..85da0e56 100644 --- a/graph_weather/models/gencast/weighted_mse_loss.py +++ b/graph_weather/models/gencast/weighted_mse_loss.py @@ -36,19 +36,19 @@ def __init__( """ super().__init__() - self.area_weights = None - self.features_weights = None + area_weights = None + features_weights = None if grid_lat is not None: - self.area_weights = torch.cos(grid_lat * np.pi / 180.0) - + area_weights = torch.abs(torch.cos(grid_lat * np.pi / 180.0)) + area_weights = area_weights / torch.mean(area_weights) if ( pressure_levels is not None and num_atmospheric_features is not None and single_features_weights is not None ): pressure_weights = pressure_levels / torch.sum(pressure_levels) - self.features_weights = torch.cat( + features_weights = torch.cat( (pressure_weights.repeat(num_atmospheric_features), single_features_weights), dim=-1 ) elif ( @@ -63,6 +63,9 @@ def __init__( self.sigma_data = 1 # assuming normalized data! + self.register_buffer("area_weights", area_weights, persistent=False) + self.register_buffer("features_weights", features_weights, persistent=False) + def _lambda_sigma(self, noise_level): noise_weights = (noise_level**2 + self.sigma_data**2) / (noise_level * self.sigma_data) ** 2 return noise_weights # [batch, 1] diff --git a/requirements.txt b/requirements.txt index c8e9ca8e..79d45e1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,10 @@ huggingface-hub datasets einops torch-geometric-temporal -pyshtools +torch_harmonics trimesh rtree +xarray +setuptools +pydantic +safetensors diff --git a/tests/test_gencast.py b/tests/test_gencast.py new file mode 100644 index 00000000..2ed23fe5 --- /dev/null +++ b/tests/test_gencast.py @@ -0,0 +1,186 @@ +import numpy as np +import torch +import pytest +from packaging.version import Version + +from torch_geometric.transforms import TwoHop + +from graph_weather.models.gencast.utils.noise import ( + generate_isotropic_noise, + sample_noise_level, +) +from graph_weather.models.gencast import GraphBuilder, WeightedMSELoss, Denoiser, Sampler +from graph_weather.models.gencast.layers.modules import FourierEmbedding + + +def test_gencast_noise(): + num_lon = 360 + num_lat = 180 + num_samples = 5 + target_residuals = np.zeros((num_lon, num_lat, num_samples)) + noise_level = sample_noise_level() + noise = generate_isotropic_noise( + num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1] + ) + corrupted_residuals = target_residuals + noise_level * noise + assert corrupted_residuals.shape == target_residuals.shape + assert not np.isnan(corrupted_residuals).any() + + num_lon = 360 + num_lat = 181 + num_samples = 5 + target_residuals = np.zeros((num_lon, num_lat, num_samples)) + noise_level = sample_noise_level() + noise = generate_isotropic_noise( + num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1] + ) + corrupted_residuals = target_residuals + noise_level * noise + assert corrupted_residuals.shape == target_residuals.shape + assert not np.isnan(corrupted_residuals).any() + + num_lon = 100 + num_lat = 100 + num_samples = 5 + target_residuals = np.zeros((num_lon, num_lat, num_samples)) + noise_level = sample_noise_level() + noise = generate_isotropic_noise( + num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1], isotropic=False + ) + corrupted_residuals = target_residuals + noise_level * noise + assert corrupted_residuals.shape == target_residuals.shape + assert not np.isnan(corrupted_residuals).any() + + +def test_gencast_graph(): + grid_lat = np.arange(-90, 90, 1) + grid_lon = np.arange(0, 360, 1) + graphs = GraphBuilder(grid_lon=grid_lon, grid_lat=grid_lat, splits=4, num_hops=8) + + # compare khop sparse implementation with pyg. + transform = TwoHop() + khop_mesh_graph_pyg = graphs.mesh_graph + for i in range(3): # 8-hop mesh + khop_mesh_graph_pyg = transform(khop_mesh_graph_pyg) + + assert graphs.mesh_graph.x.shape[0] == 2562 + assert graphs.g2m_graph["grid_nodes"].x.shape[0] == 360 * 180 + assert graphs.m2g_graph["mesh_nodes"].x.shape[0] == 2562 + assert not torch.isnan(graphs.mesh_graph.edge_attr).any() + assert graphs.khop_mesh_graph.x.shape[0] == 2562 + assert torch.allclose(graphs.khop_mesh_graph.x, khop_mesh_graph_pyg.x) + assert torch.allclose(graphs.khop_mesh_graph.edge_index, khop_mesh_graph_pyg.edge_index) + + +def test_gencast_loss(): + grid_lat = torch.arange(-90, 90, 1) + grid_lon = torch.arange(0, 360, 1) + pressure_levels = torch.tensor( + [50.0, 100.0, 150.0, 200.0, 250, 300, 400, 500, 600, 700, 850, 925, 1000.0] + ) + single_features_weights = torch.tensor([1, 0.1, 0.1, 0.1, 0.1]) + num_atmospheric_features = 6 + batch_size = 3 + features_dim = len(pressure_levels) * num_atmospheric_features + len(single_features_weights) + + loss = WeightedMSELoss( + grid_lat=grid_lat, + pressure_levels=pressure_levels, + num_atmospheric_features=num_atmospheric_features, + single_features_weights=single_features_weights, + ) + + preds = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) + noise_levels = torch.rand((batch_size, 1)) + targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) + assert loss.forward(preds, noise_levels, targets) is not None + + +def test_gencast_denoiser(): + grid_lat = np.arange(-90, 90, 1) + grid_lon = np.arange(0, 360, 1) + input_features_dim = 10 + output_features_dim = 5 + batch_size = 3 + + denoiser = Denoiser( + grid_lon=grid_lon, + grid_lat=grid_lat, + input_features_dim=input_features_dim, + output_features_dim=output_features_dim, + hidden_dims=[16, 32], + num_blocks=3, + num_heads=4, + splits=0, + num_hops=1, + device=torch.device("cpu"), + ).eval() + + corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim)) + prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2 * input_features_dim)) + noise_levels = torch.rand((batch_size, 1)) + + with torch.no_grad(): + preds = denoiser( + corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels + ) + + assert not torch.isnan(preds).any() + + +def test_gencast_fourier(): + batch_size = 10 + output_dim = 20 + fourier_embedder = FourierEmbedding(output_dim=output_dim, num_frequencies=32, base_period=16) + t = torch.rand((batch_size, 1)) + assert fourier_embedder(t).shape == (batch_size, output_dim) + + +def test_gencast_sampler(): + grid_lat = np.arange(-90, 90, 1) + grid_lon = np.arange(0, 360, 1) + input_features_dim = 10 + output_features_dim = 5 + + denoiser = Denoiser( + grid_lon=grid_lon, + grid_lat=grid_lat, + input_features_dim=input_features_dim, + output_features_dim=output_features_dim, + hidden_dims=[16, 32], + num_blocks=3, + num_heads=4, + splits=0, + num_hops=1, + device=torch.device("cpu"), + ).eval() + + prev_inputs = torch.randn((1, len(grid_lon), len(grid_lat), 2 * input_features_dim)) + + sampler = Sampler() + preds = sampler.sample(denoiser, prev_inputs) + assert not torch.isnan(preds).any() + assert preds.shape == (1, len(grid_lon), len(grid_lat), output_features_dim) + + +@pytest.mark.skipif( + Version(torch.__version__).release != Version("2.3.0").release, + reason="dgl tests for experimental features only runs with torch 2.3.0", +) +def test_gencast_full(): + # download weights from HF + denoiser = Denoiser.from_pretrained( + "openclimatefix/gencast-128x64", + grid_lon=np.arange(0, 360, 360 / 128), + grid_lat=np.arange(-90, 90, 180 / 64) + 1 / 2 * 180 / 64, + ) + + # load inputs and targets + prev_inputs = torch.randn([1, 128, 64, 178]) + target_residuals = torch.randn([1, 128, 64, 83]) + + # predict + sampler = Sampler() + preds = sampler.sample(denoiser, prev_inputs) + + assert not torch.isnan(preds).any() + assert preds.shape == target_residuals.shape diff --git a/tests/test_model.py b/tests/test_model.py index 90da16b7..94ed991b 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,7 +1,6 @@ import h3 import numpy as np import torch -from torch_geometric.transforms import TwoHop from graph_weather import GraphWeatherAssimilator, GraphWeatherForecaster from graph_weather.models import ( @@ -17,13 +16,6 @@ ) from graph_weather.models.losses import NormalizedMSELoss -from graph_weather.models.gencast.utils.noise import ( - generate_isotropic_noise, - sample_noise_level, -) -from graph_weather.models.gencast import GraphBuilder, WeightedMSELoss, Denoiser, Sampler -from graph_weather.models.gencast.layers.modules import FourierEmbedding - def test_encoder(): lat_lons = [] @@ -313,152 +305,3 @@ def test_meta_model(): assert not torch.isnan(out).any() assert not torch.isnan(out).any() assert out.size() == features.size() - - -def test_gencast_noise(): - num_lon = 360 - num_lat = 180 - num_samples = 5 - target_residuals = np.zeros((num_lon, num_lat, num_samples)) - noise_level = sample_noise_level() - noise = generate_isotropic_noise( - num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1] - ) - corrupted_residuals = target_residuals + noise_level * noise - assert corrupted_residuals.shape == target_residuals.shape - assert not np.isnan(corrupted_residuals).any() - - num_lon = 360 - num_lat = 181 - num_samples = 5 - target_residuals = np.zeros((num_lon, num_lat, num_samples)) - noise_level = sample_noise_level() - noise = generate_isotropic_noise( - num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1] - ) - corrupted_residuals = target_residuals + noise_level * noise - assert corrupted_residuals.shape == target_residuals.shape - assert not np.isnan(corrupted_residuals).any() - - num_lon = 100 - num_lat = 100 - num_samples = 5 - target_residuals = np.zeros((num_lon, num_lat, num_samples)) - noise_level = sample_noise_level() - noise = generate_isotropic_noise( - num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1], isotropic=False - ) - corrupted_residuals = target_residuals + noise_level * noise - assert corrupted_residuals.shape == target_residuals.shape - assert not np.isnan(corrupted_residuals).any() - - -def test_gencast_graph(): - grid_lat = np.arange(-90, 90, 1) - grid_lon = np.arange(0, 360, 1) - graphs = GraphBuilder(grid_lon=grid_lon, grid_lat=grid_lat, splits=4, num_hops=8) - - # compare khop sparse implementation with pyg. - transform = TwoHop() - khop_mesh_graph_pyg = graphs.mesh_graph - for i in range(3): # 8-hop mesh - khop_mesh_graph_pyg = transform(khop_mesh_graph_pyg) - - assert graphs.mesh_graph.x.shape[0] == 2562 - assert graphs.g2m_graph["grid_nodes"].x.shape[0] == 360 * 180 - assert graphs.m2g_graph["mesh_nodes"].x.shape[0] == 2562 - assert not torch.isnan(graphs.mesh_graph.edge_attr).any() - assert graphs.khop_mesh_graph.x.shape[0] == 2562 - assert torch.allclose(graphs.khop_mesh_graph.x, khop_mesh_graph_pyg.x) - assert torch.allclose(graphs.khop_mesh_graph.edge_index, khop_mesh_graph_pyg.edge_index) - - -def test_gencast_loss(): - grid_lat = torch.arange(-90, 90, 1) - grid_lon = torch.arange(0, 360, 1) - pressure_levels = torch.tensor( - [50.0, 100.0, 150.0, 200.0, 250, 300, 400, 500, 600, 700, 850, 925, 1000.0] - ) - single_features_weights = torch.tensor([1, 0.1, 0.1, 0.1, 0.1]) - num_atmospheric_features = 6 - batch_size = 3 - features_dim = len(pressure_levels) * num_atmospheric_features + len(single_features_weights) - - loss = WeightedMSELoss( - grid_lat=grid_lat, - pressure_levels=pressure_levels, - num_atmospheric_features=num_atmospheric_features, - single_features_weights=single_features_weights, - ) - - preds = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) - noise_levels = torch.rand((batch_size, 1)) - targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) - assert loss.forward(preds, noise_levels, targets) is not None - - -def test_gencast_denoiser(): - grid_lat = np.arange(-90, 90, 1) - grid_lon = np.arange(0, 360, 1) - input_features_dim = 10 - output_features_dim = 5 - batch_size = 3 - - denoiser = Denoiser( - grid_lon=grid_lon, - grid_lat=grid_lat, - input_features_dim=input_features_dim, - output_features_dim=output_features_dim, - hidden_dims=[16, 32], - num_blocks=3, - num_heads=4, - splits=0, - num_hops=1, - device=torch.device("cpu"), - ).eval() - - corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim)) - prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2 * input_features_dim)) - noise_levels = torch.rand((batch_size, 1)) - - with torch.no_grad(): - preds = denoiser( - corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels - ) - - assert not torch.isnan(preds).any() - - -def test_gencast_fourier(): - batch_size = 10 - output_dim = 20 - fourier_embedder = FourierEmbedding(output_dim=output_dim, num_frequencies=32, base_period=16) - t = torch.rand((batch_size, 1)) - assert fourier_embedder(t).shape == (batch_size, output_dim) - - -def test_gencast_sampler(): - grid_lat = np.arange(-90, 90, 1) - grid_lon = np.arange(0, 360, 1) - input_features_dim = 10 - output_features_dim = 5 - - denoiser = Denoiser( - grid_lon=grid_lon, - grid_lat=grid_lat, - input_features_dim=input_features_dim, - output_features_dim=output_features_dim, - hidden_dims=[16, 32], - num_blocks=3, - num_heads=4, - splits=0, - num_hops=1, - device=torch.device("cpu"), - ).eval() - - prev_inputs = torch.randn((1, len(grid_lon), len(grid_lat), 2 * input_features_dim)) - - sampler = Sampler() - preds = sampler.sample(denoiser, prev_inputs) - assert not torch.isnan(preds).any() - assert preds.shape == (1, len(grid_lon), len(grid_lat), output_features_dim)