Skip to content

Commit

Permalink
fixing LLM imports
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 661176481
Change-Id: I3fcce9ae534ce3f7260511b2e5491c71a6897324
  • Loading branch information
vezhnick authored and copybara-github committed Aug 9, 2024
1 parent 4562135 commit edfd15f
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions examples/tutorials/agent_development.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,14 @@
"import pathlib\n",
"import sys\n",
"\n",
"from concordia import language_model\n",
"from concordia.language_model import gpt_model\n",
"from concordia.language_model import mistral_model\n",
"from concordia.language_model import no_language_model\n",
"from concordia.language_model import amazon_bedrock_model\n",
"from concordia.language_model import google_aistudio_model\n",
"from concordia.language_model import langchain_ollama_model\n",
"from concordia.language_model import ollama_model\n",
"from concordia.language_model import pytorch_gemma_model\n",
"from concordia.utils import measurements as measurements_lib\n",
"import openai\n",
"import sentence_transformers"
Expand Down Expand Up @@ -163,35 +170,35 @@
" # simply replace the following with the correct initialization for the model\n",
" # you want to use.\n",
" if API_TYPE == 'amazon_bedrock':\n",
" model = language_model.amazon_bedrock_model.AmazonBedrockLanguageModel(\n",
" model = amazon_bedrock_model.AmazonBedrockLanguageModel(\n",
" model_name=MODEL_NAME)\n",
" elif API_TYPE == 'google_aistudio_model':\n",
" model = language_model.google_aistudio_model.GoogleAIStudioLanguageModel(\n",
" model = google_aistudio_model.GoogleAIStudioLanguageModel(\n",
" model_name=MODEL_NAME)\n",
" elif API_TYPE == 'langchain_ollama':\n",
" model = language_model.langchain_ollama_model.LangchainOllamaLanguageModel(\n",
" model = langchain_ollama_model.LangchainOllamaLanguageModel(\n",
" model_name=MODEL_NAME)\n",
" elif API_TYPE == 'mistral':\n",
" mistral_api_key = os.environ['MISTRAL_API_KEY']\n",
" if not mistral_api_key:\n",
" raise ValueError('Mistral api_key is required.')\n",
" model = language_model.mistral_model.MistralLanguageModel(api_key=mistral_api_key,\n",
" model = mistral_model.MistralLanguageModel(api_key=mistral_api_key,\n",
" model_name=MODEL_NAME)\n",
" elif API_TYPE == 'ollama':\n",
" model = language_model.ollama_model.OllamaLanguageModel(model_name=MODEL_NAME)\n",
" model = ollama_model.OllamaLanguageModel(model_name=MODEL_NAME)\n",
" elif API_TYPE == 'openai':\n",
" openai.api_key = os.environ['OPENAI_API_KEY']\n",
" if not openai.api_key:\n",
" raise ValueError('OpenAI api_key is required.')\n",
" model = language_model.gpt_model.GptLanguageModel(api_key=openai.api_key,\n",
" model = gpt_model.GptLanguageModel(api_key=openai.api_key,\n",
" model_name=MODEL_NAME)\n",
" elif API_TYPE == 'pytorch_gemma':\n",
" model = language_model.pytorch_gemma_model.PyTorchGemmaLanguageModel(\n",
" model = pytorch_gemma_model.PyTorchGemmaLanguageModel(\n",
" model_name=MODEL_NAME)\n",
" else:\n",
" raise ValueError(f'Unrecognized api type: {API_TYPE}')\n",
"else:\n",
" model = language_model.no_language_model.NoLanguageModel()"
" model = no_language_model.NoLanguageModel()"
]
},
{
Expand Down

0 comments on commit edfd15f

Please sign in to comment.