Training GenCast #123

Merged: 32 commits, Aug 26, 2024
9eb6f40
Add graph as buffer
gbruno16 Jul 22, 2024
6d366c6
Add training script
gbruno16 Jul 22, 2024
99e9d6f
Remove unused variable
gbruno16 Jul 22, 2024
8ab40ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2024
112ee21
Update README.md
gbruno16 Aug 4, 2024
8452ef2
Add images folder
gbruno16 Aug 4, 2024
5f4b24f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 4, 2024
de92b09
Add files via upload
gbruno16 Aug 4, 2024
7e7ee3a
Add files via upload
gbruno16 Aug 4, 2024
61eedbf
Update README.md
gbruno16 Aug 4, 2024
cac3986
Add norm first
gbruno16 Aug 14, 2024
c5aa995
Update noise and train
gbruno16 Aug 14, 2024
fd75b19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 14, 2024
c5d0206
Add requirement
gbruno16 Aug 14, 2024
b63c62d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 14, 2024
b023c8a
Add requirement
gbruno16 Aug 14, 2024
1f59f4e
Add files via upload
gbruno16 Aug 14, 2024
beee826
Update README.md
gbruno16 Aug 14, 2024
7bcd250
Update graph_weather/models/gencast/denoiser.py
gbruno16 Aug 15, 2024
e163cb7
Add HF support
gbruno16 Aug 19, 2024
2b6e323
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
e83a920
Add full test
gbruno16 Aug 23, 2024
b74fe57
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2024
ada0d76
Update workflows.yaml
gbruno16 Aug 23, 2024
9775cf0
Update requirements.txt
gbruno16 Aug 23, 2024
842829f
Update requirements.txt
gbruno16 Aug 23, 2024
81d66c0
Edit test full model
gbruno16 Aug 23, 2024
d30a0ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2024
fd9aee1
Update tests
gbruno16 Aug 24, 2024
fad4e38
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2024
1471392
Fix version check
gbruno16 Aug 24, 2024
b604526
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2024
2 changes: 1 addition & 1 deletion environment_cpu.yml
@@ -24,7 +24,7 @@ dependencies:
- zarr
- h3-py
- numpy
- pyshtools
- torch_harmonics
- pip:
- datasets
- einops
2 changes: 1 addition & 1 deletion environment_cuda.yml
@@ -25,7 +25,7 @@ dependencies:
- zarr
- h3-py
- numpy
- pyshtools
- torch_harmonics
- pip:
- datasets
- einops
126 changes: 126 additions & 0 deletions graph_weather/models/gencast/README.md
@@ -1 +1,127 @@
# GenCast

## Overview
This repository offers an unofficial implementation of [GenCast](https://arxiv.org/abs/2312.15796), a cutting-edge model designed to enhance weather forecasting accuracy. GenCast integrates diffusion models, sparse transformers, and graph neural networks (GNNs) to improve upon the GraphCast model with a score-based generative approach. This innovative combination aims to revolutionize weather prediction, significantly contributing to climate research and mitigation efforts. Additionally, GenCast supports ensemble predictions to assess the probability of extreme weather events.

Below is an illustration of the full model architecture:

<p align="center" width="100%">
<img width="50%" src="images/fullmodel.png">
</p>

After 20 epochs, the diffusion process for generating low-resolution (128x64) 12-hour residual predictions looks like this:

<p align="center" width="100%">
<img width="75%" src="images/animated.gif">
</p>

A longer time range can be achieved by using the model's own predictions as inputs in an autoregressive manner:

<p align="center" width="100%">
<img width="75%" src="images/autoregressive.gif">
</p>

## The Denoiser
### Description
The core component of GenCast is the `Denoiser` module: it takes as inputs the previous two timesteps, the corrupted target residual, and the
noise level, and outputs the denoised predictions. The `Denoiser` operates as follows:
- initializes the graph using the `GraphBuilder` class,
- combines `encoder`, `processor`, and `decoder`,
- preconditions inputs and outputs on the noise levels using the parametrization from [Karras et al. (2022)](https://arxiv.org/abs/2206.00364).
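
The preconditioning in the last step can be sketched as follows. This is a minimal illustration of the Karras et al. (2022) parametrization, not the repository's exact code; `sigma_data = 1` is an assumption (unit-variance targets):

```python
import math

def edm_preconditioners(sigma: float, sigma_data: float = 1.0):
    """EDM scalings from Karras et al. (2022). The denoised prediction is
    D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)."""
    s2 = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / s2                  # how much of the noisy input to keep
    c_out = sigma * sigma_data / math.sqrt(s2)   # scales the network output
    c_in = 1.0 / math.sqrt(s2)                   # normalises the network input
    c_noise = 0.25 * math.log(sigma)             # noise-level conditioning signal
    return c_skip, c_out, c_in, c_noise
```

As `sigma -> 0` the skip connection dominates (`c_skip -> 1`), so the network only has to predict a small residual correction at low noise levels.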

The code is modular, allowing for easy swapping of the graph, encoder, processor, and decoder with custom architectures. The main arguments are:
- `grid_lon` (np.ndarray): array of longitudes.
- `grid_lat` (np.ndarray): array of latitudes.
- `input_features_dim` (int): dimension of the input features for a single timestep.
- `output_features_dim` (int): dimension of the target features.
- `hidden_dims` (list[int], optional): list of dimensions for the hidden layers in the MLPs used in GenCast. This also determines the latent dimension. Defaults to [512, 512].
- `num_blocks` (int, optional): number of transformer blocks in Processor. Defaults to 16.
- `num_heads` (int, optional): number of heads for each transformer. Defaults to 4.
- `splits` (int, optional): number of times to split the icosphere during graph building. Defaults to 5.
- `num_hops` (int, optional): the transformers will attend to the (2^`num_hops`)-neighbours of each node. Defaults to 16.
- `device` (torch.device, optional): device on which to build the graph. Defaults to `torch.device("cpu")`.
- `sparse` (bool): if True, the Processor applies sparse attention using the DGL backend. Defaults to False.
- `use_edges_features` (bool): if True, mesh edge features are used inside the Processor. Defaults to True.
- `scale_factor` (float): the Encoder's message-passing output is multiplied by this factor. Defaults to 1.0.

> [!NOTE]
> If the graph has many edges, setting `sparse = True` may perform better in terms of memory and speed. Note that `sparse = False` uses PyG as the backend, while `sparse = True` uses DGL. The two implementations are not exactly equivalent: the former is described in the paper _"Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification"_ and can also handle edge features, while the latter is a classical transformer that performs multi-head attention utilizing the mask's sparsity and does not include edge features in the computations.

> [!WARNING]
> The sparse implementation currently does not support `Float16/BFloat16` precision.

> [!NOTE]
> To fine-tune a pretrained model with a higher resolution dataset, the `scale_factor` should be set accordingly. For example: if the starting resolution is 1 deg and the final resolution is 0.25 deg, then the scale factor is 1/16.
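
The 1/16 above follows from the grid refinement: halving the grid spacing in both latitude and longitude quadruples the number of grid nodes sending messages to each mesh node. A hypothetical helper (not part of the package) makes the arithmetic explicit:

```python
def encoder_scale_factor(pretrain_res_deg: float, finetune_res_deg: float) -> float:
    """Rescale the Encoder messages when fine-tuning at a new grid resolution.
    The grid-node count per mesh node grows with the square of the
    resolution increase, so the message is downscaled by that ratio."""
    return (finetune_res_deg / pretrain_res_deg) ** 2
```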

### Example of usage
```python
import torch
import numpy as np
from graph_weather.models.gencast import Denoiser

grid_lat = np.arange(-90, 90, 1)
grid_lon = np.arange(0, 360, 1)
input_features_dim = 10
output_features_dim = 5
batch_size = 16

denoiser = Denoiser(
    grid_lon=grid_lon,
    grid_lat=grid_lat,
    input_features_dim=input_features_dim,
    output_features_dim=output_features_dim,
    hidden_dims=[32, 32],
    num_blocks=8,
    num_heads=4,
    splits=4,
    num_hops=8,
    device=torch.device("cpu"),
    sparse=True,
    use_edges_features=False,
).eval()

corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim))
prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2 * input_features_dim))
noise_levels = torch.rand((batch_size, 1))

preds = denoiser(
    corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels
)
```
## The Sampler
The sampler employs the second-order `DPMSolver++2S` solver, augmented with stochastic churn and noise inflation techniques from [Karras et al. (2022)](https://arxiv.org/abs/2206.00364), to introduce additional stochasticity into the sampling process. When conditioning on previous timesteps, it follows the Conditional Denoising Estimator approach as outlined by Batzolis et al. (2021).
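
The stochastic churn step can be sketched like this (illustrative only; parameter names and defaults are assumptions, not the repository's exact implementation):

```python
import math
import random

def churn(x, sigma, num_steps, s_churn=2.5, s_noise=1.05):
    """Noise inflation from Karras et al. (2022): raise the noise level from
    sigma to sigma_hat and add matching fresh noise before the solver step."""
    gamma = min(s_churn / num_steps, math.sqrt(2.0) - 1.0)
    sigma_hat = sigma * (1.0 + gamma)
    # Standard deviation needed to bring a sample at level sigma up to sigma_hat.
    std = s_noise * math.sqrt(max(sigma_hat**2 - sigma**2, 0.0))
    x_hat = [xi + std * random.gauss(0.0, 1.0) for xi in x]
    return x_hat, sigma_hat
```

With `s_churn = 0` the step is a no-op and the sampler reduces to the deterministic solver.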

### Example of usage
```python
import torch
import numpy as np
from graph_weather.models.gencast import Sampler, Denoiser

grid_lat = np.arange(-90, 90, 1)
grid_lon = np.arange(0, 360, 1)
input_features_dim = 10
output_features_dim = 5

denoiser = Denoiser(
    grid_lon=grid_lon,
    grid_lat=grid_lat,
    input_features_dim=input_features_dim,
    output_features_dim=output_features_dim,
    hidden_dims=[16, 32],
    num_blocks=3,
    num_heads=4,
    splits=0,
    num_hops=1,
    device=torch.device("cpu"),
).eval()

prev_inputs = torch.randn((1, len(grid_lon), len(grid_lat), 2 * input_features_dim))

sampler = Sampler()
preds = sampler.sample(denoiser, prev_inputs)
```
> [!NOTE]
> The `Sampler` class supports modifying all the sampling parameters ($S_{churn}$, $S_{noise}$, $\sigma_{max}$, $\sigma_{min}$, ...). The default values are detailed in GenCast's paper.

## Training
The script `train.py` provides a basic setup for training the model. It follows guidelines from GraphCast and GenCast, combining a linear warmup phase with a cosine scheduler. The script supports multi-device DDP training, gradient accumulation, and WandB logging.
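
The warmup-plus-cosine schedule can be sketched as follows (a minimal version; the actual step counts and defaults in `train.py` may differ):

```python
import math

def lr_multiplier(step: int, warmup_steps: int = 1_000, total_steps: int = 100_000) -> float:
    """Linear warmup from 0 to 1 over warmup_steps, then cosine decay to 0."""
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
```

This multiplier can be plugged into `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_multiplier)` to scale the base learning rate per step.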
102 changes: 87 additions & 15 deletions graph_weather/models/gencast/denoiser.py
@@ -10,12 +10,12 @@
import einops
import numpy as np
import torch
from torch_geometric.data import Batch

from graph_weather.models.gencast.graph.graph_builder import GraphBuilder
from graph_weather.models.gencast.layers.decoder import Decoder
from graph_weather.models.gencast.layers.encoder import Encoder
from graph_weather.models.gencast.layers.processor import Processor
from graph_weather.models.gencast.utils.batching import batch, hetero_batch
from graph_weather.models.gencast.utils.noise import Preconditioner


@@ -36,6 +36,7 @@ def __init__(
device: torch.device = torch.device("cpu"),
sparse: bool = False,
use_edges_features: bool = True,
scale_factor: float = 1.0,
):
"""Initialize the Denoiser.

@@ -58,6 +59,9 @@
Defaults to False.
use_edges_features (bool): if true use mesh edges features inside the Processor.
Defaults to True.
scale_factor (float): in the Encoder the message passing output is multiplied by the
scale factor. This is important when you want to fine-tune a pretrained model to a
higher resolution. Defaults to 1.
"""
super().__init__()
self.num_lon = len(grid_lon)
@@ -76,6 +80,8 @@ def __init__(
add_edge_features_to_khop=use_edges_features,
)

self._register_graph()

# Initialize Encoder
self.encoder = Encoder(
grid_dim=output_features_dim + 2 * input_features_dim + self.graphs.grid_nodes_dim,
@@ -84,6 +90,7 @@
hidden_dims=hidden_dims,
activation_layer=torch.nn.SiLU,
use_layer_norm=True,
scale_factor=scale_factor,
)

# Initialize Processor
@@ -138,16 +145,19 @@ def _check_shapes(self, corrupted_targets, prev_inputs, noise_levels):
def _run_encoder(self, grid_features):
# build big graph with batch_size disconnected copies of the graph, with features [(b n) f].
batch_size = grid_features.shape[0]
g2m_batched = Batch.from_data_list([self.graphs.g2m_graph] * batch_size)

batched_senders, batched_receivers, batched_edge_index, batched_edge_attr = hetero_batch(
self.g2m_grid_nodes,
self.g2m_mesh_nodes,
self.g2m_edge_index,
self.g2m_edge_attr,
batch_size,
)
# load features.
grid_features = einops.rearrange(grid_features, "b n f -> (b n) f")
input_grid_nodes = torch.cat([grid_features, g2m_batched["grid_nodes"].x], dim=-1).type(
torch.float32
)
input_mesh_nodes = g2m_batched["mesh_nodes"].x
input_edge_attr = g2m_batched["grid_nodes", "to", "mesh_nodes"].edge_attr
edge_index = g2m_batched["grid_nodes", "to", "mesh_nodes"].edge_index
input_grid_nodes = torch.cat([grid_features, batched_senders], dim=-1)
input_mesh_nodes = batched_receivers
input_edge_attr = batched_edge_attr
edge_index = batched_edge_index

# run the encoder.
latent_grid_nodes, latent_mesh_nodes = self.encoder(
@@ -168,13 +178,19 @@ def _run_decoder(self, latent_mesh_nodes, latent_grid_nodes):
def _run_decoder(self, latent_mesh_nodes, latent_grid_nodes):
# build big graph with batch_size disconnected copies of the graph, with features [(b n) f].
batch_size = latent_mesh_nodes.shape[0]
m2g_batched = Batch.from_data_list([self.graphs.m2g_graph] * batch_size)
_, _, batched_edge_index, batched_edge_attr = hetero_batch(
self.m2g_mesh_nodes,
self.m2g_grid_nodes,
self.m2g_edge_index,
self.m2g_edge_attr,
batch_size,
)

# load features.
input_mesh_nodes = einops.rearrange(latent_mesh_nodes, "b n f -> (b n) f")
input_grid_nodes = einops.rearrange(latent_grid_nodes, "b n f -> (b n) f")
input_edge_attr = m2g_batched["mesh_nodes", "to", "grid_nodes"].edge_attr
edge_index = m2g_batched["mesh_nodes", "to", "grid_nodes"].edge_index
input_edge_attr = batched_edge_attr
edge_index = batched_edge_index

# run the decoder.
output_grid_nodes = self.decoder(
@@ -194,12 +210,17 @@ def _run_processor(self, latent_mesh_nodes, noise_levels):
# build big graph with batch_size disconnected copies of the graph, with features [(b n) f].
batch_size = latent_mesh_nodes.shape[0]
num_nodes = latent_mesh_nodes.shape[1]
mesh_batched = Batch.from_data_list([self.graphs.khop_mesh_graph] * batch_size)
_, batched_edge_index, batched_edge_attr = batch(
self.khop_mesh_nodes,
self.khop_mesh_edge_index,
self.khop_mesh_edge_attr if self.use_edges_features else None,
batch_size,
)

# load features.
latent_mesh_nodes = einops.rearrange(latent_mesh_nodes, "b n f -> (b n) f")
input_edge_attr = mesh_batched.edge_attr if self.use_edges_features else None
edge_index = mesh_batched.edge_index
input_edge_attr = batched_edge_attr
edge_index = batched_edge_index

# repeat noise levels for each node.
noise_levels = einops.repeat(noise_levels, "b f -> (b n) f", n=num_nodes)
@@ -272,3 +293,54 @@ def forward(
# restore lon/lat dimensions.
out = einops.rearrange(out, "b (lon lat) f -> b lon lat f", lon=self.num_lon)
return out

def _register_graph(self):
# we need to register all the tensors associated with the graph as buffers. In this way they
# will move to the same device as the model. These tensors won't be part of the state since
# persistent is set to False.

self.register_buffer(
"g2m_grid_nodes", self.graphs.g2m_graph["grid_nodes"].x, persistent=False
)
self.register_buffer(
"g2m_mesh_nodes", self.graphs.g2m_graph["mesh_nodes"].x, persistent=False
)
self.register_buffer(
"g2m_edge_attr",
self.graphs.g2m_graph["grid_nodes", "to", "mesh_nodes"].edge_attr,
persistent=False,
)
self.register_buffer(
"g2m_edge_index",
self.graphs.g2m_graph["grid_nodes", "to", "mesh_nodes"].edge_index,
persistent=False,
)

self.register_buffer("mesh_nodes", self.graphs.mesh_graph.x, persistent=False)
self.register_buffer("mesh_edge_attr", self.graphs.mesh_graph.edge_attr, persistent=False)
self.register_buffer("mesh_edge_index", self.graphs.mesh_graph.edge_index, persistent=False)

self.register_buffer("khop_mesh_nodes", self.graphs.khop_mesh_graph.x, persistent=False)
self.register_buffer(
"khop_mesh_edge_attr", self.graphs.khop_mesh_graph.edge_attr, persistent=False
)
self.register_buffer(
"khop_mesh_edge_index", self.graphs.khop_mesh_graph.edge_index, persistent=False
)

self.register_buffer(
"m2g_grid_nodes", self.graphs.m2g_graph["grid_nodes"].x, persistent=False
)
self.register_buffer(
"m2g_mesh_nodes", self.graphs.m2g_graph["mesh_nodes"].x, persistent=False
)
self.register_buffer(
"m2g_edge_attr",
self.graphs.m2g_graph["mesh_nodes", "to", "grid_nodes"].edge_attr,
persistent=False,
)
self.register_buffer(
"m2g_edge_index",
self.graphs.m2g_graph["mesh_nodes", "to", "grid_nodes"].edge_index,
persistent=False,
)
Binary file added graph_weather/models/gencast/images/animated.gif
5 changes: 5 additions & 0 deletions graph_weather/models/gencast/layers/encoder.py
@@ -22,6 +22,7 @@ def __init__(
hidden_dims: list[int],
activation_layer: torch.nn.Module = torch.nn.ReLU,
use_layer_norm: bool = True,
scale_factor: float = 1.0,
):
"""Initialize the Encoder.

@@ -34,6 +35,9 @@
Defaults to torch.nn.ReLU.
use_layer_norm (bool, optional): if true add a LayerNorm at the end of each MLP.
Defaults to True.
scale_factor (float): the message of the interaction network between the grid and the
the mesh is multiplied by the scale factor. Useful when fine-tuning a pretrained
model to a higher resolution. Defaults to 1.
"""
super().__init__()

@@ -79,6 +83,7 @@ def __init__(
hidden_dims=hidden_dims,
use_layer_norm=use_layer_norm,
activation_layer=activation_layer,
scale_factor=scale_factor,
)

# Final grid nodes update