diff --git a/environment_cuda.yml b/environment_cuda.yml
index 9f76251a..3d06cbb9 100644
--- a/environment_cuda.yml
+++ b/environment_cuda.yml
@@ -6,12 +6,12 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - pytorch-cuda=12.1
+  - pytorch-cuda
   - numcodecs
   - pandas
   - pip
   - pyg
-  - python=3.12
+  - python
   - pytorch
   - pytorch-cluster
   - pytorch-scatter
diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index d129d2dd..2d032ab8 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -1,3 +1,6 @@
+from scipy.interpolate import griddata
+from torch_geometric.nn import knn
+from torch_geometric.utils import scatter
 import torch
 from einops import rearrange
 from einops.layers.torch import Rearrange
@@ -12,6 +15,22 @@
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
+def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
+                    k: int = 4, num_workers: int = 1):
+    with torch.no_grad():
+        assign_index = knn(pos_x, pos_y, k, num_workers=num_workers)
+        y_idx, x_idx = assign_index[0], assign_index[1]
+        diff = pos_x[x_idx] - pos_y[y_idx]
+        squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
+        weights = 1.0 / torch.clamp(squared_distance, min=1e-16)
+
+    den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum")
+    y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum")
+
+    y = y / den
+
+    return y
+
 def knn_interpolate(
     x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 4, num_workers: int = 1
 ):
@@ -344,3 +363,93 @@ def forward(self, x):
         x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
 
         return x
+
+class MetaModel(nn.Module):
+    def __init__(
+        self,
+        lat_lons: list,
+        *,
+        image_size,
+        patch_size,
+        depth,
+        heads,
+        mlp_dim,
+        channels,
+        dim_head=64,
+    ):
+        super().__init__()
+        self.i_h, self.i_w = pair(image_size)
+
+        self.pos_x = torch.tensor(lat_lons)
+        self.pos_y = torch.cartesian_prod(
+            (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
+            (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
+        )
+
+        self.image_meta_model = ImageMetaModel(
+            image_size=image_size,
+            patch_size=patch_size,
+            depth=depth,
+            heads=heads,
+            mlp_dim=mlp_dim,
+            channels=channels,
+            dim_head=dim_head,
+        )
+
+    def forward(self, x):
+        b, n, c = x.shape
+
+        x = rearrange(x, "b n c -> n (b c)")
+        x = knn_interpolate(x, self.pos_x, self.pos_y)
+        x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)
+        x = self.image_meta_model(x)
+
+        x = rearrange(x, "b c h w -> (h w) (b c)")
+        x = knn_interpolate(x, self.pos_y, self.pos_x)
+        x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
+        return x
+
+
+class WrapperMetaModel(nn.Module):
+    def __init__(
+        self,
+        lat_lons: list,
+        meta_model: MetaModel,
+        scale_factor,
+    ):
+        super().__init__()
+        s_h, s_w = pair(scale_factor)
+        self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w
+        self.pos_x = torch.tensor(lat_lons)
+        self.pos_y = torch.cartesian_prod(
+            (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
+            (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
+        )
+
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)
+
+        imm_args = vars(meta_model.image_meta_model)
+        imm_args.update({"res": True, "scale_factor": scale_factor})
+        self.image_meta_model = ImageMetaModel(**imm_args)
+        self.image_meta_model.load_state_dict(
+            meta_model.image_meta_model.state_dict(), strict=False
+        )
+
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)
+
+    def forward(self, x):
+        b, n, c = x.shape
+
+        x = rearrange(x, "b n c -> n (b c)")
+        x = knn_interpolate(x, self.pos_x, self.pos_y)
+        x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)
+
+        x = self.batcher(x)
+        x = self.image_meta_model(x)
+        x = self.debatcher(x)
+
+        x = rearrange(x, "b c h w -> (h w) (b c)")
+        x = knn_interpolate(x, self.pos_y, self.pos_x)
+        x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
+
+        return x
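For review context: the added `knn_interpolate` follows PyG's inverse-squared-distance k-NN interpolation, with only the neighbor search and weight computation under `torch.no_grad()` so gradients still flow through the feature values. A minimal sanity-check sketch of the computation (the toy coordinates and feature values below are invented for illustration):

```python
import torch

from graph_weather.models.fengwu_ghr.layers import knn_interpolate

# Three source points on a line, one query point coinciding with the middle one.
pos_x = torch.tensor([[0.0], [1.0], [2.0]])
pos_y = torch.tensor([[1.0]])
x = torch.tensor([[10.0], [20.0], [30.0]])  # one feature per source point

# With k=3 the query's squared distance to the middle point is clamped to 1e-16,
# so its inverse-distance weight dominates and the output is pulled to ~20.0.
out = knn_interpolate(x, pos_x, pos_y, k=3)
print(out)  # approximately [[20.0]]
```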
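And a hypothetical driver for the new `MetaModel`, showing the intended tensor contract (scatter the graph nodes onto an internal raster, run the ViT core, interpolate back). The hyperparameter values are invented, and this assumes `ImageMetaModel`, which is defined elsewhere in `layers.py` and not shown in this diff, accepts exactly the keyword set `MetaModel` forwards to it:

```python
import torch

from graph_weather.models.fengwu_ghr.layers import MetaModel

# Hypothetical 5-degree node set; any list of (lat, lon) pairs works.
lat_lons = [(lat, lon) for lat in range(-90, 90, 5) for lon in range(0, 360, 5)]

model = MetaModel(
    lat_lons,
    image_size=(36, 72),  # internal raster the nodes are interpolated onto
    patch_size=(4, 4),
    depth=2,
    heads=4,
    mlp_dim=64,
    channels=3,
)

x = torch.randn(1, len(lat_lons), 3)  # (batch, nodes, channels)
y = model(x)                          # nodes -> raster -> ViT core -> nodes
assert y.shape == x.shape
```

One design point worth double-checking in `WrapperMetaModel`: `vars(meta_model.image_meta_model)` returns the module's full `__dict__`, which includes `nn.Module` internals such as `_parameters` and `_modules` alongside the plain constructor attributes, so `ImageMetaModel(**imm_args)` only works if its constructor tolerates or filters those extra keys.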