-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Sebastian Raschka <[email protected]>
- Loading branch information
Showing
2 changed files
with
102 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import torch | ||
import pytest | ||
import warnings | ||
from pathlib import Path | ||
from litgpt.generate.base import next_token, batched_next_token | ||
from litgpt.api import LLM, GPT | ||
from litgpt.scripts.download import download_from_hub | ||
|
||
warnings.filterwarnings("ignore") | ||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires a GPU.") | ||
def test_batched_equivalence(tmp_path): | ||
|
||
model_name = "microsoft/phi-2" | ||
download_from_hub(repo_id=model_name, tokenizer_only=True, checkpoint_dir=tmp_path) | ||
|
||
device = "cuda:0" | ||
batch_size = 3 | ||
sample_kwargs = {"top_k": 1} | ||
|
||
llm: LLM = LLM.load( | ||
model_name, | ||
tokenizer_dir=Path(tmp_path / model_name), | ||
init="random", | ||
) | ||
model: GPT = llm.model | ||
model.set_kv_cache(batch_size=1, max_seq_length=50, device=device) | ||
|
||
input_pos_1 = torch.tensor( | ||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device=device | ||
) | ||
input_pos_2 = torch.tensor([10], dtype=torch.int64, device=device) | ||
|
||
x = torch.tensor( | ||
[43993, 25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410], | ||
device=device, | ||
dtype=torch.int64, | ||
) | ||
|
||
batch_x1 = torch.stack([x] * batch_size, dim=0) | ||
|
||
# Single token generation baseline | ||
tok_1 = next_token(model, input_pos_1, x.unsqueeze(0), **sample_kwargs) | ||
tok_2 = next_token(model, input_pos_2, tok_1.unsqueeze(0), **sample_kwargs) | ||
|
||
assert tok_1.ndim == 1 | ||
assert tok_2.ndim == 1 | ||
assert tok_1.size(0) == 1 | ||
assert tok_2.size(0) == 1 | ||
|
||
# Switch to batched generation | ||
model.clear_kv_cache() | ||
model.set_kv_cache(batch_size=batch_size, max_seq_length=50, device="cuda:0") | ||
|
||
toks_1: torch.Tensor = batched_next_token(model, input_pos_1, batch_x1, sample_kwargs) | ||
toks_2: torch.Tensor = batched_next_token(model, input_pos_2, toks_1, sample_kwargs) | ||
|
||
assert toks_1.ndim == 2 | ||
assert toks_2.ndim == 2 | ||
assert toks_1.size(0) == batch_size | ||
assert toks_2.size(0) == batch_size | ||
|
||
# Assert that single and batched next token generation are equivalent | ||
assert all(t == tok_1 for t in toks_1), f"{tok_1} != {toks_1}" | ||
assert all(t == tok_2 for t in toks_2), f"{tok_2} != {toks_2}" |