Skip to content

Commit

Permalink
fengwu_ghr: initial
Browse files Browse the repository at this point in the history
fengwu_ghr: fixes

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Interpolate initial

ImageMetaModel

MetaModel initial

tested metamodel

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

wrapper meta model

RES

load RES state_dict

bug fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

env yml fix
  • Loading branch information
rnwzd committed Jul 29, 2024
1 parent 743cf97 commit 257b353
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 2 deletions.
4 changes: 2 additions & 2 deletions environment_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ channels:
- conda-forge
- defaults
dependencies:
- pytorch-cuda=12.1
- pytorch-cuda
- numcodecs
- pandas
- pip
- pyg
- python=3.12
- python
- pytorch
- pytorch-cluster
- pytorch-scatter
Expand Down
109 changes: 109 additions & 0 deletions graph_weather/models/fengwu_ghr/layers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from scipy.interpolate import griddata
from torch_geometric.nn import knn
from torch_geometric.utils import scatter
import torch
from einops import rearrange
from einops.layers.torch import Rearrange
Expand All @@ -12,6 +15,22 @@ def pair(t):
return t if isinstance(t, tuple) else (t, t)


def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
k: int = 4, num_workers: int = 1):
with torch.no_grad():
assign_index = knn(pos_x, pos_y, k, num_workers=num_workers)
y_idx, x_idx = assign_index[0], assign_index[1]
diff = pos_x[x_idx] - pos_y[y_idx]
squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
weights = 1.0 / torch.clamp(squared_distance, min=1e-16)

den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum")
y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum")

y = y / den

return y

def knn_interpolate(
x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 4, num_workers: int = 1
):
Expand Down Expand Up @@ -344,3 +363,93 @@ def forward(self, x):
x = rearrange(x, "n (b c) -> b n c", b=b, c=c)

return x

class MetaModel(nn.Module):
def __init__(
self,
lat_lons: list,
*,
image_size,
patch_size,
depth,
heads,
mlp_dim,
channels,
dim_head=64
):
super().__init__()
self.i_h, self.i_w = pair(image_size)

self.pos_x = torch.tensor(lat_lons)
self.pos_y = torch.cartesian_prod(
(torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
(torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
)

self.image_meta_model = ImageMetaModel(
image_size=image_size,
patch_size=patch_size,
depth=depth,
heads=heads,
mlp_dim=mlp_dim,
channels=channels,
dim_head=dim_head,
)

def forward(self, x):
b, n, c = x.shape

x = rearrange(x, "b n c -> n (b c)")
x = knn_interpolate(x, self.pos_x, self.pos_y)
x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)
x = self.image_meta_model(x)

x = rearrange(x, "b c h w -> (h w) (b c)")
x = knn_interpolate(x, self.pos_y, self.pos_x)
x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
return x


class WrapperMetaModel(nn.Module):
def __init__(
self,
lat_lons: list,
meta_model: MetaModel,
scale_factor
):
super().__init__()
s_h, s_w = pair(scale_factor)
self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w
self.pos_x = torch.tensor(lat_lons)
self.pos_y = torch.cartesian_prod(
(torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
(torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
)

self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)

imm_args = vars(meta_model.image_meta_model)
imm_args.update({"res": True, "scale_factor": scale_factor})
self.image_meta_model = ImageMetaModel(**imm_args)
self.image_meta_model.load_state_dict(
meta_model.image_meta_model.state_dict(), strict=False
)

self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)

def forward(self, x):
b, n, c = x.shape

x = rearrange(x, "b n c -> n (b c)")
x = knn_interpolate(x, self.pos_x, self.pos_y)
x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)

x = self.batcher(x)
x = self.image_meta_model(x)
x = self.debatcher(x)

x = rearrange(x, "b c h w -> (h w) (b c)")
x = knn_interpolate(x, self.pos_y, self.pos_x)
x = rearrange(x, "n (b c) -> b n c", b=b, c=c)

return x

0 comments on commit 257b353

Please sign in to comment.