From 36f2880452d2242df26278ad2127b43d29ee680f Mon Sep 17 00:00:00 2001
From: thucpn
Date: Mon, 8 Apr 2024 15:40:08 +0700
Subject: [PATCH] feat: ask to use embedding model

---
 helpers/env-variables.ts                           | 10 ++--
 questions.ts                                       | 48 ++++++++++++-------
 .../src/controllers/engine/settings.ts             |  6 ++-
 .../nextjs/app/api/chat/engine/settings.ts         |  6 ++-
 4 files changed, 46 insertions(+), 24 deletions(-)

diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index f7aab3f5..b6b75260 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -146,6 +146,11 @@ export const createBackendEnvFile = async (
       description: `The Llama Cloud API key.`,
       value: opts.llamaCloudKey,
     },
+    {
+      name: "EMBEDDING_MODEL",
+      description: "Name of the embedding model to use.",
+      value: opts.embeddingModel,
+    },
     // Add vector database environment variables
     ...(opts.vectorDb ? getVectorDBEnvs(opts.vectorDb) : []),
   ];
@@ -164,11 +169,6 @@
       description: "The port to start the backend app.",
       value: opts.port?.toString() || "8000",
     },
-    {
-      name: "EMBEDDING_MODEL",
-      description: "Name of the embedding model to use.",
-      value: opts.embeddingModel,
-    },
     {
       name: "EMBEDDING_DIM",
       description: "Dimension of the embedding model to use.",
diff --git a/questions.ts b/questions.ts
index 6532882b..5ecdff73 100644
--- a/questions.ts
+++ b/questions.ts
@@ -581,26 +581,40 @@ export const askQuestions = async (
     }
   }

-  if (!program.embeddingModel && program.framework === "fastapi") {
+  if (!program.embeddingModel) {
     if (ciInfo.isCI) {
       program.embeddingModel = getPrefOrDefault("embeddingModel");
     } else {
-      const { embeddingModel } = await prompts(
-        {
-          type: "select",
-          name: "embeddingModel",
-          message: "Which embedding model would you like to use?",
-          choices: await getAvailableModelChoices(
-            true,
-            program.openAiKey,
-            program.listServerModels,
-          ),
-          initial: 0,
-        },
-        handlers,
-      );
-      program.embeddingModel = embeddingModel;
-      preferences.embeddingModel = embeddingModel;
+      const { useEmbeddingModel } = await prompts({
+        onState: onPromptState,
+        type: "toggle",
+        name: "useEmbeddingModel",
+        message: "Would you like to use an embedding model?",
+        initial: false,
+        active: "Yes",
+        inactive: "No",
+      });
+
+      let selectedEmbeddingModel = getPrefOrDefault("embeddingModel");
+      if (useEmbeddingModel) {
+        const { embeddingModel } = await prompts(
+          {
+            type: "select",
+            name: "embeddingModel",
+            message: "Which embedding model would you like to use?",
+            choices: await getAvailableModelChoices(
+              true,
+              program.openAiKey,
+              program.listServerModels,
+            ),
+            initial: 0,
+          },
+          handlers,
+        );
+        selectedEmbeddingModel = embeddingModel;
+      }
+      program.embeddingModel = selectedEmbeddingModel;
+      preferences.embeddingModel = selectedEmbeddingModel;
     }
   }

diff --git a/templates/types/streaming/express/src/controllers/engine/settings.ts b/templates/types/streaming/express/src/controllers/engine/settings.ts
index 25c077a5..1efb1a24 100644
--- a/templates/types/streaming/express/src/controllers/engine/settings.ts
+++ b/templates/types/streaming/express/src/controllers/engine/settings.ts
@@ -1,7 +1,8 @@
-import { OpenAI, Settings } from "llamaindex";
+import { OpenAI, OpenAIEmbedding, Settings } from "llamaindex";

 const CHUNK_SIZE = 512;
 const CHUNK_OVERLAP = 20;
+const EMBEDDING_MODEL = process.env.EMBEDDING_MODEL;

 export const initSettings = async () => {
   Settings.llm = new OpenAI({
@@ -10,4 +11,7 @@ export const initSettings = async () => {
   });
   Settings.chunkSize = CHUNK_SIZE;
   Settings.chunkOverlap = CHUNK_OVERLAP;
+  Settings.embedModel = new OpenAIEmbedding({
+    model: EMBEDDING_MODEL,
+  });
 };
diff --git a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
index 25c077a5..1efb1a24 100644
--- a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
@@ -1,7 +1,8 @@
-import { OpenAI, Settings } from "llamaindex";
+import { OpenAI, OpenAIEmbedding, Settings } from "llamaindex";

 const CHUNK_SIZE = 512;
 const CHUNK_OVERLAP = 20;
+const EMBEDDING_MODEL = process.env.EMBEDDING_MODEL;

 export const initSettings = async () => {
   Settings.llm = new OpenAI({
@@ -10,4 +11,7 @@ export const initSettings = async () => {
   });
   Settings.chunkSize = CHUNK_SIZE;
   Settings.chunkOverlap = CHUNK_OVERLAP;
+  Settings.embedModel = new OpenAIEmbedding({
+    model: EMBEDDING_MODEL,
+  });
 };
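
Note on the settings changes: process.env.EMBEDDING_MODEL is string | undefined, so when
the variable is unset the patched initSettings passes model: undefined to OpenAIEmbedding.
If the intent is to fall back to the library's default embedding model in that case, a
guarded variant could look like the sketch below (llm/chunk setup abridged; the guard is a
suggestion, not part of this patch):

  import { OpenAIEmbedding, Settings } from "llamaindex";

  const EMBEDDING_MODEL = process.env.EMBEDDING_MODEL;

  export const initSettings = async () => {
    // ...llm, chunkSize, chunkOverlap setup as in the patch...
    // Only override the embed model when EMBEDDING_MODEL is set;
    // otherwise keep llamaindex's default embedding model instead of
    // passing model: undefined through to the OpenAI client.
    if (EMBEDDING_MODEL) {
      Settings.embedModel = new OpenAIEmbedding({ model: EMBEDDING_MODEL });
    }
  };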