
Commit

Merge branch 'main' into eval-quickstart
vchiley committed Jun 30, 2023
2 parents b08ef44 + 8746a73 commit 2187c67
Showing 2 changed files with 31 additions and 3 deletions.
7 changes: 7 additions & 0 deletions llmfoundry/data/text_data.py
@@ -51,6 +51,11 @@ class StreamingTextDataset(StreamingDataset):
smaller epoch size. Defaults to ``None``.
predownload (int, optional): Target number of samples ahead to download the shards of while
iterating. Defaults to ``100_000``.
cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
shard cache. Before downloading a shard, the least recently used resident shard(s) may
be evicted (deleted from the local cache) in order to stay under the limit. Set to ``None``
to disable shard eviction. Supports integer bytes as well as string human-readable
bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to ``None``.
partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
@@ -77,6 +82,7 @@ def __init__(self,
keep_zip: bool = False,
epoch_size: Optional[int] = None,
predownload: int = 100_000,
cache_limit: Optional[Union[int, str]] = None,
partition_algo: str = 'orig',
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None,
@@ -118,6 +124,7 @@ def __init__(self,
keep_zip=keep_zip,
epoch_size=epoch_size,
predownload=predownload,
cache_limit=cache_limit,
partition_algo=partition_algo,
num_canonical_nodes=num_canonical_nodes,
batch_size=batch_size,
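For illustration, a minimal sketch of how the new cache_limit argument might be used when constructing a StreamingTextDataset. The model name, dataset paths, and the other constructor arguments shown here are assumptions for the example (based on typical llm-foundry usage), not part of this commit:

    from transformers import AutoTokenizer

    from llmfoundry.data.text_data import StreamingTextDataset

    # Placeholder tokenizer and dataset locations.
    tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

    dataset = StreamingTextDataset(
        tokenizer=tokenizer,
        max_seq_len=2048,
        remote='s3://my-bucket/my-text-dataset',
        local='/tmp/streaming_cache',
        shuffle=True,
        batch_size=8,
        # New in this commit: cap the local shard cache. The least recently
        # used shards are evicted before new ones are downloaded. Accepts
        # integer bytes or a human-readable string (the 'gb' suffix is assumed
        # to be supported alongside the 'kb'/'mb' examples in the docstring).
        cache_limit='10gb',
    )

Leaving cache_limit at its default of None keeps the previous behavior of never evicting shards.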
27 changes: 24 additions & 3 deletions scripts/inference/hf_generate.py
@@ -9,6 +9,7 @@
from argparse import ArgumentParser, ArgumentTypeError, Namespace
from contextlib import nullcontext

import numpy as np
import torch
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
pipeline)
@@ -318,17 +319,37 @@ def _generate(encoded_inp):

# Print generations
delimiter = '#' * 100
for prompt, gen in zip(batch, decoded_gen):
continuation = gen[len(prompt):]
# decode the encoded prompt to handle the case when the tokenizer
# trims extra spaces or does other pre-tokenization things
effective_prompts = tokenizer.batch_decode(encoded_inp['input_ids'],
skip_special_tokens=True)
for idx, (effective_prompt, prompt, gen) in enumerate(
zip(effective_prompts, batch, decoded_gen)):
continuation = gen[len(effective_prompt):]
print(delimiter)
print('\033[92m' + prompt + '\033[0m' + continuation)
if len(continuation) > 0:
print('\033[92m' + prompt + '\033[0m' + continuation)
else:
print('Warning. No non-special output tokens generated.')
print(
'This can happen if the generation only contains padding/eos tokens.'
)
print('Debug:')
full_generation = tokenizer.batch_decode(
encoded_gen, skip_special_tokens=False)[idx]
print('\033[92m' + 'Prompt:\n' + prompt + '\033[0m')
print('Full generation:\n' + full_generation)

print(delimiter)

# Print timing info
bs = len(batch)
# ensure that gen_tokens >= 1 in case model only generated padding tokens
gen_tokens = np.maximum(gen_tokens, np.ones_like(gen_tokens))
output_tokens = gen_tokens - input_tokens
total_input_tokens = input_tokens.sum()
total_output_tokens = output_tokens.sum()

encode_latency = 1000 * (encode_end - encode_start)
gen_latency = 1000 * (gen_end - gen_start)
decode_latency = 1000 * (decode_end - decode_start)
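To see why the continuation is now sliced with the decoded prompt rather than the raw prompt string, here is a small self-contained sketch; the tokenizer and strings are illustrative and not taken from hf_generate.py. As the new comment in the diff notes, some tokenizers trim extra spaces or apply other pre-tokenization, so the decoded form of the prompt's token ids is the prefix that actually appears in the decoded generation:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    prompt = 'MosaicML builds tools for  '  # note the doubled trailing space
    encoded = tokenizer(prompt, return_tensors='pt')

    # The decoded prompt may differ from the raw prompt string, depending on
    # the tokenizer's pre-tokenization behavior.
    effective_prompt = tokenizer.batch_decode(encoded['input_ids'],
                                              skip_special_tokens=True)[0]

    # Stand-in for a decoded model generation that begins with the prompt.
    decoded_gen = effective_prompt + 'efficient LLM training.'

    # Slicing by the decoded prompt always removes exactly the prompt prefix;
    # slicing by len(prompt) can cut too much or too little whenever the raw
    # and decoded prompts differ in length.
    continuation = decoded_gen[len(effective_prompt):]
    print(repr(continuation))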

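Finally, a hedged sketch of the throughput arithmetic the timing block feeds into. The token counts and timestamps below are made up, and the assumption that gen_tokens counts prompt-plus-generated tokens per sample is inferred from the subtraction in the diff, not stated by it:

    import numpy as np

    # Made-up per-sample token counts: full output (prompt + generation) and
    # prompt only. The middle sample generated nothing beyond its prompt.
    gen_tokens = np.array([48, 15, 61])
    input_tokens = np.array([12, 15, 9])

    # Mirror the diff's guard: keep every entry of gen_tokens at least 1, per
    # its comment about generations that contain only padding/eos tokens.
    gen_tokens = np.maximum(gen_tokens, np.ones_like(gen_tokens))
    output_tokens = gen_tokens - input_tokens

    # Made-up wall-clock timestamps (seconds), converted to ms as in the diff.
    gen_start, gen_end = 0.0, 2.4
    gen_latency = 1000 * (gen_end - gen_start)

    tokens_per_sec = output_tokens.sum() / (gen_latency / 1000)
    print(f'{tokens_per_sec:.1f} output tokens/sec')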