Skip to content

Commit

Permalink
Use vectorized implementation for batched choice
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Aug 1, 2023
1 parent 827f77d commit a756115
Showing 1 changed file with 7 additions and 25 deletions.
32 changes: 7 additions & 25 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -855,31 +855,13 @@ defmodule Bumblebee.Text.Generation do

keys = Nx.Random.split(key, parts: batch_size)

tokens =
for i <- 0..(batch_size - 1) do
probabilities = scores[i]
{token, _} = Nx.Random.choice(keys[i], vocab, probabilities, samples: 1)
token
end

Nx.concatenate(tokens, axis: 0)
end

# TODO: once vectorization is in
# deftransformp batched_choice(key, scores) do
# {batch_size, vocab_size} = Nx.shape(scores)
key = Nx.vectorize(keys, :batch)
probabilities = Nx.vectorize(scores, :batch)

# vocab = Nx.iota({vocab_size})
{tokens, _} = Nx.Random.choice(key, vocab, probabilities, samples: 1)

# keys = Nx.Random.split(key, parts: batch_size)

# key = Nx.vectorize(keys, :batch)
# probabilities = Nx.vectorize(scores, :batch)

# {tokens, _} = Nx.Random.choice(key, vocab, probabilities, samples: 1)

# tokens
# |> Nx.squeeze()
# |> Nx.devectorize()
# end
tokens
|> Nx.squeeze()
|> Nx.devectorize()
end
end

0 comments on commit a756115

Please sign in to comment.