From 27397143db6c8d1c27dace6a317ab1c76a3077f6 Mon Sep 17 00:00:00 2001 From: "Huu Le (Lee)" <39040748+leehuwuj@users.noreply.github.com> Date: Mon, 1 Apr 2024 09:51:49 +0700 Subject: [PATCH] feat: Add database data source (MySQL and PostgreSQL) (#28) --- .changeset/seven-zebras-allow.md | 5 + helpers/python.ts | 60 +++++++-- helpers/types.ts | 11 +- questions.ts | 125 ++++++++++++------ .../components/loaders/python/__init__.py | 19 ++- templates/components/loaders/python/db.py | 26 ++++ 6 files changed, 187 insertions(+), 59 deletions(-) create mode 100644 .changeset/seven-zebras-allow.md create mode 100644 templates/components/loaders/python/db.py diff --git a/.changeset/seven-zebras-allow.md b/.changeset/seven-zebras-allow.md new file mode 100644 index 00000000..579a93fe --- /dev/null +++ b/.changeset/seven-zebras-allow.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Use databases as data source diff --git a/helpers/python.ts b/helpers/python.ts index 047c4876..c02cd8f1 100644 --- a/helpers/python.ts +++ b/helpers/python.ts @@ -9,6 +9,7 @@ import { templatesDir } from "./dir"; import { isPoetryAvailable, tryPoetryInstall } from "./poetry"; import { Tool } from "./tools"; import { + DbSourceConfig, InstallTemplateArgs, TemplateDataSource, TemplateVectorDB, @@ -65,17 +66,34 @@ const getAdditionalDependencies = ( // Add data source dependencies const dataSourceType = dataSource?.type; - if (dataSourceType === "file") { - // llama-index-readers-file (pdf, excel, csv) is already included in llama_index package - dependencies.push({ - name: "docx2txt", - version: "^0.8", - }); - } else if (dataSourceType === "web") { - dependencies.push({ - name: "llama-index-readers-web", - version: "^0.1.6", - }); + switch (dataSourceType) { + case "file": + dependencies.push({ + name: "docx2txt", + version: "^0.8", + }); + break; + case "web": + dependencies.push({ + name: "llama-index-readers-web", + version: "^0.1.6", + }); + break; + case "db": + dependencies.push({ + name: "llama-index-readers-database", + version: "^0.1.3", + }); + dependencies.push({ + name: "pymysql", + version: "^1.1.0", + extras: ["rsa"], + }); + dependencies.push({ + name: "psycopg2", + version: "^2.9.9", + }); + break; } // Add tools dependencies @@ -307,6 +325,26 @@ export const installPythonTemplate = async ({ node.commentBefore = ` use_llama_parse: Use LlamaParse if \`true\`. Needs a \`LLAMA_CLOUD_API_KEY\` from https://cloud.llamaindex.ai set as environment variable`; loaderConfig.set("file", node); } + + // DB loader config + const dbLoaders = dataSources.filter((ds) => ds.type === "db"); + if (dbLoaders.length > 0) { + const dbLoaderConfig = new Document({}); + const configEntries = dbLoaders.map((ds) => { + const dsConfig = ds.config as DbSourceConfig; + return { + uri: dsConfig.uri, + queries: [dsConfig.queries], + }; + }); + + const node = dbLoaderConfig.createNode(configEntries); + node.commentBefore = ` The configuration for the database loader, only supports MySQL and PostgreSQL databases for now. + uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db + query: The query to fetch data from the database. E.g.: SELECT * FROM table`; + loaderConfig.set("db", node); + } + // Write loaders config if (Object.keys(loaderConfig).length > 0) { const loaderConfigPath = path.join(root, "config/loaders.yaml"); diff --git a/helpers/types.ts b/helpers/types.ts index d94993bd..f7d0f414 100644 --- a/helpers/types.ts +++ b/helpers/types.ts @@ -14,7 +14,7 @@ export type TemplateDataSource = { type: TemplateDataSourceType; config: TemplateDataSourceConfig; }; -export type TemplateDataSourceType = "file" | "web"; +export type TemplateDataSourceType = "file" | "web" | "db"; export type TemplateObservability = "none" | "opentelemetry"; // Config for both file and folder export type FileSourceConfig = { @@ -25,8 +25,15 @@ export type WebSourceConfig = { prefix?: string; depth?: number; }; +export type DbSourceConfig = { + uri?: string; + queries?: string; +}; -export type TemplateDataSourceConfig = FileSourceConfig | WebSourceConfig; +export type TemplateDataSourceConfig = + | FileSourceConfig + | WebSourceConfig + | DbSourceConfig; export type CommunityProjectConfig = { owner: string; diff --git a/questions.ts b/questions.ts index 8354880a..231be855 100644 --- a/questions.ts +++ b/questions.ts @@ -159,6 +159,10 @@ export const getDataSourceChoices = ( title: "Use website content (requires Chrome)", value: "web", }); + choices.push({ + title: "Use data from a database (Mysql, PostgreSQL)", + value: "db", + }); } return choices; }; @@ -629,52 +633,93 @@ export const askQuestions = async ( // user doesn't want another data source or any data source break; } - if (selectedSource === "exampleFile") { - program.dataSources.push(EXAMPLE_FILE); - } else if (selectedSource === "file" || selectedSource === "folder") { - // Select local data source - const selectedPaths = await selectLocalContextData(selectedSource); - for (const p of selectedPaths) { + switch (selectedSource) { + case "exampleFile": { + program.dataSources.push(EXAMPLE_FILE); + break; + } + case "file": + case "folder": { + const selectedPaths = await selectLocalContextData(selectedSource); + for (const p of selectedPaths) { + program.dataSources.push({ + type: "file", + config: { + path: p, + }, + }); + } + break; + } + case "web": { + const { baseUrl } = await prompts( + { + type: "text", + name: "baseUrl", + message: "Please provide base URL of the website: ", + initial: "https://www.llamaindex.ai", + validate: (value: string) => { + if (!value.includes("://")) { + value = `https://${value}`; + } + const urlObj = new URL(value); + if ( + urlObj.protocol !== "https:" && + urlObj.protocol !== "http:" + ) { + return `URL=${value} has invalid protocol, only allow http or https`; + } + return true; + }, + }, + handlers, + ); + program.dataSources.push({ - type: "file", + type: "web", config: { - path: p, + baseUrl, + prefix: baseUrl, + depth: 1, }, }); + break; } - } else if (selectedSource === "web") { - // Selected web data source - const { baseUrl } = await prompts( - { - type: "text", - name: "baseUrl", - message: "Please provide base URL of the website: ", - initial: "https://www.llamaindex.ai", - validate: (value: string) => { - if (!value.includes("://")) { - value = `https://${value}`; - } - const urlObj = new URL(value); - if ( - urlObj.protocol !== "https:" && - urlObj.protocol !== "http:" - ) { - return `URL=${value} has invalid protocol, only allow http or https`; - } - return true; + case "db": { + const dbPrompts: prompts.PromptObject[] = [ + { + type: "text", + name: "uri", + message: + "Please enter the connection string (URI) for the database.", + initial: "mysql+pymysql://user:pass@localhost:3306/mydb", + validate: (value: string) => { + if (!value) { + return "Please provide a valid connection string"; + } else if ( + !( + value.startsWith("mysql+pymysql://") || + value.startsWith("postgresql+psycopg://") + ) + ) { + return "The connection string must start with 'mysql+pymysql://' for MySQL or 'postgresql+psycopg://' for PostgreSQL"; + } + return true; + }, }, - }, - handlers, - ); - - program.dataSources.push({ - type: "web", - config: { - baseUrl, - prefix: baseUrl, - depth: 1, - }, - }); + // Only ask for a query, user can provide more complex queries in the config file later + { + type: (prev) => (prev ? "text" : null), + name: "queries", + message: "Please enter the SQL query to fetch data:", + initial: "SELECT * FROM mytable", + }, + ]; + program.dataSources.push({ + type: "db", + config: await prompts(dbPrompts, handlers), + }); + } } } } diff --git a/templates/components/loaders/python/__init__.py b/templates/components/loaders/python/__init__.py index 662c65a9..d17df8e0 100644 --- a/templates/components/loaders/python/__init__.py +++ b/templates/components/loaders/python/__init__.py @@ -5,6 +5,7 @@ from typing import Dict from app.engine.loaders.file import FileLoaderConfig, get_file_documents from app.engine.loaders.web import WebLoaderConfig, get_web_documents +from app.engine.loaders.db import DBLoaderConfig, get_db_documents logger = logging.getLogger(__name__) @@ -22,11 +23,17 @@ def get_documents(): logger.info( f"Loading documents from loader: {loader_type}, config: {loader_config}" ) - if loader_type == "file": - document = get_file_documents(FileLoaderConfig(**loader_config)) - documents.extend(document) - elif loader_type == "web": - document = get_web_documents(WebLoaderConfig(**loader_config)) - documents.extend(document) + match loader_type: + case "file": + document = get_file_documents(FileLoaderConfig(**loader_config)) + case "web": + document = get_web_documents(WebLoaderConfig(**loader_config)) + case "db": + document = get_db_documents( + configs=[DBLoaderConfig(**cfg) for cfg in loader_config] + ) + case _: + raise ValueError(f"Invalid loader type: {loader_type}") + documents.extend(document) return documents diff --git a/templates/components/loaders/python/db.py b/templates/components/loaders/python/db.py new file mode 100644 index 00000000..d5c9ffde --- /dev/null +++ b/templates/components/loaders/python/db.py @@ -0,0 +1,26 @@ +import os +import logging +from typing import List +from pydantic import BaseModel, validator +from llama_index.core.indices.vector_store import VectorStoreIndex + +logger = logging.getLogger(__name__) + + +class DBLoaderConfig(BaseModel): + uri: str + queries: List[str] + + +def get_db_documents(configs: list[DBLoaderConfig]): + from llama_index.readers.database import DatabaseReader + + docs = [] + for entry in configs: + loader = DatabaseReader(uri=entry.uri) + for query in entry.queries: + logger.info(f"Loading data from database with query: {query}") + documents = loader.load_data(query=query) + docs.extend(documents) + + return documents