Add an AutoTokenizer (stanford-crfm#1994)
JosselinSomervilleRoberts authored Nov 20, 2023
1 parent f264d65 commit 32342d9
Showing 8 changed files with 200 additions and 128 deletions.
17 changes: 14 additions & 3 deletions scripts/compute_request_limits.py
@@ -6,6 +6,7 @@
from typing import Any, Optional, Dict
from helm.proxy.clients.auto_client import AutoClient
from helm.benchmark.model_deployment_registry import ModelDeployment, get_model_deployment
from helm.proxy.tokenizers.auto_tokenizer import AutoTokenizer
from helm.common.request import Request
from helm.common.tokenization_request import TokenizationRequest

@@ -81,6 +82,7 @@ class RequestLimits:

def figure_out_max_prompt_length(
client: AutoClient,
auto_tokenizer: AutoTokenizer,
model_deployment_name: str,
model_name: str,
tokenizer_name: str,
@@ -89,7 +91,7 @@ def figure_out_max_prompt_length(
prefix: str = "",
suffix: str = "",
) -> RequestLimits:
tokenizer = client._get_tokenizer(tokenizer_name)
tokenizer = auto_tokenizer._get_tokenizer(tokenizer_name)
num_tokens_prefix = get_number_of_tokens(prefix, tokenizer, tokenizer_name)
num_tokens_suffix = get_number_of_tokens(suffix, tokenizer, tokenizer_name)

@@ -177,14 +179,15 @@ def figure_out_max_prompt_length_plus_tokens(

def check_limits(
client: AutoClient,
auto_tokenizer: AutoTokenizer,
model_deployment_name: str,
model_name: str,
tokenizer_name: str,
limits: RequestLimits,
prefix: str = "",
suffix: str = "",
) -> bool:
tokenizer = client._get_tokenizer(tokenizer_name)
tokenizer = auto_tokenizer._get_tokenizer(tokenizer_name)
result: bool = True

# Check the max_prompt_length
@@ -342,6 +345,7 @@ def main():
print(f"cache_path: {cache_path}")

client = AutoClient(credentials=credentials, cache_path=cache_path)
auto_tokenizer = AutoTokenizer(credentials=credentials, cache_path=cache_path)
print("client successfully created")

print("Making short request...")
@@ -368,7 +372,13 @@

print("========== Figure out max_prompt_length ==========")
limits: RequestLimits = figure_out_max_prompt_length(
client, args.model_deployment_name, args.model_name, args.tokenizer_name, prefix=args.prefix, suffix=args.suffix
client,
auto_tokenizer,
args.model_deployment_name,
args.model_name,
args.tokenizer_name,
prefix=args.prefix,
suffix=args.suffix,
)
print(f"max_prompt_length: {limits.max_prompt_length}")
print("===================================================")
@@ -393,6 +403,7 @@ def main():
print("========== Check the limits ==========")
result: bool = check_limits(
client,
auto_tokenizer,
args.model_deployment_name,
args.model_name,
args.tokenizer_name,
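The hunks above show the commit's core change: tokenizer lookup and tokenization move from AutoClient to the new AutoTokenizer, constructed from the same credentials and cache path. A minimal sketch of the new call path (the cache path, empty credentials, and tokenizer choice are illustrative assumptions, not part of this diff):

    # Sketch only: assumes empty credentials suffice for a local Hugging Face tokenizer.
    from helm.common.tokenization_request import TokenizationRequest
    from helm.proxy.tokenizers.auto_tokenizer import AutoTokenizer

    auto_tokenizer = AutoTokenizer(credentials={}, cache_path="prod_env/cache")
    request = TokenizationRequest(text="Hello world", tokenizer="huggingface/gpt2")
    result = auto_tokenizer.tokenize(request)  # dispatches to the tokenizer named in the request
    print(len(result.tokens))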
53 changes: 28 additions & 25 deletions scripts/efficiency/generate_instances.py
@@ -17,8 +17,8 @@
DecodeRequestResult,
TokenizationToken,
)
from helm.proxy.clients.client import Client
from helm.proxy.clients.auto_client import AutoClient
from helm.proxy.tokenizers.tokenizer import Tokenizer
from helm.proxy.tokenizers.auto_tokenizer import AutoTokenizer
from helm.proxy.services.service import (
CACHE_DIR,
)
@@ -40,25 +40,28 @@
}


def _count_prompt_tokens(client: Client, prompt: str, tokenizer: str):
request: TokenizationRequest = TokenizationRequest(text=prompt, tokenizer=tokenizer)
result: TokenizationRequestResult = client.tokenize(request)
def _count_prompt_tokens(tokenizer: Tokenizer, prompt: str, tokenizer_name: str):
request: TokenizationRequest = TokenizationRequest(text=prompt, tokenizer=tokenizer_name)
result: TokenizationRequestResult = tokenizer.tokenize(request)
return len(result.tokens)


def get_client(base_path: str = "prod_env"):
def get_tokenizer(base_path: str = "prod_env") -> AutoTokenizer:
credentials = get_credentials(base_path)
cache_path = os.path.join(base_path, CACHE_DIR)
ensure_directory_exists(cache_path)

# TODO: Pass mongo_uri to AutoClient
client = AutoClient(credentials, cache_path)
tokenizer = AutoTokenizer(credentials, cache_path)

return client
return tokenizer


def tokenize_text(
client: AutoClient, tokenizer: str, output_path: str = "synthetic_efficiency_instances", base_path: str = "prod_env"
tokenizer: AutoTokenizer,
tokenizer_name: str,
output_path: str = "synthetic_efficiency_instances",
base_path: str = "prod_env",
) -> Tuple[Dict[str, List[TokenizationToken]], Dict[str, List[str]]]:
"""Tokenizes each book using the requested tokenizer service."""
sources = {
@@ -72,7 +75,7 @@ def tokenize_text(
tokens: Dict[str, List[TokenizationToken]] = {}
text_chunks: Dict[str, List[str]] = {}

tokenizer_organization: str = tokenizer.split("/")[0]
tokenizer_organization: str = tokenizer_name.split("/")[0]
ai21_tokenizer: bool = tokenizer_organization == "ai21"

# Extract tokens from book sources
@@ -96,9 +99,9 @@
batch = " ".join(text[i * batch_size : (i + 1) * batch_size])
while True:
request: TokenizationRequest = TokenizationRequest(
text=batch, tokenizer=tokenizer, encode=(not ai21_tokenizer)
text=batch, tokenizer=tokenizer_name, encode=(not ai21_tokenizer)
)
result: TokenizationRequestResult = client.tokenize(request)
result: TokenizationRequestResult = tokenizer.tokenize(request)
tokens_ = frozenset([token.value for token in result.tokens])
if tokens_ not in seen_tokens:
seen_tokens.add(tokens_)
@@ -116,15 +119,15 @@
def generate_synthetic_efficiency_instances(
tokens: Dict[str, List[TokenizationToken]],
text_chunks: Dict[str, List[str]],
client: Client,
tokenizer: Tokenizer,
num_instances: int,
num_prompt_tokens: int,
tokenizer: str,
tokenizer_name: str,
output_path: str = "synthetic_efficiency_instances",
base_path: str = "prod_env",
):
"""Generates the synthetic efficiency instances given the tokenized book sources."""
tokenizer_organization: str = tokenizer.split("/")[0]
tokenizer_organization: str = tokenizer_name.split("/")[0]
ai21_tokenizer: bool = tokenizer_organization == "ai21"

books = list(tokens.keys())
@@ -155,13 +158,13 @@ def generate_synthetic_efficiency_instances(
prompt = "".join(per_instance_tokens)
else:
decode_request: DecodeRequest = DecodeRequest(tokens=per_instance_tokens) # type: ignore
decode_result: DecodeRequestResult = client.decode(decode_request)
decode_result: DecodeRequestResult = tokenizer.decode(decode_request)
prompt = decode_result.text

if prompt == "":
num_generated_tokens = 0
else:
num_generated_tokens = _count_prompt_tokens(client, prompt, tokenizer)
num_generated_tokens = _count_prompt_tokens(tokenizer, prompt, tokenizer_name)
if num_generated_tokens != num_prompt_tokens:
temp_num_tokens = num_generated_tokens
while temp_num_tokens < num_prompt_tokens:
@@ -190,7 +193,7 @@ def generate_synthetic_efficiency_instances(
if not finished:
print(
f"Requested {num_prompt_tokens}, got {num_generated_tokens} for "
f"book {books[j]}, instance #{orig_i}, tokenizer={tokenizer}, "
f"book {books[j]}, instance #{orig_i}, tokenizer={tokenizer_name}, "
"trying again with a new span of text..."
)
attempt_num += 1
@@ -199,15 +202,15 @@

for i, prompt in enumerate(prompts):
for k, v in TOKENIZER_REPLACEMENTS.items():
tokenizer = tokenizer.replace(k, v)
name = f"num_prompt_tokens={num_prompt_tokens}," f"tokenizer={tokenizer.replace('/', '_')}," f"id={i}.txt"
tokenizer_name = tokenizer_name.replace(k, v)
name = f"num_prompt_tokens={num_prompt_tokens}," f"tokenizer={tokenizer_name.replace('/', '_')}," f"id={i}.txt"
write(os.path.join(output_path, name), prompt)


if __name__ == "__main__":
client = get_client()
tokenizer = get_tokenizer()

for tokenizer in [
for tokenizer_name in [
"huggingface/gpt2",
"ai21/j1",
"cohere/cohere",
@@ -221,13 +224,13 @@ def generate_synthetic_efficiency_instances(
"EleutherAI/gpt-neox-20b",
"EleutherAI/gpt-j-6B",
]:
tokens, text_chunks = tokenize_text(tokenizer=tokenizer, client=client)
tokens, text_chunks = tokenize_text(tokenizer=tokenizer, tokenizer_name=tokenizer_name)
for num_prompt_tokens in NUM_INPUT_TOKENS:
generate_synthetic_efficiency_instances(
tokens=tokens,
text_chunks=text_chunks,
client=client,
tokenizer=tokenizer,
num_instances=30,
num_prompt_tokens=num_prompt_tokens,
tokenizer=tokenizer,
tokenizer_name=tokenizer_name,
)
4 changes: 3 additions & 1 deletion src/helm/benchmark/test_model_properties.py
@@ -18,6 +18,7 @@

from helm.benchmark.window_services.window_service_factory import WindowServiceFactory
from helm.proxy.clients.auto_client import AutoClient
from helm.proxy.tokenizers.auto_tokenizer import AutoTokenizer
from collections import defaultdict


@@ -1393,6 +1394,7 @@ class TestModelProperties:
@pytest.mark.parametrize("model", ALL_MODEL_DEPLOYMENTS)
def test_models_has_window_service(self, model: ModelMetadata):
auto_client = AutoClient(defaultdict(str), "", "")
auto_tokenizer = AutoTokenizer(defaultdict(str), "", "")
model_deployments = {
model_deployment.name: model_deployment for model_deployment in _BUILT_IN_MODEL_DEPLOYMENTS
}
@@ -1413,7 +1415,7 @@ def test_models_has_window_service(self, model: ModelMetadata):
client = auto_client._get_client(deployment_name)
window_service = WindowServiceFactory.get_window_service(deployment_name, tokenizer_service)
tokenizer_name = window_service.tokenizer_name
tokenizer = auto_client._get_tokenizer(tokenizer_name)
tokenizer = auto_tokenizer._get_tokenizer(tokenizer_name)

client_class_name = _full_class_name(client)
tokenizer_class_name = _full_class_name(tokenizer)
14 changes: 14 additions & 0 deletions src/helm/common/cache_utils.py
@@ -0,0 +1,14 @@
"""Functions used for caching."""

import os

from helm.common.cache import CacheConfig, MongoCacheConfig, SqliteCacheConfig


def build_cache_config(cache_path: str, mongo_uri: str, organization: str) -> CacheConfig:
    if mongo_uri:
        return MongoCacheConfig(mongo_uri, collection_name=organization)

    client_cache_path: str = os.path.join(cache_path, f"{organization}.sqlite")
    # TODO: Allow setting CacheConfig.follower_cache_path from a command line flag.
    return SqliteCacheConfig(client_cache_path)
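A hedged usage sketch for build_cache_config (the paths and organization name are made-up values, not part of this commit):

    # Sketch only: a MongoDB URI selects a Mongo-backed cache; otherwise a per-organization SQLite file is used.
    from helm.common.cache_utils import build_cache_config

    sqlite_config = build_cache_config("prod_env/cache", mongo_uri="", organization="openai")
    mongo_config = build_cache_config("prod_env/cache", mongo_uri="mongodb://localhost:27017/helm", organization="openai")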
28 changes: 28 additions & 0 deletions src/helm/common/credentials_utils.py
@@ -0,0 +1,28 @@
"""Functions used for credentials."""

from typing import Any, Mapping, Optional

from helm.common.hierarchical_logger import hlog


def provide_api_key(
    credentials: Mapping[str, Any], host_organization: str, model: Optional[str] = None
) -> Optional[str]:
    api_key_name = host_organization + "ApiKey"
    if api_key_name in credentials:
        hlog(f"Using host_organization api key defined in credentials.conf: {api_key_name}")
        return credentials[api_key_name]
    if "deployments" not in credentials:
        hlog(
            "WARNING: Could not find key 'deployments' in credentials.conf, "
            f"therefore the API key {api_key_name} should be specified."
        )
        return None
    deployment_api_keys = credentials["deployments"]
    if model is None:
        hlog(f"WARNING: Could not find key '{host_organization}' in credentials.conf and no model provided")
        return None
    if model not in deployment_api_keys:
        hlog(f"WARNING: Could not find key '{model}' under key 'deployments' in credentials.conf")
        return None
    return deployment_api_keys[model]
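A hedged usage sketch for provide_api_key (the credential keys below are assumed examples of the two supported layouts, not values from this commit):

    # Sketch only: a top-level "<organization>ApiKey" entry wins; otherwise the key is looked up under "deployments" by model.
    from helm.common.credentials_utils import provide_api_key

    provide_api_key({"openaiApiKey": "sk-host"}, host_organization="openai")  # -> "sk-host"
    provide_api_key({"deployments": {"openai/gpt-4": "sk-dep"}}, host_organization="openai", model="openai/gpt-4")  # -> "sk-dep"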