From cb7055190831d203aaa77cdfdef4a7066bbc4029 Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Wed, 29 May 2024 10:58:50 +0200
Subject: [PATCH 01/16] fengwu_ghr: initial

---
 graph_weather/models/__init__.py          |   1 +
 graph_weather/models/fengwu_ghr/layers.py | 133 ++++++++++++++++++++++
 tests/test_model.py                       |  14 ++-
 3 files changed, 147 insertions(+), 1 deletion(-)
 create mode 100644 graph_weather/models/fengwu_ghr/layers.py

diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py
index a18cda87..289d3724 100644
--- a/graph_weather/models/__init__.py
+++ b/graph_weather/models/__init__.py
@@ -5,3 +5,4 @@
 from .layers.decoder import Decoder
 from .layers.encoder import Encoder
 from .layers.processor import Processor
+from .fengwu_ghr.layers import MetaModel

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
new file mode 100644
index 00000000..62425651
--- /dev/null
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -0,0 +1,133 @@
+import torch
+from torch import nn
+
+from einops import rearrange
+from einops.layers.torch import Rearrange
+
+# helpers
+
+
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
+
+
+def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
+    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
+    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
+    omega = torch.arange(dim // 4) / (dim // 4 - 1)
+    omega = 1.0 / (temperature ** omega)
+
+    y = y.flatten()[:, None] * omega[None, :]
+    x = x.flatten()[:, None] * omega[None, :]
+    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
+    return pe.type(dtype)
+
+# classes
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, dim),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads=8, dim_head=64):
+        super().__init__()
+        inner_dim = dim_head * heads
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+        self.norm = nn.LayerNorm(dim)
+
+        self.attend = nn.Softmax(dim=-1)
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x):
+        x = self.norm(x)
+
+        qkv = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = map(lambda t: rearrange(
+            t, 'b n (h d) -> b h n d', h=self.heads), qkv)
+
+        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+        attn = self.attend(dots)
+
+        out = torch.matmul(attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+
+class Transformer(nn.Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                Attention(dim, heads=heads, dim_head=dim_head),
+                FeedForward(dim, mlp_dim)
+            ]))
+
+    def forward(self, x):
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+        return self.norm(x)
+
+
+class MetaModel(nn.Module):
+    def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64):
+        super().__init__()
+        image_height, image_width = pair(image_size)
+        patch_height, patch_width = pair(patch_size)
+
+        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+
+        patch_dim = channels * patch_height * patch_width
+        dim = patch_dim
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange("b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)",
+                      p_h=patch_height, p_w=patch_width),
+            nn.LayerNorm(patch_dim),  # TODO Do we need this?
+            nn.Linear(patch_dim, dim),  # TODO Do we need this?
+            nn.LayerNorm(dim),  # TODO Do we need this?
+        )
+
+        self.pos_embedding = posemb_sincos_2d(
+            h=image_height // patch_height,
+            w=image_width // patch_width,
+            dim=dim,
+        )
+
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
+
+        self.reshaper = nn.Sequential(
+            Rearrange("b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)",
+                      h=image_height // patch_height,
+                      w=image_width // patch_width,
+                      p_h=patch_height, p_w=patch_width)
+        )
+
+    def forward(self, img):
+        device = img.device
+
+        x = self.to_patch_embedding(img)
+        x += self.pos_embedding.to(device, dtype=x.dtype)
+
+        x = self.transformer(x)
+
+        print(x.shape)
+        x = self.reshaper(x)
+
+        return x

diff --git a/tests/test_model.py b/tests/test_model.py
index 58904292..2ef43cc8 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -3,7 +3,7 @@
 import torch
 
 from graph_weather import GraphWeatherAssimilator, GraphWeatherForecaster
-from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor
+from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor, MetaModel
 from graph_weather.models.losses import NormalizedMSELoss
 
 
@@ -222,3 +222,15 @@ def test_normalized_loss():
     assert not torch.isnan(loss)
     # Since feature_variance = out**2 and target = 0, we expect loss = weights
     assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())
+
+
+def test_meta_model():
+    model = MetaModel(image_size=100,patch_size=10,
+                      depth=1, heads=1, mlp_dim=7,
+                      channels=3 )
+    features = torch.randn((1,3, 100,100) )
+
+    out = model(features)
+    assert not torch.isnan(out).any()
+    assert not torch.isnan(out).any()
+    assert out.size() == (1,3, 100,100)
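The MetaModel introduced above is a plain ViT image-to-image model: patchify, embed, pre-norm Transformer, un-patchify. A minimal shape check of the 2-D sin-cos positional embedding it uses, assuming layers.py from this patch is importable (sizes mirror the test: a 100x100 image with 10x10 patches gives dim = 3 * 10 * 10 = 300):

import torch

from graph_weather.models.fengwu_ghr.layers import posemb_sincos_2d

pe = posemb_sincos_2d(h=10, w=10, dim=300)  # one embedding per patch position
assert pe.shape == (10 * 10, 300)           # dim must be a multiple of 4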
From 9eaf70d5dad163afa5311c0819ce9f363abd82cb Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Wed, 29 May 2024 11:00:25 +0200
Subject: [PATCH 02/16] fengwu_ghr: fixes

---
 .gitignore                                | 1 +
 environment_cpu.yml                       | 2 +-
 graph_weather/models/fengwu_ghr/layers.py | 3 +--
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index 9b8aab0b..d248bf98 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@
 *.txt
 # pixi environments
 .pixi
+.vscode/

diff --git a/environment_cpu.yml b/environment_cpu.yml
index db087fc2..e854dc84 100644
--- a/environment_cpu.yml
+++ b/environment_cpu.yml
@@ -9,7 +9,7 @@ dependencies:
   - pandas
   - pip
   - pyg
-  - python=3.12
+  - python
   - pytorch
   - cpuonly
   - pytorch-cluster

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index 62425651..9aeaf508 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -126,8 +126,7 @@ def forward(self, img):
         x += self.pos_embedding.to(device, dtype=x.dtype)
 
         x = self.transformer(x)
-
-        print(x.shape)
+
         x = self.reshaper(x)
 
         return x
From 4f3d4c1774fded109433219dd56b7de2ef2c27c0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 29 May 2024 09:11:13 +0000
Subject: [PATCH 03/16] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 graph_weather/models/__init__.py          |  2 +-
 graph_weather/models/fengwu_ghr/layers.py | 50 +++++++++++++----------
 tests/test_model.py                       | 19 +++++----
 3 files changed, 41 insertions(+), 30 deletions(-)

diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py
index 289d3724..72d222a8 100644
--- a/graph_weather/models/__init__.py
+++ b/graph_weather/models/__init__.py
@@ -1,8 +1,8 @@
 """Models"""
 
+from .fengwu_ghr.layers import MetaModel
 from .layers.assimilator_decoder import AssimilatorDecoder
 from .layers.assimilator_encoder import AssimilatorEncoder
 from .layers.decoder import Decoder
 from .layers.encoder import Encoder
 from .layers.processor import Processor
-from .fengwu_ghr.layers import MetaModel

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index 9aeaf508..cd81218d 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -1,8 +1,7 @@
 import torch
-from torch import nn
-
 from einops import rearrange
 from einops.layers.torch import Rearrange
+from torch import nn
 
 # helpers
 
@@ -15,13 +14,14 @@ def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
     y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
     assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
     omega = torch.arange(dim // 4) / (dim // 4 - 1)
-    omega = 1.0 / (temperature ** omega)
+    omega = 1.0 / (temperature**omega)
 
     y = y.flatten()[:, None] * omega[None, :]
     x = x.flatten()[:, None] * omega[None, :]
     pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
     return pe.type(dtype)
 
+
 # classes
 
 
@@ -44,7 +44,7 @@ class Attention(nn.Module):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
-        self.scale = dim_head ** -0.5
+        self.scale = dim_head**-0.5
         self.norm = nn.LayerNorm(dim)
 
         self.attend = nn.Softmax(dim=-1)
@@ -56,15 +56,14 @@ def forward(self, x):
         x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(
-            t, 'b n (h d) -> b h n d', h=self.heads), qkv)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
         attn = self.attend(dots)
 
         out = torch.matmul(attn, v)
-        out = rearrange(out, 'b h n d -> b n (h d)')
+        out = rearrange(out, "b h n d -> b n (h d)")
         return self.to_out(out)
 
 
@@ -74,10 +73,11 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
         for _ in range(depth):
-            self.layers.append(nn.ModuleList([
-                Attention(dim, heads=heads, dim_head=dim_head),
-                FeedForward(dim, mlp_dim)
-            ]))
+            self.layers.append(
+                nn.ModuleList(
+                    [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)]
+                )
+            )
 
     def forward(self, x):
         for attn, ff in self.layers:
@@ -92,16 +92,19 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3,
         image_height, image_width = pair(image_size)
         patch_height, patch_width = pair(patch_size)
 
-        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+        assert (
+            image_height % patch_height == 0 and image_width % patch_width == 0
+        ), "Image dimensions must be divisible by the patch size."
 
         patch_dim = channels * patch_height * patch_width
         dim = patch_dim
         self.to_patch_embedding = nn.Sequential(
-            Rearrange("b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)",
-                      p_h=patch_height, p_w=patch_width),
-            nn.LayerNorm(patch_dim),  # TODO Do we need this?
-            nn.Linear(patch_dim, dim),  # TODO Do we need this?
-            nn.LayerNorm(dim),  # TODO Do we need this?
+            Rearrange(
+                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width
+            ),
+            nn.LayerNorm(patch_dim),  # TODO Do we need this?
+            nn.Linear(patch_dim, dim),  # TODO Do we need this?
+            nn.LayerNorm(dim),  # TODO Do we need this?
         )
 
         self.pos_embedding = posemb_sincos_2d(
@@ -113,10 +116,13 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3,
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
 
         self.reshaper = nn.Sequential(
-            Rearrange("b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)",
-                      h=image_height // patch_height,
-                      w=image_width // patch_width,
-                      p_h=patch_height, p_w=patch_width)
+            Rearrange(
+                "b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)",
+                h=image_height // patch_height,
+                w=image_width // patch_width,
+                p_h=patch_height,
+                p_w=patch_width,
+            )
         )
 
     def forward(self, img):
@@ -126,7 +132,7 @@ def forward(self, img):
         x += self.pos_embedding.to(device, dtype=x.dtype)
 
         x = self.transformer(x)
-
+
         x = self.reshaper(x)
 
         return x

diff --git a/tests/test_model.py b/tests/test_model.py
index 2ef43cc8..050f3e28 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -3,7 +3,14 @@
 import torch
 
 from graph_weather import GraphWeatherAssimilator, GraphWeatherForecaster
-from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor, MetaModel
+from graph_weather.models import (
+    AssimilatorDecoder,
+    AssimilatorEncoder,
+    Decoder,
+    Encoder,
+    Processor,
+    MetaModel,
+)
 from graph_weather.models.losses import NormalizedMSELoss
 
 
@@ -225,12 +232,10 @@ def test_normalized_loss():
 
 
 def test_meta_model():
-    model = MetaModel(image_size=100,patch_size=10,
-                      depth=1, heads=1, mlp_dim=7,
-                      channels=3 )
-    features = torch.randn((1,3, 100,100) )
-
+    model = MetaModel(image_size=100, patch_size=10, depth=1, heads=1, mlp_dim=7, channels=3)
+    features = torch.randn((1, 3, 100, 100))
+
     out = model(features)
     assert not torch.isnan(out).any()
     assert not torch.isnan(out).any()
-    assert out.size() == (1,3, 100,100)
+    assert out.size() == (1, 3, 100, 100)
From 8c60fb71e4de98560d601b0a295b2b957fbe9e3b Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Thu, 6 Jun 2024 15:07:10 +0200
Subject: [PATCH 04/16] Interpolate initial

---
 graph_weather/models/fengwu_ghr/layers.py | 79 ++++++++++++++++++++---
 tests/test_model.py                       | 31 ++++++---
 2 files changed, 92 insertions(+), 18 deletions(-)

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index cd81218d..036d9048 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -1,3 +1,5 @@
+import numpy as np
+from scipy.interpolate import griddata, interpn
 import torch
 from einops import rearrange
 from einops.layers.torch import Rearrange
@@ -10,6 +12,39 @@ def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
 
+def grid_interpolate(lat_lons: list, z: torch.Tensor,
+                     height, width,
+                     method: str = "cubic"):
+    # TODO 1. CPU only
+    #      2. The mesh is a rectangle, not a sphere
+
+    xi = np.arange(0.5, width, 1)/width*360
+    yi = np.arange(0.5, height, 1)/height*180
+
+    xi, yi = np.meshgrid(xi, yi)
+    z = rearrange(z, "b n c -> n b c")
+    z = griddata(
+        lat_lons, z, (xi, yi),
+        fill_value=0, method=method)
+    z = rearrange(z, "h w b c -> b c h w")  # hw ?
+    z = torch.tensor(z)
+    return z
+
+def grid_extrapolate(lat_lons, z,
+                     height, width,
+                     method: str = "cubic"):
+    xi = np.arange(0.5, width, 1)/width*360
+    yi = np.arange(0.5, height, 1)/height*180
+    z = rearrange(z, "b c h w -> h w b c")
+    z = z.detach().numpy()
+    z= interpn((xi,yi),z, lat_lons,
+               bounds_error=False,
+               method=method)
+    z = rearrange(z, "n b c -> b n c")
+    z = torch.tensor(z)
+    return z
+
+
 def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
     y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
     assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
@@ -56,7 +91,8 @@ def forward(self, x):
         x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
+        q, k, v = map(lambda t: rearrange(
+            t, "b n (h d) -> b h n d", h=self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
@@ -75,10 +111,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         for _ in range(depth):
             self.layers.append(
                 nn.ModuleList(
-                    [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)]
+                    [Attention(dim, heads=heads, dim_head=dim_head),
+                     FeedForward(dim, mlp_dim)]
                 )
             )
 
@@ -87,20 +124,31 @@ def forward(self, x):
             x = ff(x) + x
         return self.norm(x)
 
 
 class MetaModel(nn.Module):
-    def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64):
+    def __init__(self, lat_lons: list, *,
+                 patch_size, depth,
+                 heads, mlp_dim,
+                 resolution=(721, 1440),
+                 channels=3, dim_head=64,
+                 interp_method='cubic'):
         super().__init__()
-        image_height, image_width = pair(image_size)
+        image_height, image_width = pair(resolution)
         patch_height, patch_width = pair(patch_size)
 
         assert (
             image_height % patch_height == 0 and image_width % patch_width == 0
         ), "Image dimensions must be divisible by the patch size."
 
+        # interpolate
+        self.interpolate = lambda z: grid_interpolate(
+            lat_lons, z, image_height, image_width,
+            method=interp_method)
+
         patch_dim = channels * patch_height * patch_width
         dim = patch_dim
         self.to_patch_embedding = nn.Sequential(
             Rearrange(
-                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width
+                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)",
+                p_h=patch_height, p_w=patch_width
             ),
             nn.LayerNorm(patch_dim),  # TODO Do we need this?
             nn.Linear(patch_dim, dim),  # TODO Do we need this?
@@ -125,14 +173,27 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3,
             )
         )
 
-    def forward(self, img):
-        device = img.device
+        # extrapolate
+        self.extrapolate = lambda z: grid_extrapolate(
+            lat_lons, z, image_height, image_width,
+            method=interp_method)
+
 
-        x = self.to_patch_embedding(img)
-        x += self.pos_embedding.to(device, dtype=x.dtype)
+    def forward(self, x):
+        device = x.device
+        dtype = x.dtype
+
+        x = self.interpolate(x.to("cpu"))
+        x = x.to(device, dtype=dtype)
+
+        x = self.to_patch_embedding(x)
+        x += self.pos_embedding.to(device, dtype=dtype)
 
         x = self.transformer(x)
 
         x = self.reshaper(x)
 
+        x = self.extrapolate(x.to("cpu"))
+        x = x.to(device, dtype=dtype)
+
         return x

diff --git a/tests/test_model.py b/tests/test_model.py
index 050f3e28..474a3f89 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -142,7 +142,8 @@ def test_assimilator_model():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             output_lat_lons.append((lat, lon))
-    model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24)
+    model = GraphWeatherAssimilator(
+        output_lat_lons=output_lat_lons, analysis_dim=24)
 
     features = torch.randn((1, len(obs_lat_lons), 2))
     lat_lon_heights = torch.tensor(obs_lat_lons)
@@ -156,7 +157,8 @@ def test_forecaster_and_loss():
     for lat in range(-90, 90, 5):
        for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(
+        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -197,7 +199,8 @@ def test_forecaster_and_loss_grad_checkpoint():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(
+        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons, use_checkpointing=True)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -228,14 +231,24 @@ def test_normalized_loss():
     assert not torch.isnan(loss)
     # Since feature_variance = out**2 and target = 0, we expect loss = weights
-    assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())
+    assert torch.isclose(
+        loss, criterion.weights.expand_as(out.mean(-1)).mean())
 
 
 def test_meta_model():
-    model = MetaModel(image_size=100, patch_size=10, depth=1, heads=1, mlp_dim=7, channels=3)
-    features = torch.randn((1, 3, 100, 100))
+    lat_lons = []
+    for lat in range(-90, 90, 5):
+        for lon in range(0, 360, 5):
+            lat_lons.append((lat, lon))
+
+    batch =2
+    channels = 3
+    model = MetaModel(lat_lons,
+                      resolution=4, patch_size=2,
+                      depth=1, heads=1, mlp_dim=7, channels=channels)
+    features = torch.randn((batch,len(lat_lons), channels))
 
     out = model(features)
-    assert not torch.isnan(out).any()
-    assert not torch.isnan(out).any()
-    assert out.size() == (1, 3, 100, 100)
+    #assert not torch.isnan(out).any()
+    #assert not torch.isnan(out).any()
+    assert out.size() == (batch,len(lat_lons), channels)
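The interpolation scheme above resamples scattered lat/lon node features onto a regular grid with scipy.interpolate.griddata before the ViT, and back with interpn afterwards; the TODOs note it is CPU-only and treats the globe as a rectangle. A self-contained sketch of the scattered-to-grid step on synthetic data (coordinates, sizes, and the sine field are illustrative):

import numpy as np
from scipy.interpolate import griddata

points = np.column_stack([np.random.uniform(0, 360, 500),   # lon-like axis
                          np.random.uniform(0, 180, 500)])  # lat-like axis
values = np.sin(np.deg2rad(points[:, 0]))                   # one scalar per scattered point
xi, yi = np.meshgrid(np.arange(0.5, 360, 1), np.arange(0.5, 180, 1))
grid = griddata(points, values, (xi, yi), method="cubic", fill_value=0)
assert grid.shape == xi.shape                               # 180 x 360 regular grid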
From 725421df408eafcd4bcea62a420b7a8e704ad299 Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Tue, 11 Jun 2024 15:26:54 +0200
Subject: [PATCH 05/16] ImageMetaModel

---
 graph_weather/models/__init__.py          |  2 +-
 graph_weather/models/fengwu_ghr/layers.py | 86 ++++++++++++++++-------
 tests/test_model.py                       | 27 +++++--
 3 files changed, 82 insertions(+), 33 deletions(-)

diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py
index 72d222a8..fadc1d52 100644
--- a/graph_weather/models/__init__.py
+++ b/graph_weather/models/__init__.py
@@ -1,6 +1,6 @@
 """Models"""
 
-from .fengwu_ghr.layers import MetaModel
+from .fengwu_ghr.layers import MetaModel,ImageMetaModel
 from .layers.assimilator_decoder import AssimilatorDecoder
 from .layers.assimilator_encoder import AssimilatorEncoder
 from .layers.decoder import Decoder

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index 036d9048..e3da7e31 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -5,12 +5,38 @@ from einops.layers.torch import Rearrange
 from torch import nn
 
+
 # helpers
 
 
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
+from torch_geometric.nn import knn
+from torch_geometric.utils import scatter
+
+
+def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
+                    k: int = 3, num_workers: int = 1):
+    with torch.no_grad():
+        assign_index = knn(pos_x, pos_y, k,
+                           num_workers=num_workers)
+        y_idx, x_idx = assign_index[0], assign_index[1]
+        diff = pos_x[x_idx] - pos_y[y_idx]
+        squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
+        weights = 1.0 / torch.clamp(squared_distance, min=1e-16)
+
+
+    # print((x[x_idx]*weights).shape)
+    # print(weights.shape)
+    den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum')
+    # print(den.shape)
+    y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum')
+
+
+    y = y / den
+
+    return y
 
 def grid_interpolate(lat_lons: list, z: torch.Tensor,
                      height, width,
                      method: str = "cubic"):
@@ -37,11 +63,11 @@ def grid_extrapolate(lat_lons, z,
     xi = np.arange(0.5, width, 1)/width*360
     yi = np.arange(0.5, height, 1)/height*180
     z = rearrange(z, "b c h w -> h w b c")
     z = z.detach().numpy()
-    z= interpn((xi,yi),z, lat_lons,
-               bounds_error=False,
+    z = interpn((xi, yi), z, lat_lons,
+                bounds_error=False,
                 method=method)
     z = rearrange(z, "n b c -> b n c")
     z = torch.tensor(z)
     return z
@@ -87,24 +113,26 @@ def forward(self, x):
             x = ff(x) + x
         return self.norm(x)
 
-class MetaModel(nn.Module):
-    def __init__(self, lat_lons: list, *,
-                 patch_size, depth,
-                 heads, mlp_dim,
-                 resolution=(721, 1440),
-                 channels=3, dim_head=64,
-                 interp_method='cubic'):
+class ImageMetaModel(nn.Module):
+    def __init__(self, *,
+                 image_size,
+                 patch_size, depth,
+                 heads, mlp_dim,
+                 channels=3, dim_head=64):
         super().__init__()
-        image_height, image_width = pair(resolution)
+        image_height, image_width = pair(image_size)
         patch_height, patch_width = pair(patch_size)
 
         assert (
             image_height % patch_height == 0 and image_width % patch_width == 0
         ), "Image dimensions must be divisible by the patch size."
 
-        # interpolate
-        self.interpolate = lambda z: grid_interpolate(
-            lat_lons, z, image_height, image_width,
-            method=interp_method)
-
         patch_dim = channels * patch_height * patch_width
         dim = patch_dim
         self.to_patch_embedding = nn.Sequential(
             Rearrange(
-                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)",
-                p_h=patch_height, p_w=patch_width
+                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width
             ),
             nn.LayerNorm(patch_dim),  # TODO Do we need this?
             nn.Linear(patch_dim, dim),  # TODO Do we need this?
@@ -173,27 +193,39 @@ def __init__(self, lat_lons: list, *,
             )
         )
 
-        # extrapolate
-        self.extrapolate = lambda z: grid_extrapolate(
-            lat_lons, z, image_height, image_width,
-            method=interp_method)
-
-
     def forward(self, x):
         device = x.device
         dtype = x.dtype
 
-        x = self.interpolate(x.to("cpu"))
-        x = x.to(device, dtype=dtype)
-
         x = self.to_patch_embedding(x)
         x += self.pos_embedding.to(device, dtype=dtype)
 
         x = self.transformer(x)
 
         x = self.reshaper(x)
 
-        x = self.extrapolate(x.to("cpu"))
-        x = x.to(device, dtype=dtype)
-
         return x
+
+
+class MetaModel(nn.Module):
+    def __init__(self, lat_lons: list, *,
+                 patch_size, depth,
+                 heads, mlp_dim,
+                 resolution=(721, 1440),
+                 channels=3, dim_head=64,
+                 interp_method='cubic'):
+        super().__init__()
+        resolution = pair(resolution)
+        b=3
+        n=len(lat_lons)
+        d=7
+        x=torch.randn((b,n,d))
+        x=rearrange(x,"b n d -> n (b d)")
+
+        pos_x= torch.tensor(lat_lons)
+        pos_y = torch.cartesian_prod(
+            torch.arange(0.5,resolution[0],1),
+            torch.arange(0.5,resolution[1],1)
+        )
+        x = knn_interpolate(x,pos_x,pos_y)
+        x = rearrange(x,"m (b d) -> b m d", b=b,d=d)
+        print(x.shape)

diff --git a/tests/test_model.py b/tests/test_model.py
index 474a3f89..95ec75bf 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -10,6 +10,7 @@
     Encoder,
     Processor,
     MetaModel,
+    ImageMetaModel
 )
 from graph_weather.models.losses import NormalizedMSELoss
 
@@ -235,20 +236,36 @@ def test_normalized_loss():
         loss, criterion.weights.expand_as(out.mean(-1)).mean())
 
 
+def test_image_meta_model():
+    batch = 2
+    channels = 3
+    size = 900
+    image = torch.randn((batch, channels, size, size))
+    model = ImageMetaModel(image_size=size,
+                           patch_size=10,
+                           depth=1, heads=1, mlp_dim=7,
+                           channels=channels)
+
+    out = model(image)
+    assert not torch.isnan(out).any()
+    assert not torch.isnan(out).any()
+    assert out.size() == (batch, channels,size,size)
+
+
 def test_meta_model():
     lat_lons = []
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
 
-    batch =2
+    batch = 2
     channels = 3
     model = MetaModel(lat_lons,
                       resolution=4, patch_size=2,
                       depth=1, heads=1, mlp_dim=7, channels=channels)
-    features = torch.randn((batch,len(lat_lons), channels))
+    features = torch.randn((batch, len(lat_lons), channels))
 
     out = model(features)
-    #assert not torch.isnan(out).any()
-    #assert not torch.isnan(out).any()
-    assert out.size() == (batch,len(lat_lons), channels)
+    # assert not torch.isnan(out).any()
+    # assert not torch.isnan(out).any()
+    assert out.size() == (batch, len(lat_lons), channels)
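With ImageMetaModel split out above, the image pipeline is self-contained, and the only non-obvious step is the einops patchify. A standalone demo of that Rearrange pattern (toy sizes, independent of the model):

import torch
from einops.layers.torch import Rearrange

to_patches = Rearrange("b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=2, p_w=2)
x = torch.randn(1, 3, 4, 4)        # b=1, c=3, 4x4 image
tokens = to_patches(x)
assert tokens.shape == (1, 4, 12)  # 2x2 = 4 patches, each flattened to 2*2*3 = 12 values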
From c57a27ec2358d5cd2ed0883d1f406efa52806a7b Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Tue, 11 Jun 2024 15:59:59 +0200
Subject: [PATCH 06/16] MetaModel initial

---
 graph_weather/models/fengwu_ghr/layers.py | 78 +++++++++++++++++------
 1 file changed, 60 insertions(+), 18 deletions(-)

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index e3da7e31..0a4b7c3c 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -1,3 +1,6 @@
+from scipy.interpolate import griddata
+from torch_geometric.nn import knn
+from torch_geometric.utils import scatter
 import numpy as np
 from scipy.interpolate import griddata, interpn
 import torch
@@ -12,9 +15,6 @@ def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
 
-from torch_geometric.nn import knn
-from torch_geometric.utils import scatter
-
 
 def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
                     k: int = 3, num_workers: int = 1):
@@ -26,18 +26,17 @@ def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
         squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
         weights = 1.0 / torch.clamp(squared_distance, min=1e-16)
 
-
     # print((x[x_idx]*weights).shape)
     # print(weights.shape)
     den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum')
     # print(den.shape)
     y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum')
-
-
+
     y = y / den
 
     return y
+
 
 def grid_interpolate(lat_lons: list, z: torch.Tensor,
                      height, width,
@@ -149,6 +148,7 @@ def forward(self, x):
             x = ff(x) + x
         return self.norm(x)
 
+
 class ImageMetaModel(nn.Module):
     def __init__(self, *,
                  image_size,
@@ -167,7 +205,50 @@ def forward(self, x):
 
         return x
 
 
 class MetaModel(nn.Module):
+    def __init__(self, lat_lons: list, *,
+                 patch_size, depth,
+                 heads, mlp_dim,
+                 resolution=(721, 1440),
+                 channels=3, dim_head=64,
+                 interp_method='cubic'):
+        super().__init__()
+        self.resolution = pair(resolution)
+
+        self.pos_x = torch.tensor(lat_lons)
+        self.pos_y = torch.cartesian_prod(
+            torch.arange(0, self.resolution[0], 1),
+            torch.arange(0, self.resolution[1], 1)
+        )
+
+        self.image_model = ImageMetaModel(image_size=resolution,
+                                          patch_size=patch_size,
+                                          depth=depth,
+                                          heads=heads,
+                                          mlp_dim=mlp_dim,
+                                          channels=channels,
+                                          dim_head=dim_head)
+
+    def forward(self, x):
+        b, n, c = x.shape
+
+        x = rearrange(x, "b n c -> n (b c)")
+        x = knn_interpolate(x, self.pos_x, self.pos_y)
+        x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c,
+                      w=self.resolution[0],
+                      h=self.resolution[1])
+
+        x = self.image_model(x)
+
+        x = rearrange(x, "b c h w -> (h w) (b c)")
+        x = knn_interpolate(x, self.pos_y, self.pos_x)
+        x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
+
+        return x
+
+
+class MetaModel2(nn.Module):
     def __init__(self, lat_lons: list, *,
                  patch_size, depth,
                  heads, mlp_dim,
@@ -214,18 +257,17 @@ def __init__(self, lat_lons: list, *,
                  interp_method='cubic'):
         super().__init__()
         resolution = pair(resolution)
-        b=3
-        n=len(lat_lons)
-        d=7
-        x=torch.randn((b,n,d))
-        x=rearrange(x,"b n d -> n (b d)")
-
-        pos_x= torch.tensor(lat_lons)
+        b = 3
+        n = len(lat_lons)
+        d = 7
+        x = torch.randn((b, n, d))
+        x = rearrange(x, "b n d -> n (b d)")
+
+        pos_x = torch.tensor(lat_lons)
         pos_y = torch.cartesian_prod(
-            torch.arange(0.5,resolution[0],1),
-            torch.arange(0.5,resolution[1],1)
+            torch.arange(0, resolution[0], 1),
+            torch.arange(0, resolution[1], 1)
         )
-        x = knn_interpolate(x,pos_x,pos_y)
-        x = rearrange(x,"m (b d) -> b m d", b=b,d=d)
+        x = knn_interpolate(x, pos_x, pos_y)
+        x = rearrange(x, "m (b d) -> b m d", b=b, d=d)
         print(x.shape)

From 3d2a17d62c15d93f29a22fc54f8ebc7ad4052d9e Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Fri, 14 Jun 2024 16:44:08 +0200
Subject: [PATCH 07/16] tested metamodel

---
 graph_weather/models/fengwu_ghr/layers.py | 90 +++--------------------
 tests/test_model.py                       | 19 +++--
 2 files changed, 23 insertions(+), 86 deletions(-)

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index 0a4b7c3c..dc530eac 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -1,8 +1,6 @@
 from scipy.interpolate import griddata
 from torch_geometric.nn import knn
 from torch_geometric.utils import scatter
-import numpy as np
-from scipy.interpolate import griddata, interpn
 import torch
 from einops import rearrange
 from einops.layers.torch import Rearrange
@@ -17,7 +15,7 @@ def pair(t):
 
 
 def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
-                    k: int = 3, num_workers: int = 1):
+                    k: int = 4, num_workers: int = 1):
     with torch.no_grad():
         assign_index = knn(pos_x, pos_y, k,
                            num_workers=num_workers)
@@ -26,10 +24,7 @@ def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
         squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
         weights = 1.0 / torch.clamp(squared_distance, min=1e-16)
 
-    # print((x[x_idx]*weights).shape)
-    # print(weights.shape)
     den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum')
-    # print(den.shape)
     y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum')
 
     y = y / den
@@ -37,40 +32,6 @@ def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
     return y
 
 
-def grid_interpolate(lat_lons: list, z: torch.Tensor,
-                     height, width,
-                     method: str = "cubic"):
-    # TODO 1. CPU only
-    #      2. The mesh is a rectangle, not a sphere
-
-    xi = np.arange(0.5, width, 1)/width*360
-    yi = np.arange(0.5, height, 1)/height*180
-
-    xi, yi = np.meshgrid(xi, yi)
-    z = rearrange(z, "b n c -> n b c")
-    z = griddata(
-        lat_lons, z, (xi, yi),
-        fill_value=0, method=method)
-    z = rearrange(z, "h w b c -> b c h w")  # hw ?
-    z = torch.tensor(z)
-    return z
-
-
-def grid_extrapolate(lat_lons, z,
-                     height, width,
-                     method: str = "cubic"):
-    xi = np.arange(0.5, width, 1)/width*360
-    yi = np.arange(0.5, height, 1)/height*180
-    z = rearrange(z, "b c h w -> h w b c")
-    z = z.detach().numpy()
-    z = interpn((xi, yi), z, lat_lons,
-                bounds_error=False,
-                method=method)
-    z = rearrange(z, "n b c -> b n c")
-    z = torch.tensor(z)
-    return z
-
-
 def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
     y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
     assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
@@ -171,19 +171,19 @@ class MetaModel(nn.Module):
     def __init__(self, lat_lons: list, *,
                  patch_size, depth,
                  heads, mlp_dim,
-                 resolution=(721, 1440),
-                 channels=3, dim_head=64,
-                 interp_method='cubic'):
+                 image_size=(721, 1440),
+                 channels=3, dim_head=64):
         super().__init__()
-        self.resolution = pair(resolution)
+        self.image_size = pair(image_size)
 
         self.pos_x = torch.tensor(lat_lons)
         self.pos_y = torch.cartesian_prod(
-            torch.arange(0, self.resolution[0], 1),
-            torch.arange(0, self.resolution[1], 1)
+            (torch.arange(-self.image_size[0]/2,
+                          self.image_size[0]/2, 1)/self.image_size[0]*180).to(torch.long),
+            (torch.arange(0, self.image_size[1], 1)/self.image_size[1]*360).to(torch.long)
         )
 
-        self.image_model = ImageMetaModel(image_size=resolution,
+        self.image_model = ImageMetaModel(image_size=image_size,
                                           patch_size=patch_size,
                                           depth=depth,
                                           heads=heads,
@@ -196,12 +196,12 @@ def forward(self, x):
 
         x = rearrange(x, "b n c -> n (b c)")
         x = knn_interpolate(x, self.pos_x, self.pos_y)
-        x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c,
-                      w=self.resolution[0],
-                      h=self.resolution[1])
-
+        x = rearrange(x, "(w h) (b c) -> b c w h", b=b, c=c,
+                      w=self.image_size[0],
+                      h=self.image_size[1])
         x = self.image_model(x)
 
-        x = rearrange(x, "b c h w -> (h w) (b c)")
+        x = rearrange(x, "b c w h -> (w h) (b c)")
         x = knn_interpolate(x, self.pos_y, self.pos_x)
         x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
         return x
-
-
-class MetaModel2(nn.Module):
-    def __init__(self, lat_lons: list, *,
-                 patch_size, depth,
-                 heads, mlp_dim,
-                 resolution=(721, 1440),
-                 channels=3, dim_head=64,
-                 interp_method='cubic'):
-        super().__init__()
-        resolution = pair(resolution)
-        b = 3
-        n = len(lat_lons)
-        d = 7
-        x = torch.randn((b, n, d))
-        x = rearrange(x, "b n d -> n (b d)")
-
-        pos_x = torch.tensor(lat_lons)
-        pos_y = torch.cartesian_prod(
-            torch.arange(0, resolution[0], 1),
-            torch.arange(0, resolution[1], 1)
-        )
-        x = knn_interpolate(x, pos_x, pos_y)
-        x = rearrange(x, "m (b d) -> b m d", b=b, d=d)
-        print(x.shape)

diff --git a/tests/test_model.py b/tests/test_model.py
index 95ec75bf..c290b118 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -239,19 +239,20 @@ def test_normalized_loss():
 def test_image_meta_model():
     batch = 2
     channels = 3
-    size = 900
+    size = 4
+    patch_size = 2
     image = torch.randn((batch, channels, size, size))
     model = ImageMetaModel(image_size=size,
-                           patch_size=10,
-                           depth=1, heads=1, mlp_dim=7,
-                           channels=channels)
+                           patch_size=patch_size,
+                           channels=channels,
+                           depth=1, heads=1, mlp_dim=7
+                           )
 
     out = model(image)
     assert not torch.isnan(out).any()
     assert not torch.isnan(out).any()
     assert out.size() == (batch, channels,size,size)
 
+
 def test_meta_model():
     lat_lons = []
     for lat in range(-90, 90, 5):
@@ -260,12 +261,14 @@ def test_meta_model():
 
     batch = 2
     channels = 3
+    image_size=20
+    patch_size=4
     model = MetaModel(lat_lons,
-                      resolution=4, patch_size=2,
+                      image_size=image_size, patch_size=patch_size,
                       depth=1, heads=1, mlp_dim=7, channels=channels)
     features = torch.randn((batch, len(lat_lons), channels))
 
     out = model(features)
-    # assert not torch.isnan(out).any()
-    # assert not torch.isnan(out).any()
+    assert not torch.isnan(out).any()
+    assert not torch.isnan(out).any()
     assert out.size() == (batch, len(lat_lons), channels)
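After the two patches above, resampling is done by knn_interpolate: each target position takes an inverse-squared-distance weighted average of its k = 4 nearest source nodes, using knn and scatter from torch_geometric. The same math in plain PyTorch, as a sketch for checking the weighting (not the code the model uses; names and sizes are illustrative):

import torch

def idw_knn(x, pos_x, pos_y, k=4):
    # squared distances from every target (pos_y) to every source (pos_x)
    d2 = torch.cdist(pos_y.float(), pos_x.float()).pow(2)
    d2, idx = d2.topk(k, dim=-1, largest=False)    # k nearest sources per target
    w = 1.0 / d2.clamp(min=1e-16)                  # inverse-distance-squared weights
    w = w / w.sum(dim=-1, keepdim=True)            # normalise, as dividing by `den` does
    return torch.einsum("mk,mkc->mc", w, x[idx])   # weighted neighbour average

x = torch.randn(100, 5)            # 100 source nodes, 5 features each
pos_x = torch.rand(100, 2) * 360   # source coordinates
pos_y = torch.rand(50, 2) * 360    # target coordinates
assert idw_knn(x, pos_x, pos_y).shape == (50, 5)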
From 87d1ffd22e34960bca767d71ca0325e368e59f19 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 17 Jun 2024 17:16:17 +0000
Subject: [PATCH 08/16] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 graph_weather/models/__init__.py          |  2 +-
 graph_weather/models/fengwu_ghr/layers.py | 80 ++++++++++++-----------
 tests/test_model.py                       | 43 ++++++-------
 3 files changed, 65 insertions(+), 60 deletions(-)

diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py
index fadc1d52..0083b16f 100644
--- a/graph_weather/models/__init__.py
+++ b/graph_weather/models/__init__.py
@@ -1,6 +1,6 @@
 """Models"""
 
-from .fengwu_ghr.layers import MetaModel,ImageMetaModel
+from .fengwu_ghr.layers import ImageMetaModel, MetaModel
 from .layers.assimilator_decoder import AssimilatorDecoder
 from .layers.assimilator_encoder import AssimilatorEncoder
 from .layers.decoder import Decoder

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index dc530eac..06d15681 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -1,10 +1,9 @@
-from scipy.interpolate import griddata
-from torch_geometric.nn import knn
-from torch_geometric.utils import scatter
 import torch
 from einops import rearrange
 from einops.layers.torch import Rearrange
 from torch import nn
-
+from torch_geometric.nn import knn
+from torch_geometric.utils import scatter
 
 # helpers
 
@@ -13,18 +12,18 @@ def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
 
-def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
-                    k: int = 4, num_workers: int = 1):
+def knn_interpolate(
+    x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 4, num_workers: int = 1
+):
     with torch.no_grad():
-        assign_index = knn(pos_x, pos_y, k,
-                           num_workers=num_workers)
+        assign_index = knn(pos_x, pos_y, k, num_workers=num_workers)
         y_idx, x_idx = assign_index[0], assign_index[1]
         diff = pos_x[x_idx] - pos_y[y_idx]
         squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
         weights = 1.0 / torch.clamp(squared_distance, min=1e-16)
 
-    den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum')
-    y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum')
+    den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum")
+    y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum")
 
     y = y / den
 
@@ -77,8 +76,7 @@ def forward(self, x):
         x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(
-            t, "b n (h d) -> b h n d", h=self.heads), qkv)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
@@ -97,8 +95,7 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         for _ in range(depth):
             self.layers.append(
                 nn.ModuleList(
-                    [Attention(dim, heads=heads, dim_head=dim_head),
-                     FeedForward(dim, mlp_dim)]
+                    [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)]
                 )
             )
 
@@ -110,11 +107,7 @@ def forward(self, x):
 
 
 class ImageMetaModel(nn.Module):
-    def __init__(self, *,
-                 image_size,
-                 patch_size, depth,
-                 heads, mlp_dim,
-                 channels=3, dim_head=64):
+    def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64):
         super().__init__()
         image_height, image_width = pair(image_size)
         patch_height, patch_width = pair(patch_size)
@@ -127,8 +120,7 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3,
         dim = patch_dim
         self.to_patch_embedding = nn.Sequential(
             Rearrange(
-                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)",
-                p_h=patch_height, p_w=patch_width
+                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width
             ),
             nn.LayerNorm(patch_dim),  # TODO Do we need this?
             nn.Linear(patch_dim, dim),  # TODO Do we need this?
@@ -167,37 +159,49 @@ def forward(self, x):
 
 
 class MetaModel(nn.Module):
-    def __init__(self, lat_lons: list, *,
-                 patch_size, depth,
-                 heads, mlp_dim,
-                 image_size=(721, 1440),
-                 channels=3, dim_head=64):
+    def __init__(
+        self,
+        lat_lons: list,
+        *,
+        patch_size,
+        depth,
+        heads,
+        mlp_dim,
+        image_size=(721, 1440),
+        channels=3,
+        dim_head=64
+    ):
         super().__init__()
         self.image_size = pair(image_size)
 
         self.pos_x = torch.tensor(lat_lons)
         self.pos_y = torch.cartesian_prod(
-            (torch.arange(-self.image_size[0]/2,
-                          self.image_size[0]/2, 1)/self.image_size[0]*180).to(torch.long),
-            (torch.arange(0, self.image_size[1], 1)/self.image_size[1]*360).to(torch.long)
+            (
+                torch.arange(-self.image_size[0] / 2, self.image_size[0] / 2, 1)
+                / self.image_size[0]
+                * 180
+            ).to(torch.long),
+            (torch.arange(0, self.image_size[1], 1) / self.image_size[1] * 360).to(torch.long),
        )
 
-        self.image_model = ImageMetaModel(image_size=image_size,
-                                          patch_size=patch_size,
-                                          depth=depth,
-                                          heads=heads,
-                                          mlp_dim=mlp_dim,
-                                          channels=channels,
-                                          dim_head=dim_head)
+        self.image_model = ImageMetaModel(
+            image_size=image_size,
+            patch_size=patch_size,
+            depth=depth,
+            heads=heads,
+            mlp_dim=mlp_dim,
+            channels=channels,
+            dim_head=dim_head,
+        )
 
     def forward(self, x):
         b, n, c = x.shape
 
         x = rearrange(x, "b n c -> n (b c)")
         x = knn_interpolate(x, self.pos_x, self.pos_y)
-        x = rearrange(x, "(w h) (b c) -> b c w h", b=b, c=c,
-                      w=self.image_size[0],
-                      h=self.image_size[1])
+        x = rearrange(
+            x, "(w h) (b c) -> b c w h", b=b, c=c, w=self.image_size[0], h=self.image_size[1]
+        )
         x = self.image_model(x)
 
         x = rearrange(x, "b c w h -> (w h) (b c)")

diff --git a/tests/test_model.py b/tests/test_model.py
index c290b118..5959349b 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -10,7 +10,7 @@
     Encoder,
     Processor,
     MetaModel,
-    ImageMetaModel
+    ImageMetaModel,
 )
 from graph_weather.models.losses import NormalizedMSELoss
 from graph_weather.models.gencast.utils.noise import (
@@ -147,8 +147,7 @@ def test_assimilator_model():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             output_lat_lons.append((lat, lon))
-    model = GraphWeatherAssimilator(
-        output_lat_lons=output_lat_lons, analysis_dim=24)
+    model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24)
 
     features = torch.randn((1, len(obs_lat_lons), 2))
     lat_lon_heights = torch.tensor(obs_lat_lons)
@@ -162,8 +161,7 @@ def test_forecaster_and_loss():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(
-        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -204,8 +202,7 @@ def test_forecaster_and_loss_grad_checkpoint():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(
-        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons, use_checkpointing=True)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -236,8 +233,7 @@ def test_normalized_loss():
 
     assert not torch.isnan(loss)
     # Since feature_variance = out**2 and target = 0, we expect loss = weights
-    assert torch.isclose(
-        loss, criterion.weights.expand_as(out.mean(-1)).mean())
+    assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())
 
 
 def test_image_meta_model():
@@ -246,16 +242,15 @@ def test_image_meta_model():
     size = 4
     patch_size = 2
     image = torch.randn((batch, channels, size, size))
-    model = ImageMetaModel(image_size=size,
-                           patch_size=patch_size,
-                           channels=channels,
-                           depth=1, heads=1, mlp_dim=7
-                           )
+    model = ImageMetaModel(
+        image_size=size, patch_size=patch_size, channels=channels, depth=1, heads=1, mlp_dim=7
+    )
 
     out = model(image)
     assert not torch.isnan(out).any()
     assert not torch.isnan(out).any()
-    assert out.size() == (batch, channels,size,size)
+    assert out.size() == (batch, channels, size, size)
 
 
 def test_meta_model():
@@ -265,14 +260,20 @@ def test_meta_model():
 
     batch = 2
     channels = 3
-    image_size=20
-    patch_size=4
-    model = MetaModel(lat_lons,
-                      image_size=image_size, patch_size=patch_size,
-                      depth=1, heads=1, mlp_dim=7, channels=channels)
+    image_size = 20
+    patch_size = 4
+    model = MetaModel(
+        lat_lons,
+        image_size=image_size,
+        patch_size=patch_size,
+        depth=1,
+        heads=1,
+        mlp_dim=7,
+        channels=channels,
+    )
     features = torch.randn((batch, len(lat_lons), channels))
 
     out = model(features)
     assert not torch.isnan(out).any()
     assert not torch.isnan(out).any()
     assert out.size() == (batch, len(lat_lons), channels)
From 21d84c785ece635a6b753d2da1dc5d140c3bfca4 Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Fri, 21 Jun 2024 17:58:13 +0200
Subject: [PATCH 09/16] wrapper meta model

---
 graph_weather/models/__init__.py          |   2 +-
 graph_weather/models/fengwu_ghr/layers.py | 126 +++++++++++++++++-----
 tests/test_model.py                       |  76 +++++++++++--
 3 files changed, 171 insertions(+), 33 deletions(-)

diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py
index 0083b16f..ace964db 100644
--- a/graph_weather/models/__init__.py
+++ b/graph_weather/models/__init__.py
@@ -1,6 +1,6 @@
 """Models"""
 
-from .fengwu_ghr.layers import ImageMetaModel, MetaModel
+from .fengwu_ghr.layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel
 from .layers.assimilator_decoder import AssimilatorDecoder
 from .layers.assimilator_encoder import AssimilatorEncoder
 from .layers.decoder import Decoder

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index 06d15681..f5dbda57 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -76,7 +76,8 @@ def forward(self, x):
         x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
+        q, k, v = map(lambda t: rearrange(
+            t, "b n (h d) -> b h n d", h=self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
@@ -95,7 +96,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         for _ in range(depth):
             self.layers.append(
                 nn.ModuleList(
-                    [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)]
+                    [Attention(dim, heads=heads, dim_head=dim_head),
+                     FeedForward(dim, mlp_dim)]
                 )
             )
 
@@ -109,20 +111,22 @@ def forward(self, x):
 
 
 class ImageMetaModel(nn.Module):
-    def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64):
+    def __init__(self, *, image_size,
+                 patch_size, depth, heads,
+                 mlp_dim, channels, dim_head):
         super().__init__()
-        image_height, image_width = pair(image_size)
-        patch_height, patch_width = pair(patch_size)
+        self.image_height, self.image_width = pair(image_size)
+        self.patch_height, self.patch_width = pair(patch_size)
 
         assert (
-            image_height % patch_height == 0 and image_width % patch_width == 0
+            self.image_height % self.patch_height == 0 and self.image_width % self.patch_width == 0
         ), "Image dimensions must be divisible by the patch size."
 
-        patch_dim = channels * patch_height * patch_width
+        patch_dim = channels * self.patch_height * self.patch_width
         dim = patch_dim
         self.to_patch_embedding = nn.Sequential(
             Rearrange(
-                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width
+                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=self.patch_height, p_w=self.patch_width
             ),
             nn.LayerNorm(patch_dim),  # TODO Do we need this?
             nn.Linear(patch_dim, dim),  # TODO Do we need this?
@@ -132,8 +136,8 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3,
         )
 
         self.pos_embedding = posemb_sincos_2d(
-            h=image_height // patch_height,
-            w=image_width // patch_width,
+            h=self.image_height // self.patch_height,
+            w=self.image_width // self.patch_width,
             dim=dim,
         )
 
@@ -142,10 +146,10 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3,
         self.reshaper = nn.Sequential(
             Rearrange(
                 "b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)",
-                h=image_height // patch_height,
-                w=image_width // patch_width,
-                p_h=patch_height,
-                p_w=patch_width,
+                h=self.image_height // self.patch_height,
+                w=self.image_width // self.patch_width,
+                p_h=self.patch_height,
+                p_w=self.patch_width,
             )
         )
 
@@ -162,33 +166,53 @@ def forward(self, x):
         return x
 
 
+class WrapperImageModel(nn.Module):
+    def __init__(self, image_meta_model: ImageMetaModel,
+                 scale_factor):
+        super().__init__()
+        s_h, s_w = pair(scale_factor)
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
+                                 s_h=s_h, s_w=s_w)
+        self.image_meta_model = image_meta_model
+        self.debatcher = Rearrange(" (b s_h s_w) c h w -> b c (h s_h) (w s_w)",
+                                   s_h=s_h, s_w=s_w)
+
+    def forward(self, x):
+        x = self.batcher(x)
+        x = self.image_meta_model(x)
+        x = self.debatcher(x)
+        return x
+
+
 class MetaModel(nn.Module):
     def __init__(
         self,
         lat_lons: list,
         *,
+        image_size,
         patch_size,
         depth,
         heads,
         mlp_dim,
-        image_size=(721, 1440),
-        channels=3,
+        channels,
         dim_head=64
     ):
         super().__init__()
-        self.image_size = pair(image_size)
+        self.i_h, self.i_w = pair(image_size)
 
         self.pos_x = torch.tensor(lat_lons)
         self.pos_y = torch.cartesian_prod(
             (
-                torch.arange(-self.image_size[0] / 2, self.image_size[0] / 2, 1)
-                / self.image_size[0]
+                torch.arange(-self.i_h / 2,
+                             self.i_h / 2, 1)
+                / self.i_h
                 * 180
             ).to(torch.long),
-            (torch.arange(0, self.image_size[1], 1) / self.image_size[1] * 360).to(torch.long),
+            (torch.arange(0, self.i_w, 1) /
+             self.i_w * 360).to(torch.long),
         )
 
-        self.image_model = ImageMetaModel(
+        self.image_meta_model = ImageMetaModel(
             image_size=image_size,
             patch_size=patch_size,
             depth=depth,
             heads=heads,
             mlp_dim=mlp_dim,
             channels=channels,
             dim_head=dim_head,
         )
 
@@ -200,11 +224,65 @@ def forward(self, x):
 
         x = rearrange(x, "b n c -> n (b c)")
         x = knn_interpolate(x, self.pos_x, self.pos_y)
         x = rearrange(
-            x, "(w h) (b c) -> b c w h", b=b, c=c, w=self.image_size[0], h=self.image_size[1]
+            x, "(h w) (b c) -> b c h w", b=b, c=c,
+            h=self.i_h, w=self.i_w
         )
-        x = self.image_model(x)
+        x = self.image_meta_model(x)
 
-        x = rearrange(x, "b c w h -> (w h) (b c)")
+        x = rearrange(x, "b c h w -> (h w) (b c)")
         x = knn_interpolate(x, self.pos_y, self.pos_x)
         x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
         return x
+
+
+class WrapperMetaModel(nn.Module):
+    def __init__(
+        self,
+        lat_lons: list,
+        meta_model: MetaModel,
+        scale_factor
+    ):
+        super().__init__()
+        self.image_meta_model = meta_model.image_meta_model
+
+        s_h, s_w = pair(scale_factor)
+        self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w
+        self.pos_x = torch.tensor(lat_lons)
+        self.pos_y = torch.cartesian_prod(
+            (
+                torch.arange(-self.i_h / 2,
+                             self.i_h / 2, 1)
+                / self.i_h
+                * 180
+            ).to(torch.long),
+            (torch.arange(0, self.i_w, 1) /
+             self.i_w * 360).to(torch.long),
+        )
+
+
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
+                                 s_h=s_h, s_w=s_w)
+
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
+                                   s_h=s_h, s_w=s_w)
+
+    def forward(self, x):
+        b, n, c = x.shape
+
+        x = rearrange(x, "b n c -> n (b c)")
+        x = knn_interpolate(x, self.pos_x, self.pos_y)
+        x = rearrange(
+            x, "(h w) (b c) -> b c h w", b=b, c=c,
+            h=self.i_h, w=self.i_w
+        )
+
+        x = self.batcher(x)
+        x = self.image_meta_model(x)
+        x = self.debatcher(x)
+
+        x = rearrange(x, "b c h w -> (h w) (b c)")
+        x = knn_interpolate(x, self.pos_y, self.pos_x)
+        x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
+
+
+        return x

diff --git a/tests/test_model.py b/tests/test_model.py
index 5959349b..e19e831e 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -9,8 +9,10 @@
     Decoder,
     Encoder,
     Processor,
-    MetaModel,
     ImageMetaModel,
+    MetaModel,
+    WrapperImageModel,
+    WrapperMetaModel
 )
 from graph_weather.models.losses import NormalizedMSELoss
 from graph_weather.models.gencast.utils.noise import (
@@ -149,7 +151,8 @@ def test_assimilator_model():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             output_lat_lons.append((lat, lon))
-    model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24)
+    model = GraphWeatherAssimilator(
+        output_lat_lons=output_lat_lons, analysis_dim=24)
 
     features = torch.randn((1, len(obs_lat_lons), 2))
     lat_lon_heights = torch.tensor(obs_lat_lons)
@@ -164,7 +167,8 @@ def test_forecaster_and_loss():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(
+        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -206,7 +210,8 @@ def test_forecaster_and_loss_grad_checkpoint():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(
+        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons, use_checkpointing=True)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -238,7 +243,8 @@ def test_normalized_loss():
 
     assert not torch.isnan(loss)
     # Since feature_variance = out**2 and target = 0, we expect loss = weights
-    assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())
+    assert torch.isclose(
+        loss, criterion.weights.expand_as(out.mean(-1)).mean())
 
 
 def test_image_meta_model():
@@ -247,15 +253,37 @@ def test_image_meta_model():
     size = 4
     patch_size = 2
     image = torch.randn((batch, channels, size, size))
-    model = ImageMetaModel(
-        image_size=size, patch_size=patch_size, channels=channels, depth=1, heads=1, mlp_dim=7
-    )
+    model = ImageMetaModel(
+        image_size=size, patch_size=patch_size,
+        channels=channels, depth=1, heads=1, mlp_dim=7,
+        dim_head=64
+    )
 
     out = model(image)
     assert not torch.isnan(out).any()
     assert not torch.isnan(out).any()
-    assert out.size() == (batch, channels, size, size)
+    assert out.size() == image.size()
+
+
+def test_wrapper_image_meta_model():
+    batch = 2
+    channels = 3
+    size = 4
+    patch_size = 2
+    model = ImageMetaModel(
+        image_size=size, patch_size=patch_size,
+        channels=channels, depth=1, heads=1, mlp_dim=7,
+        dim_head=64
+    )
+    scale_factor = 3
+    big_image = torch.randn((batch, channels,
+                             size*scale_factor, size*scale_factor))
+    big_model = WrapperImageModel(model, scale_factor)
+    out = big_model(big_image)
+    assert not torch.isnan(out).any()
+    assert not torch.isnan(out).any()
+    assert out.size() == big_image.size()
 
 
 def test_meta_model():
@@ -275,10 +303,42 @@ def test_meta_model():
         heads=1,
         mlp_dim=7,
         channels=channels,
+        dim_head=64
     )
     features = torch.randn((batch, len(lat_lons), channels))
 
     out = model(features)
     assert not torch.isnan(out).any()
     assert not torch.isnan(out).any()
-    assert out.size() == (batch, len(lat_lons), channels)
+    assert out.size() == features.size()
+
+
+def test_wrapper_meta_model():
+    lat_lons = []
+    for lat in range(-90, 90, 5):
+        for lon in range(0, 360, 5):
+            lat_lons.append((lat, lon))
+
+    batch = 2
+    channels = 3
+    image_size = 20
+    patch_size = 4
+    scale_factor=3
+    model = MetaModel(
+        lat_lons,
+        image_size=image_size,
+        patch_size=patch_size,
+        depth=1,
+        heads=1,
+        mlp_dim=7,
+        channels=channels,
+        dim_head=64
+    )
+
+    big_features = torch.randn((batch, len(lat_lons), channels))
+    big_model = WrapperMetaModel(lat_lons, model, scale_factor)
+    out = big_model(big_features)
+
+    assert not torch.isnan(out).any()
+    assert not torch.isnan(out).any()
+    assert out.size() == big_features.size()
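WrapperImageModel, added above, scales a trained ImageMetaModel to a larger input by splitting it into sub-images that are batched through the inner model. Note that the einops pattern "b c (h s_h) (w s_w) -> (b s_h s_w) c h w" makes each sub-image a strided sampling of the full image (every s_h-th row, every s_w-th column), not a contiguous tile. A small check (values chosen so the sampling is visible):

import torch
from einops import rearrange

x = torch.arange(16).reshape(1, 1, 4, 4)
y = rearrange(x, "b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=2, s_w=2)
assert y.shape == (4, 1, 2, 2)
assert torch.equal(y[0, 0], x[0, 0, 0::2, 0::2])  # first sub-image == strided slice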
From 07a8d0f6078355cd47f19b0fa95ecdb6b2ae2be5 Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Mon, 1 Jul 2024 16:30:29 +0200
Subject: [PATCH 10/16] RES

---
 graph_weather/models/fengwu_ghr/layers.py | 72 ++++++++++++++++++-----
 1 file changed, 57 insertions(+), 15 deletions(-)

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index f5dbda57..ad82fcad 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -89,33 +89,64 @@ def forward(self, x):
 
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, h=None, w=None, scale_factor=None):
         super().__init__()
+        self.depth = depth
+        self.res = res
         self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
-        for _ in range(depth):
+        self.res_layers = nn.ModuleList([])
+        for _ in range(self.depth):
             self.layers.append(
                 nn.ModuleList(
                     [Attention(dim, heads=heads, dim_head=dim_head),
                      FeedForward(dim, mlp_dim)]
                 )
             )
+            if self.res:
+                assert h is not None and w is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor"
+                s_h, s_w = pair(scale_factor)
+                self.res_layers.append(
+                    nn.ModuleList(
+                        [  # reshape to original shape window partition operation
+                            # (b s_h s_w) (h w) d -> b (s_h h) (s_w w) d -> (b h w) (s_h s_w) d
+                            Rearrange("(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d",
+                                      h=h, w=w, s_h=s_h, s_w=s_w
+                                      ),
+                            # TODO ?????
+                            Attention(dim, heads=heads, dim_head=dim_head),
+                            # restore shape
+                            Rearrange("(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d",
+                                      h=h, w=w, s_h=s_h, s_w=s_w
+                                      ),
+                        ]))
 
     def forward(self, x):
-        for attn, ff in self.layers:
+        for i in range(self.depth):
+            attn, ff = self.layers[i]
             x = attn(x) + x
             x = ff(x) + x
+            if self.res:
+                reshape, loc_attn, restore = self.res_layers[i]
+                x = reshape(x)
+                x = loc_attn(x) + x
+                x = restore(x)
         return self.norm(x)
 
 
 class ImageMetaModel(nn.Module):
     def __init__(self, *, image_size,
                  patch_size, depth, heads,
-                 mlp_dim, channels, dim_head):
+                 mlp_dim, channels, dim_head,
+                 res=False,
+                 scale_factor=None):
         super().__init__()
         self.image_height, self.image_width = pair(image_size)
         self.patch_height, self.patch_width = pair(patch_size)
+        s_h, s_w = pair(scale_factor)
+        if res:
+            assert scale_factor is not None, "If res=True, you must provide scale_factor"
 
         assert (
             self.image_height % self.patch_height == 0 and self.image_width % self.patch_width == 0
         ), "Image dimensions must be divisible by the patch size."
@@ -168,7 +199,12 @@ def __init__(self, *, image_size,
             dim=dim,
         )
 
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim,
+                                       res=res,
+                                       h=self.image_height // self.patch_height,
+                                       w=self.image_width // self.patch_width,
+                                       s_h=s_h,
+                                       s_w=s_w)
 
         self.reshaper = nn.Sequential(
             Rearrange(
@@ -205,8 +241,13 @@ def __init__(self, image_meta_model: ImageMetaModel,
         s_h, s_w = pair(scale_factor)
         self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
                                  s_h=s_h, s_w=s_w)
-        self.image_meta_model = image_meta_model
-        self.debatcher = Rearrange(" (b s_h s_w) c h w -> b c (h s_h) (w s_w)",
+
+        imm_args = image_meta_model.vars().update(
+            {"res": True, "scale_factor": scale_factor})
+        self.image_meta_model = ImageMetaModel(**imm_args)
+        self.image_meta_model.load(image_meta_model, strict=False)
+
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
                                    s_h=s_h, s_w=s_w)
 
     def forward(self, x):
@@ -224,7 +265,7 @@ def forward(self, x):
         x = rearrange(x, "b n c -> n (b c)")
         x = knn_interpolate(x, self.pos_x, self.pos_y)
         x = rearrange(
-            x, "(h w) (b c) -> b c h w", b=b, c=c, 
+            x, "(h w) (b c) -> b c h w", b=b, c=c,
             h=self.i_h, w=self.i_w
         )
         x = self.image_meta_model(x)
@@ -243,8 +284,6 @@ def __init__(
         scale_factor
     ):
         super().__init__()
-        self.image_meta_model = meta_model.image_meta_model
-
         s_h, s_w = pair(scale_factor)
         self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w
         self.pos_x = torch.tensor(lat_lons)
@@ -259,23 +298,27 @@ def __init__(
              self.i_w * 360).to(torch.long),
         )
 
-
         self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
                                  s_h=s_h, s_w=s_w)
+
+        imm_args = meta_model.image_meta_model.vars().update(
+            {"res": True, "scale_factor": scale_factor})
+        self.image_meta_model = ImageMetaModel(**imm_args)
+        self.image_meta_model.load(meta_model.image_meta_model, strict=False)
 
         self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
                                    s_h=s_h, s_w=s_w)
 
     def forward(self, x):
         b, n, c = x.shape
-
+
         x = rearrange(x, "b n c -> n (b c)")
         x = knn_interpolate(x, self.pos_x, self.pos_y)
         x = rearrange(
-            x, "(h w) (b c) -> b c h w", b=b, c=c, 
+            x, "(h w) (b c) -> b c h w", b=b, c=c,
             h=self.i_h, w=self.i_w
         )
-
+
         x = self.batcher(x)
         x = self.image_meta_model(x)
         x = self.debatcher(x)
@@ -283,6 +326,5 @@ def forward(self, x):
         x = rearrange(x, "b c h w -> (h w) (b c)")
         x = knn_interpolate(x, self.pos_y, self.pos_x)
         x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
-
 
         return x
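The RES branch above adds, per layer, a second attention over a regrouped sequence: the token at a given patch position is gathered from all s_h*s_w sub-images, and those s_h*s_w tokens attend to each other, mixing information across the strided sub-images at the same spatial location. A shape-only sketch of that regrouping (sizes are illustrative):

import torch
from einops import rearrange

b, s_h, s_w, h, w, d = 1, 2, 2, 3, 3, 8
tokens = torch.randn(b * s_h * s_w, h * w, d)   # one token sequence per sub-image
local = rearrange(tokens, "(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d",
                  h=h, w=w, s_h=s_h, s_w=s_w)
assert local.shape == (b * h * w, s_h * s_w, d)  # short cross-sub-image sequences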
+ Attention(dim, heads=heads, dim_head=dim_head), + # restore shape + Rearrange("(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d", + h=h, w=w, s_h=s_h, s_w=s_w + ), + ])) def forward(self, x): - for attn, ff in self.layers: + for i in range(self.depth): + attn, ff = self.layers[i] x = attn(x) + x x = ff(x) + x + if self.res: + reshape, loc_attn, restore = self.res_layers[i] + x = reshape(x) + x = loc_attn(x) + x + x = restore(x) return self.norm(x) class ImageMetaModel(nn.Module): def __init__(self, *, image_size, patch_size, depth, heads, - mlp_dim, channels, dim_head): + mlp_dim, channels, dim_head, + res=False, + scale_factor=None): super().__init__() self.image_height, self.image_width = pair(image_size) self.patch_height, self.patch_width = pair(patch_size) + s_h, s_w = pair(scale_factor) + if res: + assert scale_factor is not None, "If res=True, you must provide scale_factor" assert ( self.image_height % self.patch_height == 0 and self.image_width % self.patch_width == 0 ), "Image dimensions must be divisible by the patch size." @@ -137,7 +168,12 @@ def __init__(self, *, image_size, dim=dim, ) - self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, + res=res, + h=self.image_height // self.patch_height, + w=self.image_width // self.patch_width, + s_h=s_h, + s_w=s_w) self.reshaper = nn.Sequential( Rearrange( @@ -169,8 +205,13 @@ def __init__(self, image_meta_model: ImageMetaModel, s_h, s_w = pair(scale_factor) self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) - self.image_meta_model = image_meta_model - self.debatcher = Rearrange(" (b s_h s_w) c h w -> b c (h s_h) (w s_w)", + + imm_args = image_meta_model.vars().update( + {"res": True, "scale_factor": scale_factor}) + self.image_meta_model = ImageMetaModel(**imm_args) + self.image_meta_model.load(image_meta_model, strict=False) + + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) def forward(self, x): @@ -224,7 +265,7 @@ def forward(self, x): x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, + x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w ) x = self.image_meta_model(x) @@ -243,8 +284,6 @@ def __init__( scale_factor ): super().__init__() - self.image_meta_model = meta_model.image_meta_model - s_h, s_w = pair(scale_factor) self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w self.pos_x = torch.tensor(lat_lons) @@ -259,23 +298,27 @@ def __init__( self.i_w * 360).to(torch.long), ) - self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) + + imm_args = meta_model.image_meta_model.vars().update( + {"res": True, "scale_factor": scale_factor}) + self.image_meta_model = ImageMetaModel(**imm_args) + self.image_meta_model.load(meta_model.image_meta_model, strict=False) self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) def forward(self, x): b, n, c = x.shape - + x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, + x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w ) - + x = self.batcher(x) x = self.image_meta_model(x) x = self.debatcher(x) @@ -283,6 +326,5 @@ def forward(self, x): x = rearrange(x, "b c h w -> (h w) (b c)") x = knn_interpolate(x, self.pos_y, self.pos_x) x = rearrange(x, "n (b c) -> b n c", b=b, 
                       c=c)
-
         return x

From cd84968524f41ff654a16bcd1587ab2a1c59ae38 Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Tue, 2 Jul 2024 11:15:32 +0200
Subject: [PATCH 11/16] load RES state_dict

---
 graph_weather/models/fengwu_ghr/layers.py | 45 ++++++++++++++++-------
 1 file changed, 31 insertions(+), 14 deletions(-)

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index ad82fcad..e42d7545 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -89,7 +89,7 @@ def forward(self, x):
 
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, h=None, w=None, scale_factor=None):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None):
         super().__init__()
         self.depth = depth
         self.res = res
@@ -104,7 +104,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, h=None, w=No
                 )
             )
         if self.res:
-            assert h is not None and w is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor"
+            assert image_size is not None and scale_factor is not None, "If res=True, you must provide image_size and scale_factor"
+            h, w = pair(image_size)
             s_h, s_w = pair(scale_factor)
             self.res_layers.append(
                 nn.ModuleList(
@@ -139,8 +140,20 @@ def __init__(self, *, image_size,
                  patch_size, depth, heads,
                  mlp_dim, channels, dim_head,
                  res=False,
-                 scale_factor=None):
+                 scale_factor=None,
+                 **kwargs):
         super().__init__()
+        #TODO this can probably be done better
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.depth = depth
+        self.heads = heads
+        self.mlp_dim = mlp_dim
+        self.channels = channels
+        self.dim_head = dim_head
+        self.res = res
+        self.scale_factor = scale_factor
+
         self.image_height, self.image_width = pair(image_size)
         self.patch_height, self.patch_width = pair(patch_size)
         s_h, s_w = pair(scale_factor)
@@ -170,10 +183,12 @@ def __init__(self, *, image_size,
 
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim,
                                        res=res,
-                                       h=self.image_height // self.patch_height,
-                                       w=self.image_width // self.patch_width,
-                                       s_h=s_h,
-                                       s_w=s_w)
+                                       image_size=(
+                                           self.image_height // self.patch_height,
+                                           self.image_width // self.patch_width),
+                                       scale_factor=(
+                                           s_h,
+                                           s_w))
 
         self.reshaper = nn.Sequential(
             Rearrange(
@@ -205,12 +220,13 @@ def __init__(self, image_meta_model: ImageMetaModel,
         s_h, s_w = pair(scale_factor)
         self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
                                  s_h=s_h, s_w=s_w)
-
-        imm_args = image_meta_model.vars().update(
+
+        imm_args = vars(image_meta_model)
+        imm_args.update(
             {"res": True, "scale_factor": scale_factor})
         self.image_meta_model = ImageMetaModel(**imm_args)
-        self.image_meta_model.load(image_meta_model, strict=False)
-
+        self.image_meta_model.load_state_dict(image_meta_model.state_dict(), strict=False)
+
         self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
                                    s_h=s_h, s_w=s_w)
@@ -301,11 +317,12 @@ def __init__(
 
         self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
                                  s_h=s_h, s_w=s_w)
 
-        imm_args = meta_model.image_meta_model.vars().update(
+        imm_args = vars(meta_model.image_meta_model)
+        imm_args.update(
             {"res": True, "scale_factor": scale_factor})
         self.image_meta_model = ImageMetaModel(**imm_args)
-        self.image_meta_model.load(meta_model.image_meta_model, strict=False)
-
+        self.image_meta_model.load_state_dict(meta_model.image_meta_model.state_dict(), strict=False)
+
         self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
                                    s_h=s_h, s_w=s_w)
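The pattern patch 11 settles on — rebuild the module with res=True, then copy the trained weights with load_state_dict(..., strict=False) so that only the new res_layers stay randomly initialised — can be seen in isolation in the following sketch (a toy module with assumed names, not the library's classes):

    import torch
    from torch import nn


    class TinyBackbone(nn.Module):
        """Toy stand-in for ImageMetaModel: `extra` only exists when res=True."""

        def __init__(self, dim, res=False):
            super().__init__()
            self.core = nn.Linear(dim, dim)
            self.extra = nn.Linear(dim, dim) if res else None


    pretrained = TinyBackbone(8)
    scaled = TinyBackbone(8, res=True)

    # strict=False tolerates keys that exist only in the res variant;
    # they are reported back instead of raising an error.
    missing, unexpected = scaled.load_state_dict(pretrained.state_dict(), strict=False)
    assert missing == ["extra.weight", "extra.bias"]
    assert unexpected == []

    # The shared weights really were copied.
    assert torch.equal(scaled.core.weight, pretrained.core.weight)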

From b15110f19b253e22a27abdaa7b8db066d2ff0703 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 2 Jul 2024 09:25:20 +0000
Subject: [PATCH 12/16] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 graph_weather/models/fengwu_ghr/layers.py | 150 +++++++++++-----------
 tests/test_model.py                       |  43 ++++---
 2 files changed, 98 insertions(+), 95 deletions(-)

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index f03235e3..bbe19a85 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -76,8 +76,7 @@ def forward(self, x):
         x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(
-            t, "b n (h d) -> b h n d", h=self.heads), qkv)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
@@ -89,7 +88,9 @@ def forward(self, x):
 
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None):
+    def __init__(
+        self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None
+    ):
         super().__init__()
         self.depth = depth
         self.res = res
@@ -99,28 +100,39 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=N
         for _ in range(self.depth):
             self.layers.append(
                 nn.ModuleList(
-                    [Attention(dim, heads=heads, dim_head=dim_head),
-                     FeedForward(dim, mlp_dim)]
+                    [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)]
                 )
             )
         if self.res:
-            assert image_size is not None and scale_factor is not None, "If res=True, you must provide image_size and scale_factor"
+            assert (
+                image_size is not None and scale_factor is not None
+            ), "If res=True, you must provide image_size and scale_factor"
             h, w = pair(image_size)
             s_h, s_w = pair(scale_factor)
             self.res_layers.append(
                 nn.ModuleList(
                     [  # reshape to original shape window partition operation
                         # (b s_h s_w) (h w) d -> b (s_h h) (s_w w) d -> (b h w) (s_h s_w) d
-                        Rearrange("(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d",
-                                  h=h, w=w, s_h=s_h, s_w=s_w
-                                  ),
+                        Rearrange(
+                            "(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d",
+                            h=h,
+                            w=w,
+                            s_h=s_h,
+                            s_w=s_w,
+                        ),
                         # TODO ?????
                         Attention(dim, heads=heads, dim_head=dim_head),
                         # restore shape
-                        Rearrange("(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d",
-                                  h=h, w=w, s_h=s_h, s_w=s_w
-                                  ),
-                    ]))
+                        Rearrange(
+                            "(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d",
+                            h=h,
+                            w=w,
+                            s_h=s_h,
+                            s_w=s_w,
+                        ),
+                    ]
+                )
+            )
 
     def forward(self, x):
         for i in range(self.depth):
@@ -136,14 +148,22 @@ def forward(self, x):
 
 
 class ImageMetaModel(nn.Module):
-    def __init__(self, *, image_size,
-                 patch_size, depth, heads,
-                 mlp_dim, channels, dim_head,
-                 res=False,
-                 scale_factor=None,
-                 **kwargs):
+    def __init__(
+        self,
+        *,
+        image_size,
+        patch_size,
+        depth,
+        heads,
+        mlp_dim,
+        channels,
+        dim_head,
+        res=False,
+        scale_factor=None,
+        **kwargs
+    ):
         super().__init__()
-        #TODO this can probably be done better
+        # TODO this can probably be done better
         self.image_size = image_size
         self.patch_size = patch_size
         self.depth = depth
@@ -168,7 +188,9 @@ def __init__(self, *, image_size,
         dim = patch_dim
         self.to_patch_embedding = nn.Sequential(
             Rearrange(
-                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=self.patch_height, p_w=self.patch_width
+                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)",
+                p_h=self.patch_height,
+                p_w=self.patch_width,
             ),
             nn.LayerNorm(patch_dim),  # TODO Do we need this?
             nn.Linear(patch_dim, dim),  # TODO Do we need this?
@@ -181,14 +203,19 @@ def __init__(self, *, image_size,
             dim=dim,
         )
 
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim,
-                                       res=res,
-                                       image_size=(
-                                           self.image_height // self.patch_height,
-                                           self.image_width // self.patch_width),
-                                       scale_factor=(
-                                           s_h,
-                                           s_w))
+        self.transformer = Transformer(
+            dim,
+            depth,
+            heads,
+            dim_head,
+            mlp_dim,
+            res=res,
+            image_size=(
+                self.image_height // self.patch_height,
+                self.image_width // self.patch_width,
+            ),
+            scale_factor=(s_h, s_w),
+        )
 
         self.reshaper = nn.Sequential(
             Rearrange(
@@ -203,6 +230,7 @@ def __init__(self, *, image_size,
     def forward(self, x):
         device = x.device
         dtype = x.dtype
+
     def forward(self, x):
         device = x.device
         dtype = x.dtype
@@ -219,21 +247,17 @@ def forward(self, x):
 
 
 class WrapperImageModel(nn.Module):
-    def __init__(self, image_meta_model: ImageMetaModel,
-                 scale_factor):
+    def __init__(self, image_meta_model: ImageMetaModel, scale_factor):
         super().__init__()
         s_h, s_w = pair(scale_factor)
-        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
-                                 s_h=s_h, s_w=s_w)
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)
 
         imm_args = vars(image_meta_model)
-        imm_args.update(
-            {"res": True, "scale_factor": scale_factor})
+        imm_args.update({"res": True, "scale_factor": scale_factor})
         self.image_meta_model = ImageMetaModel(**imm_args)
         self.image_meta_model.load_state_dict(image_meta_model.state_dict(), strict=False)
 
-        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
-                                   s_h=s_h, s_w=s_w)
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)
 
     def forward(self, x):
         x = self.batcher(x)
@@ -260,14 +284,8 @@ def __init__(
 
         self.pos_x = torch.tensor(lat_lons)
         self.pos_y = torch.cartesian_prod(
-            (
-                torch.arange(-self.i_h / 2,
-                             self.i_h / 2, 1)
-                / self.i_h
-                * 180
-            ).to(torch.long),
-            (torch.arange(0, self.i_w, 1) /
-             self.i_w * 360).to(torch.long),
+            (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
+            (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
         )
 
         self.image_meta_model = ImageMetaModel(
@@ -285,10 +303,7 @@ def forward(self, x):
 
         x = rearrange(x, "b n c -> n (b c)")
         x = knn_interpolate(x, self.pos_x, self.pos_y)
-        x = rearrange(
-            x, "(h w) (b c) -> b c h w", b=b, c=c,
-            h=self.i_h, w=self.i_w
-        )
+        x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)
         x = self.image_meta_model(x)
 
         x = rearrange(x, "b c h w -> (h w) (b c)")
@@ -298,48 +313,33 @@ def forward(self, x):
 
 
 class WrapperMetaModel(nn.Module):
-    def __init__(
-        self,
-        lat_lons: list,
-        meta_model: MetaModel,
-        scale_factor
-    ):
+    def __init__(self, lat_lons: list, meta_model: MetaModel, scale_factor):
         super().__init__()
         s_h, s_w = pair(scale_factor)
-        self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w
+        self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w
 
         self.pos_x = torch.tensor(lat_lons)
         self.pos_y = torch.cartesian_prod(
-            (
-                torch.arange(-self.i_h / 2,
-                             self.i_h / 2, 1)
-                / self.i_h
-                * 180
-            ).to(torch.long),
-            (torch.arange(0, self.i_w, 1) /
-             self.i_w * 360).to(torch.long),
+            (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
+            (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
         )
 
-        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
-                                 s_h=s_h, s_w=s_w)
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)
 
         imm_args = vars(meta_model.image_meta_model)
-        imm_args.update(
-            {"res": True, "scale_factor": scale_factor})
+        imm_args.update({"res": True, "scale_factor": scale_factor})
         self.image_meta_model = ImageMetaModel(**imm_args)
-        self.image_meta_model.load_state_dict(meta_model.image_meta_model.state_dict(), strict=False)
+        self.image_meta_model.load_state_dict(
+            meta_model.image_meta_model.state_dict(), strict=False
+        )
 
-        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
-                                   s_h=s_h, s_w=s_w)
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)
 
     def forward(self, x):
         b, n, c = x.shape
 
         x = rearrange(x, "b n c -> n (b c)")
         x = knn_interpolate(x, self.pos_x, self.pos_y)
-        x = rearrange(
-            x, "(h w) (b c) -> b c h w", b=b, c=c,
-            h=self.i_h, w=self.i_w
-        )
+        x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)
 
         x = self.batcher(x)
         x = self.image_meta_model(x)
diff --git a/tests/test_model.py b/tests/test_model.py
index ac260b48..8f52116b 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -12,7 +12,7 @@
     ImageMetaModel,
     MetaModel,
     WrapperImageModel,
-    WrapperMetaModel
+    WrapperMetaModel,
 )
 from graph_weather.models.losses import NormalizedMSELoss
 
@@ -151,8 +151,7 @@ def test_assimilator_model():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             output_lat_lons.append((lat, lon))
-    model = GraphWeatherAssimilator(
-        output_lat_lons=output_lat_lons, analysis_dim=24)
+    model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24)
 
     features = torch.randn((1, len(obs_lat_lons), 2))
     lat_lon_heights = torch.tensor(obs_lat_lons)
@@ -166,8 +165,7 @@ def test_forecaster_and_loss():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(
-        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -208,8 +206,7 @@ def test_forecaster_and_loss_grad_checkpoint():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(
-        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons, use_checkpointing=True)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -240,8 +237,7 @@ def test_normalized_loss():
 
     assert not torch.isnan(loss)
     # Since feature_variance = out**2 and target = 0, we expect loss = weights
-    assert torch.isclose(
-        loss, criterion.weights.expand_as(out.mean(-1)).mean())
+    assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())
 
 
 def test_image_meta_model():
@@ -251,9 +247,13 @@ def test_image_meta_model():
     patch_size = 2
     image = torch.randn((batch, channels, size, size))
     model = ImageMetaModel(
-        image_size=size, patch_size=patch_size,
-        channels=channels, depth=1, heads=1, mlp_dim=7,
-        dim_head=64
+        image_size=size,
+        patch_size=patch_size,
+        channels=channels,
+        depth=1,
+        heads=1,
+        mlp_dim=7,
+        dim_head=64,
     )
 
     out = model(image)
@@ -268,13 +268,16 @@ def test_wrapper_image_meta_model():
     size = 4
     patch_size = 2
     model = ImageMetaModel(
-        image_size=size, patch_size=patch_size,
-        channels=channels, depth=1, heads=1, mlp_dim=7,
-        dim_head=64
+        image_size=size,
+        patch_size=patch_size,
+        channels=channels,
+        depth=1,
+        heads=1,
+        mlp_dim=7,
+        dim_head=64,
     )
     scale_factor = 3
-    big_image = torch.randn((batch, channels,
-                             size*scale_factor, size*scale_factor))
+    big_image = torch.randn((batch, channels, size * scale_factor, size * scale_factor))
     big_model = WrapperImageModel(model, scale_factor)
     out = big_model(big_image)
     assert not torch.isnan(out).any()
@@ -300,7 +303,7 @@ def test_meta_model():
         heads=1,
         mlp_dim=7,
         channels=channels,
-        dim_head=64
+        dim_head=64,
     )
     features = torch.randn((batch, len(lat_lons), channels))
 
@@ -320,7 +323,7 @@ def test_wrapper_meta_model():
     channels = 3
     image_size = 20
     patch_size = 4
-    scale_factor=3
+    scale_factor = 3
     model = MetaModel(
         lat_lons,
         image_size=image_size,
@@ -329,7 +332,7 @@ def test_wrapper_meta_model():
         heads=1,
         mlp_dim=7,
         channels=channels,
-        dim_head=64
+        dim_head=64,
     )
 
     big_features = torch.randn((batch, len(lat_lons), channels))

From 1146db9de02985298ba0539f3a7c661e759fe31e Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Tue, 2 Jul 2024 14:14:11 +0200
Subject: [PATCH 13/16] bug fix: remove duplicated forward in ImageMetaModel

---
 graph_weather/models/fengwu_ghr/layers.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index f03235e3..e42d7545 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -200,15 +200,10 @@ def __init__(self, *, image_size,
             )
         )
 
-    def forward(self, x):
-        device = x.device
-        dtype = x.dtype
 
     def forward(self, x):
         device = x.device
         dtype = x.dtype
 
-        x = self.to_patch_embedding(x)
-        x += self.pos_embedding.to(device, dtype=dtype)
         x = self.to_patch_embedding(x)
         x += self.pos_embedding.to(device, dtype=dtype)

From 325fd0e9b57f03a7d225e37d6bf7bd0de79b4038 Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Tue, 2 Jul 2024 15:46:45 +0200
Subject: [PATCH 14/16] bug fix: GenCast loss argument order; add Denoiser test

---
 tests/test_model.py | 34 +++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/tests/test_model.py b/tests/test_model.py
index c7b24bee..e5604fce 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -358,4 +358,36 @@ def test_gencast_loss():
     preds = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim))
     noise_levels = torch.rand((batch_size, 1))
     targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim))
-    assert loss.forward(preds, targets, noise_levels) is not None
+    assert loss.forward(preds, noise_levels, targets) is not None
+
+
+def test_gencast_denoiser():
+    grid_lat = np.arange(-90, 90, 1)
+    grid_lon = np.arange(0, 360, 1)
+    input_features_dim = 10
+    output_features_dim = 5
+    batch_size = 3
+
+    denoiser = Denoiser(
+        grid_lon=grid_lon,
+        grid_lat=grid_lat,
+        input_features_dim=input_features_dim,
+        output_features_dim=output_features_dim,
+        hidden_dims=[16, 32],
+        num_blocks=3,
+        num_heads=4,
+        splits=0,
+        num_hops=1,
+        device=torch.device("cpu"),
+    ).eval()
+
+    corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim))
+    prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2 * input_features_dim))
+    noise_levels = torch.randn((batch_size, 1))
+
+    with torch.no_grad():
+        preds = denoiser(
+            corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels
+        )
+
+    assert not torch.isnan(preds).any()
\ No newline at end of file

From cfa9c3f7e073e07e796dc6de6cc1a26670af2d27 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 2 Jul 2024 13:48:05 +0000
Subject: [PATCH 15/16] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/test_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_model.py b/tests/test_model.py
index e5604fce..9b43cb16 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -390,4 +390,4 @@ def test_gencast_denoiser():
             corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels
         )
 
-    assert not torch.isnan(preds).any()
\ No newline at end of file
+    assert not torch.isnan(preds).any()

From 2fadf974fe5087d974ed4caa1fd5c1d3bab9afe3 Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Mon, 29 Jul 2024 12:28:26 +0200
Subject: [PATCH 16/16] env yml fix: unpin python and pytorch-cuda versions

---
 environment_cuda.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/environment_cuda.yml b/environment_cuda.yml
index 9f76251a..3d06cbb9 100644
--- a/environment_cuda.yml
+++ b/environment_cuda.yml
@@ -6,12 +6,12 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - pytorch-cuda=12.1
+  - pytorch-cuda
   - numcodecs
   - pandas
   - pip
   - pyg
-  - python=3.12
+  - python
   - pytorch
   - pytorch-cluster
   - pytorch-scatter
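
Taken together, the series leaves the FengWu-GHR components importable from graph_weather.models. A minimal usage sketch, distilled from the tests above (the hyper-parameters are the illustrative values used in test_wrapper_meta_model, not recommended settings):

    import torch

    from graph_weather.models import MetaModel, WrapperMetaModel

    # Nodes on a 5-degree lat/lon grid, as in the tests.
    lat_lons = [(lat, lon) for lat in range(-90, 90, 5) for lon in range(0, 360, 5)]

    # Backbone ViT over an internal image grid interpolated from the nodes.
    model = MetaModel(
        lat_lons,
        image_size=20,
        patch_size=4,
        depth=1,
        heads=1,
        mlp_dim=7,
        channels=3,
        dim_head=64,
    )
    features = torch.randn((2, len(lat_lons), 3))
    out = model(features)  # same shape as features

    # Inference-time wrapper: rebuilds the backbone with res=True (local window
    # attention), copies its weights via load_state_dict, and runs the nodes
    # through a scale_factor-times larger internal grid.
    big_model = WrapperMetaModel(lat_lons, model, scale_factor=3)
    big_out = big_model(features)  # same shape as features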