Allow multiline prompts #1279

Merged
14 commits, merged May 10, 2024
90 changes: 65 additions & 25 deletions litgpt/chat/base.py
@@ -110,6 +110,51 @@ def decode(fabric: L.Fabric, tokenizer: Tokenizer, token_stream: Iterator[torch.
return tokens_generated


def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, top_k, top_p, stop_tokens):
prompt = prompt_style.apply(prompt=prompt)
encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
y = generate(
model, encoded_prompt, model.max_seq_length, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens
)
fabric.print(">> Reply: ", end="")
t0 = time.perf_counter()
tokens_generated = decode(fabric, tokenizer, y)
t = time.perf_counter() - t0
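    # reset the KV cache so the next prompt starts from a clean state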
for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
fabric.print(
f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec,"
f" {tokens_generated} tokens",
file=sys.stderr,
)
fabric.print()


def interact(multiline_prompts, model, tokenizer, prompt_style, fabric, temperature, top_k, top_p, stop_tokens):
while True:
try:
if not multiline_prompts:
prompt = input(">> Prompt: ")
else:
print(">> Prompt: (Type '!submit' on a new line to end input).")
prompt_lines = []
while True:
line = input()
if line.strip().lower() in ("!submit", "!quit", "!exit"):
break
prompt_lines.append(line)
prompt = "\n".join(prompt_lines)

except KeyboardInterrupt:
break

        # strip for the exit check, but keep the prompt's original casing for the model
        prompt = prompt.strip()
        if not prompt or prompt.lower() in ("!quit", "!exit"):
break

process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, top_k, top_p, stop_tokens)


@torch.inference_mode()
def main(
*,
@@ -120,6 +165,7 @@ def main(
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
precision: Optional[str] = None,
compile: bool = False,
multiline_prompts: bool = False,
) -> None:
"""Starts a conversation with a tuned GPT model.

@@ -148,6 +194,7 @@ def main(
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
precision: Indicates the Fabric precision setting to use.
compile: Whether to use compilation to speed up token generation. Will increase startup time.
multiline_prompts: Whether to support multiline input prompts.
"""
precision = precision or get_default_supported_precision(training=False)

@@ -193,29 +240,22 @@
)
stop_tokens = prompt_style.stop_tokens(tokenizer)

print(f"Now chatting with {config.name}.\nTo exit, press 'Enter' on an empty prompt.\n")
if multiline_prompts:
exit_instruction = "To exit, enter '!quit' or '!exit' on an empty prompt and press 'Enter'."
else:
exit_instruction = "To exit, press 'Enter' on an empty prompt."

print(f"Now chatting with {config.name}.\n{exit_instruction}\n")
L.seed_everything(1234)
while True:
try:
prompt = input(">> Prompt: ")
except KeyboardInterrupt:
break
if prompt.lower().strip() in ("", "quit", "exit"):
break
prompt = prompt_style.apply(prompt=prompt)
encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
y = generate(
model, encoded_prompt, model.max_seq_length, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens
)
fabric.print(">> Reply: ", end="")
t0 = time.perf_counter()
tokens_generated = decode(fabric, tokenizer, y)
t = time.perf_counter() - t0
for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
fabric.print(
f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec,"
f" {tokens_generated} tokens",
file=sys.stderr,
)
fabric.print()

interact(
multiline_prompts=multiline_prompts,
model=model,
tokenizer=tokenizer,
prompt_style=prompt_style,
fabric=fabric,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_tokens=stop_tokens
)
5 changes: 5 additions & 0 deletions tutorials/0_to_litgpt.md
@@ -219,7 +219,12 @@ Time for inference: 1.26 sec total, 27.81 tokens/sec, 35 tokens

>> Prompt:
```
 

> [!TIP]
> Use `--multiline_prompts true` to support prompts that require multiple input lines.
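
With the flag enabled, a session might look roughly like the following (illustrative transcript; the `!submit` terminator and the prompt header come from the chat script in this PR):

```
>> Prompt: (Type '!submit' on a new line to end input).
Summarize the following notes:
- multiline prompts are now supported
- end the input with !submit
!submit
>> Reply: ...
```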

<br>

&nbsp;
**More information and additional resources**
4 changes: 4 additions & 0 deletions tutorials/inference.md
@@ -27,6 +27,10 @@ litgpt chat --checkpoint_dir checkpoints/stabilityai/stablelm-tuned-alpha-3b
This script can work with any checkpoint. For the best chat-like experience, we recommend using it with checkpoints
fine-tuned for chatting such as `stabilityai/stablelm-tuned-alpha-3b` or `togethercomputer/RedPajama-INCITE-Chat-3B-v1`.

> [!TIP]
> Use `--multiline_prompts` to work with inputs that span multiple lines.
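
For example, combined with the command above (an illustrative invocation; the boolean value is passed explicitly for clarity):

```bash
litgpt chat --checkpoint_dir checkpoints/stabilityai/stablelm-tuned-alpha-3b --multiline_prompts true
```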


## Run a large model on one smaller device

Check out our [quantization tutorial](quantize.md).