Skip to content

Commit

Permalink
feat: Add database data source (MySQL and PostgreSQL) (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
leehuwuj authored Apr 1, 2024
1 parent 665c26c commit 2739714
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 59 deletions.
5 changes: 5 additions & 0 deletions .changeset/seven-zebras-allow.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"create-llama": patch
---

Use databases as data source
60 changes: 49 additions & 11 deletions helpers/python.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { templatesDir } from "./dir";
import { isPoetryAvailable, tryPoetryInstall } from "./poetry";
import { Tool } from "./tools";
import {
DbSourceConfig,
InstallTemplateArgs,
TemplateDataSource,
TemplateVectorDB,
Expand Down Expand Up @@ -65,17 +66,34 @@ const getAdditionalDependencies = (

// Add data source dependencies
const dataSourceType = dataSource?.type;
if (dataSourceType === "file") {
// llama-index-readers-file (pdf, excel, csv) is already included in llama_index package
dependencies.push({
name: "docx2txt",
version: "^0.8",
});
} else if (dataSourceType === "web") {
dependencies.push({
name: "llama-index-readers-web",
version: "^0.1.6",
});
switch (dataSourceType) {
case "file":
dependencies.push({
name: "docx2txt",
version: "^0.8",
});
break;
case "web":
dependencies.push({
name: "llama-index-readers-web",
version: "^0.1.6",
});
break;
case "db":
dependencies.push({
name: "llama-index-readers-database",
version: "^0.1.3",
});
dependencies.push({
name: "pymysql",
version: "^1.1.0",
extras: ["rsa"],
});
dependencies.push({
name: "psycopg2",
version: "^2.9.9",
});
break;
}

// Add tools dependencies
Expand Down Expand Up @@ -307,6 +325,26 @@ export const installPythonTemplate = async ({
node.commentBefore = ` use_llama_parse: Use LlamaParse if \`true\`. Needs a \`LLAMA_CLOUD_API_KEY\` from https://cloud.llamaindex.ai set as environment variable`;
loaderConfig.set("file", node);
}

// DB loader config
const dbLoaders = dataSources.filter((ds) => ds.type === "db");
if (dbLoaders.length > 0) {
const dbLoaderConfig = new Document({});
const configEntries = dbLoaders.map((ds) => {
const dsConfig = ds.config as DbSourceConfig;
return {
uri: dsConfig.uri,
queries: [dsConfig.queries],
};
});

const node = dbLoaderConfig.createNode(configEntries);
node.commentBefore = ` The configuration for the database loader, only supports MySQL and PostgreSQL databases for now.
uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
query: The query to fetch data from the database. E.g.: SELECT * FROM table`;
loaderConfig.set("db", node);
}

// Write loaders config
if (Object.keys(loaderConfig).length > 0) {
const loaderConfigPath = path.join(root, "config/loaders.yaml");
Expand Down
11 changes: 9 additions & 2 deletions helpers/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export type TemplateDataSource = {
type: TemplateDataSourceType;
config: TemplateDataSourceConfig;
};
export type TemplateDataSourceType = "file" | "web";
export type TemplateDataSourceType = "file" | "web" | "db";
export type TemplateObservability = "none" | "opentelemetry";
// Config for both file and folder
export type FileSourceConfig = {
Expand All @@ -25,8 +25,15 @@ export type WebSourceConfig = {
prefix?: string;
depth?: number;
};
export type DbSourceConfig = {
uri?: string;
queries?: string;
};

export type TemplateDataSourceConfig = FileSourceConfig | WebSourceConfig;
export type TemplateDataSourceConfig =
| FileSourceConfig
| WebSourceConfig
| DbSourceConfig;

export type CommunityProjectConfig = {
owner: string;
Expand Down
125 changes: 85 additions & 40 deletions questions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ export const getDataSourceChoices = (
title: "Use website content (requires Chrome)",
value: "web",
});
choices.push({
title: "Use data from a database (Mysql, PostgreSQL)",
value: "db",
});
}
return choices;
};
Expand Down Expand Up @@ -629,52 +633,93 @@ export const askQuestions = async (
// user doesn't want another data source or any data source
break;
}
if (selectedSource === "exampleFile") {
program.dataSources.push(EXAMPLE_FILE);
} else if (selectedSource === "file" || selectedSource === "folder") {
// Select local data source
const selectedPaths = await selectLocalContextData(selectedSource);
for (const p of selectedPaths) {
switch (selectedSource) {
case "exampleFile": {
program.dataSources.push(EXAMPLE_FILE);
break;
}
case "file":
case "folder": {
const selectedPaths = await selectLocalContextData(selectedSource);
for (const p of selectedPaths) {
program.dataSources.push({
type: "file",
config: {
path: p,
},
});
}
break;
}
case "web": {
const { baseUrl } = await prompts(
{
type: "text",
name: "baseUrl",
message: "Please provide base URL of the website: ",
initial: "https://www.llamaindex.ai",
validate: (value: string) => {
if (!value.includes("://")) {
value = `https://${value}`;
}
const urlObj = new URL(value);
if (
urlObj.protocol !== "https:" &&
urlObj.protocol !== "http:"
) {
return `URL=${value} has invalid protocol, only allow http or https`;
}
return true;
},
},
handlers,
);

program.dataSources.push({
type: "file",
type: "web",
config: {
path: p,
baseUrl,
prefix: baseUrl,
depth: 1,
},
});
break;
}
} else if (selectedSource === "web") {
// Selected web data source
const { baseUrl } = await prompts(
{
type: "text",
name: "baseUrl",
message: "Please provide base URL of the website: ",
initial: "https://www.llamaindex.ai",
validate: (value: string) => {
if (!value.includes("://")) {
value = `https://${value}`;
}
const urlObj = new URL(value);
if (
urlObj.protocol !== "https:" &&
urlObj.protocol !== "http:"
) {
return `URL=${value} has invalid protocol, only allow http or https`;
}
return true;
case "db": {
const dbPrompts: prompts.PromptObject<string>[] = [
{
type: "text",
name: "uri",
message:
"Please enter the connection string (URI) for the database.",
initial: "mysql+pymysql://user:pass@localhost:3306/mydb",
validate: (value: string) => {
if (!value) {
return "Please provide a valid connection string";
} else if (
!(
value.startsWith("mysql+pymysql://") ||
value.startsWith("postgresql+psycopg://")
)
) {
return "The connection string must start with 'mysql+pymysql://' for MySQL or 'postgresql+psycopg://' for PostgreSQL";
}
return true;
},
},
},
handlers,
);

program.dataSources.push({
type: "web",
config: {
baseUrl,
prefix: baseUrl,
depth: 1,
},
});
// Only ask for a query, user can provide more complex queries in the config file later
{
type: (prev) => (prev ? "text" : null),
name: "queries",
message: "Please enter the SQL query to fetch data:",
initial: "SELECT * FROM mytable",
},
];
program.dataSources.push({
type: "db",
config: await prompts(dbPrompts, handlers),
});
}
}
}
}
Expand Down
19 changes: 13 additions & 6 deletions templates/components/loaders/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
from app.engine.loaders.db import DBLoaderConfig, get_db_documents

logger = logging.getLogger(__name__)

Expand All @@ -22,11 +23,17 @@ def get_documents():
logger.info(
f"Loading documents from loader: {loader_type}, config: {loader_config}"
)
if loader_type == "file":
document = get_file_documents(FileLoaderConfig(**loader_config))
documents.extend(document)
elif loader_type == "web":
document = get_web_documents(WebLoaderConfig(**loader_config))
documents.extend(document)
match loader_type:
case "file":
document = get_file_documents(FileLoaderConfig(**loader_config))
case "web":
document = get_web_documents(WebLoaderConfig(**loader_config))
case "db":
document = get_db_documents(
configs=[DBLoaderConfig(**cfg) for cfg in loader_config]
)
case _:
raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document)

return documents
26 changes: 26 additions & 0 deletions templates/components/loaders/python/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import logging
from typing import List
from pydantic import BaseModel, validator
from llama_index.core.indices.vector_store import VectorStoreIndex

logger = logging.getLogger(__name__)


class DBLoaderConfig(BaseModel):
uri: str
queries: List[str]


def get_db_documents(configs: list[DBLoaderConfig]):
from llama_index.readers.database import DatabaseReader

docs = []
for entry in configs:
loader = DatabaseReader(uri=entry.uri)
for query in entry.queries:
logger.info(f"Loading data from database with query: {query}")
documents = loader.load_data(query=query)
docs.extend(documents)

return documents

0 comments on commit 2739714

Please sign in to comment.