diff --git a/concordia/language_model/gcloud_model.py b/concordia/language_model/gcloud_model.py index d7d2db2b..9c118ffe 100644 --- a/concordia/language_model/gcloud_model.py +++ b/concordia/language_model/gcloud_model.py @@ -16,6 +16,7 @@ from collections.abc import Collection, Sequence from concordia.language_model import language_model +from concordia.utils import measurements as measurements_lib from concordia.utils import text from google import auth from typing_extensions import override @@ -34,6 +35,8 @@ def __init__( model_name: str = 'text-bison@001', location: str = 'us-central1', credentials: auth.credentials.Credentials | None = None, + measurements: measurements_lib.Measurements | None = None, + channel: str = language_model.DEFAULT_STATS_CHANNEL, ) -> None: """Initializes a model instance using the Google Cloud language model API. @@ -43,6 +46,8 @@ def __init__( location: The location to use when making API calls. credentials: Custom credentials to use when making API calls. If not provided credentials will be ascertained from the environment. + measurements: The measurements object to log usage statistics to. + channel: The channel to write the statistics to. """ if credentials is None: credentials, _ = auth.default() @@ -50,6 +55,8 @@ def __init__( project=project_id, location=location, credentials=credentials ) self._model = vertex_models.TextGenerationModel.from_pretrained(model_name) + self._measurements = measurements + self._channel = channel @override def sample_text( @@ -72,6 +79,10 @@ def sample_text( temperature=temperature, max_output_tokens=max_tokens, ) + if self._measurements is not None: + self._measurements.publish_datum( + self._channel, + {'raw_text_length': len(sample)}) return text.truncate( sample.text, max_length=max_characters, delimiters=terminators ) @@ -86,6 +97,7 @@ def sample_choice( ) -> tuple[int, str, dict[str, float]]: max_characters = max([len(response) for response in responses]) + attempts = 1 for _ in range(MAX_MULTIPLE_CHOICE_ATTEMPTS): sample = self.sample_text( prompt, @@ -97,8 +109,13 @@ def sample_choice( try: idx = responses.index(sample) except ValueError: + attempts += 1 continue else: + if self._measurements is not None: + self._measurements.publish_datum( + self._channel, + {'choices_calls': attempts}) debug = {} return idx, responses[idx], debug diff --git a/concordia/language_model/language_model.py b/concordia/language_model/language_model.py index d2509ec0..9f86519c 100644 --- a/concordia/language_model/language_model.py +++ b/concordia/language_model/language_model.py @@ -26,6 +26,8 @@ DEFAULT_MAX_CHARACTERS = sys.maxsize DEFAULT_MAX_TOKENS = 50 +DEFAULT_STATS_CHANNEL = 'language_model_stats' + class InvalidResponseError(Exception): """Exception to throw when exceeding max attempts to get a choice."""