Add measurements for calls to the language model. These should help with estimating the cost of running an experiment.

PiperOrigin-RevId: 588115287
Change-Id: If214b0f6d820f8204fcedb2fa0a9febf522d3655
duenez authored and copybara-github committed Dec 5, 2023
1 parent 0aa31ee commit 1577c65
Showing 2 changed files with 19 additions and 0 deletions.
17 changes: 17 additions & 0 deletions concordia/language_model/gcloud_model.py
@@ -16,6 +16,7 @@
from collections.abc import Collection, Sequence

from concordia.language_model import language_model
from concordia.utils import measurements as measurements_lib
from concordia.utils import text
from google import auth
from typing_extensions import override
@@ -34,6 +35,8 @@ def __init__(
model_name: str = 'text-bison@001',
location: str = 'us-central1',
credentials: auth.credentials.Credentials | None = None,
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
) -> None:
"""Initializes a model instance using the Google Cloud language model API.
@@ -43,13 +46,17 @@ def __init__(
location: The location to use when making API calls.
credentials: Custom credentials to use when making API calls. If not
provided, credentials will be ascertained from the environment.
measurements: The measurements object to log usage statistics to.
channel: The channel to write the statistics to.
"""
if credentials is None:
credentials, _ = auth.default()
vertexai.init(
project=project_id, location=location, credentials=credentials
)
self._model = vertex_models.TextGenerationModel.from_pretrained(model_name)
self._measurements = measurements
self._channel = channel

@override
def sample_text(
@@ -72,6 +79,10 @@ def sample_text(
temperature=temperature,
max_output_tokens=max_tokens,
)
if self._measurements is not None:
self._measurements.publish_datum(
self._channel,
{'raw_text_length': len(sample.text)})
return text.truncate(
sample.text, max_length=max_characters, delimiters=terminators
)
@@ -86,6 +97,7 @@ def sample_choice(
) -> tuple[int, str, dict[str, float]]:
max_characters = max([len(response) for response in responses])

attempts = 1
for _ in range(MAX_MULTIPLE_CHOICE_ATTEMPTS):
sample = self.sample_text(
prompt,
@@ -97,8 +109,13 @@
try:
idx = responses.index(sample)
except ValueError:
attempts += 1
continue
else:
if self._measurements is not None:
self._measurements.publish_datum(
self._channel,
{'choices_calls': attempts})
debug = {}
return idx, responses[idx], debug

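Taken together, the hooks above publish one datum per sample_text call (the length of the raw response) and one per successful sample_choice call (the number of attempts it took). The following sketch, which is not part of the commit, shows one way an experiment could consume those datums to estimate cost. It relies only on the publish_datum(channel, datum) call visible in the diff; the CostTracker class, the per-character price, and the CloudLanguageModel class name are assumptions, and since the constructor is annotated with measurements_lib.Measurements, a duck-typed stand-in like this is for illustration only.

# Hypothetical sketch: tally the datums this commit publishes in order to
# estimate experiment cost. Only publish_datum(channel, datum), as used in
# the diff above, is assumed; the price per thousand characters and the
# CloudLanguageModel class name are placeholders.
from concordia.language_model import language_model


class CostTracker:
  """Accumulates the usage statistics published by the model wrapper."""

  def __init__(self, usd_per_1k_characters: float = 0.0005):
    self._usd_per_1k_characters = usd_per_1k_characters  # placeholder price
    self.total_characters = 0
    self.choice_calls = 0

  def publish_datum(self, channel: str, datum: dict) -> None:
    # Mirrors the payloads published in sample_text and sample_choice.
    if channel != language_model.DEFAULT_STATS_CHANNEL:
      return
    self.total_characters += datum.get('raw_text_length', 0)
    self.choice_calls += datum.get('choices_calls', 0)

  def estimated_cost_usd(self) -> float:
    return self.total_characters / 1000 * self._usd_per_1k_characters


# Usage (class name assumed; pass a real measurements_lib.Measurements in
# production to match the type annotation):
#   tracker = CostTracker()
#   model = gcloud_model.CloudLanguageModel(
#       project_id='my-project', measurements=tracker)
#   ...run the experiment...
#   print(f'~{tracker.total_characters} characters, '
#         f'~${tracker.estimated_cost_usd():.4f}')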
2 changes: 2 additions & 0 deletions concordia/language_model/language_model.py
@@ -26,6 +26,8 @@
DEFAULT_MAX_CHARACTERS = sys.maxsize
DEFAULT_MAX_TOKENS = 50

DEFAULT_STATS_CHANNEL = 'language_model_stats'


class InvalidResponseError(Exception):
"""Exception to throw when exceeding max attempts to get a choice."""
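Defining DEFAULT_STATS_CHANNEL at module level gives every model wrapper a shared, well-known channel name for its usage statistics, while the new channel argument in gcloud_model.py still lets a caller route a particular model's datums elsewhere (for example, one channel per experiment) without touching the constant.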
