diff --git a/t5/hf_t5.py b/t5/hf_t5.py
index 98c6da8..23d9644 100644
--- a/t5/hf_t5.py
+++ b/t5/hf_t5.py
@@ -23,11 +23,11 @@ def embed(t5_model: str):
 def generate(t5_model: str):
-    prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
+    prompt = "def print_hello_world():<extra_id_0>"
     tokenizer = AutoTokenizer.from_pretrained(t5_model)
     torch_model = AutoModelForSeq2SeqLM.from_pretrained(t5_model)
     torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
-    outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512)
+    outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=10)
     print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Thanks for flagging. Indeed, the way we do streaming decoding in the T5 example is not correct for most tokenizers: you typically can't decode each new token individually as we do here. It should either use a proper streaming decoder, or we just eat the quadratic cost and re-decode the entire prefix.
I'll mark this as a bug; it should be a fairly simple fix.
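For reference, a minimal sketch of the second option (eating the quadratic cost) might look like the following. The `t5-small` model name and the `stream_decode` helper are illustrative, not code from this repo:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative model

def stream_decode(token_ids):
    """Yield text increments by re-decoding the full prefix at each step."""
    emitted = ""
    prefix = []
    for tok in token_ids:
        prefix.append(tok)
        text = tokenizer.decode(prefix, skip_special_tokens=True)
        # Assumes the decoded text only ever grows; a production streaming
        # decoder must also handle prefixes whose decode revises earlier text.
        yield text[len(emitted):]
        emitted = text
```

Re-decoding the whole prefix is quadratic in sequence length overall, which is tolerable for short generations like the example's max_length=10.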
The hf_t5.py example produces correct output with the changes above. It seems that the tokenizer does not work well with streaming decoding.
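To make the failure concrete, here is a small illustration, assuming the stock `t5-small` SentencePiece tokenizer (the input string is arbitrary):

```python
from transformers import AutoTokenizer

# Decoding ids one at a time strips the SentencePiece word-boundary
# markers, so concatenating per-token decodes loses the spaces that
# a full-sequence decode preserves.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
ids = tokenizer("hello world", add_special_tokens=False).input_ids

per_token = "".join(tokenizer.decode([i]) for i in ids)
full = tokenizer.decode(ids)
print(repr(per_token))  # e.g. 'helloworld' -- spaces dropped
print(repr(full))       # 'hello world'
```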