Add batched_next_token() and batched_sample() (#1693)
Co-authored-by: Sebastian Raschka <[email protected]>
apaz-cli and rasbt authored Aug 28, 2024
1 parent f655f01 commit 3c0c479
Showing 2 changed files with 102 additions and 2 deletions.
39 changes: 37 additions & 2 deletions litgpt/generate/base.py
@@ -75,8 +75,43 @@ def sample(

def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    logits = model(x, input_pos)
-    next = sample(logits, **kwargs)
-    return next.to(dtype=x.dtype)
+    _next = sample(logits, **kwargs).to(dtype=torch.int64)
+    return _next

def batched_sample(logits: list[torch.Tensor], kwargs: list[dict]) -> torch.Tensor:
    assert len(logits) == len(kwargs), "logits and kwargs must have the same length."
    return torch.stack([sample(l, **sample_args).to(dtype=torch.int64) for sample_args, l in zip(kwargs, logits)], dim=0)


def batched_next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, kwargs: Union[dict, list[dict]]) -> torch.Tensor:
    # Where:
    #   input_pos is a 1d tensor of shape [seq_length...]
    #   x is the context tokens to add to the kv-cache.
    #     For prefill, x is a 2d tensor of shape [batch_size, prompt_length].
    #     For subsequent tokens, x is a 2d tensor of shape [batch_size, 1].
    #   kwargs is a list of dictionaries, each containing the keyword arguments for the sample function.
    #     If a single dictionary is passed, it is repeated for each sample in the batch.

    # In the future, we would like input_pos to be a 2d tensor of shape [batch_size, seq_length]
    # so that we can support prompts of different lengths. That means making the rope cache and the
    # kv-cache forward() work with batches, which they currently do not. This is relatively
    # complicated given the current implementation and will require some rewriting.
    # Relevant thread: https://discuss.pytorch.org/t/batched-index-select/9115
    # We will also need a batched equivalent of tensor.index_copy_(). Neither works for batches, and
    # the replacement is somewhat nontrivial. Until then, we can only accept prompts that are all the
    # same length. Once that is resolved, the next problem is continuous batched prefill.
    # If you have any ideas on this, let me know; I don't think that padding input_pos is viable.

    _kwargs = kwargs if isinstance(kwargs, list) else [kwargs] * x.size(0)

    # Run the model on the batch.
    logits_stack = model(x, input_pos)

    # Unbind the logits stack into a list of logits.
    logits_list = [logits_stack] if logits_stack.ndim == 1 else logits_stack.unbind(0)
    logits_list = [l.unsqueeze(0) for l in logits_list]

    # Return the next token for each sample in the batch.
    return batched_sample(logits_list, kwargs=_kwargs)


@torch.inference_mode()
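To make the per-sample kwargs handling concrete, here is a minimal sketch of calling batched_sample directly with different sampling settings per batch element. It is not part of the commit: the vocabulary size and settings are made up, and it assumes logits shaped the way batched_next_token prepares them before calling sample(), i.e. one [1, seq_length, vocab_size] tensor per sample so that sample() can index the last position.

import torch
from litgpt.generate.base import batched_sample

vocab_size = 8
# One logits tensor per sample, shaped [1, seq_length, vocab_size] (illustrative values).
logits = [torch.randn(1, 4, vocab_size) for _ in range(3)]

# Per-sample sampling settings: greedy, a lower temperature, and top-k sampling.
per_sample_kwargs = [{"top_k": 1}, {"temperature": 0.5}, {"top_k": 3}]
next_tokens = batched_sample(logits, kwargs=per_sample_kwargs)
print(next_tokens.shape, next_tokens.dtype)  # expected: torch.Size([3, 1]) torch.int64

On the batching limitation flagged in the comments above: torch.gather and Tensor.scatter_ with expanded index tensors are common per-row stand-ins for index_select and index_copy_, though threading them through the rope cache and kv-cache forward() is the nontrivial part the comment alludes to.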
65 changes: 65 additions & 0 deletions tests/test_batch.py
@@ -0,0 +1,65 @@
import torch
import pytest
import warnings
from pathlib import Path
from litgpt.generate.base import next_token, batched_next_token
from litgpt.api import LLM, GPT
from litgpt.scripts.download import download_from_hub

warnings.filterwarnings("ignore")

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires a GPU.")
def test_batched_equivalence(tmp_path):

    model_name = "microsoft/phi-2"
    download_from_hub(repo_id=model_name, tokenizer_only=True, checkpoint_dir=tmp_path)

    device = "cuda:0"
    batch_size = 3
    sample_kwargs = {"top_k": 1}

    llm: LLM = LLM.load(
        model_name,
        tokenizer_dir=Path(tmp_path / model_name),
        init="random",
    )
    model: GPT = llm.model
    model.set_kv_cache(batch_size=1, max_seq_length=50, device=device)

    input_pos_1 = torch.tensor(
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device=device
    )
    input_pos_2 = torch.tensor([10], dtype=torch.int64, device=device)

    x = torch.tensor(
        [43993, 25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410],
        device=device,
        dtype=torch.int64,
    )

    batch_x1 = torch.stack([x] * batch_size, dim=0)

    # Single token generation baseline
    tok_1 = next_token(model, input_pos_1, x.unsqueeze(0), **sample_kwargs)
    tok_2 = next_token(model, input_pos_2, tok_1.unsqueeze(0), **sample_kwargs)

    assert tok_1.ndim == 1
    assert tok_2.ndim == 1
    assert tok_1.size(0) == 1
    assert tok_2.size(0) == 1

    # Switch to batched generation
    model.clear_kv_cache()
    model.set_kv_cache(batch_size=batch_size, max_seq_length=50, device="cuda:0")

    toks_1: torch.Tensor = batched_next_token(model, input_pos_1, batch_x1, sample_kwargs)
    toks_2: torch.Tensor = batched_next_token(model, input_pos_2, toks_1, sample_kwargs)

    assert toks_1.ndim == 2
    assert toks_2.ndim == 2
    assert toks_1.size(0) == batch_size
    assert toks_2.size(0) == batch_size

    # Assert that single and batched next token generation are equivalent
    assert all(t == tok_1 for t in toks_1), f"{tok_1} != {toks_1}"
    assert all(t == tok_2 for t in toks_2), f"{tok_2} != {toks_2}"
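
For readers who want to see the new API in a loop rather than the two-step test above, here is a small greedy-decoding sketch built on batched_next_token. It is not part of this commit: the function name, max_new_tokens, and the {"top_k": 1} setting are illustrative. It assumes the model's kv-cache has already been set up for the batch size and total sequence length (as in the test), and that all prompts in the batch are the same length, per the limitation noted in batched_next_token.

import torch
from litgpt.generate.base import batched_next_token

def batched_greedy_decode(model, prompt, batch_size, max_new_tokens, device):
    # Replicate a single prompt across the batch; all rows must have equal length.
    batch = torch.stack([prompt] * batch_size, dim=0).to(device)
    prompt_len = prompt.size(0)

    # Prefill: positions 0..prompt_len-1, x of shape [batch_size, prompt_len].
    input_pos = torch.arange(prompt_len, dtype=torch.int64, device=device)
    tokens = batched_next_token(model, input_pos, batch, {"top_k": 1})  # [batch_size, 1]
    generated = [tokens]

    # Decode: feed the last tokens back in, advancing one shared position at a time.
    for step in range(1, max_new_tokens):
        input_pos = torch.tensor([prompt_len + step - 1], dtype=torch.int64, device=device)
        tokens = batched_next_token(model, input_pos, tokens, {"top_k": 1})
        generated.append(tokens)

    return torch.cat(generated, dim=1)  # [batch_size, max_new_tokens]

Continuous batching and ragged (different-length) prompts would require the batched input_pos work discussed in the comments in litgpt/generate/base.py above.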
