Skip to content

Commit

Permalink
init flamingo embeds new weights
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-awadalla committed Sep 17, 2023
1 parent 4875822 commit 8f2f040
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions open_flamingo/src/vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ def __init__(
_weight=self.lang_model.get_input_embeddings().weight,
pad_token_id=self.pad_token_id,
)
input_embeds.additional_fc.weight.data.normal_(
mean=0.0, std=self.lang_model.config.initializer_range
)
if hasattr(input_embeds, "additional_embedding"):
input_embeds.additional_embedding.weight.data.normal_(
mean=0.0, std=self.lang_model.config.initializer_range
)

self.lang_model.set_input_embeddings(input_embeds)

Expand All @@ -76,9 +77,10 @@ def __init__(
_weight=self.lang_model.get_output_embeddings().weight,
_bias=self.lang_model.get_output_embeddings().bias,
)
out_embeds.additional_fc.weight.data.normal_(
mean=0.0, std=self.lang_model.config.initializer_range
)
if hasattr(out_embeds, "additional_fc"):
out_embeds.additional_fc.weight.data.normal_(
mean=0.0, std=self.lang_model.config.initializer_range
)
self.lang_model.set_output_embeddings(out_embeds)

# gradient checkpointing
Expand Down

0 comments on commit 8f2f040

Please sign in to comment.