From 48e01c9f30f9a87f94ab77909096a5c288fbe19b Mon Sep 17 00:00:00 2001
From: JosselinSomervilleRoberts
Date: Tue, 7 Nov 2023 13:54:10 -0800
Subject: [PATCH] Lazy instantiate Aleph Alpha Client to pass regression test

---
 src/helm/proxy/clients/aleph_alpha_client.py       | 3 ---
 src/helm/proxy/tokenizers/aleph_alpha_tokenizer.py | 4 +++-
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/helm/proxy/clients/aleph_alpha_client.py b/src/helm/proxy/clients/aleph_alpha_client.py
index a988938ae3..ae7116cef2 100644
--- a/src/helm/proxy/clients/aleph_alpha_client.py
+++ b/src/helm/proxy/clients/aleph_alpha_client.py
@@ -2,8 +2,6 @@
 import requests
 from typing import Any, Dict, List
 
-from aleph_alpha_client import Client as AlephAlphaPythonClient
-
 from helm.common.cache import CacheConfig
 from helm.common.request import wrap_request_time, Request, RequestResult, Sequence, Token
 from helm.proxy.tokenizers.tokenizer import Tokenizer
@@ -16,7 +14,6 @@ class AlephAlphaClient(CachingClient):
     def __init__(self, api_key: str, tokenizer: Tokenizer, cache_config: CacheConfig):
         super().__init__(cache_config=cache_config, tokenizer=tokenizer)
         self.api_key: str = api_key
-        self._aleph_alpha_client = AlephAlphaPythonClient(token=api_key)
 
     def _send_request(self, endpoint: str, raw_request: Dict[str, Any]) -> Dict[str, Any]:
         response = requests.request(
diff --git a/src/helm/proxy/tokenizers/aleph_alpha_tokenizer.py b/src/helm/proxy/tokenizers/aleph_alpha_tokenizer.py
index a43c63b841..313cc0a4be 100644
--- a/src/helm/proxy/tokenizers/aleph_alpha_tokenizer.py
+++ b/src/helm/proxy/tokenizers/aleph_alpha_tokenizer.py
@@ -31,7 +31,7 @@ class AlephAlphaTokenizer(CachingTokenizer):
     def __init__(self, api_key: str, cache_config: CacheConfig) -> None:
         super().__init__(cache_config)
         self.api_key: str = api_key
-        self._aleph_alpha_client = AlephAlphaPythonClient(token=api_key)
+        self._aleph_alpha_client = AlephAlphaPythonClient(token=api_key) if api_key else None
         self._tokenizer_name_to_tokenizer: Dict[str, InternalTokenizer] = {}
 
     def _get_tokenizer(self, tokenizer_name: str) -> InternalTokenizer:
@@ -40,6 +40,8 @@ def _get_tokenizer(self, tokenizer_name: str) -> InternalTokenizer:
 
         # Check if the tokenizer is cached
         if tokenizer_name not in self._tokenizer_name_to_tokenizer:
+            if self._aleph_alpha_client is None:
+                raise ValueError("Aleph Alpha API key not set.")
             self._tokenizer_name_to_tokenizer[tokenizer_name] = self._aleph_alpha_client.tokenizer(tokenizer_name)
             hlog(f"Initialized tokenizer: {tokenizer_name}")
         return self._tokenizer_name_to_tokenizer[tokenizer_name]