diff --git a/lib/bumblebee/utils/tokenizers.ex b/lib/bumblebee/utils/tokenizers.ex index bb78041f..505478dc 100644 --- a/lib/bumblebee/utils/tokenizers.ex +++ b/lib/bumblebee/utils/tokenizers.ex @@ -151,15 +151,15 @@ defmodule Bumblebee.Utils.Tokenizers do |> Nx.reshape({length(list), :auto}) end - def decode(tokenizer, [id | _] = ids) when is_number(id) do - case Tokenizer.decode(tokenizer, ids) do + def decode(tokenizer, [ids | _] = batch_ids) when is_list(ids) do + case Tokenizer.decode_batch(tokenizer, batch_ids) do {:ok, decoded} -> decoded {:error, term} -> raise "decoding failed with error: #{inspect(term)}" end end - def decode(tokenizer, batch_ids) do - case Tokenizer.decode_batch(tokenizer, batch_ids) do + def decode(tokenizer, ids) do + case Tokenizer.decode(tokenizer, ids) do {:ok, decoded} -> decoded {:error, term} -> raise "decoding failed with error: #{inspect(term)}" end