Add phi model
seanmor5 committed Mar 1, 2024
1 parent 4c8bb21 commit 9e2441d
Showing 5 changed files with 453 additions and 5 deletions.
3 changes: 3 additions & 0 deletions lib/bumblebee.ex
@@ -164,6 +164,9 @@ defmodule Bumblebee do
"MistralModel" => {Bumblebee.Text.Mistral, :base},
"MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling},
"MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification},
"PhiModel" => {Bumblebee.Text.Phi, :base},
"PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling},
"PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification},
"ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
"ResNetModel" => {Bumblebee.Vision.ResNet, :base},
"RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
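These entries map the `architectures` field of a Hugging Face checkpoint config onto Bumblebee modules, so a Phi checkpoint loads like any other supported model. A minimal sketch of what the new mappings enable; the "microsoft/phi-2" repository name is an assumption for illustration, not taken from this commit:

    # Sketch: loading a Phi checkpoint through the new architecture mappings.
    # The repository name below is an illustrative assumption.
    {:ok, model_info} = Bumblebee.load_model({:hf, "microsoft/phi-2"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "microsoft/phi-2"})

    # "PhiForCausalLM" in the checkpoint config resolves to
    # {Bumblebee.Text.Phi, :for_causal_language_modeling} via the map above.
    inputs = Bumblebee.apply_tokenizer(tokenizer, "def fib(n):")
    outputs = Axon.predict(model_info.model, model_info.params, inputs)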
8 changes: 4 additions & 4 deletions lib/bumblebee/layers/transformer.ex
@@ -285,11 +285,11 @@ defmodule Bumblebee.Layers.Transformer do
      * `:max_positions` - the maximum number of distinct positions
-     * `:rotary_embedding_base` - base for computing rotary embedding frequency. Defaults
-       to `10_000`.
+     * `:base` - base for computing rotary embedding frequency. Defaults
+       to `10_000`.
-     * `:rotary_percentage` - percentage of hidden dimensions to allocate to rotary embeddings.
-       Defaults to `1.0`.
+     * `:percentage` - percentage of hidden dimensions to allocate to rotary embeddings.
+       Defaults to `1.0`.
      * `:name` - the prefix for layer names
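The rename drops the redundant prefix from these sub-options of the transformer's rotary embedding configuration. A hypothetical option list under the new names; the enclosing `rotary_embedding` keyword list and all values are assumptions for illustration, not shown in this diff:

    # Hypothetical sub-option list illustrating the renamed keys;
    # shapes and values are placeholders, not taken from the commit.
    position_ids = Nx.iota({1, 16})

    rotary_embedding = [
      position_ids: position_ids,
      max_positions: 2048,
      base: 10_000,    # previously `:rotary_embedding_base`
      percentage: 1.0  # previously `:rotary_percentage`
    ]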
2 changes: 1 addition & 1 deletion lib/bumblebee/text/mistral.ex
@@ -362,7 +362,7 @@ defmodule Bumblebee.Text.Mistral do

gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false)

-   hidden_state = Axon.multiply(intermediate, Axon.activation(gate, activation))
+   hidden_state = Axon.multiply(intermediate, Layers.activation(gate, activation))

Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false)
end
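This hunk swaps `Axon.activation/2` for `Layers.activation/2`, Bumblebee's own activation helper, inside Mistral's gated feed-forward block; the diff doesn't state the rationale, but the helper presumably covers activation names beyond Axon's built-ins. A rough standalone sketch of that gated ("SwiGLU"-style) pattern, with the sizes and the `:silu` activation hard-coded as assumptions:

    # Standalone sketch of the gated feed-forward block; dimensions and
    # the :silu activation are illustrative assumptions.
    hidden_state = Axon.input("hidden_state", shape: {nil, 128})
    intermediate = Axon.dense(hidden_state, 512, use_bias: false)
    gate = Axon.dense(hidden_state, 512, use_bias: false)
    gated = Axon.multiply(intermediate, Axon.silu(gate))
    output = Axon.dense(gated, 128, use_bias: false)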
(Diffs for the remaining two changed files are not shown.)
