Skip to content

Commit

Permalink
Use default API keys.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 661182893
Change-Id: I3e734467208bfed9c4880a3134e2886bae29b07b
  • Loading branch information
jagapiou authored and copybara-github committed Aug 9, 2024
1 parent dba1105 commit ac10c34
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 31 deletions.
1 change: 1 addition & 0 deletions concordia/language_model/amazon_bedrock_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class AmazonBedrockLanguageModel(language_model.LanguageModel):
def __init__(
self,
model_name: str,
*,
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
):
Expand Down
1 change: 1 addition & 0 deletions concordia/language_model/cloud_vertex_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class VertexLanguageModel(language_model.LanguageModel):
def __init__(
self,
model_name: str = 'gemini-pro',
*,
harm_block_threshold: HarmBlockThreshold = HarmBlockThreshold.BLOCK_NONE,
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
Expand Down
9 changes: 7 additions & 2 deletions concordia/language_model/google_aistudio_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from collections.abc import Collection, Mapping, Sequence
import copy
import os
import time

from concordia.language_model import language_model
Expand Down Expand Up @@ -97,8 +98,9 @@ class GoogleAIStudioLanguageModel(language_model.LanguageModel):

def __init__(
self,
api_key: str,
model_name: str = 'gemini-1.5-pro-latest',
*,
api_key: str | None = None,
safety_settings: Sequence[Mapping[str, str]] = DEFAULT_SAFETY_SETTINGS,
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
Expand All @@ -107,14 +109,17 @@ def __init__(
"""Initializes a model API instance using Google AI Studio.
Args:
api_key: The API key to use when accessing the Google AI Studio API
model_name: which language model to use. For more details, see
https://aistudio.google.com/
api_key: The API key to use when accessing the Google AI Studio API. If
None, will use the GOOGLE_API_KEY environment variable.
safety_settings: See https://ai.google.dev/gemini-api/docs/safety-guidance
measurements: The measurements object to log usage statistics to
channel: The channel to write the statistics to
sleep_periodically: Whether to sleep between API calls to avoid rate limits
"""
if api_key is None:
api_key = os.environ['GOOGLE_API_KEY']
self._api_key = api_key
self._model_name = model_name
self._safety_settings = safety_settings
Expand Down
14 changes: 9 additions & 5 deletions concordia/language_model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
"""Language Model that uses OpenAI's GPT models."""

from collections.abc import Collection, Sequence
import os

from concordia.language_model import language_model
from concordia.utils import measurements as measurements_lib
from concordia.utils import sampling
Expand All @@ -30,27 +32,29 @@ class GptLanguageModel(language_model.LanguageModel):

def __init__(
self,
api_key: str,
model_name: str,
*,
api_key: str | None = None,
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
):
"""Initializes the instance.
Args:
api_key: The API key to use when accessing the OpenAI API.
model_name: The language model to use. For more details, see
https://platform.openai.com/docs/guides/text-generation/which-model-should-i-use.
api_key: The API key to use when accessing the OpenAI API. If None, will
use the OPENAI_API_KEY environment variable.
measurements: The measurements object to log usage statistics to.
channel: The channel to write the statistics to.
"""
if api_key is None:
api_key = os.environ['OPENAI_API_KEY']
self._api_key = api_key
self._model_name = model_name
self._measurements = measurements
self._channel = channel
self._client = openai.OpenAI(
api_key=api_key,
)
self._client = openai.OpenAI(api_key=self._api_key)

@override
def sample_text(
Expand Down
54 changes: 32 additions & 22 deletions concordia/language_model/mistral_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
"""Language Model wrapper for Mistral models."""

from collections.abc import Collection, Sequence
import os
import time

from concordia.language_model import language_model
from concordia.utils import measurements as measurements_lib
from concordia.utils import sampling
from mistralai import Mistral
from mistralai.models import AssistantMessage
from mistralai.models import SystemMessage
from mistralai.models import UserMessage
import mistralai
from mistralai import models
from typing_extensions import override

_MAX_MULTIPLE_CHOICE_ATTEMPTS = 20
Expand All @@ -39,28 +39,32 @@ class MistralLanguageModel(language_model.LanguageModel):

def __init__(
self,
api_key: str,
model_name: str,
*,
api_key: str | None = None,
use_codestral_for_choices: bool = False,
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
):
"""Initializes the instance.
Args:
api_key: The API key to use when accessing the OpenAI API.
model_name: The language model to use. For more details, see
https://docs.mistral.ai/getting-started/models/.
api_key: The API key to use when accessing the Mistral API. If None, will
use the MISTRAL_API_KEY environment variable.
use_codestral_for_choices: When enabled, use codestral for multiple choice
questions. Otherwise, use the model specified in the param `model_name`.
measurements: The measurements object to log usage statistics to.
channel: The channel to write the statistics to.
"""
if api_key is None:
api_key = os.environ['MISTRAL_API_KEY']
self._api_key = api_key
self._text_model_name = model_name
self._measurements = measurements
self._channel = channel
self._client = Mistral(api_key=api_key)
self._client = mistralai.Mistral(api_key=api_key)

self._choice_model_name = self._text_model_name
if use_codestral_for_choices:
Expand Down Expand Up @@ -123,21 +127,27 @@ def _chat_text(
) -> str:
del terminators
messages = [
SystemMessage(role='system',
content=('You always continue sentences provided ' +
'by the user and you never repeat what ' +
'the user already said.')),
UserMessage(role='user',
content='Question: Is Jake a turtle?\nAnswer: Jake is '),
AssistantMessage(role='assistant',
content='not a turtle.'),
UserMessage(role='user',
content=('Question: What is Priya doing right '
'now?\nAnswer: Priya is currently ')),
AssistantMessage(role='assistant',
content='sleeping.'),
UserMessage(role='user',
content=prompt)
models.SystemMessage(
role='system',
content=(
'You always continue sentences provided by the user and you '
'never repeat what the user already said.'
)
),
models.UserMessage(
role='user',
content='Question: Is Jake a turtle?\nAnswer: Jake is ',
),
models.AssistantMessage(role='assistant', content='not a turtle.'),
models.UserMessage(
role='user',
content=(
'Question: What is Priya doing right now?\n'
'Answer: Priya is currently '
),
),
models.AssistantMessage(role='assistant', content='sleeping.'),
models.UserMessage(role='user', content=prompt),
]
for attempts in range(_MAX_CHAT_ATTEMPTS):
if attempts > 0:
Expand Down
3 changes: 1 addition & 2 deletions concordia/language_model/pytorch_gemma_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ class PyTorchGemmaLanguageModel(language_model.LanguageModel):

def __init__(
self,
*,
# The default model is the 2 billion parameter instruction-tuned Gemma.
model_name: str = 'google/gemma-2b-it',
*,
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
) -> None:
Expand Down

0 comments on commit ac10c34

Please sign in to comment.