
Commit

Add DINOv2 model (#334)
Co-authored-by: Jonatan Kłosko <[email protected]>
joelpaulkoch and jonatanklosko committed Feb 21, 2024
1 parent 5b3d7ac commit f739e0a
Showing 11 changed files with 952 additions and 73 deletions.
lib/bumblebee.ex (6 changes: 5 additions & 1 deletion)
@@ -123,6 +123,9 @@ defmodule Bumblebee do
{Bumblebee.Vision.Deit, :for_image_classification_with_teacher},
"DeiTForMaskedImageModeling" => {Bumblebee.Vision.Deit, :for_masked_image_modeling},
"DeiTModel" => {Bumblebee.Vision.Deit, :base},
"Dinov2Model" => {Bumblebee.Vision.DinoV2, :base},
"Dinov2Backbone" => {Bumblebee.Vision.DinoV2, :backbone},
"Dinov2ForImageClassification" => {Bumblebee.Vision.DinoV2, :for_image_classification},
"DistilBertModel" => {Bumblebee.Text.Distilbert, :base},
"DistilBertForMaskedLM" => {Bumblebee.Text.Distilbert, :for_masked_language_modeling},
"DistilBertForSequenceClassification" =>
@@ -203,7 +206,8 @@ defmodule Bumblebee do
}

@transformers_image_processor_type_to_featurizer %{
"BlipImageProcessor" => Bumblebee.Vision.BlipFeaturizer
"BlipImageProcessor" => Bumblebee.Vision.BlipFeaturizer,
"BitImageProcessor" => Bumblebee.Vision.BitFeaturizer
}

@model_type_to_featurizer %{
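With these mappings in place, DINOv2 checkpoints from the Hugging Face Hub load like any other Bumblebee vision model. A minimal usage sketch (not part of this commit; the "facebook/dinov2-base" repository name and the placeholder input are assumptions):

{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/dinov2-base"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "facebook/dinov2-base"})

# Any {height, width, channels} image tensor should work as featurizer input;
# a flat gray image is used here purely as a placeholder.
image = Nx.broadcast(Nx.tensor(128, type: :u8), {224, 224, 3})

inputs = Bumblebee.apply_featurizer(featurizer, image)
outputs = Axon.predict(model_info.model, model_info.params, inputs)
# For the base model, outputs should include the final hidden state (patch embeddings).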
lib/bumblebee/diffusion/layers/unet.ex (8 changes: 3 additions & 5 deletions)
@@ -323,7 +323,7 @@ defmodule Bumblebee.Diffusion.Layers.UNet do
epsilon: 1.0e-5
],
dropout_rate: dropout,
ffn: &ffn_geglu(&1, hidden_size, dropout: dropout, name: &2),
ffn: &ffn_geglu(&1, 4 * hidden_size, hidden_size, dropout: dropout, name: &2),
block_type: :norm_first,
name: join(name, "blocks")
)
@@ -347,12 +347,10 @@ defmodule Bumblebee.Diffusion.Layers.UNet do
end

# A feed-forward network with GEGLU nonlinearity as in https://arxiv.org/abs/2002.05202
defp ffn_geglu(x, size, opts) do
defp ffn_geglu(x, intermediate_size, output_size, opts) do
name = opts[:name]
dropout = opts[:dropout] || 0.0

intermediate_size = 4 * size

{x, gate} =
x
|> Axon.dense(intermediate_size * 2, name: join(name, "intermediate"))
@@ -362,6 +360,6 @@ defmodule Bumblebee.Diffusion.Layers.UNet do

x
|> Axon.dropout(rate: dropout, name: join(name, "dropout"))
|> Axon.dense(size, name: join(name, "output"))
|> Axon.dense(output_size, name: join(name, "output"))
end
end
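For reference, the GEGLU feed-forward block above projects its input once, splits the result into a value half and a gate half, and multiplies the value by GELU of the gate; the refactor only makes the intermediate and output sizes explicit arguments instead of deriving the intermediate size as 4 * size inside the helper. A standalone Nx sketch of the nonlinearity (illustrative only; w and b stand in for the dense layer's parameters):

geglu = fn x, w, b ->
  projected = x |> Nx.dot(w) |> Nx.add(b)

  # Split the joint projection into a value half and a gate half
  half = div(Nx.axis_size(projected, -1), 2)
  value = Nx.slice_along_axis(projected, 0, half, axis: -1)
  gate = Nx.slice_along_axis(projected, half, half, axis: -1)

  # GEGLU(x) = value * gelu(gate), as in https://arxiv.org/abs/2002.05202
  Nx.multiply(value, Axon.Activations.gelu(gate))
end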
lib/bumblebee/layers/transformer.ex (104 changes: 43 additions & 61 deletions)
@@ -265,6 +265,10 @@ defmodule Bumblebee.Layers.Transformer do
* `:parallel` - block with attention and FFN independently (in parallel).
This type doesn't support cross-attention
Alternatively a custom 3-arity function may be given. The function
receives the input hidden state, a map with block steps and a
name to prefix any additional layers.
* `:scale_attention_weights` - whether to scale query in the traditional style of
multi-headed attention. Defaults to `true`
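To illustrate the contract documented above for a custom `:block_type` function: it receives the hidden state, a map of block steps, and a name prefix, and, judging by the built-in implementations below, returns the hidden state together with the attention and cross-attention info. A hypothetical function mirroring the `:norm_first` layout (a sketch, not part of this commit):

custom_block = fn hidden_state, steps, _name ->
  shortcut = hidden_state

  # Pre-norm self-attention with a residual connection
  {hidden_state, attention_info} =
    hidden_state
    |> steps.self_attention_norm.()
    |> steps.self_attention.()

  hidden_state = Axon.add(hidden_state, shortcut)

  # This sketch does not handle cross-attention
  {hidden_state, cross_attention_info} =
    steps.cross_attention_maybe.(hidden_state, fn _hidden_state ->
      raise "cross attention not supported"
    end)

  shortcut = hidden_state

  # Pre-norm feed-forward network with a residual connection
  hidden_state =
    hidden_state
    |> steps.output_norm.()
    |> steps.ffn.()
    |> Axon.add(shortcut)

  {hidden_state, attention_info, cross_attention_info}
end

Such a function could then be passed as the `:block_type` option in place of `:standard`, `:norm_first`, or `:parallel`.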
@@ -469,17 +473,25 @@ defmodule Bumblebee.Layers.Transformer do

ffn = &ffn_fun.(&1, join(name, "ffn"))

block_impl =
case block_type do
type when is_atom(type) -> &block_impl(type, &1, &2, &3)
fun when is_function(fun) -> fun
end

{hidden_state, attention_info, cross_attention_info} =
block_impl(
block_type,
block_impl.(
hidden_state,
self_attention_norm,
self_attention,
cross_attention_maybe,
cross_attention_norm,
cross_attention,
output_norm,
ffn
%{
self_attention_norm: self_attention_norm,
self_attention: self_attention,
cross_attention_maybe: cross_attention_maybe,
cross_attention_norm: cross_attention_norm,
cross_attention: cross_attention,
output_norm: output_norm,
ffn: ffn
},
name
)

{attention, self_attention_cache, attention_relative_bias} = attention_info
@@ -495,36 +507,26 @@ defmodule Bumblebee.Layers.Transformer do
{hidden_state, attention, cross_attention, block_cache, attention_relative_bias}
end

defp block_impl(
:standard,
hidden_state,
self_attention_norm,
self_attention,
cross_attention_maybe,
cross_attention_norm,
cross_attention,
output_norm,
ffn
) do
defp block_impl(:standard, hidden_state, steps, _name) do
shortcut = hidden_state

{hidden_state, attention_info} = self_attention.(hidden_state)
{hidden_state, attention_info} = steps.self_attention.(hidden_state)

hidden_state =
hidden_state
|> Axon.add(shortcut)
|> self_attention_norm.()
|> steps.self_attention_norm.()

{hidden_state, cross_attention_info} =
cross_attention_maybe.(hidden_state, fn hidden_state ->
steps.cross_attention_maybe.(hidden_state, fn hidden_state ->
shortcut = hidden_state

{hidden_state, cross_attention_info} = cross_attention.(hidden_state)
{hidden_state, cross_attention_info} = steps.cross_attention.(hidden_state)

hidden_state =
hidden_state
|> Axon.add(shortcut)
|> cross_attention_norm.()
|> steps.cross_attention_norm.()

{hidden_state, cross_attention_info}
end)
@@ -533,41 +535,31 @@ defmodule Bumblebee.Layers.Transformer do

hidden_state =
hidden_state
|> ffn.()
|> steps.ffn.()
|> Axon.add(shortcut)
|> output_norm.()
|> steps.output_norm.()

{hidden_state, attention_info, cross_attention_info}
end

defp block_impl(
:norm_first,
hidden_state,
self_attention_norm,
self_attention,
cross_attention_maybe,
cross_attention_norm,
cross_attention,
output_norm,
ffn
) do
defp block_impl(:norm_first, hidden_state, steps, _name) do
shortcut = hidden_state

{hidden_state, attention_info} =
hidden_state
|> self_attention_norm.()
|> self_attention.()
|> steps.self_attention_norm.()
|> steps.self_attention.()

hidden_state = Axon.add(hidden_state, shortcut)

{hidden_state, cross_attention_info} =
cross_attention_maybe.(hidden_state, fn hidden_state ->
steps.cross_attention_maybe.(hidden_state, fn hidden_state ->
shortcut = hidden_state

{hidden_state, cross_attention_info} =
hidden_state
|> cross_attention_norm.()
|> cross_attention.()
|> steps.cross_attention_norm.()
|> steps.cross_attention.()

hidden_state = Axon.add(hidden_state, shortcut)

@@ -578,40 +570,30 @@ defmodule Bumblebee.Layers.Transformer do

hidden_state =
hidden_state
|> output_norm.()
|> ffn.()
|> steps.output_norm.()
|> steps.ffn.()
|> Axon.add(shortcut)

{hidden_state, attention_info, cross_attention_info}
end

defp block_impl(
:parallel,
hidden_state,
self_attention_norm,
self_attention,
cross_attention_maybe,
_cross_attention_norm,
_cross_attention,
output_norm,
ffn
) do
defp block_impl(:parallel, hidden_state, steps, _name) do
shortcut = hidden_state

{attention_hidden_state, attention_info} =
hidden_state
|> self_attention_norm.()
|> self_attention.()
|> steps.self_attention_norm.()
|> steps.self_attention.()

{_hidden_state, cross_attention_info} =
cross_attention_maybe.(hidden_state, fn _hidden_state ->
steps.cross_attention_maybe.(hidden_state, fn _hidden_state ->
raise "cross attention not supported"
end)

ffn_hidden_state =
hidden_state
|> output_norm.()
|> ffn.()
|> steps.output_norm.()
|> steps.ffn.()

hidden_state = Axon.add([shortcut, attention_hidden_state, ffn_hidden_state])

(The remaining changed files in this commit are not shown here.)
