
Feature/top p sampling #1360

Merged: 19 commits, May 3, 2024
Changes from 3 commits
18 changes: 15 additions & 3 deletions litgpt/chat/base.py
@@ -24,6 +24,7 @@ def generate(
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
stop_tokens: Tuple[List[int], ...] = (),
) -> Iterator[torch.Tensor]:
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as possible.
@@ -33,7 +34,12 @@ def generate(
prompt: Tensor of shape (T) with indices of the prompt sequence.
max_returned_tokens: The maximum number of tokens to return (given plus generated).
temperature: Scales the predicted logits by 1 / temperature
top_k: If specified, only sample among the tokens with the k highest probabilities
top_k: If specified, only sample among the tokens with the k highest probabilities.
top_p: The cumulative probability threshold to consider in the sampling process.
belerico marked this conversation as resolved.
Show resolved Hide resolved
In top-p sampling the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top-p`.
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
stop_tokens: If specified, stop generating any more tokens once one of these sequences is generated.
"""
T = prompt.size(0)
@@ -51,7 +57,7 @@
tokens = []
token = prompt
for t in range(1, max_returned_tokens - T + 1):
token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k)
token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p)
tokens.append(token)
# check the stop condition
if any((l := len(st)) <= len(tokens) and all(a == b for a, b in zip(tokens[-l:], st)) for st in stop_tokens):
@@ -99,6 +105,7 @@ def decode(fabric: L.Fabric, tokenizer: Tokenizer, token_stream: Iterator[torch.
def main(
*,
top_k: Optional[int] = 200,
top_p: Optional[float] = None,
temperature: float = 0.8,
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-tuned-alpha-3b"),
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
@@ -109,6 +116,11 @@ def main(

Args:
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: The cumulative probability threshold to consider in the sampling process.
In top-p (nucleus) sampling the next token is sampled from the smallest set of most probable tokens
whose cumulative probability exceeds the threshold `top_p`.
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
checkpoint_dir: The checkpoint directory to load.
@@ -175,7 +187,7 @@ def main(
prompt = prompt_style.apply(prompt=prompt)
encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
y = generate(
model, encoded_prompt, model.max_seq_length, temperature=temperature, top_k=top_k, stop_tokens=stop_tokens
model, encoded_prompt, model.max_seq_length, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens
)
fabric.print(">> Reply: ", end="")
t0 = time.perf_counter()
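For reviewers who want to try the chat path locally, a minimal sketch of how the streaming generate() above could be driven with the new argument. The model and tokenizer setup is assumed to exist already, the helper name stream_reply is hypothetical, and per-token decoding follows the decode() helper referenced in the hunk above:

import torch
from litgpt.chat.base import generate

def stream_reply(model, tokenizer, prompt_ids: torch.Tensor, stop_tokens=()):
    # Stream tokens with nucleus sampling enabled: top_p=0.9 samples from the
    # smallest set of most probable tokens covering ~90% of the probability mass.
    for token in generate(
        model,
        prompt_ids,
        model.max_seq_length,
        temperature=0.8,
        top_k=200,
        top_p=0.9,
        stop_tokens=stop_tokens,
    ):
        print(tokenizer.decode(token), end="", flush=True)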
8 changes: 7 additions & 1 deletion litgpt/generate/adapter.py
@@ -24,6 +24,7 @@ def main(
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
max_new_tokens: int = 100,
top_k: Optional[int] = 50,
top_p: Optional[float] = None,
temperature: float = 0.8,
precision: Optional[str] = None,
) -> None:
@@ -42,6 +43,11 @@ def main(
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: The cumulative probability threshold to consider in the sampling process.
In top-p (nucleus) sampling the next token is sampled from the smallest set of most probable tokens
whose cumulative probability exceeds the threshold `top_p`.
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
precision: Indicates the Fabric precision setting to use.
@@ -97,7 +103,7 @@ def main(

L.seed_everything(1234)
t0 = time.perf_counter()
y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id)
t = time.perf_counter() - t0

output = tokenizer.decode(y)
6 changes: 6 additions & 0 deletions litgpt/generate/adapter_v2.py
@@ -24,6 +24,7 @@ def main(
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
max_new_tokens: int = 100,
top_k: Optional[int] = 50,
top_p: Optional[float] = None,
temperature: float = 0.8,
precision: Optional[str] = None,
) -> None:
@@ -42,6 +43,11 @@ def main(
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: The cumulative probability threshold to consider in the sampling process.
In top-p (nucleus) sampling the next token is sampled from the smallest set of most probable tokens
whose cumulative probability exceeds the threshold `top_p`.
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
precision: Indicates the Fabric precision setting to use.
36 changes: 32 additions & 4 deletions litgpt/generate/base.py
@@ -24,13 +24,28 @@ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
return torch.multinomial(probs, num_samples=1)


def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
def sample(
logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None
) -> torch.Tensor:
logits = logits[0, -1]
# optionally crop the logits to only the top k options
if top_k is not None:
v, i = torch.topk(logits, min(top_k, logits.size(-1)))
# do not use `torch.where` as in nanogpt because it will repeat top-k collisions
logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
# optionally crop the logits to smallest set of logits with a cumulative probability above top_p
if top_p is not None:
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Example:
# sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
# sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# Keep at least 1 token always to prevent the case where no token is selected
# In this case the most probable one is always kept
sorted_indices_to_remove[-1:] = 0
indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, float("-inf"))
# optionally scale the logits and sample from a probability distribution
if temperature > 0.0:
probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
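To make the new masking step concrete, here is a small standalone sketch that applies the same top-p filter to a toy distribution; the numbers mirror the example in the code comment above:

import torch

def top_p_mask(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Same logic as sample() above: sort ascending, drop the low-probability prefix
    # whose cumulative mass is <= 1 - top_p, and always keep the most probable token.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    sorted_indices_to_remove[-1:] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, float("-inf"))

probs = torch.tensor([0.1, 0.15, 0.2, 0.25, 0.3])
print(top_p_mask(probs.log(), top_p=0.7))
# tensor([  -inf,   -inf, -1.6094, -1.3863, -1.2040]) -> only the top three tokens survive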
@@ -52,6 +67,7 @@ def generate(
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
eos_id: Optional[int] = None,
) -> torch.Tensor:
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
@@ -64,6 +80,11 @@ def generate(
max_returned_tokens: The maximum number of tokens to return (given plus generated).
temperature: Scales the predicted logits by 1 / temperature.
top_k: If specified, only sample among the tokens with the k highest probabilities.
top_p: The cumulative probability threshold to consider in the sampling process.
In top-p (nucleus) sampling the next token is sampled from the smallest set of most probable tokens
whose cumulative probability exceeds the threshold `top_p`.
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
eos_id: If specified, stop generating any more tokens once the <eos> token is triggered.
"""
T = prompt.size(0)
@@ -78,11 +99,13 @@
tokens = [prompt]
input_pos = torch.tensor([T], device=device)
token = next_token(
model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k
model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
).clone()
tokens.append(token)
for _ in range(2, max_returned_tokens - T + 1):
token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k).clone()
token = next_token(
model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
).clone()
tokens.append(token)
if token == eos_id:
break
@@ -97,6 +120,7 @@ def main(
num_samples: int = 1,
max_new_tokens: int = 50,
top_k: Optional[int] = 50,
top_p: Optional[float] = None,
temperature: float = 0.8,
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
@@ -110,6 +134,10 @@ def main(
num_samples: The number of text samples to generate.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: The cumulative probability threshold to consider in the sampling process.
In top-p (nucleus) sampling the next token is sampled from the smallest set of most probable tokens
whose cumulative probability exceeds the threshold `top_p`.
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
checkpoint_dir: The checkpoint directory to load.
@@ -175,7 +203,7 @@ def main(
L.seed_everything(1234)
for i in range(num_samples):
t0 = time.perf_counter()
y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id)
t = time.perf_counter() - t0
for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
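For an end-to-end check, the new flag can be exercised through main() with the same call pattern the updated test uses further down; all keyword arguments appear in the signature above, and the checkpoint directory is the signature default, which has to exist locally:

from pathlib import Path

from litgpt.generate import base as generate_base

generate_base.main(
    num_samples=1,
    max_new_tokens=50,
    temperature=0.8,
    top_k=50,
    top_p=0.9,  # nucleus sampling: keep the most probable tokens covering ~90% of the mass
    checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
)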
8 changes: 7 additions & 1 deletion litgpt/generate/full.py
@@ -23,6 +23,7 @@ def main(
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
max_new_tokens: int = 100,
top_k: Optional[int] = 50,
top_p: Optional[float] = None,
temperature: float = 0.8,
precision: Optional[str] = None,
) -> None:
@@ -41,6 +42,11 @@ def main(
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: The cumulative probability threshold to consider in the sampling process.
In top-p (nucleus) sampling the next token is sampled from the smallest set of most probable tokens
whose cumulative probability exceeds the threshold `top_p`.
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
precision: Indicates the Fabric precision setting to use.
@@ -93,7 +99,7 @@ def main(

L.seed_everything(1234)
t0 = time.perf_counter()
y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id)
t = time.perf_counter() - t0

output = tokenizer.decode(y)
8 changes: 7 additions & 1 deletion litgpt/generate/sequentially.py
@@ -117,6 +117,7 @@ def main(
num_samples: int = 1,
max_new_tokens: int = 50,
top_k: Optional[int] = 50,
top_p: Optional[float] = None,
temperature: float = 0.8,
checkpoint_dir: Path = Path("checkpoints/mistralai/Mistral-7B-Instruct-v0.1"),
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"]] = None,
@@ -130,6 +131,11 @@ def main(
num_samples: The number of text samples to generate.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: The cumulative probability threshold to consider in the sampling process.
In top-p (nucleus) sampling the next token is sampled from the smallest set of most probable tokens
whose cumulative probability exceeds the threshold `top_p`.
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
checkpoint_dir: The checkpoint directory to load.
@@ -206,7 +212,7 @@ def main(
for i in range(num_samples):
t0 = time.perf_counter()
y = generate_base.generate(
model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id
model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id
)
t = time.perf_counter() - t0
for block in model.transformer.h:
6 changes: 6 additions & 0 deletions litgpt/generate/tp.py
@@ -95,6 +95,7 @@ def main(
num_samples: int = 1,
max_new_tokens: int = 50,
top_k: Optional[int] = 50,
top_p: Optional[float] = None,
temperature: float = 0.8,
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"]] = None,
@@ -108,6 +109,11 @@ def main(
num_samples: The number of text samples to generate.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: The cumulative probability threshold to consider in the sampling process.
In top-p (nucleus) sampling the next token is sampled from the smallest set of most probable tokens
whose cumulative probability exceeds the threshold `top_p`.
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
checkpoint_dir: The checkpoint directory to load.
4 changes: 2 additions & 2 deletions tests/test_generate.py
@@ -68,7 +68,7 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
num_samples = 2
out, err = StringIO(), StringIO()
with redirect_stdout(out), redirect_stderr(err):
generate.main(temperature=2.0, top_k=2, num_samples=num_samples, checkpoint_dir=fake_checkpoint_dir)
generate.main(temperature=2.0, top_k=2, top_p=0.9, num_samples=num_samples, checkpoint_dir=fake_checkpoint_dir)

assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
@@ -104,7 +104,7 @@ def test_sample(temperature):
[[85, 79, 57, 68, 50], [89, 46, 72, 45, 32], [68, 96, 68, 24, 36]],
]
)
token = sample(logits, temperature=temperature)
token = sample(logits.float(), temperature=temperature, top_p=0.8)

assert token.shape == (1,)
# sample is batch size 1 only for now - this should be [0, 1] once batched generation is supported
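A possible follow-up test for the new path: with a very small top_p only the most probable token survives the mask, so sampling is deterministic even at a non-zero temperature. The test name and assertion style are suggestions; sample is imported from litgpt.generate.base, where it is defined:

import torch

from litgpt.generate.base import sample

def test_sample_top_p_keeps_argmax():
    # Sharply peaked distribution: index 2 holds almost all of the probability mass.
    logits = torch.tensor([[[-1.0, 2.0, 5.0, 0.5]]])
    # A tiny top_p masks every token except the most probable one,
    # so the multinomial draw can only return index 2.
    token = sample(logits, temperature=0.8, top_p=0.01)
    assert token.shape == (1,)
    assert token.item() == 2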