Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] HybridCache not subscriptable #1047

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions guidance/models/transformers/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def __init__(self, model, tokenizer, compute_log_probs: bool, chat_template=None
self.model = model.__class__.__name__
self.device = self.model_obj.device # otherwise note the current device

self._past_key_values = None
self._past_key_values: Union[transformers_package.Cache, tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]], None] = None
self._cached_logits = None
self._cached_token_ids: list[int] = []

Expand Down Expand Up @@ -454,13 +454,42 @@ def get_logits(self, token_ids):

# reset the cache length according to that number of positions
past_key_values = self._past_key_values
past_length = past_key_values[0][0].size(-2) if past_key_values is not None else 0
if past_length > num_cached:
# note we recompute the last token because we don't bother to handle the special case of just computing logits
max_cache_shape = None
if past_key_values is None:
past_length = 0
elif isinstance(past_key_values, tuple):
past_length = past_key_values[0][0].size(-2)
elif isinstance(past_key_values, transformers_package.Cache):
# TODO: use model's `cache_position` as this may be deprecated in a future version
# https://github.com/huggingface/transformers/blob/70b07d97cf2c5f61fff55700b65528a1b6845cd2/src/transformers/cache_utils.py#L64
past_length = past_key_values.get_seq_length()
# TODO: use `get_max_cache_shape` as `get_max_length` will be deprecated in a future version
# (`get_max_cache_shape` is not yet available so we can't use it yet)
# https://github.com/huggingface/transformers/blob/70b07d97cf2c5f61fff55700b65528a1b6845cd2/src/transformers/cache_utils.py#L67
max_cache_shape = past_key_values.get_max_length()
else:
raise TypeError(f"Unknown type of past_key_values: {type(past_key_values)}")

if max_cache_shape is not None and len(token_ids) > max_cache_shape:
warnings.warn("Cache is too small. Resetting.")
# TODO: this seems to get set to the length of the first sequence we pass for models using
# StaticCache or HybridCache. We need to initialize our own cache with a large enough size
# if we want to continue generation with the same cache.
self._past_key_values = None
past_length = 0
Copy link
Collaborator Author

@hudson-ai hudson-ai Oct 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is necessary for SlidingWindowCache?

elif past_length > num_cached:
past_length = max(0, num_cached - 1)
self._past_key_values = tuple(
tuple(p[..., :past_length, :] for p in v) for v in past_key_values
)
if isinstance(past_key_values, tuple):
self._past_key_values = tuple(
tuple(p[..., :past_length, :] for p in v) for v in past_key_values
)
elif isinstance(past_key_values, transformers_package.Cache) and hasattr(past_key_values, "crop"):
self._past_key_values.crop(past_length)
else:
warnings.warn(f"Cropping unsupported for cache type: {type(self._past_key_values)}. Resetting cache.")
self._past_key_values = None
past_length = 0

cache_token_ids[past_length:] = []

# call the model
Expand Down
Loading