Skip to content

Commit

Permalink
Merge branch 'fengwu_ghr' of https://github.com/openclimatefix/graph_…
Browse files Browse the repository at this point in the history
…weather into fengwu_ghr
  • Loading branch information
rnwzd committed Jul 29, 2024
2 parents 257b353 + 0f46d7a commit 04d4776
Showing 1 changed file with 0 additions and 90 deletions.
90 changes: 0 additions & 90 deletions graph_weather/models/fengwu_ghr/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,93 +363,3 @@ def forward(self, x):
x = rearrange(x, "n (b c) -> b n c", b=b, c=c)

return x

class MetaModel(nn.Module):
def __init__(
self,
lat_lons: list,
*,
image_size,
patch_size,
depth,
heads,
mlp_dim,
channels,
dim_head=64
):
super().__init__()
self.i_h, self.i_w = pair(image_size)

self.pos_x = torch.tensor(lat_lons)
self.pos_y = torch.cartesian_prod(
(torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
(torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
)

self.image_meta_model = ImageMetaModel(
image_size=image_size,
patch_size=patch_size,
depth=depth,
heads=heads,
mlp_dim=mlp_dim,
channels=channels,
dim_head=dim_head,
)

def forward(self, x):
b, n, c = x.shape

x = rearrange(x, "b n c -> n (b c)")
x = knn_interpolate(x, self.pos_x, self.pos_y)
x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)
x = self.image_meta_model(x)

x = rearrange(x, "b c h w -> (h w) (b c)")
x = knn_interpolate(x, self.pos_y, self.pos_x)
x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
return x


class WrapperMetaModel(nn.Module):
def __init__(
self,
lat_lons: list,
meta_model: MetaModel,
scale_factor
):
super().__init__()
s_h, s_w = pair(scale_factor)
self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w
self.pos_x = torch.tensor(lat_lons)
self.pos_y = torch.cartesian_prod(
(torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
(torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
)

self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)

imm_args = vars(meta_model.image_meta_model)
imm_args.update({"res": True, "scale_factor": scale_factor})
self.image_meta_model = ImageMetaModel(**imm_args)
self.image_meta_model.load_state_dict(
meta_model.image_meta_model.state_dict(), strict=False
)

self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)

def forward(self, x):
b, n, c = x.shape

x = rearrange(x, "b n c -> n (b c)")
x = knn_interpolate(x, self.pos_x, self.pos_y)
x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)

x = self.batcher(x)
x = self.image_meta_model(x)
x = self.debatcher(x)

x = rearrange(x, "b c h w -> (h w) (b c)")
x = knn_interpolate(x, self.pos_y, self.pos_x)
x = rearrange(x, "n (b c) -> b n c", b=b, c=c)

return x

0 comments on commit 04d4776

Please sign in to comment.