From 01548880d4bd5d60b58fc1a22956432b89ca59b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Thu, 3 Aug 2023 17:15:24 +0200 Subject: [PATCH] Fix decoding an empty sequence --- lib/bumblebee/utils/tokenizers.ex | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/bumblebee/utils/tokenizers.ex b/lib/bumblebee/utils/tokenizers.ex index bb78041f..505478dc 100644 --- a/lib/bumblebee/utils/tokenizers.ex +++ b/lib/bumblebee/utils/tokenizers.ex @@ -151,15 +151,15 @@ defmodule Bumblebee.Utils.Tokenizers do |> Nx.reshape({length(list), :auto}) end - def decode(tokenizer, [id | _] = ids) when is_number(id) do - case Tokenizer.decode(tokenizer, ids) do + def decode(tokenizer, [ids | _] = batch_ids) when is_list(ids) do + case Tokenizer.decode_batch(tokenizer, batch_ids) do {:ok, decoded} -> decoded {:error, term} -> raise "decoding failed with error: #{inspect(term)}" end end - def decode(tokenizer, batch_ids) do - case Tokenizer.decode_batch(tokenizer, batch_ids) do + def decode(tokenizer, ids) do + case Tokenizer.decode(tokenizer, ids) do {:ok, decoded} -> decoded {:error, term} -> raise "decoding failed with error: #{inspect(term)}" end