Skip to content

Commit

Permalink
GenCast Sampling (#122)
Browse files Browse the repository at this point in the history
* Add sampler

* Add tests

* fix

* Add new noise shapes

* Add tests for new noise shapes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
gbruno16 and pre-commit-ci[bot] authored Jul 22, 2024
1 parent 0a1a5f3 commit 68f8a11
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 636 deletions.
38 changes: 34 additions & 4 deletions graph_weather/data/gencast_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
- corrupt the residual with noise generated at the given noise level.
"""

import warnings

import einops
import numpy as np
import xarray as xr
Expand Down Expand Up @@ -65,9 +67,19 @@ def __init__(
self.atmospheric_features = atmospheric_features
self.single_features = single_features
self.static_features = static_features
# Lat and long will be added by the model itself in the graph

self.means, self.stds, self.diff_means, self.diff_stds = self._init_means_and_stds()
# Lat and long will be added by the model itself in the graph

# check if fast isotropic noise generation is possible
if (self.num_lon == 2 * self.num_lat) or (self.num_lon == 2 * (self.num_lat - 1)):
self.use_isotropic_noise = True
else:
self.use_isotropic_noise = False
warnings.warn(
"Isotropic noise requires grid's shape to be 2N x N or 2N x (N+1): "
f"got {self.num_lon} x {self.num_lat}: falling back to flat normal random noise"
)

def _init_means_and_stds(self):
means = []
Expand Down Expand Up @@ -194,7 +206,10 @@ def __getitem__(self, item):
# Corrupt targets with noise
noise_levels = np.array([sample_noise_level()]).astype(np.float32)
noise = generate_isotropic_noise(
num_lat=self.num_lat, num_samples=target_residuals.shape[-1]
num_lon=self.num_lon,
num_lat=self.num_lat,
num_samples=target_residuals.shape[-1],
isotropic=self.use_isotropic_noise,
)
corrupted_targets = target_residuals + noise_levels * noise

Expand Down Expand Up @@ -240,6 +255,8 @@ def __init__(
self.data = xr.open_zarr(obs_path, chunks={})
self.max_year = max_year

self.grid_lon = self.data["longitude"].values
self.grid_lat = self.data["latitude"].values
self.num_lon = len(self.data["longitude"].values)
self.num_lat = len(self.data["latitude"].values)
self.num_vars = len(self.data.keys())
Expand All @@ -253,11 +270,21 @@ def __init__(
self.atmospheric_features = atmospheric_features
self.single_features = single_features
self.static_features = static_features
# Lat and long will be added by the model itself in the graph

self.clock_features = ["local_time_of_the_day", "elapsed_year_progress"]

self.means, self.stds, self.diff_means, self.diff_stds = self._init_means_and_stds()
# Lat and long will be added by the model itself in the graph

# check if fast isotropic noise generation is possible
if (self.num_lon == 2 * self.num_lat) or (self.num_lon == 2 * (self.num_lat - 1)):
self.use_isotropic_noise = True
else:
self.use_isotropic_noise = False
warnings.warn(
"Isotropic noise requires grid's shape to be 2N x N or 2N x (N+1): "
f"got {self.num_lon} x {self.num_lat}: falling back to flat normal random noise"
)

def _init_means_and_stds(self):
means = []
Expand Down Expand Up @@ -386,7 +413,10 @@ def __getitem__(self, item):
for b in range(self.batch_size):
noise_level = sample_noise_level()
noise = generate_isotropic_noise(
num_lat=self.num_lat, num_samples=target_residuals.shape[-1]
num_lon=self.num_lon,
num_lat=self.num_lat,
num_samples=target_residuals.shape[-1],
isotropic=self.use_isotropic_noise,
)
corrupted_targets[b] = target_residuals[b] + noise_level * noise
noise_levels[b] = noise_level
Expand Down
1 change: 1 addition & 0 deletions graph_weather/models/gencast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

from .denoiser import Denoiser
from .graph.graph_builder import GraphBuilder
from .sampler import Sampler
from .weighted_mse_loss import WeightedMSELoss
130 changes: 130 additions & 0 deletions graph_weather/models/gencast/sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Diffusion sampler"""

import math

import torch

from graph_weather.models.gencast import Denoiser
from graph_weather.models.gencast.utils.noise import generate_isotropic_noise


class Sampler:
"""Sampler for the denoiser.
The sampler consists in the second-order DPMSolver++2S solver (Lu et al., 2022), augmented with
the stochastic churn (again making use of the isotropic noise) and noise inflation techniques
used in Karras et al. (2022) to inject further stochasticity into the sampling process. In
conditioning on previous timesteps it follows the Conditional Denoising Estimator approach
outlined and motivated by Batzolis et al. (2021).
"""

def __init__(
self,
S_noise: float = 1.05,
S_tmin: float = 0.75,
S_tmax: float = 80.0,
S_churn: float = 2.5,
r: float = 0.5,
sigma_max: float = 80.0,
sigma_min: float = 0.03,
rho: float = 7,
num_steps: int = 20,
):
"""Initialize the sampler.
Args:
S_noise (float): noise inflation parameter. Defaults to 1.05.
S_tmin (float): minimum noise for sampling. Defaults to 0.75.
S_tmax (float): maximum noise for sampling. Defaults to 80.
S_churn (float): stochastic churn rate. Defaults to 2.5.
r (float): _description_. Defaults to 0.5.
sigma_max (float): maximum value of sigma for sigma's distribution. Defaults to 80.
sigma_min (float): minimum value of sigma for sigma's distribution. Defaults to 0.03.
rho (float): exponent of the sigma's distribution. Defaults to 7.
num_steps (int): number of timesteps during sampling. Defaults to 20.
"""
self.S_noise = S_noise
self.S_tmin = S_tmin
self.S_tmax = S_tmax
self.S_churn = S_churn
self.r = r
self.num_steps = num_steps

self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.rho = rho

def _sigmas_fn(self, u):
return (
self.sigma_max ** (1 / self.rho)
+ u * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))
) ** self.rho

