fix resampler projection
anas-awadalla authored Mar 21, 2024
1 parent 292afa1 commit a72c96b
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions open_flamingo/src/helpers.py
@@ -105,7 +105,7 @@ def __init__(
         """
         Perceiver module which takes in image features and outputs image tokens.
         Args:
-            dim (int): final dimension of the incoming image features
+            dim (int): dimension of the incoming image features
             dim_inner (int, optional): final dimension to project the incoming image features to;
                 also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
             depth (int, optional): number of layers. Defaults to 6.
@@ -124,17 +124,17 @@ def __init__(
         else:
             projection = None
             dim_inner = dim
-        super().__init__(dim_media=dim_inner, num_tokens_per_media=num_latents)
+        super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
         self.projection = projection
-        self.latents = nn.Parameter(torch.randn(num_latents, dim_inner))
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
         # positional embeddings
         self.frame_embs = (
-            nn.Parameter(torch.randn(max_num_frames, dim_inner))
+            nn.Parameter(torch.randn(max_num_frames, dim))
             if exists(max_num_frames)
             else None
         )
         self.media_time_embs = (
-            nn.Parameter(torch.randn(max_num_media, 1, dim_inner))
+            nn.Parameter(torch.randn(max_num_media, 1, dim))
             if exists(max_num_media)
             else None
         )
@@ -145,14 +145,14 @@ def __init__(
                 nn.ModuleList(
                     [
                         PerceiverAttention(
-                            dim=dim_inner, dim_head=dim_head, heads=heads
+                            dim=dim, dim_head=dim_head, heads=heads
                         ),
-                        FeedForward(dim=dim_inner, mult=ff_mult),
+                        FeedForward(dim=dim, mult=ff_mult),
                     ]
                 )
             )

-        self.norm = nn.LayerNorm(dim_inner)
+        self.norm = nn.LayerNorm(dim)

     def forward(self, x):
         """
@@ -164,9 +164,6 @@ def forward(self, x):
         """
         b, T, F, v = x.shape[:4]

-        if exists(self.projection):
-            x = self.projection(x)
-
         # frame and media time embeddings
         if exists(self.frame_embs):
             frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
@@ -182,7 +179,11 @@ def forward(self, x):
         for attn, ff in self.layers:
             latents = attn(x, latents) + latents
             latents = ff(latents) + latents
-        return self.norm(latents)
+
+        if exists(self.projection):
+            return self.projection(self.norm(latents))
+        else:
+            return self.norm(latents)

 class LinearPatchProjection(VisionTokenizer):
     """Linear projection from patch features to image tokens."""
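A quick shape check against the real module, assuming the constructor arguments documented in the docstring above and the import path of the file this diff touches (run from the repository root):

import torch
from open_flamingo.src.helpers import PerceiverResampler

resampler = PerceiverResampler(dim=1024, dim_inner=512, depth=2, num_latents=64)
x = torch.randn(2, 3, 1, 256, 1024)  # (b, T, F, v, dim), matching the x.shape[:4] unpacking
out = resampler(x)
print(out.shape)  # with this fix: torch.Size([2, 3, 64, 512]), i.e. projected to dim_inner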
