Fixing hf_generate bug to account for pre-tokenization (#387)
* adding fix to trim gen correctly before printing

* adding warning if total_output_tokens=0

* modifying debug statement to include special tokens

* formatting fixes from pre-commit

* updates from code review
ksreenivasan committed Jun 29, 2023
1 parent 73783b9 commit 8746a73
Showing 1 changed file with 24 additions and 3 deletions.
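
For context on the bug being fixed: a prompt string is not guaranteed to survive a tokenize/decode round trip unchanged, because some tokenizers trim extra spaces or apply other pre-tokenization normalization. The sketch below illustrates the effect; 'gpt2' is only a stand-in tokenizer, not something the script prescribes, and whether the strings actually differ depends on the tokenizer.

# Hypothetical illustration of the round-trip mismatch behind this fix.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # stand-in tokenizer
prompt = 'Here is a prompt with a trailing space: '
round_trip = tokenizer.decode(tokenizer(prompt)['input_ids'],
                              skip_special_tokens=True)
print(prompt == round_trip)  # may be False when the tokenizer pre-processes text
# When the strings differ, gen[len(prompt):] slices the generation at the wrong
# offset, which is why the script now slices by the decoded prompt instead.
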
27 changes: 24 additions & 3 deletions scripts/inference/hf_generate.py
@@ -9,6 +9,7 @@
from argparse import ArgumentParser, ArgumentTypeError, Namespace
from contextlib import nullcontext

import numpy as np
import torch
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
pipeline)
@@ -318,17 +319,37 @@ def _generate(encoded_inp):

# Print generations
delimiter = '#' * 100
for prompt, gen in zip(batch, decoded_gen):
continuation = gen[len(prompt):]
# decode the encoded prompt to handle the case when the tokenizer
# trims extra spaces or does other pre-tokenization things
effective_prompts = tokenizer.batch_decode(encoded_inp['input_ids'],
skip_special_tokens=True)
for idx, (effective_prompt, prompt, gen) in enumerate(
zip(effective_prompts, batch, decoded_gen)):
continuation = gen[len(effective_prompt):]
print(delimiter)
print('\033[92m' + prompt + '\033[0m' + continuation)
if len(continuation) > 0:
print('\033[92m' + prompt + '\033[0m' + continuation)
else:
print('Warning. No non-special output tokens generated.')
print(
'This can happen if the generation only contains padding/eos tokens.'
)
print('Debug:')
full_generation = tokenizer.batch_decode(
encoded_gen, skip_special_tokens=False)[idx]
print('\033[92m' + 'Prompt:\n' + prompt + '\033[0m')
print('Full generation:\n' + full_generation)

print(delimiter)

# Print timing info
bs = len(batch)
# ensure that gen_tokens >= 1 in case model only generated padding tokens
gen_tokens = np.maximum(gen_tokens, np.ones_like(gen_tokens))
output_tokens = gen_tokens - input_tokens
total_input_tokens = input_tokens.sum()
total_output_tokens = output_tokens.sum()

encode_latency = 1000 * (encode_end - encode_start)
gen_latency = 1000 * (gen_end - gen_start)
decode_latency = 1000 * (decode_end - decode_start)
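
Putting the changes together outside the diff: a minimal, self-contained sketch of the fixed flow, assuming a placeholder 'gpt2' model (the script itself loads whatever HF causal LM the user requests). It slices the decoded generation by the decoded ("effective") prompt and applies the same gen_tokens clamp as above.

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = 'gpt2'  # stand-in model for illustration only
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

batch = ['The quick brown fox ']
encoded_inp = tokenizer(batch, return_tensors='pt')
with torch.no_grad():
    encoded_gen = model.generate(**encoded_inp, max_new_tokens=8)
decoded_gen = tokenizer.batch_decode(encoded_gen, skip_special_tokens=True)

# Decode the encoded prompt so the slice point matches what the model saw,
# even if the tokenizer trimmed spaces or otherwise normalized the prompt.
effective_prompts = tokenizer.batch_decode(encoded_inp['input_ids'],
                                           skip_special_tokens=True)
for effective_prompt, prompt, gen in zip(effective_prompts, batch, decoded_gen):
    continuation = gen[len(effective_prompt):]
    print(prompt + '|' + continuation)

# Clamp generated-token counts to at least 1, as the diff does, so the timing
# math never sees a zero count when a sequence produced only padding/eos tokens.
input_tokens = np.array([encoded_inp['input_ids'].shape[1]])
gen_tokens = np.array([encoded_gen.shape[1]])
gen_tokens = np.maximum(gen_tokens, np.ones_like(gen_tokens))
output_tokens = gen_tokens - input_tokens
print(f'total output tokens: {output_tokens.sum()}')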