diff --git a/concordia/language_model/amazon_bedrock_model.py b/concordia/language_model/amazon_bedrock_model.py index 3447445d..1effa9b4 100644 --- a/concordia/language_model/amazon_bedrock_model.py +++ b/concordia/language_model/amazon_bedrock_model.py @@ -58,6 +58,7 @@ class AmazonBedrockLanguageModel(language_model.LanguageModel): def __init__( self, model_name: str, + *, measurements: measurements_lib.Measurements | None = None, channel: str = language_model.DEFAULT_STATS_CHANNEL, ): diff --git a/concordia/language_model/cloud_vertex_model.py b/concordia/language_model/cloud_vertex_model.py index c1de7016..c0dd3432 100644 --- a/concordia/language_model/cloud_vertex_model.py +++ b/concordia/language_model/cloud_vertex_model.py @@ -68,6 +68,7 @@ class VertexLanguageModel(language_model.LanguageModel): def __init__( self, model_name: str = 'gemini-pro', + *, harm_block_threshold: HarmBlockThreshold = HarmBlockThreshold.BLOCK_NONE, measurements: measurements_lib.Measurements | None = None, channel: str = language_model.DEFAULT_STATS_CHANNEL, diff --git a/concordia/language_model/google_aistudio_model.py b/concordia/language_model/google_aistudio_model.py index 9d3bee09..ec5230fc 100644 --- a/concordia/language_model/google_aistudio_model.py +++ b/concordia/language_model/google_aistudio_model.py @@ -16,6 +16,7 @@ from collections.abc import Collection, Mapping, Sequence import copy +import os import time from concordia.language_model import language_model @@ -97,8 +98,9 @@ class GoogleAIStudioLanguageModel(language_model.LanguageModel): def __init__( self, - api_key: str, model_name: str = 'gemini-1.5-pro-latest', + *, + api_key: str | None = None, safety_settings: Sequence[Mapping[str, str]] = DEFAULT_SAFETY_SETTINGS, measurements: measurements_lib.Measurements | None = None, channel: str = language_model.DEFAULT_STATS_CHANNEL, @@ -107,14 +109,17 @@ def __init__( """Initializes a model API instance using Google AI Studio. 
Args: - api_key: The API key to use when accessing the Google AI Studio API model_name: which language model to use. For more details, see https://aistudio.google.com/ + api_key: The API key to use when accessing the Google AI Studio API. If + None, will use the GOOGLE_API_KEY environment variable. safety_settings: See https://ai.google.dev/gemini-api/docs/safety-guidance measurements: The measurements object to log usage statistics to channel: The channel to write the statistics to sleep_periodically: Whether to sleep between API calls to avoid rate limit """ + if api_key is None: + api_key = os.environ['GOOGLE_API_KEY'] self._api_key = api_key self._model_name = model_name self._safety_settings = safety_settings diff --git a/concordia/language_model/gpt_model.py b/concordia/language_model/gpt_model.py index 745a62a6..93d4ab50 100644 --- a/concordia/language_model/gpt_model.py +++ b/concordia/language_model/gpt_model.py @@ -16,6 +16,8 @@ """Language Model that uses OpenAI's GPT models.""" from collections.abc import Collection, Sequence +import os + from concordia.language_model import language_model from concordia.utils import measurements as measurements_lib from concordia.utils import sampling @@ -30,27 +32,29 @@ class GptLanguageModel(language_model.LanguageModel): def __init__( self, - api_key: str, model_name: str, + *, + api_key: str | None = None, measurements: measurements_lib.Measurements | None = None, channel: str = language_model.DEFAULT_STATS_CHANNEL, ): """Initializes the instance. Args: - api_key: The API key to use when accessing the OpenAI API. model_name: The language model to use. For more details, see https://platform.openai.com/docs/guides/text-generation/which-model-should-i-use. + api_key: The API key to use when accessing the OpenAI API. If None, will + use the OPENAI_API_KEY environment variable. measurements: The measurements object to log usage statistics to. channel: The channel to write the statistics to. 
""" + if api_key is None: + api_key = os.environ['OPENAI_API_KEY'] self._api_key = api_key self._model_name = model_name self._measurements = measurements self._channel = channel - self._client = openai.OpenAI( - api_key=api_key, - ) + self._client = openai.OpenAI(api_key=self._api_key) @override def sample_text( diff --git a/concordia/language_model/mistral_model.py b/concordia/language_model/mistral_model.py index bfb597bb..e4861948 100644 --- a/concordia/language_model/mistral_model.py +++ b/concordia/language_model/mistral_model.py @@ -15,14 +15,14 @@ """Language Model wrapper for Mistral models.""" from collections.abc import Collection, Sequence +import os import time + from concordia.language_model import language_model from concordia.utils import measurements as measurements_lib from concordia.utils import sampling -from mistralai import Mistral -from mistralai.models import AssistantMessage -from mistralai.models import SystemMessage -from mistralai.models import UserMessage +import mistralai +from mistralai import models from typing_extensions import override _MAX_MULTIPLE_CHOICE_ATTEMPTS = 20 @@ -39,8 +39,9 @@ class MistralLanguageModel(language_model.LanguageModel): def __init__( self, - api_key: str, model_name: str, + *, + api_key: str | None = None, use_codestral_for_choices: bool = False, measurements: measurements_lib.Measurements | None = None, channel: str = language_model.DEFAULT_STATS_CHANNEL, @@ -48,19 +49,22 @@ def __init__( """Initializes the instance. Args: - api_key: The API key to use when accessing the OpenAI API. model_name: The language model to use. For more details, see https://docs.mistral.ai/getting-started/models/. + api_key: The API key to use when accessing the Mistral API. If None, will + use the MISTRAL_API_KEY environment variable. use_codestral_for_choices: When enabled, use codestral for multiple choice questions. Otherwise, use the model specified in the param `model_name`. 
measurements: The measurements object to log usage statistics to. channel: The channel to write the statistics to. """ + if api_key is None: + api_key = os.environ['MISTRAL_API_KEY'] self._api_key = api_key self._text_model_name = model_name self._measurements = measurements self._channel = channel - self._client = Mistral(api_key=api_key) + self._client = mistralai.Mistral(api_key=api_key) self._choice_model_name = self._text_model_name if use_codestral_for_choices: @@ -123,21 +127,27 @@ def _chat_text( ) -> str: del terminators messages = [ - SystemMessage(role='system', - content=('You always continue sentences provided ' + - 'by the user and you never repeat what ' + - 'the user already said.')), - UserMessage(role='user', - content='Question: Is Jake a turtle?\nAnswer: Jake is '), - AssistantMessage(role='assistant', - content='not a turtle.'), - UserMessage(role='user', - content=('Question: What is Priya doing right ' - 'now?\nAnswer: Priya is currently ')), - AssistantMessage(role='assistant', - content='sleeping.'), - UserMessage(role='user', - content=prompt) + models.SystemMessage( + role='system', + content=( + 'You always continue sentences provided by the user and you ' + 'never repeat what the user already said.' 
+ ) + ), + models.UserMessage( + role='user', + content='Question: Is Jake a turtle?\nAnswer: Jake is ', + ), + models.AssistantMessage(role='assistant', content='not a turtle.'), + models.UserMessage( + role='user', + content=( + 'Question: What is Priya doing right now?\n' + 'Answer: Priya is currently ' + ), + ), + models.AssistantMessage(role='assistant', content='sleeping.'), + models.UserMessage(role='user', content=prompt), ] for attempts in range(_MAX_CHAT_ATTEMPTS): if attempts > 0: diff --git a/concordia/language_model/pytorch_gemma_model.py b/concordia/language_model/pytorch_gemma_model.py index f055af38..dd59721d 100644 --- a/concordia/language_model/pytorch_gemma_model.py +++ b/concordia/language_model/pytorch_gemma_model.py @@ -31,9 +31,8 @@ class PyTorchGemmaLanguageModel(language_model.LanguageModel): def __init__( self, - *, - # The default model is the 2 billion parameter instruction-tuned Gemma. model_name: str = 'google/gemma-2b-it', + *, measurements: measurements_lib.Measurements | None = None, channel: str = language_model.DEFAULT_STATS_CHANNEL, ) -> None: