Commit
Replace transformers list of logits warpers with a fused logits warper (#234)

* doc(tgi): typo

* feat(generate): add fused logits warper

* feat(generate): use FusedLogitsWarper
dacorvo authored Sep 19, 2023
1 parent b538488 commit 0cab527
Showing 5 changed files with 199 additions and 29 deletions.
2 changes: 1 addition & 1 deletion optimum/neuron/generation/__init__.py
@@ -13,6 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from .logits_process import FastTopKLogitsWarper
+from .logits_process import FusedLogitsWarper
from .token_selector import TokenSelector
from .utils import NeuronGenerationMixin
92 changes: 83 additions & 9 deletions optimum/neuron/generation/logits_process.py
@@ -1,14 +1,88 @@
+from dataclasses import dataclass
+from typing import Tuple

 import torch
-from transformers.generation import LogitsWarper
+from transformers import GenerationConfig


+@dataclass
+class FusedLogitsWarper:
+    """
+    A class that performs top-k then top-p filtering, optionally applying a temperature.
+
+    Top-k filtering only keeps the `k` tokens with the best scores.
+
+    Top-p filtering only keeps the smallest set of most probable tokens whose cumulative probability reaches `p`.
+
+    The filtered tokens are returned as a list of indices, along with the corresponding subset of
+    the original logits.
+
+    If only top-k filtering is active, the filtered tokens are sorted in descending order.
+
+    If top-p filtering is active, the filtered tokens are sorted in ascending order.
+
+    Args:
+        temperature (`float`):
+            Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases
+            randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely
+            token.
+        top_k (`int`):
+            The number of highest-probability vocabulary tokens to keep for top-k filtering.
+        top_p (`float`):
+            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+            higher are kept for generation.
+    """
+
+    temperature: float = 1.0
+    top_k: int = 0
+    top_p: float = 1.0
+
+    @classmethod
+    def from_config(cls, generation_config: GenerationConfig) -> "FusedLogitsWarper":
+        """Instantiate a fused warper from a generation configuration.
+
+        Args:
+            generation_config (`~transformers.generation.GenerationConfig`):
+                The generation configuration to be used as base parametrization for the fused warper.
+
+        Returns:
+            a `FusedLogitsWarper` instance.
+        """
+        if generation_config.do_sample and generation_config.top_k == 0 and generation_config.top_p == 1.0:
+            raise ValueError("Multinomial sampling requires at least top-k or top-p to be specified.")
+        return cls(generation_config.temperature, generation_config.top_k, generation_config.top_p)
+
+    def __call__(self, logits: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        if self.temperature != 1.0:
+            logits = logits / self.temperature
+
+        do_top_k = self.top_k > 0 and self.top_k < logits.shape[-1]
+        do_top_p = self.top_p < 1.0 and self.top_p > 0.0
+
+        if not do_top_k and not do_top_p:
+            return logits, None

-class FastTopKLogitsWarper(LogitsWarper):
-    r"""Returns [batch_size, top_k] scores and indices instead of [batch_size, vocab_size] scores."""
+        if do_top_k:
+            sorted_logits, sorted_indices = torch.topk(logits, self.top_k)
+        else:
+            # Warning: not applying top-k filtering leads to this very slow sort operation
+            sorted_logits, sorted_indices = torch.sort(logits)

-    def __init__(self, top_k: int):
-        self.top_k = top_k
+        if do_top_p:
+            if do_top_k:
+                # logits have been sorted in descending order, so we need to flip them
+                sorted_logits = torch.flip(sorted_logits, [-1])
+                sorted_indices = torch.flip(sorted_indices, [-1])
+            # We always keep the best logit, plus any token whose ascending cumulative probability exceeds 1 - top_p
+            cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+            keep_mask = cum_probs > (1 - self.top_p)
+            keep_mask[:, -1] = True
+            # Set rejected logits to -inf so that they are ignored in downstream comparisons
+            sorted_logits[~keep_mask] = float("-Inf")
+            # Clip the [batch_size, vocab_size] logits tensor to speed-up downstream ops
+            keep_by_batch = torch.sum(keep_mask, dim=-1)
+            keep = torch.amax(keep_by_batch)
+            sorted_logits = sorted_logits[:, -keep:]
+            sorted_indices = sorted_indices[:, -keep:]

-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        top_k = min(self.top_k, scores.size(-1))  # Safety check
-        # Remove all tokens with a probability less than the last token of the top-k
-        return torch.topk(scores, top_k)
+        return sorted_logits, sorted_indices
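
To make the new API concrete, here is a minimal usage sketch combining the fused warper with the sampling step that `TokenSelector._sample` performs below; the tensor sizes and parameter values are illustrative assumptions, not taken from the commit:

```
import torch

from optimum.neuron.generation import FusedLogitsWarper

# Illustrative sizes: any [batch_size, vocab_size] float tensor works
logits = torch.randn(2, 32000)

# Keep the 50 best tokens, then the smallest subset covering 90% of the probability mass
warper = FusedLogitsWarper(temperature=0.8, top_k=50, top_p=0.9)
filtered_logits, filtered_indices = warper(logits)

# Sampling happens in the reduced [batch_size, kept] space...
probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)
# ...and the sampled positions are mapped back to vocabulary token ids
next_tokens = torch.gather(filtered_indices, 1, next_tokens)
```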
24 changes: 6 additions & 18 deletions optimum/neuron/generation/token_selector.py
@@ -7,11 +7,10 @@
GenerationMixin,
LogitsProcessorList,
StoppingCriteriaList,
-    TopKLogitsWarper,
)
from transformers.generation.utils import GenerationMode

-from .logits_process import FastTopKLogitsWarper
+from .logits_process import FusedLogitsWarper


logger = logging.getLogger(__name__)
@@ -51,13 +50,6 @@ def __init__(
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.logits_warper = logits_warper
-        if self.mode == GenerationMode.SAMPLE:
-            assert len(self.logits_warper) > 0
-            last_warper = self.logits_warper[-1]
-            self.fast_topk = isinstance(last_warper, TopKLogitsWarper)
-            if self.fast_topk:
-                # Replace the last warping operation by a faster alternative
-                self.logits_warper[-1] = FastTopKLogitsWarper(last_warper.top_k)

@classmethod
def create(
@@ -134,7 +126,7 @@ def create(

logits_warper = None
if generation_mode == GenerationMode.SAMPLE:
-            logits_warper = model._get_logits_warper(generation_config)
+            logits_warper = FusedLogitsWarper.from_config(generation_config)

return cls(
mode=generation_mode,
@@ -164,16 +156,12 @@ def select(self, input_ids: torch.LongTensor, logits: torch.Tensor) -> torch.LongTensor:
return torch.argmax(scores, dim=-1)

def _sample(self, scores: torch.Tensor) -> torch.LongTensor:
-        if self.fast_topk:
-            # Get [batch_size, top_k] scores and indices instead of [batch_size, vocab_size] scores
-            scores, next_token_indices = self.logits_warper(None, scores)
-        else:
-            scores = self.logits_warper(None, scores)
+        # Get [batch_size, kept] scores and indices instead of [batch_size, vocab_size] scores
+        scores, next_token_indices = self.logits_warper(scores)

# sample
probs = torch.nn.functional.softmax(scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)
-        if self.fast_topk:
-            # Convert the topk relative tokens to actual vocabulary tokens
-            next_tokens = torch.gather(next_token_indices, 1, next_tokens)
+        # Convert the filtered tokens to actual vocabulary tokens
+        next_tokens = torch.gather(next_token_indices, 1, next_tokens)
return next_tokens.squeeze(1)
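
The `TokenSelector` now builds its warper through `FusedLogitsWarper.from_config`; a short sketch of what that entry point accepts and rejects (configuration values are made up for illustration):

```
from transformers import GenerationConfig

from optimum.neuron.generation import FusedLogitsWarper

# A sampling configuration with top-k/top-p yields a fused warper
config = GenerationConfig(do_sample=True, temperature=0.7, top_k=20, top_p=0.95)
warper = FusedLogitsWarper.from_config(config)
assert (warper.temperature, warper.top_k, warper.top_p) == (0.7, 20, 0.95)

# Pure multinomial sampling (neither top-k nor top-p) is rejected
try:
    FusedLogitsWarper.from_config(GenerationConfig(do_sample=True, top_k=0, top_p=1.0))
except ValueError as err:
    print(err)  # Multinomial sampling requires at least top-k or top-p to be specified.
```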
108 changes: 108 additions & 0 deletions tests/generation/test_fused_logits_warper.py
@@ -0,0 +1,108 @@
import pytest
import torch

from optimum.neuron.generation import FusedLogitsWarper


def test_temperature():
logits = torch.rand([10, 10, 10])
temperature = 0.9
warper = FusedLogitsWarper(temperature=temperature)
warped_logits, warped_indices = warper(logits)
assert warped_indices is None
assert torch.allclose(warped_logits * temperature, logits)


def shuffle_logits(logits):
shuffled_logits = torch.empty_like(logits)
batch_size, vocab_size = logits.shape
for i in range(batch_size):
shuffled_indices = torch.randperm(vocab_size)
shuffled_logits[i] = logits[i, shuffled_indices]
return shuffled_logits


@pytest.mark.parametrize("batch_size", [1, 2, 10])
@pytest.mark.parametrize("vocab_size", [10, 1000, 50000])
def test_top_k(batch_size, vocab_size):
    # Create logits sorted in descending order for easier comparison, then shuffle them
sorted_logits = (
torch.arange(start=vocab_size, end=0, step=-1, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
)
shuffled_logits = shuffle_logits(sorted_logits)

top_k = vocab_size // 2

warper = FusedLogitsWarper(top_k=top_k)

filtered_logits, indices = warper(shuffled_logits)

assert filtered_logits.shape[-1] == top_k
assert indices.shape[-1] == top_k

for i in range(batch_size):
# Verify indices are correct
assert torch.equal(shuffled_logits[i, indices[i]], filtered_logits[i])
# Since the original logits were sorted, filtered logits should match the start of the sequence
assert torch.equal(filtered_logits[i], sorted_logits[i, :top_k])


@pytest.mark.parametrize("batch_size", [1, 2, 10])
@pytest.mark.parametrize("vocab_size", [30, 1000, 50000])
def test_top_p(batch_size, vocab_size):
# Create normalized logits
norm_logits = torch.zeros(batch_size, vocab_size, dtype=torch.float)
    # We have 4 buckets, each holding 0.25 of the total probability mass,
    # with populations corresponding to 40%, 30%, 20% and 10% of the vocab_size
buckets = [0.4, 0.3, 0.2, 0.1]
bucket_weight = 1.0 / len(buckets)
index = 0
for bucket in buckets:
bucket_size = int(bucket * vocab_size)
norm_logits[:, index : index + bucket_size] = bucket_weight / bucket_size
index += bucket_size
# Sanity check: the sum of the normalized logits should be one
assert torch.allclose(torch.sum(norm_logits, axis=-1), torch.ones(batch_size))

    # The first bucket accounts for a cumulative probability of 0.25: set top_p to 0.75 to exclude it
warper = FusedLogitsWarper(top_p=0.75)

# top_p will apply a softmax, so we need to take the log of our normalized logits
sorted_logits = torch.log(norm_logits)
shuffled_logits = shuffle_logits(sorted_logits)

filtered_logits, indices = warper(shuffled_logits)

# We expect all logits but the first bucket
expected_n_logits = int((1.0 - buckets[0]) * vocab_size)
assert filtered_logits.shape[-1] == expected_n_logits
assert indices.shape[-1] == expected_n_logits

for i in range(batch_size):
# Verify indices are correct
assert torch.equal(shuffled_logits[i, indices[i]], filtered_logits[i])
# Since the original logits were sorted, filtered logits should match the end of the sequence
assert torch.equal(filtered_logits[i], sorted_logits[i, -expected_n_logits:])


def test_top_k_top_p():
warper = FusedLogitsWarper(top_k=3, top_p=0.8)

# Prepare logits with normalized top-3, with distributions
# so that cumulative prob > top_p requires 3, 2, and 1 logits resp.
norm_top3_logits = torch.tensor(
[[0.01, 0.01, 0.25, 0.25, 0.5], [0.01, 0.01, 0.2, 0.2, 0.6], [0.01, 0.01, 0.1, 0.1, 0.8]]
)

# Top_p will apply a softmax, so take the log
sorted_logits = torch.log(norm_top3_logits)
shuffled_logits = shuffle_logits(sorted_logits)

filtered_logits, indices = warper(shuffled_logits)

assert filtered_logits.shape[-1] == 3
assert torch.all(filtered_logits[0, :] == sorted_logits[0, -3:])
assert filtered_logits[1, 0] == float("-Inf")
assert torch.all(filtered_logits[1, 1:] == sorted_logits[1, -2:])
assert torch.all(filtered_logits[2, :2] == float("-Inf"))
assert filtered_logits[2, -1] == sorted_logits[2, -1]
2 changes: 1 addition & 1 deletion text-generation-inference/README.md
@@ -104,7 +104,7 @@ You can query the model using either the `/generate` or `/generate_stream` route
```
curl 127.0.0.1:8080/generate \
-X POST \
-    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}' \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
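
For readers who prefer Python over curl, an equivalent request might look like the sketch below; it assumes the TGI server from this README is listening on 127.0.0.1:8080:

```
import requests

# Same payload as the (fixed) curl example above
response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}},
)
print(response.json())
```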

