Improve Mistral model wrapper with better error handling. Also added optional functionality to allow the use of different models for choice and text.

PiperOrigin-RevId: 657516383
Change-Id: Iaa5aa4249b56a539566b4c566fae149cfe310853
jzleibo authored and copybara-github committed Jul 30, 2024
1 parent 153610a commit adf8136
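In practice the new flag is set once, at construction time; the rest of the wrapper's interface is unchanged. A minimal usage sketch, assuming the wrapper class in this module is named MistralLanguageModel (the class name itself is not visible in this diff) and using placeholder key and model names:

    from concordia.language_model import mistral_model

    model = mistral_model.MistralLanguageModel(
        api_key='YOUR_MISTRAL_API_KEY',  # Placeholder; supply a real key.
        model_name='mistral-large-latest',
        # New in this commit: route multiple-choice questions to codestral.
        use_codestral_for_choices=True,
    )
    text = model.sample_text('The rabbit looked at the carrot and')

With the flag enabled, sample_choice uses 'codestral-latest' while sample_text keeps using model_name.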
58 changes: 43 additions & 15 deletions concordia/language_model/mistral_model.py
@@ -12,18 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 
 """Language Model wrapper for Mistral models."""
 
 from collections.abc import Collection, Sequence
+import time
 from concordia.language_model import language_model
 from concordia.utils import measurements as measurements_lib
 from concordia.utils import sampling
 from mistralai.client import MistralClient
+from mistralai.exceptions import MistralException
 from mistralai.models.chat_completion import ChatMessage
 from typing_extensions import override
 
 _MAX_MULTIPLE_CHOICE_ATTEMPTS = 20
+_MAX_CHAT_ATTEMPTS = 20
 
 # At least one Mistral model supports completion mode.
 COMPLETION_MODELS = (
@@ -38,6 +40,7 @@ def __init__(
       self,
       api_key: str,
       model_name: str,
+      use_codestral_for_choices: bool = False,
       measurements: measurements_lib.Measurements | None = None,
       channel: str = language_model.DEFAULT_STATS_CHANNEL,
   ):
@@ -47,18 +50,28 @@
       api_key: The API key to use when accessing the Mistral API.
       model_name: The language model to use. For more details, see
         https://docs.mistral.ai/getting-started/models/.
+      use_codestral_for_choices: When enabled, use codestral for multiple choice
+        questions. Otherwise, use the model specified in the param `model_name`.
       measurements: The measurements object to log usage statistics to.
       channel: The channel to write the statistics to.
     """
     self._api_key = api_key
-    self._model_name = model_name
+    self._text_model_name = model_name
     self._measurements = measurements
     self._channel = channel
     self._client = MistralClient(api_key=api_key)
 
-    self._completion = False
-    if self._model_name in COMPLETION_MODELS:
-      self._completion = True
+    self._choice_model_name = self._text_model_name
+    if use_codestral_for_choices:
+      self._choice_model_name = 'codestral-latest'
+
+    self._completion_for_text = False
+    if self._text_model_name in COMPLETION_MODELS:
+      self._completion_for_text = True
+
+    self._completion_for_choice = False
+    if self._choice_model_name in COMPLETION_MODELS:
+      self._completion_for_choice = True
 
   def _complete_text(
       self,
@@ -76,7 +89,7 @@ def _complete_text(
       terminators = ('\n\n',)
 
     response = self._client.completion(
-        model=self._model_name,
+        model=self._choice_model_name,
         prompt=prompt,
         suffix=suffix,
         temperature=temperature,
@@ -125,14 +138,29 @@ def _chat_text(
         ChatMessage(role='user',
                     content=prompt)
     ]
-    response = self._client.chat(
-        model=self._model_name,
-        messages=messages,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        random_seed=seed,
-    )
-    return response.choices[0].message.content
+    for attempts in range(_MAX_CHAT_ATTEMPTS):
+      if attempts > 0:
+        print('Sleeping for 10 seconds...')
+        time.sleep(10)
+      # Increase temperature after the first failed attempt.
+      temperature = sampling.dynamically_adjust_temperature(
+          attempts, _MAX_CHAT_ATTEMPTS)
+      try:
+        response = self._client.chat(
+            model=self._text_model_name,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            random_seed=seed,
+        )
+      except MistralException:
+        continue
+      else:
+        return response.choices[0].message.content
+
+    raise language_model.InvalidResponseError(
+        (f'Too many chat attempts.\n Prompt: {prompt}')
+    )
 
   @override
   def sample_text(
@@ -147,7 +175,7 @@
   ) -> str:
     del timeout
 
-    if self._completion:
+    if self._completion_for_text:
       response = self._complete_text(
           prompt=prompt,
           suffix='.\n',
@@ -194,7 +222,7 @@ def sample_choice(
       temperature = sampling.dynamically_adjust_temperature(
           attempts, _MAX_MULTIPLE_CHOICE_ATTEMPTS)
 
-      if self._completion:
+      if self._completion_for_choice:
         sample = self._complete_text(
             prompt=prompt,
             suffix=')',
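For reference, the retry pattern the commit adds to _chat_text reads cleanly in isolation: sleep before each retry, catch MistralException and try again, return on success, and raise once the attempt budget is exhausted. Below is a standalone sketch of that pattern; chat_with_retries is a hypothetical helper written for illustration, not part of this commit, though the mistralai calls are the same ones the diff uses.

    import time

    from mistralai.client import MistralClient
    from mistralai.exceptions import MistralException
    from mistralai.models.chat_completion import ChatMessage

    _MAX_CHAT_ATTEMPTS = 20


    def chat_with_retries(client: MistralClient, model: str, prompt: str) -> str:
      """Calls chat, retrying on API errors with a fixed 10-second backoff."""
      for attempts in range(_MAX_CHAT_ATTEMPTS):
        if attempts > 0:
          time.sleep(10)  # Back off before retrying.
        try:
          response = client.chat(
              model=model,
              messages=[ChatMessage(role='user', content=prompt)],
          )
        except MistralException:
          continue  # Transient API error; try the call again.
        else:
          return response.choices[0].message.content
      raise RuntimeError(f'Too many chat attempts.\n Prompt: {prompt}')

Returning from the else branch keeps the success path outside the try block, so only the API call itself is guarded.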
