Skip to content

Commit

Permalink
Fix Assistant cache
Browse files Browse the repository at this point in the history
  • Loading branch information
artitw committed Feb 12, 2024
1 parent 5e95d04 commit 8ced9d5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 12 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="text2text",
version="1.4.2",
version="1.4.3",
author="artitw",
author_email="[email protected]",
description="Text2Text: Crosslingual NLP/G toolkit",
Expand Down
17 changes: 6 additions & 11 deletions text2text/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,12 @@ def chat_completion(self, messages=[{"role": "user", "content": "hello"}], strea

input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)

attention_mask = None
past_key_values = None
for i in range(1,len(messages)):
past_input_string = tokenizer.apply_chat_template(messages[:-i], tokenize=False)
past_key_values = cache.get(past_input_string, None)
if past_key_values:
seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1)
attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device)
break

if attention_mask == None:
past_input_string = tokenizer.apply_chat_template(messages[:-1], tokenize=False)
past_key_values = cache.get(past_input_string, None)
if past_key_values:
seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1)
attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device)
else:
attention_mask = torch.ones_like(input_ids)

results = model.generate(
Expand Down

0 comments on commit 8ced9d5

Please sign in to comment.