From a72c96b93dfb395b1b5d40ce8a2c9e3fd9d9182f Mon Sep 17 00:00:00 2001
From: Anas Awadalla
Date: Wed, 20 Mar 2024 18:13:33 -0700
Subject: [PATCH] fix resampler projection

---
 open_flamingo/src/helpers.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/open_flamingo/src/helpers.py b/open_flamingo/src/helpers.py
index 8db95307..66c54b04 100644
--- a/open_flamingo/src/helpers.py
+++ b/open_flamingo/src/helpers.py
@@ -105,7 +105,7 @@ def __init__(
         """
         Perceiver module which takes in image features and outputs image tokens.
         Args:
-            dim (int): final dimension of the incoming image features
+            dim (int): dimension of the incoming image features
             dim_inner (int, optional): final dimension to project the incoming image features to;
                 also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
             depth (int, optional): number of layers. Defaults to 6.
@@ -124,17 +124,17 @@ def __init__(
         else:
             projection = None
             dim_inner = dim
-        super().__init__(dim_media=dim_inner, num_tokens_per_media=num_latents)
+        super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
         self.projection = projection
-        self.latents = nn.Parameter(torch.randn(num_latents, dim_inner))
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
         # positional embeddings
         self.frame_embs = (
-            nn.Parameter(torch.randn(max_num_frames, dim_inner))
+            nn.Parameter(torch.randn(max_num_frames, dim))
             if exists(max_num_frames)
             else None
         )
         self.media_time_embs = (
-            nn.Parameter(torch.randn(max_num_media, 1, dim_inner))
+            nn.Parameter(torch.randn(max_num_media, 1, dim))
             if exists(max_num_media)
             else None
         )
@@ -145,14 +145,14 @@ def __init__(
                 nn.ModuleList(
                     [
                         PerceiverAttention(
-                            dim=dim_inner, dim_head=dim_head, heads=heads
+                            dim=dim, dim_head=dim_head, heads=heads
                         ),
-                        FeedForward(dim=dim_inner, mult=ff_mult),
+                        FeedForward(dim=dim, mult=ff_mult),
                     ]
                 )
             )
 
-        self.norm = nn.LayerNorm(dim_inner)
+        self.norm = nn.LayerNorm(dim)
 
     def forward(self, x):
         """
@@ -164,9 +164,6 @@ def forward(self, x):
         """
         b, T, F, v = x.shape[:4]
 
-        if exists(self.projection):
-            x = self.projection(x)
-
         # frame and media time embeddings
         if exists(self.frame_embs):
             frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
@@ -182,7 +179,11 @@ def forward(self, x):
         for attn, ff in self.layers:
             latents = attn(x, latents) + latents
             latents = ff(latents) + latents
-        return self.norm(latents)
+
+        if exists(self.projection):
+            return self.projection(self.norm(latents))
+        else:
+            return self.norm(latents)
 
 class LinearPatchProjection(VisionTokenizer):
     """Linear projection from patch features to image tokens."""
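
Note: a minimal sketch of the data flow this patch establishes, assuming `projection` is constructed as `nn.Linear(dim, dim_inner)` when `dim_inner` is given (the dimensions and tensor shapes below are illustrative, not taken from the patch). After the fix, the latents, attention layers, and final LayerNorm all operate in the incoming feature dimension `dim`, and the optional projection to `dim_inner` is applied once, to the normed output, instead of to the raw input features.

import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
dim, dim_inner, num_latents = 1024, 4096, 64

projection = nn.Linear(dim, dim_inner)  # assumed construction when dim_inner is set
norm = nn.LayerNorm(dim)                # norm lives in `dim`, matching the patch

# Stand-in for resampled latents after the attention stack: (b, T, n, dim).
latents = torch.randn(2, 1, num_latents, dim)

# Post-patch output path: normalize in `dim`, then project to `dim_inner`.
out = projection(norm(latents))
print(out.shape)  # torch.Size([2, 1, 64, 4096])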