Skip to content

Commit

Permalink
Fengwu ghr (#114)
Browse files Browse the repository at this point in the history
* fengwu_ghr: initial

* fengwu_ghr: fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Interpolate initial

* ImageMetaModel

* MetaModel initial

* tested metamodel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
rnwzd and pre-commit-ci[bot] authored Jun 17, 2024
1 parent b0bab00 commit ccaea16
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 19 deletions.
2 changes: 1 addition & 1 deletion graph_weather/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Models"""

from .fengwu_ghr.layers import MetaModel
from .fengwu_ghr.layers import ImageMetaModel, MetaModel
from .layers.assimilator_decoder import AssimilatorDecoder
from .layers.assimilator_encoder import AssimilatorEncoder
from .layers.decoder import Decoder
Expand Down
84 changes: 78 additions & 6 deletions graph_weather/models/fengwu_ghr/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn
from torch_geometric.nn import knn
from torch_geometric.utils import scatter

# helpers

Expand All @@ -10,6 +12,24 @@ def pair(t):
return t if isinstance(t, tuple) else (t, t)


def knn_interpolate(
x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 4, num_workers: int = 1
):
with torch.no_grad():
assign_index = knn(pos_x, pos_y, k, num_workers=num_workers)
y_idx, x_idx = assign_index[0], assign_index[1]
diff = pos_x[x_idx] - pos_y[y_idx]
squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
weights = 1.0 / torch.clamp(squared_distance, min=1e-16)

den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum")
y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum")

y = y / den

return y


def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
Expand Down Expand Up @@ -86,7 +106,7 @@ def forward(self, x):
return self.norm(x)


class MetaModel(nn.Module):
class ImageMetaModel(nn.Module):
def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64):
super().__init__()
image_height, image_width = pair(image_size)
Expand Down Expand Up @@ -125,14 +145,66 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3,
)
)

def forward(self, img):
device = img.device
def forward(self, x):
device = x.device
dtype = x.dtype

x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)
x = self.to_patch_embedding(x)
x += self.pos_embedding.to(device, dtype=dtype)

x = self.transformer(x)

x = self.reshaper(x)

return x


class MetaModel(nn.Module):
def __init__(
self,
lat_lons: list,
*,
patch_size,
depth,
heads,
mlp_dim,
image_size=(721, 1440),
channels=3,
dim_head=64
):
super().__init__()
self.image_size = pair(image_size)

self.pos_x = torch.tensor(lat_lons)
self.pos_y = torch.cartesian_prod(
(
torch.arange(-self.image_size[0] / 2, self.image_size[0] / 2, 1)
/ self.image_size[0]
* 180
).to(torch.long),
(torch.arange(0, self.image_size[1], 1) / self.image_size[1] * 360).to(torch.long),
)

self.image_model = ImageMetaModel(
image_size=image_size,
patch_size=patch_size,
depth=depth,
heads=heads,
mlp_dim=mlp_dim,
channels=channels,
dim_head=dim_head,
)

def forward(self, x):
b, n, c = x.shape

x = rearrange(x, "b n c -> n (b c)")
x = knn_interpolate(x, self.pos_x, self.pos_y)
x = rearrange(
x, "(w h) (b c) -> b c w h", b=b, c=c, w=self.image_size[0], h=self.image_size[1]
)
x = self.image_model(x)

x = rearrange(x, "b c w h -> (w h) (b c)")
x = knn_interpolate(x, self.pos_y, self.pos_x)
x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
return x
47 changes: 35 additions & 12 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Encoder,
Processor,
MetaModel,
ImageMetaModel,
)
from graph_weather.models.losses import NormalizedMSELoss
from graph_weather.models.gencast.utils.noise import (
Expand Down Expand Up @@ -235,22 +236,44 @@ def test_normalized_loss():
assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())


def test_gencast_noise():
num_lat = 32
num_samples = 5
target_residuals = np.zeros((2 * num_lat, num_lat, num_samples))
noise_level = sample_noise_level()
noise = generate_isotropic_noise(num_lat=num_lat, num_samples=target_residuals.shape[-1])
corrupted_residuals = target_residuals + noise_level * noise
assert corrupted_residuals.shape == target_residuals.shape
assert not np.isnan(corrupted_residuals).any()
def test_image_meta_model():
batch = 2
channels = 3
size = 4
patch_size = 2
image = torch.randn((batch, channels, size, size))
model = ImageMetaModel(
image_size=size, patch_size=patch_size, channels=channels, depth=1, heads=1, mlp_dim=7
)

out = model(image)
assert not torch.isnan(out).any()
assert not torch.isnan(out).any()
assert out.size() == (batch, channels, size, size)


def test_meta_model():
model = MetaModel(image_size=100, patch_size=10, depth=1, heads=1, mlp_dim=7, channels=3)
features = torch.randn((1, 3, 100, 100))
lat_lons = []
for lat in range(-90, 90, 5):
for lon in range(0, 360, 5):
lat_lons.append((lat, lon))

batch = 2
channels = 3
image_size = 20
patch_size = 4
model = MetaModel(
lat_lons,
image_size=image_size,
patch_size=patch_size,
depth=1,
heads=1,
mlp_dim=7,
channels=channels,
)
features = torch.randn((batch, len(lat_lons), channels))

out = model(features)
assert not torch.isnan(out).any()
assert not torch.isnan(out).any()
assert out.size() == (1, 3, 100, 100)
assert out.size() == (batch, len(lat_lons), channels)

0 comments on commit ccaea16

Please sign in to comment.