Skip to content

Commit

Permalink
test_wrapper_meta_model
Browse files Browse the repository at this point in the history
  • Loading branch information
rnwzd committed Jul 29, 2024
1 parent 04d4776 commit f72c610
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,35 @@ def test_meta_model():
assert out.size() == features.size()


def test_wrapper_meta_model():
lat_lons = []
for lat in range(-90, 90, 5):
for lon in range(0, 360, 5):
lat_lons.append((lat, lon))

batch = 2
channels = 3
image_size = 20
patch_size = 4
scale_factor=3
model = MetaModel(
lat_lons,
image_size=image_size,
patch_size=patch_size,
depth=1,
heads=1,
mlp_dim=7,
channels=channels,
dim_head=64
)

big_features = torch.randn((batch, len(lat_lons), channels))
big_model = WrapperMetaModel(lat_lons, model, scale_factor)
out = big_model(big_features)

assert not torch.isnan(out).any()


def test_gencast_noise():
num_lon = 360
num_lat = 180
Expand Down

0 comments on commit f72c610

Please sign in to comment.