@torch.no_grad()
def sample(self, denoiser: Denoiser, prev_inputs: torch.Tensor):
"""Generate a sample from random noise for the given inputs.
Args:
denoiser (Denoiser): the denoiser model.
prev_inputs (torch.Tensor): previous two timesteps.
Returns:
torch.Tensor: normalized residuals predicted.
"""
device = prev_inputs.device

time_steps = torch.arange(0, self.num_steps).to(device) / (self.num_steps - 1)
sigmas = self._sigmas_fn(time_steps)

batch_ones = torch.ones(1, 1).to(device)

# initialize noise
x = sigmas[0] * torch.tensor(
generate_isotropic_noise(
num_lon=denoiser.num_lon,
num_lat=denoiser.num_lat,
num_samples=denoiser.output_features_dim,
)
).unsqueeze(0).to(device)

for i in range(len(sigmas) - 1):
# stochastic churn from Karras et al. (Alg. 2)
gamma = (
min(self.S_churn / self.num_steps, math.sqrt(2) - 1)
if self.S_tmin <= sigmas[i] <= self.S_tmax
else 0.0
)
# noise inflation from Karras et al. (Alg. 2)
noise = self.S_noise * torch.tensor(
generate_isotropic_noise(
num_lon=denoiser.num_lon,
num_lat=denoiser.num_lat,
num_samples=denoiser.output_features_dim,
)
)
noise = noise.unsqueeze(0).to(device)

sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 * noise
denoised = denoiser(x, prev_inputs, sigma_hat * batch_ones)

if i == len(sigmas) - 2:
# final Euler step
d = (x - denoised) / sigma_hat
x = x + d * (sigmas[i + 1] - sigma_hat)
else:
# DPMSolver++2S step (Alg. 1 in Lu et al.) with alpha_t=1.
# t_{i-1} is t_hat because of stochastic churn!
lambda_hat = -torch.log(sigma_hat)
lambda_next = -torch.log(sigmas[i + 1])
h = lambda_next - lambda_hat
lambda_mid = lambda_hat + self.r * h
sigma_mid = torch.exp(-lambda_mid)

u = sigma_mid / sigma_hat * x - (torch.exp(-self.r * h) - 1) * denoised
denoised_2 = denoiser(u, prev_inputs, sigma_mid * batch_ones)
D = (1 - 1 / (2 * self.r)) * denoised + 1 / (2 * self.r) * denoised_2
x = sigmas[i + 1] / sigma_hat * x - (torch.exp(-h) - 1) * D

return x
50 changes: 35 additions & 15 deletions graph_weather/models/gencast/utils/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,48 @@
import torch


