Add support for more language models.
PiperOrigin-RevId: 660766991
Change-Id: I5bda13a6de2cbf0247f074e4365b26b5f0d75c94
jzleibo authored and copybara-github committed Aug 8, 2024
1 parent 9fd4fea commit c4f6d00
Showing 6 changed files with 273 additions and 88 deletions.
16 changes: 8 additions & 8 deletions concordia/language_model/amazon_bedrock_model.py
@@ -57,33 +57,33 @@ class AmazonBedrockLanguageModel(language_model.LanguageModel):

def __init__(
self,
model_id: str,
model_name: str,
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
):
"""Initializes the instance.
Args:
model_id: The language model to use. For more details, see
model_name: The language model to use. For more details, see
https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
measurements: The measurements object to log usage statistics to.
channel: The channel to write the statistics to.
"""
self._model_id = model_id
self._model_name = model_name
self._measurements = measurements
self._channel = channel
self._max_tokens_limit = self._get_max_tokens_limit(model_id)
self._max_tokens_limit = self._get_max_tokens_limit(model_name)

# AWS credentials are passed via environment variables, see:
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
self._client = boto3.client('bedrock-runtime')

def _get_max_tokens_limit(self, model_id: str) -> int:
def _get_max_tokens_limit(self, model_name: str) -> int:
"""Get the max tokens limit for the given model ID."""
for pattern, value in MODEL_MAX_OUTPUT_TOKENS_LIMITS.items():
if model_id.startswith(pattern):
if model_name.startswith(pattern):
return value
raise ValueError(f'Unknown model ID: {model_id}')
raise ValueError(f'Unknown model ID: {model_name}')

@override
def sample_text(
@@ -147,7 +147,7 @@ def sample_text(
del inference_config['stopSequences']

response = self._client.converse(
modelId=self._model_id,
modelId=self._model_name,
system=system,
messages=messages,
inferenceConfig=inference_config,
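The change above is purely a rename of the constructor argument from model_id to model_name. A minimal usage sketch under the new signature; the Bedrock model identifier is an illustrative placeholder, not taken from this commit, and it must still match one of the prefixes in MODEL_MAX_OUTPUT_TOKENS_LIMITS or the constructor raises ValueError:

from concordia.language_model import amazon_bedrock_model

# AWS credentials are read from the environment by boto3, as noted in the diff above.
# The model identifier below is a placeholder, not part of this commit.
model = amazon_bedrock_model.AmazonBedrockLanguageModel(
    model_name='anthropic.claude-3-sonnet-20240229-v1:0',
)
print(model.sample_text('Alice is'))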
140 changes: 140 additions & 0 deletions concordia/language_model/langchain_ollama_model.py
@@ -0,0 +1,140 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Langchain-based language model using ollama to run on the local machine."""

from collections.abc import Collection, Sequence

from concordia.language_model import language_model
from concordia.utils import measurements as measurements_lib
from concordia.utils import sampling
from langchain import llms

from typing_extensions import override

_MAX_MULTIPLE_CHOICE_ATTEMPTS = 20
_DEFAULT_TEMPERATURE = 0.5
_DEFAULT_TERMINATORS = ()
_DEFAULT_SYSTEM_MESSAGE = (
'Continue the user\'s sentences. Never repeat their starts. For example, '
'when you see \'Bob is\', you should continue the sentence after '
'the word \'is\'. Here are some more examples: \'Question: Is Jake a '
'turtle?\nAnswer: Jake is \' should be completed as \'not a turtle.\' and '
'\'Question: What is Priya doing right now?\nAnswer: Priya is currently \' '
'should be completed as \'working on repairing the sink.\'. Notice that '
'it is OK to be creative with how you finish the user\'s sentences. The '
'most important thing is to always continue in the same style as the user.'
)


class OllamaLanguageModel(language_model.LanguageModel):
"""Language Model that uses Ollama LLM models."""

def __init__(
self,
model_name: str,
*,
system_message: str = _DEFAULT_SYSTEM_MESSAGE,
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
) -> None:
"""Initializes the instance.
Args:
model_name: The language model to use. For more details, see
https://github.com/ollama/ollama.
system_message: System message to prefix to requests when prompting the
model.
measurements: The measurements object to log usage statistics to.
channel: The channel to write the statistics to.
"""
self._model_name = model_name
self._system_message = system_message
self._terminators = []
if 'llama3' in self._model_name:
self._terminators.extend(['<|eot_id|>'])
self._client = llms.Ollama(model=model_name, stop=self._terminators)

self._measurements = measurements
self._channel = channel

@override
def sample_text(
self,
prompt: str,
*,
max_tokens: int = language_model.DEFAULT_MAX_TOKENS,
terminators: Collection[str] = _DEFAULT_TERMINATORS,
temperature: float = _DEFAULT_TEMPERATURE,
timeout: float = -1,
seed: int | None = None,
) -> str:
del max_tokens, timeout, seed # Unused.

prompt_with_system_message = f'{self._system_message}\n\n{prompt}'

terminators = (self._terminators.extend(terminators)
if terminators is not None else self._terminators)

response = self._client(
prompt_with_system_message,
stop=terminators,
temperature=temperature,
)

if self._measurements is not None:
self._measurements.publish_datum(
self._channel,
{'raw_text_length': len(response)})

return response

@override
def sample_choice(
self,
prompt: str,
responses: Sequence[str],
*,
seed: int | None = None,
) -> tuple[int, str, dict[str, float]]:
prompt_with_system_message = f'{self._system_message}\n\n{prompt}'
sample = ''
answer = ''
for attempts in range(_MAX_MULTIPLE_CHOICE_ATTEMPTS):
# Increase temperature after the first failed attempt.
temperature = sampling.dynamically_adjust_temperature(
attempts, _MAX_MULTIPLE_CHOICE_ATTEMPTS)

sample = self.sample_text(
prompt_with_system_message,
temperature=temperature,
seed=seed,
)
answer = sampling.extract_choice_response(sample)
try:
idx = responses.index(answer)
except ValueError:
continue
else:
if self._measurements is not None:
self._measurements.publish_datum(
self._channel, {'choices_calls': attempts}
)
debug = {}
return idx, responses[idx], debug

raise language_model.InvalidResponseError(
(f'Too many multiple choice attempts.\nLast attempt: {sample}, ' +
f'extracted: {answer}')
)
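For orientation, a minimal sketch of how this new langchain-based wrapper might be exercised. It assumes an ollama server is running locally and that a 'llama3' model has already been pulled; neither assumption comes from this commit:

from concordia.language_model import langchain_ollama_model

# Sketch only: 'llama3' is an assumed model tag.
model = langchain_ollama_model.OllamaLanguageModel(model_name='llama3')
# The system message above steers the model toward sentence continuations.
print(model.sample_text('Question: Is Jake a turtle?\nAnswer: Jake is '))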
66 changes: 47 additions & 19 deletions concordia/language_model/ollama_model.py
@@ -15,21 +15,27 @@
"""Ollama Language Model, a wrapper for models running on the local machine."""

from collections.abc import Collection, Sequence
import json

from concordia.language_model import language_model
from concordia.utils import measurements as measurements_lib
from concordia.utils import sampling
from langchain import llms

import ollama
from typing_extensions import override


_MAX_MULTIPLE_CHOICE_ATTEMPTS = 20
_DEFAULT_TEMPERATURE = 0.5
_DEFAULT_TERMINATORS = ()
_DEFAULT_SYSTEM_MESSAGE = (
'Continue the user\'s sentences. Never repeat their starts. For example, '
'when you see \'Bob is\', you should continue the sentence after '
'the word \'is\'.'
'the word \'is\'. Here are some more examples: \'Question: Is Jake a '
'turtle?\nAnswer: Jake is \' should be completed as \'not a turtle.\' and '
'\'Question: What is Priya doing right now?\nAnswer: Priya is currently \' '
'should be completed as \'working on repairing the sink.\'. Notice that '
'it is OK to be creative with how you finish the user\'s sentences. The '
'most important thing is to always continue in the same style as the user.'
)


@@ -55,11 +61,9 @@ def __init__(
channel: The channel to write the statistics to.
"""
self._model_name = model_name
self._client = ollama.Client()
self._system_message = system_message
self._terminators = []
if 'llama3' in self._model_name:
self._terminators.extend(['<|eot_id|>'])
self._client = llms.Ollama(model=model_name, stop=self._terminators)

self._measurements = measurements
self._channel = channel
@@ -75,25 +79,26 @@ def sample_text(
timeout: float = -1,
seed: int | None = None,
) -> str:
del max_tokens, timeout, seed # Unused.
del max_tokens, timeout, seed, temperature # Unused.

prompt_with_system_message = f'{self._system_message}\n\n{prompt}'

terminators = (self._terminators.extend(terminators)
if terminators is not None else self._terminators)
terminators = self._terminators + list(terminators)

response = self._client(
prompt_with_system_message,
stop=terminators,
temperature=temperature,
response = self._client.generate(
model=self._model_name,
prompt=prompt_with_system_message,
options={'stop': terminators},
keep_alive='10m',
)
result = response['response']

if self._measurements is not None:
self._measurements.publish_datum(
self._channel,
{'raw_text_length': len(response)})
{'raw_text_length': len(result)})

return response
return result

@override
def sample_choice(
@@ -103,19 +108,42 @@ def sample_choice(
*,
seed: int | None = None,
) -> tuple[int, str, dict[str, float]]:
del seed # Unused.
prompt_with_system_message = f'{self._system_message}\n\n{prompt}'
template = {'choice': '', 'single sentence explanation': ''}
sample = ''
answer = ''
for attempts in range(_MAX_MULTIPLE_CHOICE_ATTEMPTS):
# Increase temperature after the first failed attempt.
temperature = sampling.dynamically_adjust_temperature(
attempts, _MAX_MULTIPLE_CHOICE_ATTEMPTS)

sample = self.sample_text(
prompt_with_system_message,
temperature=temperature,
seed=seed,
response = self._client.generate(
model=self._model_name,
prompt=(f'{prompt_with_system_message}.\n'
f'Use the following json template: {json.dumps(template)}.'),
options={'stop': (), 'temperature': temperature},
format='json',
keep_alive='10m',
)
json_data = response
try:
json_data_response = json.loads(json_data['response'])
except json.JSONDecodeError:
continue
sample_or_none = json_data_response.get('choice', None)
if sample_or_none is None:
if isinstance(json_data_response, dict) and json_data_response:
sample = next(iter(json_data_response.values()))
elif isinstance(json_data_response, str) and json_data_response:
sample = sample_or_none.strip()
else:
continue
else:
sample = sample_or_none
if isinstance(sample, str) and sample:
sample = sample.strip()

answer = sampling.extract_choice_response(sample)
try:
idx = responses.index(answer)
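The hunks above replace the langchain wrapper with direct ollama client calls. A condensed sketch of the two call patterns the wrapper now uses, one free-form (sample_text) and one JSON-constrained (sample_choice); the model tag and prompts are illustrative assumptions:

import json
import ollama

client = ollama.Client()

# Free-form completion, mirroring the new sample_text path.
completion = client.generate(
    model='llama3',  # assumed model tag
    prompt='Bob is',
    options={'stop': ['<|eot_id|>']},
    keep_alive='10m',
)
print(completion['response'])

# JSON-constrained choice, mirroring the new sample_choice path.
template = {'choice': '', 'single sentence explanation': ''}
choice = client.generate(
    model='llama3',
    prompt=('Is Jake a turtle? Answer (a) yes or (b) no.\n'
            f'Use the following json template: {json.dumps(template)}.'),
    options={'stop': [], 'temperature': 0.5},
    format='json',
    keep_alive='10m',
)
print(json.loads(choice['response']).get('choice'))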
(Diffs for the remaining three changed files are not shown here.)
