Add support for AzureOpenAI API #88

Closed
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,6 +1,9 @@
# MacOS metadata.
.DS_Store

# PyCharm metadata.
.idea

# Byte-compiled Python code.
*.py[cod]
__pycache__/
69 changes: 69 additions & 0 deletions concordia/language_model/azure_gpt_model.py
@@ -0,0 +1,69 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Language Model that uses OpenAI's GPT models using AZURE"""

import os
from openai import AzureOpenAI
from concordia.language_model import language_model
from concordia.utils import measurements as measurements_lib
from concordia.language_model.base_gpt_model import BaseGPTModel


class AzureGptLanguageModel(BaseGPTModel):
"""Language Model that uses OpenAI GPT models."""

def __init__(
self,
model_name: str,
*,
api_key: str | None = None,
azure_endpoint: str | None = None,
api_version: str | None = None,
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
):
"""Initializes the instance.

Args:
      model_name: The language model to use. For Azure OpenAI this is the
        name of the model deployment. For more details, see
        https://platform.openai.com/docs/guides/text-generation/which-model-should-i-use.
      api_key: The API key to use when accessing the Azure OpenAI API. If
        None, will use the AZURE_OPENAI_API_KEY environment variable.
      azure_endpoint: The Azure endpoint to use when accessing the Azure
        OpenAI API. If None, will use the AZURE_OPENAI_ENDPOINT environment
        variable.
      api_version: The Azure API version to use when accessing the Azure
        OpenAI API. If None, will use the AZURE_OPENAI_API_VERSION
        environment variable.
measurements: The measurements object to log usage statistics to.
channel: The channel to write the statistics to.
"""
if api_key is None:
api_key = os.environ['AZURE_OPENAI_API_KEY']
if azure_endpoint is None:
azure_endpoint = os.environ['AZURE_OPENAI_ENDPOINT']
if api_version is None:
api_version = os.environ['AZURE_OPENAI_API_VERSION']

self._api_key = api_key
self._azure_endpoint = azure_endpoint
self._api_version = api_version
client = AzureOpenAI(api_key=self._api_key,
azure_endpoint=self._azure_endpoint,
api_version=self._api_version)

super().__init__(model_name=model_name,
client=client,
measurements=measurements,
channel=channel)
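
Usage sketch (not part of the diff): the new class picks up credentials from the environment. This assumes AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_VERSION are set, and 'gpt-4o' stands in for whatever deployment name exists in the Azure resource.

from concordia.language_model.azure_gpt_model import AzureGptLanguageModel

# For Azure OpenAI, model_name is the deployment name ('gpt-4o' is a
# placeholder here); the key, endpoint and API version are read from the
# environment by the constructor.
model = AzureGptLanguageModel(model_name='gpt-4o')
print(model.sample_text('Question: What is Bob doing right now?\nAnswer: Bob is '))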
132 changes: 132 additions & 0 deletions concordia/language_model/base_gpt_model.py
@@ -0,0 +1,132 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections.abc import Collection, Sequence
from concordia.language_model import language_model
from concordia.utils import measurements as measurements_lib
from concordia.utils import sampling
from openai import AzureOpenAI, OpenAI
from typing_extensions import override

_MAX_MULTIPLE_CHOICE_ATTEMPTS = 20


class BaseGPTModel(language_model.LanguageModel):
"""Base class for GPT models (OpenAI and Azure)"""

def __init__(
self,
model_name: str,
client: AzureOpenAI | OpenAI,
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
):
"""Initializes the base instance."""
self._model_name = model_name
self._measurements = measurements
self._channel = channel
self._client = client

@override
def sample_text(
self,
prompt: str,
*,
max_tokens: int = language_model.DEFAULT_MAX_TOKENS,
terminators: Collection[str] = language_model.DEFAULT_TERMINATORS,
temperature: float = language_model.DEFAULT_TEMPERATURE,
timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
seed: int | None = None,
) -> str:
    # GPT models do not support `max_tokens` > 4096, so cap requests at 4000.
max_tokens = min(max_tokens, 4000)

messages = [
{'role': 'system',
'content': ('You always continue sentences provided ' +
'by the user and you never repeat what ' +
'the user already said.')},
{'role': 'user',
'content': 'Question: Is Jake a turtle?\nAnswer: Jake is '},
{'role': 'assistant',
'content': 'not a turtle.'},
{'role': 'user',
'content': ('Question: What is Priya doing right now?\nAnswer: ' +
'Priya is currently ')},
{'role': 'assistant',
'content': 'sleeping.'},
{'role': 'user',
'content': prompt}
]

response = self._client.chat.completions.create(
model=self._model_name,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout,
stop=terminators,
seed=seed,
)

if self._measurements is not None:
self._measurements.publish_datum(
self._channel,
{'raw_text_length': len(response.choices[0].message.content)},
)
return response.choices[0].message.content

@override
def sample_choice(
self,
prompt: str,
responses: Sequence[str],
*,
seed: int | None = None,
) -> tuple[int, str, dict[str, float]]:
prompt = (
prompt
+ '\nRespond EXACTLY with one of the following strings:\n'
+ '\n'.join(responses) + '.'
)

sample = ''
answer = ''
for attempts in range(_MAX_MULTIPLE_CHOICE_ATTEMPTS):
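      # Increase temperature after the first failed attempt.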
temperature = sampling.dynamically_adjust_temperature(
attempts, _MAX_MULTIPLE_CHOICE_ATTEMPTS)

sample = self.sample_text(
prompt,
temperature=temperature,
seed=seed,
)
answer = sampling.extract_choice_response(sample)
try:
idx = responses.index(answer)
except ValueError:
continue
else:
if self._measurements is not None:
self._measurements.publish_datum(
self._channel, {'choices_calls': attempts}
)
debug = {}
return idx, responses[idx], debug

raise language_model.InvalidResponseError(
(f'Too many multiple choice attempts.\nLast attempt: {sample}, ' +
f'extracted: {answer}')
)
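
Design note, with a sketch that is an assumption rather than part of this PR: BaseGPTModel only needs a client exposing chat.completions.create, so any OpenAI-compatible backend can be wired in the same way the two subclasses do it. LocalServerModel and its base_url default are hypothetical.

from openai import OpenAI

from concordia.language_model.base_gpt_model import BaseGPTModel


class LocalServerModel(BaseGPTModel):
  """Hypothetical subclass backed by an OpenAI-compatible local server."""

  def __init__(self, model_name: str,
               base_url: str = 'http://localhost:8000/v1'):
    # Many OpenAI-compatible servers accept any placeholder API key.
    client = OpenAI(api_key='placeholder', base_url=base_url)
    super().__init__(model_name=model_name, client=client)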
115 changes: 9 additions & 106 deletions concordia/language_model/gpt_model.py
@@ -15,19 +15,14 @@

"""Language Model that uses OpenAI's GPT models."""

from collections.abc import Collection, Sequence
import os

import openai
from concordia.language_model import language_model
from concordia.utils import measurements as measurements_lib
from concordia.utils import sampling
import openai
from typing_extensions import override

_MAX_MULTIPLE_CHOICE_ATTEMPTS = 20
from concordia.language_model.base_gpt_model import BaseGPTModel


class GptLanguageModel(language_model.LanguageModel):
class GptLanguageModel(BaseGPTModel):
"""Language Model that uses OpenAI GPT models."""

def __init__(
@@ -49,102 +44,10 @@ def __init__(
channel: The channel to write the statistics to.
"""
if api_key is None:
      api_key = os.environ['OPENAI_API_KEY']
self._api_key = api_key
self._model_name = model_name
self._measurements = measurements
self._channel = channel
self._client = openai.OpenAI(api_key=self._api_key)

@override
def sample_text(
self,
prompt: str,
*,
max_tokens: int = language_model.DEFAULT_MAX_TOKENS,
terminators: Collection[str] = language_model.DEFAULT_TERMINATORS,
temperature: float = language_model.DEFAULT_TEMPERATURE,
timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
seed: int | None = None,
) -> str:
# gpt models do not support `max_tokens` > 4096.
max_tokens = min(max_tokens, 4000)

messages = [
{'role': 'system',
'content': ('You always continue sentences provided ' +
'by the user and you never repeat what ' +
'the user already said.')},
{'role': 'user',
'content': 'Question: Is Jake a turtle?\nAnswer: Jake is '},
{'role': 'assistant',
'content': 'not a turtle.'},
{'role': 'user',
'content': ('Question: What is Priya doing right now?\nAnswer: ' +
'Priya is currently ')},
{'role': 'assistant',
'content': 'sleeping.'},
{'role': 'user',
'content': prompt}
]

response = self._client.chat.completions.create(
model=self._model_name,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout,
stop=terminators,
seed=seed,
)

if self._measurements is not None:
self._measurements.publish_datum(
self._channel,
{'raw_text_length': len(response.choices[0].message.content)},
)
return response.choices[0].message.content

@override
def sample_choice(
self,
prompt: str,
responses: Sequence[str],
*,
seed: int | None = None,
) -> tuple[int, str, dict[str, float]]:
prompt = (
prompt
+ '\nRespond EXACTLY with one of the following strings:\n'
+ '\n'.join(responses) + '.'
)

sample = ''
answer = ''
for attempts in range(_MAX_MULTIPLE_CHOICE_ATTEMPTS):
# Increase temperature after the first failed attempt.
temperature = sampling.dynamically_adjust_temperature(
attempts, _MAX_MULTIPLE_CHOICE_ATTEMPTS)

sample = self.sample_text(
prompt,
temperature=temperature,
seed=seed,
)
answer = sampling.extract_choice_response(sample)
try:
idx = responses.index(answer)
except ValueError:
continue
else:
if self._measurements is not None:
self._measurements.publish_datum(
self._channel, {'choices_calls': attempts}
)
debug = {}
return idx, responses[idx], debug

raise language_model.InvalidResponseError(
(f'Too many multiple choice attempts.\nLast attempt: {sample}, ' +
f'extracted: {answer}')
)
client = openai.OpenAI(api_key=self._api_key)
super().__init__(model_name=model_name,
client=client,
measurements=measurements,
channel=channel)
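
Usage sketch (assumes OPENAI_API_KEY is set; 'gpt-4o-mini' is a placeholder model name): behaviour should be unchanged by the refactor, with both sampling methods now inherited from BaseGPTModel.

from concordia.language_model.gpt_model import GptLanguageModel

model = GptLanguageModel(model_name='gpt-4o-mini')  # placeholder model name

# sample_choice retries with rising temperature until the reply matches one
# of the given strings, then returns (index, response, debug info).
idx, choice, _ = model.sample_choice(
    'Question: Is Jake a turtle?\nAnswer:',
    responses=['Jake is a turtle.', 'Jake is not a turtle.'],
)
print(idx, choice)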