Replace transformers list of logits warpers by a fused logits warper (#234)

* doc(tgi): typo
* feat(generate): add fused logits warper
* feat(generate): use FusedLogitsWarper
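For context, the stock transformers sampling path applies temperature, top-k and top-p as a chain of separate warpers, each producing a full `[batch_size, vocab_size]` score tensor. A minimal sketch of that chained pattern, using the standard transformers classes (the exact call sites replaced by this commit may differ):

```python
import torch
from transformers.generation import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

input_ids = torch.zeros(2, 1, dtype=torch.long)  # dummy prompt ids
scores = torch.randn(2, 50000)                   # next-token logits

# Each warper rescans the full [batch_size, vocab_size] tensor in turn.
warpers = LogitsProcessorList(
    [TemperatureLogitsWarper(0.8), TopKLogitsWarper(50), TopPLogitsWarper(0.9)]
)
scores = warpers(input_ids, scores)
```

The fused warper introduced below performs the same filtering in a single call and returns a clipped `[batch_size, keep]` tensor instead of the full vocabulary.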
Showing 5 changed files with 199 additions and 29 deletions.
First changed file (the generation logits warper module): the previous `FastTopKLogitsWarper` is removed and replaced by the new `FusedLogitsWarper`.

@@ -1,14 +1,88 @@

Removed:

```python
import torch
from transformers.generation import LogitsWarper


class FastTopKLogitsWarper(LogitsWarper):
    r"""Returns [batch_size, top_k] scores and indices instead of [batch_size, vocab_size] scores."""

    def __init__(self, top_k: int):
        self.top_k = top_k

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(self.top_k, scores.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        return torch.topk(scores, top_k)
```

Added:

```python
from dataclasses import dataclass
from typing import Tuple

import torch
from transformers import GenerationConfig


@dataclass
class FusedLogitsWarper:
    """
    A class that performs top-k then top-p filtering, optionally applying a temperature.

    Top-k filtering only keeps the `k` tokens with the best scores.

    Top-p filtering only keeps the smallest set of most probable tokens whose cumulative probability is at least `p`.

    The filtered tokens are returned as a list of indices, along with the corresponding subset of
    the original logits.

    If only top-k filtering is active, the filtered tokens are sorted in descending order.
    If top-p filtering is active, the filtered tokens are sorted in ascending order.

    Args:
        temperature (`float`):
            Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases
            randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely
            token.
        top_k (`int`):
            The number of highest probability vocabulary tokens to keep for top-k filtering.
        top_p (`float`):
            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
            higher are kept for generation.
    """

    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 1.0

    @classmethod
    def from_config(cls, generation_config: GenerationConfig) -> "FusedLogitsWarper":
        """Instantiate a fused warper from a generation configuration.

        Args:
            generation_config (`~transformers.generation.GenerationConfig`):
                The generation configuration to be used as base parametrization for the fused warper.

        Returns:
            a `FusedLogitsWarper`.

        Raises:
            ValueError: if sampling is requested without specifying at least top-k or top-p.
        """
        if generation_config.do_sample and generation_config.top_k == 0 and generation_config.top_p == 1.0:
            raise ValueError("Multinomial sampling requires at least top-k or top-p to be specified.")
        return cls(generation_config.temperature, generation_config.top_k, generation_config.top_p)

    def __call__(self, logits: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        if self.temperature != 1.0:
            logits = logits / self.temperature

        do_top_k = self.top_k > 0 and self.top_k < logits.shape[-1]
        do_top_p = self.top_p < 1.0 and self.top_p > 0.0

        if not do_top_k and not do_top_p:
            return logits, None

        if do_top_k:
            # torch.topk returns logits sorted in descending order
            sorted_logits, sorted_indices = torch.topk(logits, self.top_k)
        else:
            # Warning: not applying top-k filtering leads to this very slow sort operation
            sorted_logits, sorted_indices = torch.sort(logits)

        if do_top_p:
            if do_top_k:
                # logits have been sorted in descending order, so we need to flip them
                sorted_logits = torch.flip(sorted_logits, [-1])
                sorted_indices = torch.flip(sorted_indices, [-1])
            # We always keep the best logit and those whose cumulative probability is strictly higher than 1 - top_p
            cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
            keep_mask = cum_probs > (1 - self.top_p)
            keep_mask[:, -1] = True
            # Set rejected logits to -inf so that they are ignored in downstream comparisons
            sorted_logits[~keep_mask] = float("-Inf")
            # Clip the [batch_size, vocab_size] logits tensor to speed-up downstream ops
            keep_by_batch = torch.sum(keep_mask, dim=-1)
            keep = torch.amax(keep_by_batch)
            sorted_logits = sorted_logits[:, -keep:]
            sorted_indices = sorted_indices[:, -keep:]

        return sorted_logits, sorted_indices
```
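Downstream sampling code consuming the warper's output has to map sampled positions back to vocabulary ids through the returned indices. A minimal sketch of that pattern (illustrative only, not code from this commit):

```python
import torch

from optimum.neuron.generation import FusedLogitsWarper

logits = torch.randn(2, 50000)  # hypothetical next-token logits

warper = FusedLogitsWarper(temperature=0.8, top_k=50, top_p=0.9)
filtered_logits, filtered_indices = warper(logits)

# Rejected logits are -inf, so they get zero probability after the softmax.
probs = torch.softmax(filtered_logits, dim=-1)
positions = torch.multinomial(probs, num_samples=1)
# Map positions within the filtered subset back to vocabulary token ids.
next_tokens = torch.gather(filtered_indices, 1, positions).squeeze(-1)
```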
Second changed file (new unit tests for `FusedLogitsWarper`):

@@ -0,0 +1,108 @@

```python
import pytest
import torch

from optimum.neuron.generation import FusedLogitsWarper


def test_temperature():
    logits = torch.rand([10, 10, 10])
    temperature = 0.9
    warper = FusedLogitsWarper(temperature=temperature)
    warped_logits, warped_indices = warper(logits)
    assert warped_indices is None
    assert torch.allclose(warped_logits * temperature, logits)


def shuffle_logits(logits):
    shuffled_logits = torch.empty_like(logits)
    batch_size, vocab_size = logits.shape
    for i in range(batch_size):
        shuffled_indices = torch.randperm(vocab_size)
        shuffled_logits[i] = logits[i, shuffled_indices]
    return shuffled_logits


@pytest.mark.parametrize("batch_size", [1, 2, 10])
@pytest.mark.parametrize("vocab_size", [10, 1000, 50000])
def test_top_k(batch_size, vocab_size):
    # Create logits sorted in descending order for easier comparison, then shuffle them
    sorted_logits = (
        torch.arange(start=vocab_size, end=0, step=-1, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
    )
    shuffled_logits = shuffle_logits(sorted_logits)

    top_k = vocab_size // 2

    warper = FusedLogitsWarper(top_k=top_k)

    filtered_logits, indices = warper(shuffled_logits)

    assert filtered_logits.shape[-1] == top_k
    assert indices.shape[-1] == top_k

    for i in range(batch_size):
        # Verify indices are correct
        assert torch.equal(shuffled_logits[i, indices[i]], filtered_logits[i])
        # Since the original logits were sorted, the filtered logits should match the start of the sequence
        assert torch.equal(filtered_logits[i], sorted_logits[i, :top_k])


@pytest.mark.parametrize("batch_size", [1, 2, 10])
@pytest.mark.parametrize("vocab_size", [30, 1000, 50000])
def test_top_p(batch_size, vocab_size):
    # Create normalized logits
    norm_logits = torch.zeros(batch_size, vocab_size, dtype=torch.float)
    # We have 4 buckets, each holding 0.25 of the total probability mass,
    # with populations corresponding to fractions 0.4, 0.3, 0.2 and 0.1 of the vocab_size
    buckets = [0.4, 0.3, 0.2, 0.1]
    bucket_weight = 1.0 / len(buckets)
    index = 0
    for bucket in buckets:
        bucket_size = int(bucket * vocab_size)
        norm_logits[:, index : index + bucket_size] = bucket_weight / bucket_size
        index += bucket_size
    # Sanity check: the normalized logits should sum to one
    assert torch.allclose(torch.sum(norm_logits, dim=-1), torch.ones(batch_size))

    # The first bucket's cumulative probability is 0.25: set top_p to 0.75 to exclude it
    warper = FusedLogitsWarper(top_p=0.75)

    # The warper applies a softmax, so we need to take the log of our normalized logits
    sorted_logits = torch.log(norm_logits)
    shuffled_logits = shuffle_logits(sorted_logits)

    filtered_logits, indices = warper(shuffled_logits)

    # We expect all logits but the first bucket
    expected_n_logits = int((1.0 - buckets[0]) * vocab_size)
    assert filtered_logits.shape[-1] == expected_n_logits
    assert indices.shape[-1] == expected_n_logits

    for i in range(batch_size):
        # Verify indices are correct
        assert torch.equal(shuffled_logits[i, indices[i]], filtered_logits[i])
        # Since the original logits were sorted, the filtered logits should match the end of the sequence
        assert torch.equal(filtered_logits[i], sorted_logits[i, -expected_n_logits:])


def test_top_k_top_p():
    warper = FusedLogitsWarper(top_k=3, top_p=0.8)

    # Prepare logits with normalized top-3 probabilities, with distributions
    # such that reaching a cumulative probability > top_p requires 3, 2 and 1 logits respectively
    norm_top3_logits = torch.tensor(
        [[0.01, 0.01, 0.25, 0.25, 0.5], [0.01, 0.01, 0.2, 0.2, 0.6], [0.01, 0.01, 0.1, 0.1, 0.8]]
    )

    # The warper applies a softmax, so take the log
    sorted_logits = torch.log(norm_top3_logits)
    shuffled_logits = shuffle_logits(sorted_logits)

    filtered_logits, indices = warper(shuffled_logits)

    assert filtered_logits.shape[-1] == 3
    assert torch.all(filtered_logits[0, :] == sorted_logits[0, -3:])
    assert filtered_logits[1, 0] == float("-Inf")
    assert torch.all(filtered_logits[1, 1:] == sorted_logits[1, -2:])
    assert torch.all(filtered_logits[2, :2] == float("-Inf"))
    assert filtered_logits[2, -1] == sorted_logits[2, -1]
```
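As a sanity check of the expected counts in the last test: each row of `norm_top3_logits` already sums to one, so the softmax applied inside the warper to the top-3 log probabilities reproduces those probabilities exactly. Their ascending cumulative sums are compared against `1 - top_p = 0.2`, which keeps 3, 2 and 1 logits respectively. A few lines reproducing that arithmetic (illustrative, not part of the commit):

```python
import torch

top_p = 0.8
top3_probs = torch.tensor([[0.25, 0.25, 0.5], [0.2, 0.2, 0.6], [0.1, 0.1, 0.8]])
cum_probs = top3_probs.cumsum(dim=-1)  # ascending cumulative sums per row
keep_mask = cum_probs > (1 - top_p)
keep_mask[:, -1] = True  # the best logit is always kept
print(keep_mask.sum(dim=-1))  # tensor([3, 2, 1])
```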