diff --git a/create-app.ts b/create-app.ts
index 3379ceec..8683ee9a 100644
--- a/create-app.ts
+++ b/create-app.ts
@@ -41,8 +41,9 @@ export async function createApp({
   vectorDb,
   externalPort,
   postInstallAction,
-  dataSource,
+  dataSources,
   tools,
+  useLlamaParse,
   observability,
 }: InstallAppArgs): Promise<void> {
   const root = path.resolve(appPath);
@@ -89,8 +90,9 @@ export async function createApp({
     vectorDb,
     externalPort,
     postInstallAction,
-    dataSource,
+    dataSources,
     tools,
+    useLlamaParse,
     observability,
   };
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index 06d27197..660b17a2 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -107,7 +107,7 @@ export const createBackendEnvFile = async (
     model?: string;
     embeddingModel?: string;
     framework?: TemplateFramework;
-    dataSource?: TemplateDataSource;
+    dataSources?: TemplateDataSource[];
     port?: number;
   },
 ) => {
@@ -126,19 +126,13 @@ export const createBackendEnvFile = async (
       description: "The OpenAI API key to use.",
       value: opts.openAiKey,
     },
-
+    {
+      name: "LLAMA_CLOUD_API_KEY",
+      description: `The Llama Cloud API key.`,
+      value: opts.llamaCloudKey,
+    },
     // Add vector database environment variables
     ...(opts.vectorDb ? getVectorDBEnvs(opts.vectorDb) : []),
-    // Add LlamaCloud API key
-    ...(opts.llamaCloudKey
-      ? [
-          {
-            name: "LLAMA_CLOUD_API_KEY",
-            description: `The Llama Cloud API key.`,
-            value: opts.llamaCloudKey,
-          },
-        ]
-      : []),
   ];
   let envVars: EnvVar[] = [];
   if (opts.framework === "fastapi") {
diff --git a/helpers/index.ts b/helpers/index.ts
index 6a217346..de28a18f 100644
--- a/helpers/index.ts
+++ b/helpers/index.ts
@@ -1,10 +1,9 @@
-import { copy } from "./copy";
 import { callPackageManager } from "./install";
-import fs from "fs/promises";
 import path from "path";
 import { cyan } from "picocolors";
+import fsExtra from "fs-extra";
 import { templatesDir } from "./dir";
 import { createBackendEnvFile, createFrontendEnvFile } from "./env-variables";
 import { PackageManager } from "./get-pkg-manager";
@@ -27,8 +26,8 @@ async function generateContextData(
   packageManager?: PackageManager,
   openAiKey?: string,
   vectorDb?: TemplateVectorDB,
-  dataSource?: TemplateDataSource,
   llamaCloudKey?: string,
+  useLlamaParse?: boolean,
 ) {
   if (packageManager) {
     const runGenerate = `${cyan(
@@ -37,8 +36,7 @@ async function generateContextData(
         : `${packageManager} run generate`,
     )}`;
     const openAiKeyConfigured = openAiKey || process.env["OPENAI_API_KEY"];
-    const llamaCloudKeyConfigured = (dataSource?.config as FileSourceConfig)
-      ?.useLlamaParse
+    const llamaCloudKeyConfigured = useLlamaParse
      ? llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
      : true;
    const hasVectorDb = vectorDb && vectorDb !== "none";
@@ -76,47 +74,16 @@ async function generateContextData(

 const copyContextData = async (
   root: string,
-  dataSource?: TemplateDataSource,
+  dataSources: TemplateDataSource[],
 ) => {
-  const destPath = path.join(root, "data");
-
-  const dataSourceConfig = dataSource?.config as FileSourceConfig;
-
-  // Copy file
-  if (dataSource?.type === "file") {
-    if (dataSourceConfig.paths) {
-      await fs.mkdir(destPath, { recursive: true });
-      console.log(
-        "Copying data from files:",
-        dataSourceConfig.paths.toString(),
-      );
-      for (const p of dataSourceConfig.paths) {
-        await fs.copyFile(p, path.join(destPath, path.basename(p)));
-      }
-    } else {
-      console.log("Missing file path in config");
-      process.exit(1);
-    }
-    return;
-  }
-
-  // Copy folder
-  if (dataSource?.type === "folder") {
-    // Example data does not have path config, set the default path
-    const srcPaths = dataSourceConfig.paths ?? [
-      path.join(templatesDir, "components", "data"),
-    ];
-    console.log("Copying data from folders: ", srcPaths);
-    for (const p of srcPaths) {
-      const folderName = path.basename(p);
-      const destFolderPath = path.join(destPath, folderName);
-      await fs.mkdir(destFolderPath, { recursive: true });
-      await copy("**", destFolderPath, {
-        parents: true,
-        cwd: p,
-      });
-    }
-    return;
+  for (const dataSource of dataSources) {
+    const dataSourceConfig = dataSource?.config as FileSourceConfig;
+    // Copy local data
+    const dataPath =
+      dataSourceConfig.path ?? path.join(templatesDir, "components", "data");
+    const destPath = path.join(root, "data", path.basename(dataPath));
+    console.log("Copying data from path:", dataPath);
+    await fsExtra.copy(dataPath, destPath);
   }
 };
@@ -166,12 +133,13 @@ export const installTemplate = async (
       model: props.model,
       embeddingModel: props.embeddingModel,
       framework: props.framework,
-      dataSource: props.dataSource,
+      dataSources: props.dataSources,
       port: props.externalPort,
     });

     if (props.engine === "context") {
-      await copyContextData(props.root, props.dataSource);
+      console.log("\nGenerating context data...\n");
+      await copyContextData(props.root, props.dataSources);
       if (
         props.postInstallAction === "runApp" ||
         props.postInstallAction === "dependencies"
@@ -181,14 +149,14 @@ export const installTemplate = async (
          props.packageManager,
          props.openAiKey,
          props.vectorDb,
-         props.dataSource,
          props.llamaCloudKey,
+         props.useLlamaParse,
        );
      }
    }
  } else {
    // this is a frontend for a full-stack app, create .env file with model information
-   createFrontendEnvFile(props.root, {
+   await createFrontendEnvFile(props.root, {
      model: props.model,
      customApiPath: props.customApiPath,
    });
diff --git a/helpers/python.ts b/helpers/python.ts
index 9687f371..b4427c24 100644
--- a/helpers/python.ts
+++ b/helpers/python.ts
@@ -8,7 +8,6 @@ import { templatesDir } from "./dir";
 import { isPoetryAvailable, tryPoetryInstall } from "./poetry";
 import { Tool } from "./tools";
 import {
-  FileSourceConfig,
   InstallTemplateArgs,
   TemplateDataSource,
   TemplateVectorDB,
@@ -65,7 +64,7 @@ const getAdditionalDependencies = (
   // Add data source dependencies
   const dataSourceType = dataSource?.type;
-  if (dataSourceType === "file" || dataSourceType === "folder") {
+  if (dataSourceType === "file") {
     // llama-index-readers-file (pdf, excel, csv) is already included in llama_index package
     dependencies.push({
       name: "docx2txt",
@@ -180,9 +179,10 @@ export const installPythonTemplate = async ({
   framework,
   engine,
   vectorDb,
-  dataSource,
+  dataSources,
   tools,
   postInstallAction,
+  useLlamaParse,
 }: Pick<
   InstallTemplateArgs,
   | "root"
@@ -190,8 +190,9 @@
   | "template"
   | "engine"
   | "vectorDb"
-  | "dataSource"
+  | "dataSources"
   | "tools"
+  | "useLlamaParse"
   | "postInstallAction"
 >) => {
   console.log("\nInitializing Python project with template:", template, "\n");
@@ -256,51 +257,52 @@ export const installPythonTemplate = async ({
     });
   }

-  // Write loader configs
-  if (dataSource?.type === "web") {
-    const config = dataSource.config as WebSourceConfig[];
-    const webLoaderConfig = config.map((c) => {
-      return {
-        base_url: c.baseUrl,
-        prefix: c.prefix || c.baseUrl,
-        depth: c.depth || 1,
-      };
-    });
-    const loaderConfigPath = path.join(root, "config/loaders.json");
-    await fs.mkdir(path.join(root, "config"), { recursive: true });
-    await fs.writeFile(
-      loaderConfigPath,
-      JSON.stringify(
-        {
-          web: webLoaderConfig,
-        },
-        null,
-        2,
-      ),
-    );
-  }
+  if (dataSources.length > 0) {
+    const loaderConfigs: Record<string, any> = {};
+    const loaderPath = path.join(enginePath, "loaders");

-  const dataSourceType = dataSource?.type;
-  if (dataSourceType !== undefined && dataSourceType !== "none") {
-    let loaderFolder: string;
-    if (dataSourceType === "file" || dataSourceType === "folder") {
-      const dataSourceConfig = dataSource?.config as FileSourceConfig;
-      loaderFolder = dataSourceConfig.useLlamaParse ? "llama_parse" : "file";
-    } else {
-      loaderFolder = dataSourceType;
-    }
-    await copy("**", enginePath, {
+    // Copy loaders to enginePath
+    await copy("**", loaderPath, {
       parents: true,
-      cwd: path.join(compPath, "loaders", "python", loaderFolder),
+      cwd: path.join(compPath, "loaders", "python"),
     });
+
+    // Generate loaders config
+    // Web loader config
+    if (dataSources.some((ds) => ds.type === "web")) {
+      const webLoaderConfig = dataSources
+        .filter((ds) => ds.type === "web")
+        .map((ds) => {
+          const dsConfig = ds.config as WebSourceConfig;
+          return {
+            base_url: dsConfig.baseUrl,
+            prefix: dsConfig.prefix,
+            depth: dsConfig.depth,
+          };
+        });
+      loaderConfigs["web"] = webLoaderConfig;
+    }
+    // File loader config
+    if (dataSources.some((ds) => ds.type === "file")) {
+      loaderConfigs["file"] = {
+        use_llama_parse: useLlamaParse,
+      };
+    }
+    // Write loaders config
+    if (Object.keys(loaderConfigs).length > 0) {
+      const loaderConfigPath = path.join(root, "config/loaders.json");
+      await fs.mkdir(path.join(root, "config"), { recursive: true });
+      await fs.writeFile(
+        loaderConfigPath,
+        JSON.stringify(loaderConfigs, null, 2),
+      );
+    }
   }
 }

-  const addOnDependencies = getAdditionalDependencies(
-    vectorDb,
-    dataSource,
-    tools,
-  );
+  const addOnDependencies = dataSources
+    .map((ds) => getAdditionalDependencies(vectorDb, ds, tools))
+    .flat();
   await addDependencies(root, addOnDependencies);

   if (postInstallAction === "runApp" || postInstallAction === "dependencies") {
diff --git a/helpers/types.ts b/helpers/types.ts
index d093c099..0c5a30f0 100644
--- a/helpers/types.ts
+++ b/helpers/types.ts
@@ -15,12 +15,11 @@ export type TemplateDataSource = {
   type: TemplateDataSourceType;
   config: TemplateDataSourceConfig;
 };
-export type TemplateDataSourceType = "none" | "file" | "folder" | "web";
+export type TemplateDataSourceType = "file" | "web";
 export type TemplateObservability = "none" | "opentelemetry";
 // Config for both file and folder
 export type FileSourceConfig = {
-  paths?: string[];
-  useLlamaParse?: boolean;
+  path?: string;
 };
 export type WebSourceConfig = {
   baseUrl?: string;
@@ -28,7 +27,7 @@ export type WebSourceConfig = {
   depth?: number;
 };

-export type TemplateDataSourceConfig = FileSourceConfig | WebSourceConfig[];
+export type TemplateDataSourceConfig = FileSourceConfig | WebSourceConfig;

 export type CommunityProjectConfig = {
   owner: string;
@@ -46,11 +45,12 @@ export interface InstallTemplateArgs {
   framework: TemplateFramework;
   engine: TemplateEngine;
   ui: TemplateUI;
-  dataSource?: TemplateDataSource;
+  dataSources: TemplateDataSource[];
   eslint: boolean;
   customApiPath?: string;
   openAiKey?: string;
   llamaCloudKey?: string;
+  useLlamaParse?: boolean;
   model: string;
   embeddingModel: string;
   communityProjectConfig?: CommunityProjectConfig;
diff --git a/helpers/typescript.ts b/helpers/typescript.ts
index 4fdbca98..4aefb531 100644
--- a/helpers/typescript.ts
+++ b/helpers/typescript.ts
@@ -6,7 +6,7 @@ import { copy } from "../helpers/copy";
 import { callPackageManager } from "../helpers/install";
 import { templatesDir } from "./dir";
 import { PackageManager } from "./get-pkg-manager";
-import { FileSourceConfig, InstallTemplateArgs } from "./types";
+import { InstallTemplateArgs } from "./types";

 const rename = (name: string) => {
   switch (name) {
@@ -65,7 +65,8 @@ export const installTSTemplate = async ({
   backend,
   observability,
   tools,
-  dataSource,
+  dataSources,
+  useLlamaParse,
 }: InstallTemplateArgs & { backend: boolean }) => {
   console.log(bold(`Using ${packageManager}.`));
@@ -173,15 +174,10 @@ export const installTSTemplate = async ({
   });

   // copy loader component
-  const dataSourceType = dataSource?.type;
-  if (dataSourceType && dataSourceType !== "none") {
+  const dataSourceType = dataSources[0]?.type;
+  if (dataSourceType) {
     let loaderFolder: string;
-    if (dataSourceType === "file" || dataSourceType === "folder") {
-      const dataSourceConfig = dataSource?.config as FileSourceConfig;
-      loaderFolder = dataSourceConfig.useLlamaParse ? "llama_parse" : "file";
-    } else {
-      loaderFolder = dataSourceType;
-    }
+    loaderFolder = useLlamaParse ? "llama_parse" : dataSourceType;
     await copy("**", enginePath, {
       parents: true,
       cwd: path.join(compPath, "loaders", "typescript", loaderFolder),
diff --git a/index.ts b/index.ts
index 15f453ed..2bd8cf3c 100644
--- a/index.ts
+++ b/index.ts
@@ -302,8 +302,9 @@ async function run(): Promise<void> {
     vectorDb: program.vectorDb,
     externalPort: program.externalPort,
     postInstallAction: program.postInstallAction,
-    dataSource: program.dataSource,
+    dataSources: program.dataSources,
     tools: program.tools,
+    useLlamaParse: program.useLlamaParse,
     observability: program.observability,
   });
   conf.set("preferences", preferences);
diff --git a/package.json b/package.json
index b93df497..92bc5bb4 100644
--- a/package.json
+++ b/package.json
@@ -43,6 +43,7 @@
     "@types/prompts": "2.0.1",
     "@types/tar": "6.1.5",
     "@types/validate-npm-package-name": "3.0.0",
+    "@types/fs-extra": "11.0.4",
     "@vercel/ncc": "0.38.1",
     "async-retry": "1.3.1",
     "async-sema": "3.0.1",
@@ -68,7 +69,8 @@
     "prettier-plugin-organize-imports": "^3.2.4",
     "typescript": "^5.3.3",
     "eslint-config-prettier": "^8.10.0",
-    "ora": "^8.0.1"
+    "ora": "^8.0.1",
+    "fs-extra": "11.2.0"
   },
   "engines": {
     "node": ">=16.14.0"
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 82d4b7c3..4dbb3836 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -20,6 +20,9 @@ devDependencies:
   '@types/cross-spawn':
     specifier: 6.0.0
     version: 6.0.0
+  '@types/fs-extra':
+    specifier: 11.0.4
+    version: 11.0.4
   '@types/node':
     specifier: ^20.11.7
     version: 20.11.26
@@ -62,6 +65,9 @@ devDependencies:
   fast-glob:
     specifier: 3.3.1
     version: 3.3.1
+  fs-extra:
+    specifier: 11.2.0
+    version: 11.2.0
   got:
     specifier: 10.7.0
     version: 10.7.0
@@ -489,10 +495,23 @@ packages:
       '@types/node': 20.11.26
     dev: true

+  /@types/fs-extra@11.0.4:
+    resolution: {integrity: sha512-yTbItCNreRooED33qjunPthRcSjERP1r4MqCZc7wv0u2sUkzTFp45tgUfS5+r7FrZPdmCCNflLhVSP/o+SemsQ==}
+    dependencies:
+      '@types/jsonfile': 6.1.4
+      '@types/node': 20.11.26
+    dev: true
+
   /@types/http-cache-semantics@4.0.4:
     resolution: {integrity: sha512-1m0bIFVc7eJWyve9S0RnuRgcQqF/Xd5QsUZAZeQFr1Q3/p9JWoQQEqmVy+DPTNpGXwhgIetAoYF8JSc33q29QA==}
     dev: true

+  /@types/jsonfile@6.1.4:
+    resolution: {integrity: sha512-D5qGUYwjvnNNextdU59/+fI+spnwtTFmyQP0h+PfIOSkNfpU6AOICUOkm4i0OnSk+NyjdPJrxCDro0sJsWlRpQ==}
+    dependencies:
+      '@types/node': 20.11.26
+    dev: true
+
   /@types/keyv@3.1.4:
     resolution: {integrity: sha512-BQ5aZNSCpj7D6K2ksrRCTmKRLEpnPvWDiLPfoGyhZ++8YtiK9d/3DBKPJgry359X/P1PfruyYwvnvwFjuEiEIg==}
     dependencies:
@@ -1437,6 +1456,15 @@ packages:
       signal-exit: 4.1.0
     dev: true

+  /fs-extra@11.2.0:
+    resolution: {integrity: sha512-PmDi3uwK5nFuXh7XDTlVnS17xJS7vW36is2+w3xcv8SVxiB4NyATf4ctkVY5bkSjX0Y4nbvZCq1/EjtEyr9ktw==}
+    engines: {node: '>=14.14'}
+    dependencies:
+      graceful-fs: 4.2.11
+      jsonfile: 6.1.0
+      universalify: 2.0.1
+    dev: true
+
   /fs-extra@7.0.1:
     resolution: {integrity: sha512-YJDaCJZEnBmcbw13fvdAM9AwNOJwOzrE4pqMqBq5nFiEqXUqHwlK4B+3pUw6JNvfSPtX05xFHtYy/1ni01eGCw==}
     engines: {node: '>=6 <7 || >=8'}
@@ -1982,6 +2010,14 @@ packages:
       graceful-fs: 4.2.11
     dev: true

+  /jsonfile@6.1.0:
+    resolution: {integrity: sha512-5dgndWOriYSm5cnYaJNhalLNDKOqFwyDB/rr1E9ZsGciGvKPs8R2xYGCacuf3z6K1YKDz182fd+fY3cn3pMqXQ==}
+    dependencies:
+      universalify: 2.0.1
+    optionalDependencies:
+      graceful-fs: 4.2.11
+    dev: true
+
   /keyv@4.5.4:
     resolution: {integrity: sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==}
     dependencies:
@@ -3185,6 +3221,11 @@ packages:
     engines: {node: '>= 4.0.0'}
     dev: true

+  /universalify@2.0.1:
+    resolution: {integrity: sha512-gptHNQghINnc/vTGIk0SOFGFNXw7JVrlRUtConJRlvaw6DuX0wO5Jeko9sWrMBhh+PsYAZ7oXAiOnf/UKogyiw==}
+    engines: {node: '>= 10.0.0'}
+    dev: true
+
   /update-check@1.5.4:
     resolution: {integrity: sha512-5YHsflzHP4t1G+8WGPlvKbJEbAJGCgw+Em+dGR1KmBUbr1J36SJBqlHLjR7oob7sco5hWHGQVcr9B2poIVDDTQ==}
     dependencies:
diff --git a/questions.ts b/questions.ts
index 0651f3d1..6fc06da0 100644
--- a/questions.ts
+++ b/questions.ts
@@ -8,10 +8,9 @@ import { blue, green, red } from "picocolors";
 import prompts from "prompts";
 import { InstallAppArgs } from "./create-app";
 import {
-  FileSourceConfig,
+  TemplateDataSource,
   TemplateDataSourceType,
   TemplateFramework,
-  WebSourceConfig,
 } from "./helpers";
 import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant";
 import { templatesDir } from "./helpers/dir";
@@ -26,7 +25,6 @@ export type QuestionArgs = Omit<
   "appPath" | "packageManager"
 > & {
   files?: string;
-  llamaParse?: boolean;
   listServerModels?: boolean;
 };
 const supportedContextFileTypes = [
@@ -78,15 +76,13 @@ const defaults: QuestionArgs = {
   frontend: false,
   openAiKey: "",
   llamaCloudKey: "",
+  useLlamaParse: false,
   model: "gpt-3.5-turbo",
   embeddingModel: "text-embedding-ada-002",
   communityProjectConfig: undefined,
   llamapack: "",
   postInstallAction: "dependencies",
-  dataSource: {
-    type: "none",
-    config: {},
-  },
+  dataSources: [],
   tools: [],
 };
@@ -124,27 +120,42 @@ const getVectorDbChoices = (framework: TemplateFramework) => {
   return displayedChoices;
 };

-const getDataSourceChoices = (framework: TemplateFramework) => {
-  const choices = [
-    {
-      title: "No data, just a simple chat",
-      value: "simple",
-    },
-    { title: "Use an example PDF", value: "exampleFile" },
-  ];
-  if (process.platform === "win32" || process.platform === "darwin") {
+export const getDataSourceChoices = (
+  framework: TemplateFramework,
+  selectedDataSource: TemplateDataSource[],
+) => {
+  const choices = [];
+  if (selectedDataSource.length > 0) {
     choices.push({
-      title: `Use local files (${supportedContextFileTypes.join(", ")})`,
-      value: "localFile",
+      title: "No",
+      value: "no",
+    });
+  }
+  if (selectedDataSource === undefined || selectedDataSource.length === 0) {
+    choices.push({
+      title: "No data, just a simple chat",
+      value: "none",
     });
     choices.push({
+      title: "Use an example PDF",
+      value: "exampleFile",
+    });
+  }
+
+  choices.push(
+    {
+      title: `Use local files (${supportedContextFileTypes.join(", ")})`,
+      value: "file",
+    },
+    {
       title:
         process.platform === "win32"
           ? "Use a local folder"
           : "Use local folders",
-      value: "localFolder",
-    });
-  }
+      value: "folder",
+    },
+  );
+
   if (framework === "fastapi") {
     choices.push({
       title: "Use website content (requires Chrome)",
       value: "web",
     });
@@ -182,9 +193,10 @@ const selectLocalContextData = async (type: TemplateDataSourceType) => {
       process.platform === "win32"
         ? selectedPath.split("\r\n")
         : selectedPath.split(", ");
+
     for (const p of paths) {
       if (
-        type == "file" &&
+        fs.statSync(p).isFile() &&
         !supportedContextFileTypes.includes(path.extname(p))
       ) {
         console.log(
@@ -320,9 +332,7 @@ export const askQuestions = async (
   const openAiKeyConfigured =
     program.openAiKey || process.env["OPENAI_API_KEY"];
   // If using LlamaParse, require LlamaCloud API key
-  const llamaCloudKeyConfigured = (
-    program.dataSource?.config as FileSourceConfig
-  )?.useLlamaParse
+  const llamaCloudKeyConfigured = program.useLlamaParse
    ? program.llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
    : true;
  const hasVectorDb = program.vectorDb && program.vectorDb !== "none";
@@ -620,127 +630,149 @@ export const askQuestions = async (

   if (program.files) {
     // If user specified files option, then the program should use context engine
-    program.engine == "context";
-    if (!fs.existsSync(program.files)) {
-      console.log("File or folder not found");
-      process.exit(1);
-    } else {
-      program.dataSource = {
-        type: fs.lstatSync(program.files).isDirectory() ? "folder" : "file",
+    program.engine = "context";
+    program.files.split(",").forEach((filePath) => {
+      program.dataSources.push({
+        type: "file",
         config: {
-          paths: program.files.split(","),
+          path: filePath,
         },
-      };
-    }
+      });
+    });
   }

   if (!program.engine) {
     if (ciInfo.isCI) {
       program.engine = getPrefOrDefault("engine");
+      program.dataSources = getPrefOrDefault("dataSources");
     } else {
-      const { dataSource } = await prompts(
-        {
-          type: "select",
-          name: "dataSource",
-          message: "Which data source would you like to use?",
-          choices: getDataSourceChoices(program.framework),
-          initial: 1,
-        },
-        handlers,
-      );
-      // Initialize with default config
-      program.dataSource = getPrefOrDefault("dataSource");
-      if (program.dataSource) {
-        switch (dataSource) {
-          case "simple":
-            program.engine = "simple";
-            program.dataSource = { type: "none", config: {} };
-            break;
-          case "exampleFile":
-            program.engine = "context";
-            // Treat example as a folder data source with no config
-            program.dataSource = { type: "folder", config: {} };
-            break;
-          case "localFile":
-            program.engine = "context";
-            program.dataSource = {
+      program.dataSources = [];
+      while (true) {
+        const { selectedSource } = await prompts(
+          {
+            type: "select",
+            name: "selectedSource",
+            message:
+              program.dataSources.length === 0
+                ? "Which data source would you like to use?"
+                : "Would you like to add another data source?",
+            choices: getDataSourceChoices(
+              program.framework,
+              program.dataSources,
+            ),
+            initial: 0,
+          },
+          handlers,
+        );
+
+        if (selectedSource === "no") {
+          break;
+        }
+
+        if (selectedSource === "none") {
+          // Selected simple chat
+          program.dataSources = [];
+          // Stop asking for another data source
+          break;
+        }
+
+        if (selectedSource === "exampleFile") {
+          program.dataSources.push({
+            type: "file",
+            config: {},
+          });
+        } else if (selectedSource === "file" || selectedSource === "folder") {
+          // Select local data source
+          const selectedPaths = await selectLocalContextData(selectedSource);
+          for (const p of selectedPaths) {
+            program.dataSources.push({
               type: "file",
               config: {
-                paths: await selectLocalContextData("file"),
+                path: p,
               },
-            };
-            break;
-          case "localFolder":
-            program.engine = "context";
-            program.dataSource = {
-              type: "folder",
-              config: {
-                paths: await selectLocalContextData("folder"),
+            });
+          }
+        } else if (selectedSource === "web") {
+          // Selected web data source
+          const { baseUrl } = await prompts(
+            {
+              type: "text",
+              name: "baseUrl",
+              message: "Please provide base URL of the website: ",
+              initial: "https://www.llamaindex.ai",
+              validate: (value: string) => {
+                if (!value.includes("://")) {
+                  value = `https://${value}`;
+                }
+                const urlObj = new URL(value);
+                if (
+                  urlObj.protocol !== "https:" &&
+                  urlObj.protocol !== "http:"
+                ) {
+                  return `URL=${value} has invalid protocol, only allow http or https`;
+                }
+                return true;
               },
-            };
-            break;
-          case "web":
-            program.engine = "context";
-            program.dataSource.type = "web";
-            break;
+            },
+            handlers,
+          );
+
+          program.dataSources.push({
+            type: "web",
+            config: {
+              baseUrl,
+              prefix: baseUrl,
+              depth: 1,
+            },
+          });
         }
       }
+
+      if (program.dataSources.length === 0) {
+        program.engine = "simple";
+      } else {
+        program.engine = "context";
+      }
     }
-  } else if (!program.dataSource) {
+  } else if (!program.dataSources) {
     // Handle a case when engine is specified but dataSource is not
     if (program.engine === "context") {
-      program.dataSource = {
-        type: "folder",
-        config: {},
-      };
+      program.dataSources = [
+        {
+          type: "file",
+          config: {},
+        },
+      ];
     } else if (program.engine === "simple") {
-      program.dataSource = {
-        type: "none",
-        config: {},
-      };
+      program.dataSources = [];
     }
   }

+  // Asking for LlamaParse if user selected file or folder data source
   if (
-    program.dataSource?.type === "file" ||
-    program.dataSource?.type === "folder"
+    program.dataSources.some((ds) => ds.type === "file") &&
+    !program.useLlamaParse
   ) {
     if (ciInfo.isCI) {
+      program.useLlamaParse = getPrefOrDefault("useLlamaParse");
       program.llamaCloudKey = getPrefOrDefault("llamaCloudKey");
     } else {
-      const dataSourceConfig = program.dataSource.config as FileSourceConfig;
-      dataSourceConfig.useLlamaParse = program.llamaParse;
-
-      // Is pdf file selected as data source or is it a folder data source
-      const askingLlamaParse =
-        dataSourceConfig.useLlamaParse === undefined &&
-        (program.dataSource.type === "folder" ||
-          (program.dataSource.type === "file" &&
-            dataSourceConfig.paths?.some((p) => path.extname(p) === ".pdf")));
-
-      // Ask if user wants to use LlamaParse
-      if (askingLlamaParse) {
-        const { useLlamaParse } = await prompts(
-          {
-            type: "toggle",
-            name: "useLlamaParse",
-            message:
-              "Would you like to use LlamaParse (improved parser for RAG - requires API key)?",
-            initial: true,
-            active: "yes",
-            inactive: "no",
-          },
-          handlers,
-        );
-        dataSourceConfig.useLlamaParse = useLlamaParse;
-        program.dataSource.config = dataSourceConfig;
-      }
+      const { useLlamaParse } = await prompts(
+        {
+          type: "toggle",
+          name: "useLlamaParse",
+          message:
+            "Would you like to use LlamaParse (improved parser for RAG - requires API key)?",
+          initial: true,
+          active: "yes",
+          inactive: "no",
+        },
+        handlers,
+      );
+      program.useLlamaParse = useLlamaParse;

       // Ask for LlamaCloud API key
-      if (
-        dataSourceConfig.useLlamaParse &&
-        program.llamaCloudKey === undefined
-      ) {
+      if (useLlamaParse && program.llamaCloudKey === undefined) {
         const { llamaCloudKey } = await prompts(
           {
             type: "text",
@@ -755,56 +787,6 @@ export const askQuestions = async (
     }
   }

-  if (program.dataSource?.type === "web" && program.framework === "fastapi") {
-    program.dataSource.config = [];
-
-    while (true) {
-      const questions: any[] = [
-        {
-          type: "text",
-          name: "baseUrl",
-          message: "Please provide base URL of the website: ",
-          initial: "https://www.llamaindex.ai",
-          validate: (value: string) => {
-            if (!value.includes("://")) {
-              value = `https://${value}`;
-            }
-            const urlObj = new URL(value);
-            if (urlObj.protocol !== "https:" && urlObj.protocol !== "http:") {
-              return `URL=${value} has invalid protocol, only allow http or https`;
-            }
-            // Check duplicated URL
-            if (
-              (program.dataSource?.config as WebSourceConfig[]).some(
-                (c) => c.baseUrl === value,
-              )
-            ) {
-              return `URL=${value} is already added. Please provide a different URL.`;
-            }
-            return true;
-          },
-        },
-        {
-          type: "toggle",
-          name: "shouldContinue",
-          message: "Would you like to add another website?",
-          initial: false,
-          active: "Yes",
-          inactive: "No",
-        },
-      ];
-      let { shouldContinue, baseUrl } = await prompts(questions, handlers);
-      program.dataSource.config.push({
-        baseUrl: baseUrl,
-        prefix: baseUrl,
-        depth: 1,
-      });
-      if (!shouldContinue) {
-        break;
-      }
-    }
-  }
-
   if (program.engine !== "simple" && !program.vectorDb) {
     if (ciInfo.isCI) {
       program.vectorDb = getPrefOrDefault("vectorDb");
diff --git a/templates/components/loaders/python/__init__.py b/templates/components/loaders/python/__init__.py
new file mode 100644
index 00000000..b8c0f5a1
--- /dev/null
+++ b/templates/components/loaders/python/__init__.py
@@ -0,0 +1,33 @@
+import os
+import json
+import importlib
+import logging
+from typing import Dict
+from app.engine.loaders.file import FileLoaderConfig, get_file_documents
+from app.engine.loaders.web import WebLoaderConfig, get_web_documents
+
+logger = logging.getLogger(__name__)
+
+
+def load_configs():
+    with open("config/loaders.json") as f:
+        configs = json.load(f)
+    return configs
+
+
+def get_documents():
+    documents = []
+    config = load_configs()
+    for loader_type, loader_config in config.items():
+        logger.info(
+            f"Loading documents from loader: {loader_type}, config: {loader_config}"
+        )
+        if loader_type == "file":
+            document = get_file_documents(FileLoaderConfig(**loader_config))
+            documents.extend(document)
+        elif loader_type == "web":
+            for entry in loader_config:
+                document = get_web_documents(WebLoaderConfig(**entry))
+                documents.extend(document)
+
+    return documents
diff --git a/templates/components/loaders/python/file.py b/templates/components/loaders/python/file.py
new file mode 100644
index 00000000..a814b0d0
--- /dev/null
+++ b/templates/components/loaders/python/file.py
@@ -0,0 +1,37 @@
+import os
+from llama_parse import LlamaParse
+from pydantic import BaseModel, validator
+
+
+class FileLoaderConfig(BaseModel):
+    data_dir: str = "data"
+    use_llama_parse: bool = False
+
+    @validator("data_dir")
+    def data_dir_must_exist(cls, v):
+        if not os.path.isdir(v):
+            raise ValueError(f"Directory '{v}' does not exist")
+        return v
+
+
+def llama_parse_parser():
+    if os.getenv("LLAMA_CLOUD_API_KEY") is None:
+        raise ValueError(
+            "LLAMA_CLOUD_API_KEY environment variable is not set. "
+            "Please set it in .env file or in your shell environment then run again!"
+        )
+    parser = LlamaParse(result_type="markdown", verbose=True, language="en")
+    return parser
+
+
+def get_file_documents(config: FileLoaderConfig):
+    from llama_index.core.readers import SimpleDirectoryReader
+
+    reader = SimpleDirectoryReader(
+        config.data_dir,
+        recursive=True,
+    )
+    if config.use_llama_parse:
+        parser = llama_parse_parser()
+        reader.file_extractor = {".pdf": parser}
+    return reader.load_data()
diff --git a/templates/components/loaders/python/file/loader.py b/templates/components/loaders/python/file/loader.py
deleted file mode 100644
index 40923709..00000000
--- a/templates/components/loaders/python/file/loader.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from llama_index.core.readers import SimpleDirectoryReader
-
-DATA_DIR = "data"  # directory containing the documents
-
-
-def get_documents():
-    return SimpleDirectoryReader(
-        DATA_DIR,
-        recursive=True,
-    ).load_data()
diff --git a/templates/components/loaders/python/llama_parse/loader.py b/templates/components/loaders/python/llama_parse/loader.py
deleted file mode 100644
index efaf3421..00000000
--- a/templates/components/loaders/python/llama_parse/loader.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import os
-from llama_parse import LlamaParse
-from llama_index.core import SimpleDirectoryReader
-
-DATA_DIR = "data"  # directory containing the documents
-
-
-def get_documents():
-    if os.getenv("LLAMA_CLOUD_API_KEY") is None:
-        raise ValueError(
-            "LLAMA_CLOUD_API_KEY environment variable is not set. "
-            "Please set it in .env file or in your shell environment then run again!"
-        )
-    parser = LlamaParse(result_type="markdown", verbose=True, language="en")
-
-    reader = SimpleDirectoryReader(
-        DATA_DIR, recursive=True, file_extractor={".pdf": parser}
-    )
-    return reader.load_data()
diff --git a/templates/components/loaders/python/web.py b/templates/components/loaders/python/web.py
new file mode 100644
index 00000000..bca9aaf7
--- /dev/null
+++ b/templates/components/loaders/python/web.py
@@ -0,0 +1,19 @@
+import os
+import json
+from pydantic import BaseModel, Field
+
+
+class WebLoaderConfig(BaseModel):
+    base_url: str
+    prefix: str
+    max_depth: int = Field(default=1, ge=0)
+
+
+def get_web_documents(config: WebLoaderConfig):
+    from llama_index.readers.web import WholeSiteReader
+
+    scraper = WholeSiteReader(
+        prefix=config.prefix,
+        max_depth=config.max_depth,
+    )
+    return scraper.load_data(config.base_url)
diff --git a/templates/components/loaders/python/web/loader.py b/templates/components/loaders/python/web/loader.py
deleted file mode 100644
index 096e3c97..00000000
--- a/templates/components/loaders/python/web/loader.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import os
-import json
-from pydantic import BaseModel, Field
-from llama_index.readers.web import WholeSiteReader
-
-
-class WebLoaderConfig(BaseModel):
-    base_url: str
-    prefix: str
-    max_depth: int = Field(default=1, ge=0)
-
-
-def load_configs():
-    with open("config/loaders.json") as f:
-        configs = json.load(f)
-    web_config = configs.get("web", None)
-    if web_config is None:
-        raise ValueError("No web config found in loaders.json")
-    return [WebLoaderConfig(**config) for config in web_config]
-
-
-def get_documents():
-    web_config = load_configs()
-    documents = []
-    for entry in web_config:
-        scraper = WholeSiteReader(
-            prefix=entry.prefix,
-            max_depth=entry.max_depth,
-        )
-        documents.extend(scraper.load_data(entry.base_url))
-    return documents
diff --git a/templates/components/vectordbs/python/milvus/generate.py b/templates/components/vectordbs/python/milvus/generate.py
index b980d9f8..69ee4899 100644
--- a/templates/components/vectordbs/python/milvus/generate.py
+++ b/templates/components/vectordbs/python/milvus/generate.py
@@ -8,7 +8,7 @@
 from llama_index.core.indices import VectorStoreIndex
 from llama_index.vector_stores.milvus import MilvusVectorStore
 from app.settings import init_settings
-from app.engine.loader import get_documents
+from app.engine.loaders import get_documents

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger()
diff --git a/templates/components/vectordbs/python/mongo/generate.py b/templates/components/vectordbs/python/mongo/generate.py
index 69d52071..ddc32c5a 100644
--- a/templates/components/vectordbs/python/mongo/generate.py
+++ b/templates/components/vectordbs/python/mongo/generate.py
@@ -8,7 +8,7 @@
 from llama_index.core.indices import VectorStoreIndex
 from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
 from app.settings import init_settings
-from app.engine.loader import get_documents
+from app.engine.loaders import get_documents

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger()
diff --git a/templates/components/vectordbs/python/none/generate.py b/templates/components/vectordbs/python/none/generate.py
index 3c8055f3..78fe57be 100644
--- a/templates/components/vectordbs/python/none/generate.py
+++ b/templates/components/vectordbs/python/none/generate.py
@@ -7,7 +7,7 @@
     VectorStoreIndex,
 )
 from app.engine.constants import STORAGE_DIR
-from app.engine.loader import get_documents
+from app.engine.loaders import get_documents
 from app.settings import init_settings
diff --git a/templates/components/vectordbs/python/pg/generate.py b/templates/components/vectordbs/python/pg/generate.py
index 608beb2e..5cc93244 100644
--- a/templates/components/vectordbs/python/pg/generate.py
+++ b/templates/components/vectordbs/python/pg/generate.py
@@ -6,7 +6,7 @@
 from llama_index.core.indices import VectorStoreIndex
 from llama_index.core.storage import StorageContext

-from app.engine.loader import get_documents
+from app.engine.loaders import get_documents
 from app.settings import init_settings
 from app.engine.utils import init_pg_vector_store_from_env
diff --git a/templates/components/vectordbs/python/pinecone/generate.py b/templates/components/vectordbs/python/pinecone/generate.py
index 4e14648b..c7ad55ea 100644
--- a/templates/components/vectordbs/python/pinecone/generate.py
+++ b/templates/components/vectordbs/python/pinecone/generate.py
@@ -8,7 +8,7 @@
 from llama_index.core.indices import VectorStoreIndex
 from llama_index.vector_stores.pinecone import PineconeVectorStore
 from app.settings import init_settings
-from app.engine.loader import get_documents
+from app.engine.loaders import get_documents

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger()