Skip to content

Commit

Permalink
Fix GCloud model and ensure consistent API for LanguageModel
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 587772666
Change-Id: I6ecf64887c4b0221d9be5ada27272b2645359745
  • Loading branch information
jagapiou authored and copybara-github committed Dec 4, 2023
1 parent 7668e1b commit 6520408
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 70 deletions.
37 changes: 19 additions & 18 deletions concordia/language_model/gcloud_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@
"""Google Cloud Language Model."""

from collections.abc import Collection, Sequence
import sys

from concordia.language_model import language_model
from concordia.utils import text
from google import auth
from typing_extensions import override
import vertexai
from vertexai.preview import language_models as vertex_models

DEFAULT_MAX_TOKENS = 50
MAX_MULTIPLE_CHOICE_ATTEMPTS = 20


Expand All @@ -34,7 +33,7 @@ def __init__(
project_id: str,
model_name: str = 'text-bison@001',
location: str = 'us-central1',
credentials: auth.credentials.Credentials = None
credentials: auth.credentials.Credentials | None = None,
) -> None:
"""Initializes a model instance using the Google Cloud language model API.
Expand All @@ -45,46 +44,46 @@ def __init__(
credentials: Custom credentials to use when making API calls. If not
provided, credentials will be ascertained from the environment.
"""
if not credentials:
credentials = auth.default()[0]
if credentials is None:
credentials, _ = auth.default()
vertexai.init(
project=project_id, location=location, credentials=credentials)
project=project_id, location=location, credentials=credentials
)
self._model = vertex_models.TextGenerationModel.from_pretrained(model_name)

@override
def sample_text(
self,
prompt: str,
*,
timeout: float = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
max_characters: int = sys.maxsize,
terminators: Collection[str] = (),
temperature: float = 0.5,
max_tokens: int = language_model.DEFAULT_MAX_TOKENS,
max_characters: int = language_model.DEFAULT_MAX_CHARACTERS,
terminators: Collection[str] = language_model.DEFAULT_TERMINATORS,
temperature: float = language_model.DEFAULT_TEMPERATURE,
timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
seed: int | None = None,
) -> str:
"""See base class."""
if timeout is not None:
raise NotImplementedError('Unclear how to set timeout for cloud models.')
if seed is not None:
raise NotImplementedError('Unclear how to set seed for cloud models.')

max_tokens = min(max_tokens, max_characters)
sample = self._model.predict(
prompt,
temperature=temperature,
max_output_tokens=max_tokens,)
max_output_tokens=max_tokens,
)
return text.truncate(
sample.text, max_length=max_characters, delimiters=terminators
)

@override
def sample_choice(
self,
prompt: str,
responses: Sequence[str],
*,
seed: int | None = None,
) -> tuple[int, str, dict[str, float]]:
"""See base class."""
max_characters = max([len(response) for response in responses])

for _ in range(MAX_MULTIPLE_CHOICE_ATTEMPTS):
Expand All @@ -93,7 +92,8 @@ def sample_choice(
max_tokens=1,
max_characters=max_characters,
temperature=0.0,
seed=seed)
seed=seed,
)
try:
idx = responses.index(sample)
except ValueError:
Expand All @@ -103,4 +103,5 @@ def sample_choice(
return idx, responses[idx], debug

raise language_model.InvalidResponseError(
'Too many multiple choice attempts.')
'Too many multiple choice attempts.'
)
10 changes: 8 additions & 2 deletions concordia/language_model/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
import sys
from typing import Any

DEFAULT_MAX_TOKENS = 50
DEFAULT_TEMPERATURE = 0.5
DEFAULT_MAX_CHARACTERS = sys.maxsize
DEFAULT_TERMINATORS = ()
DEFAULT_TIMEOUT_SECONDS = 60
DEFAULT_MAX_CHARACTERS = sys.maxsize
DEFAULT_MAX_TOKENS = 50


class InvalidResponseError(Exception):
Expand All @@ -43,6 +44,7 @@ def sample_text(
max_characters: int = DEFAULT_MAX_CHARACTERS,
terminators: Collection[str] = DEFAULT_TERMINATORS,
temperature: float = DEFAULT_TEMPERATURE,
timeout: float = DEFAULT_TIMEOUT_SECONDS,
seed: int | None = None,
) -> str:
"""Samples text from the model.
Expand All @@ -57,10 +59,14 @@ def sample_text(
terminators: the response will be terminated before any of these
characters.
temperature: temperature for the model.
timeout: timeout for the request.
seed: optional seed for the sampling. If None a random seed will be used.
Returns:
The sampled response (i.e. does not include the prompt).
Raises:
TimeoutError: if the operation times out.
"""
raise NotImplementedError

Expand Down
17 changes: 9 additions & 8 deletions concordia/language_model/retry_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

"""Wrapper to retry calls to an underlying language model."""

from collections.abc import Collection, Sequence
import copy
from typing import Any, Mapping, Tuple, Type
from collections.abc import Collection, Sequence, Mapping
from typing import Any, Type

from concordia.language_model import language_model
import retry
from typing_extensions import override


class RetryLanguageModel(language_model.LanguageModel):
Expand All @@ -29,9 +29,9 @@ def __init__(
self,
model: language_model.LanguageModel,
retry_on_exceptions: Collection[Type[Exception]] = (Exception,),
retry_tries: float = 3.,
retry_tries: int = 3,
retry_delay: float = 2.,
jitter: Tuple[float, float] = (0.0, 1.0),
jitter: tuple[float, float] = (0.0, 1.0),
) -> None:
"""Wrap the underlying language model with retries on given exceptions.
Expand All @@ -43,11 +43,12 @@ def __init__(
jitter: tuple of minimum and maximum jitter to add to the retry.
"""
self._model = model
self._retry_on_exceptions = copy.deepcopy(retry_on_exceptions)
self._retry_on_exceptions = tuple(retry_on_exceptions)
self._retry_tries = retry_tries
self._retry_delay = retry_delay
self._jitter = jitter

@override
def sample_text(
self,
prompt: str,
Expand All @@ -56,9 +57,9 @@ def sample_text(
max_characters: int = language_model.DEFAULT_MAX_CHARACTERS,
terminators: Collection[str] = language_model.DEFAULT_TERMINATORS,
temperature: float = language_model.DEFAULT_TEMPERATURE,
timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
seed: int | None = None,
) -> str:
"""See base class."""
@retry.retry(self._retry_on_exceptions, tries=self._retry_tries,
delay=self._retry_delay, jitter=self._jitter)
def _sample_text(model, prompt, *, max_tokens=max_tokens,
Expand All @@ -72,14 +73,14 @@ def _sample_text(model, prompt, *, max_tokens=max_tokens,
max_characters=max_characters, terminators=terminators,
temperature=temperature, seed=seed)

@override
def sample_choice(
self,
prompt: str,
responses: Sequence[str],
*,
seed: int | None = None,
) -> tuple[int, str, Mapping[str, Any]]:
"""See base class."""
@retry.retry(self._retry_on_exceptions, tries=self._retry_tries,
delay=self._retry_delay, jitter=self._jitter)
def _sample_choice(model, prompt, responses, *, seed):
Expand Down
40 changes: 8 additions & 32 deletions concordia/language_model/sax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,14 @@

from collections.abc import Collection, Sequence
import concurrent.futures
import sys

from concordia.language_model import language_model
from concordia.utils import text
import numpy as np
from saxml.client.python import sax
from scipy import special
from typing_extensions import override

DEFAULT_MAX_TOKENS = 50
DEFAULT_TIMEOUT_SECONDS = 60
DEFAULT_NUM_CONNECTIONS = 3


Expand All @@ -55,31 +53,18 @@ def __init__(
self._model = sax.Model(path, options).LM()
self._deterministic_multiple_choice = deterministic_multiple_choice

@override
def sample_text(
self,
prompt: str,
*,
timeout: float = DEFAULT_TIMEOUT_SECONDS,
max_tokens: int = DEFAULT_MAX_TOKENS,
max_characters: int = sys.maxsize,
terminators: Collection[str] = (),
temperature: float = 0.5,
max_tokens: int = language_model.DEFAULT_MAX_TOKENS,
max_characters: int = language_model.DEFAULT_MAX_CHARACTERS,
terminators: Collection[str] = language_model.DEFAULT_TERMINATORS,
temperature: float = language_model.DEFAULT_TEMPERATURE,
timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
seed: int | None = None,
) -> str:
"""Samples a string from the model.
Args:
prompt: the prompt to generate a response for.
timeout: timeout for the request.
max_tokens: maximum number of tokens to generate.
max_characters: maximum number of characters to generate.
terminators: delimiters to use in the generated response.
temperature: temperature for the model.
seed: seed for the random number generator.
Returns:
A string of the generated response.
"""
if seed is not None:
raise NotImplementedError('Unclear how to set seed for sax models.')
max_tokens = min(max_tokens, max_characters)
Expand All @@ -92,23 +77,14 @@ def sample_text(
sample, max_length=max_characters, delimiters=terminators
)

@override
def sample_choice(
self,
prompt: str,
responses: Sequence[str],
*,
seed: int | None = None,
) -> tuple[int, str, dict[str, float]]:
"""Samples a response from the model.
Args:
prompt: the prompt to generate a response for.
responses: the responses to sample.
seed: seed for the random number generator.
Returns:
A tuple of (index, response, debug).
"""
scores = self._score_responses(prompt, responses)
probs = special.softmax(scores)
entropy = probs @ np.log(probs)
Expand Down
20 changes: 10 additions & 10 deletions concordia/tests/mock_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
"""A mock Language Model."""

from collections.abc import Collection, Sequence
import sys

from concordia.language_model import language_model
from typing_extensions import override


class MockModel(language_model.LanguageModel):
Expand All @@ -32,36 +32,36 @@ def __init__(
"""
self._response = response

@override
def sample_text(
self,
prompt: str,
*,
timeout: float = 0,
max_tokens: int = 0,
max_characters: int = sys.maxsize,
terminators: Collection[str] = (),
temperature: float = 0.5,
max_tokens: int = language_model.DEFAULT_MAX_TOKENS,
max_characters: int = language_model.DEFAULT_MAX_CHARACTERS,
terminators: Collection[str] = language_model.DEFAULT_TERMINATORS,
temperature: float = language_model.DEFAULT_TEMPERATURE,
timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
seed: int | None = None,
) -> str:
"""See base class."""
del (
prompt,
timeout,
max_tokens,
max_characters,
terminators,
seed,
temperature,
timeout,
seed,
)
return self._response

@override
def sample_choice(
self,
prompt: str,
responses: Sequence[str],
*,
seed: int | None = None,
) -> tuple[int, str, dict[str, float]]:
"""See base class."""
del prompt, seed
return 0, responses[0], {}

0 comments on commit 6520408

Please sign in to comment.