diff --git a/llmfoundry/models/inference_api_wrapper/gemini_chat.py b/llmfoundry/models/inference_api_wrapper/gemini_chat.py index ce6b07ecaa..5fab95ac36 100644 --- a/llmfoundry/models/inference_api_wrapper/gemini_chat.py +++ b/llmfoundry/models/inference_api_wrapper/gemini_chat.py @@ -7,8 +7,8 @@ from time import sleep from typing import Any, List, Optional, Union -import google.generativeai as google_genai from composer.core.types import Batch +from composer.utils.import_helpers import MissingConditionalImportError from omegaconf import DictConfig from openai.types.chat.chat_completion import ChatCompletion from transformers import AutoTokenizer @@ -37,6 +37,14 @@ def __init__(self, om_model_config: DictConfig, api_key = om_model_config.pop('api_key', None) if api_key is None: api_key = os.environ.get('GEMINI_API_KEY') + try: + import google.generativeai as google_genai + except ImportError as e: + # TODO: should google-generativeai be grouped with openai in setup.py? + raise MissingConditionalImportError( + extra_deps_group='openai', + conda_package='google-generativeai', + conda_channel='conda-forge') from e google_genai.configure(api_key=api_key) super().__init__(om_model_config, tokenizer) self.model_cfg = om_model_config