def generate_isotropic_noise(num_lat, num_samples=1):
"""Generate isotropic noise on the grid.
def generate_isotropic_noise(num_lon: int, num_lat: int, num_samples=1, isotropic=True):
"""Generate noise on the grid.
Sample the equivalent of white noise on a sphere and project it onto a grid using
Driscoll and Healy, 1994 algorithm. The power spectrum is normalized to have variance 1.
We need to assume lons = 2* lats.
When isotropic is True it samples the equivalent of white noise on a sphere and project it onto
a grid using Driscoll and Healy, 1994, algorithm. The power spectrum is normalized to have
variance 1. We need to assume lons = 2 * lats or lons = 2 * (lats -1). If isotropic is false, it
samples flat normal random noise.
Args:
num_lat (int): Number of latitudes in the final grid.
num_samples (int, optional): Number of indipendent samples. Defaults to 1.
num_lon (int): number of longitudes in the grid.
num_lat (int): number of latitudes in the grid.
num_samples (int): number of indipendent samples. Defaults to 1.
isotropic (bool): if true generates isotropic noise, else flat noise. Defaults to True.
Returns:
grid: Numpy array with shape shape(grid) x num_samples.
"""
power = np.ones(num_lat // 2, dtype=float) / (
num_lat // 2
) # normalized to get each point with std 1
grid = np.zeros((num_lat * 2, num_lat, num_samples))
for i in range(num_samples):
clm = pysh.SHCoeffs.from_random(power)
grid[:, :, i] = clm.expand(grid="DH2", extend=False).to_array().transpose()
return grid.astype(np.float32)
if isotropic:
if 2 * num_lat == num_lon:
extend = False
elif 2 * (num_lat - 1) == num_lon:
extend = True
else:
raise ValueError(
"Isotropic noise requires grid's shape to be 2N x N or 2N x (N+1): "
f"got {num_lon} x {num_lat}. If the shape is correct, please specify"
"isotropic=False in the constructor.",
)

if isotropic:
l_max = num_lat // 2
power = np.ones(l_max, dtype=float) / l_max**2 # normalized to get each point with std 1
grid = np.zeros((num_lon, num_lat, num_samples))
for i in range(num_samples):
clm = pysh.SHCoeffs.from_random(power, power_unit="per_lm")
grid[:, :, i] = (
clm.expand(grid="DH2", extend=extend).to_array().transpose()[:num_lon, :num_lat]
)
noise = grid.astype(np.float32)
else:
noise = np.random.randn(num_lon, num_lat, num_samples)
return noise


def sample_noise_level(sigma_min=0.02, sigma_max=88, rho=7):
Expand Down
62 changes: 58 additions & 4 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
generate_isotropic_noise,
sample_noise_level,
)
from graph_weather.models.gencast import GraphBuilder, WeightedMSELoss, Denoiser
from graph_weather.models.gencast import GraphBuilder, WeightedMSELoss, Denoiser, Sampler
from graph_weather.models.gencast.layers.modules import FourierEmbedding


Expand Down Expand Up @@ -316,11 +316,38 @@ def test_meta_model():


def test_gencast_noise():
num_lat = 32
num_lon = 360
num_lat = 180
num_samples = 5
target_residuals = np.zeros((2 * num_lat, num_lat, num_samples))
target_residuals = np.zeros((num_lon, num_lat, num_samples))
noise_level = sample_noise_level()
noise = generate_isotropic_noise(num_lat=num_lat, num_samples=target_residuals.shape[-1])
noise = generate_isotropic_noise(
num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1]
)
corrupted_residuals = target_residuals + noise_level * noise
assert corrupted_residuals.shape == target_residuals.shape
assert not np.isnan(corrupted_residuals).any()

num_lon = 360
num_lat = 181
num_samples = 5
target_residuals = np.zeros((num_lon, num_lat, num_samples))
noise_level = sample_noise_level()
noise = generate_isotropic_noise(
num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1]
)
corrupted_residuals = target_residuals + noise_level * noise
assert corrupted_residuals.shape == target_residuals.shape
assert not np.isnan(corrupted_residuals).any()

num_lon = 100
num_lat = 100
num_samples = 5
target_residuals = np.zeros((num_lon, num_lat, num_samples))
noise_level = sample_noise_level()
noise = generate_isotropic_noise(
num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1], isotropic=False
)
corrupted_residuals = target_residuals + noise_level * noise
assert corrupted_residuals.shape == target_residuals.shape
assert not np.isnan(corrupted_residuals).any()
Expand Down Expand Up @@ -408,3 +435,30 @@ def test_gencast_fourier():
fourier_embedder = FourierEmbedding(output_dim=output_dim, num_frequencies=32, base_period=16)
t = torch.rand((batch_size, 1))
assert fourier_embedder(t).shape == (batch_size, output_dim)


def test_gencast_sampler():
grid_lat = np.arange(-90, 90, 1)
grid_lon = np.arange(0, 360, 1)
input_features_dim = 10
output_features_dim = 5

denoiser = Denoiser(
grid_lon=grid_lon,
grid_lat=grid_lat,
input_features_dim=input_features_dim,
output_features_dim=output_features_dim,
hidden_dims=[16, 32],
num_blocks=3,
num_heads=4,
splits=0,
num_hops=1,
device=torch.device("cpu"),
).eval()

prev_inputs = torch.randn((1, len(grid_lon), len(grid_lat), 2 * input_features_dim))

sampler = Sampler()
preds = sampler.sample(denoiser, prev_inputs)
assert not torch.isnan(preds).any()
assert preds.shape == (1, len(grid_lon), len(grid_lat), output_features_dim)
Loading

0 comments on commit 68f8a11

Please sign in to comment.