Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add database data source #28

Merged
merged 6 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/seven-zebras-allow.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"create-llama": patch
---

Use databases as data source
60 changes: 49 additions & 11 deletions helpers/python.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { templatesDir } from "./dir";
import { isPoetryAvailable, tryPoetryInstall } from "./poetry";
import { Tool } from "./tools";
import {
DbSourceConfig,
InstallTemplateArgs,
TemplateDataSource,
TemplateVectorDB,
Expand Down Expand Up @@ -65,17 +66,34 @@ const getAdditionalDependencies = (

// Add data source dependencies
const dataSourceType = dataSource?.type;
if (dataSourceType === "file") {
// llama-index-readers-file (pdf, excel, csv) is already included in llama_index package
dependencies.push({
name: "docx2txt",
version: "^0.8",
});
} else if (dataSourceType === "web") {
dependencies.push({
name: "llama-index-readers-web",
version: "^0.1.6",
});
switch (dataSourceType) {
case "file":
dependencies.push({
name: "docx2txt",
version: "^0.8",
});
break;
case "web":
dependencies.push({
name: "llama-index-readers-web",
version: "^0.1.6",
});
break;
case "db":
dependencies.push({
name: "llama-index-readers-database",
version: "^0.1.3",
});
dependencies.push({
name: "pymysql",
version: "^1.1.0",
extras: ["rsa"],
});
dependencies.push({
marcusschiesser marked this conversation as resolved.
Show resolved Hide resolved
name: "psycopg2",
version: "^2.9.9",
});
break;
}

// Add tools dependencies
Expand Down Expand Up @@ -307,6 +325,26 @@ export const installPythonTemplate = async ({
node.commentBefore = ` use_llama_parse: Use LlamaParse if \`true\`. Needs a \`LLAMA_CLOUD_API_KEY\` from https://cloud.llamaindex.ai set as environment variable`;
loaderConfig.set("file", node);
}

// DB loader config
const dbLoaders = dataSources.filter((ds) => ds.type === "db");
if (dbLoaders.length > 0) {
const dbLoaderConfig = new Document({});
const configEntries = dbLoaders.map((ds) => {
const dsConfig = ds.config as DbSourceConfig;
return {
uri: dsConfig.uri,
queries: [dsConfig.queries],
};
});

const node = dbLoaderConfig.createNode(configEntries);
node.commentBefore = ` The configuration for the database loader, only supports MySQL and PostgreSQL databases for now.
uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
query: The query to fetch data from the database. E.g.: SELECT * FROM table`;
loaderConfig.set("db", node);
}

// Write loaders config
if (Object.keys(loaderConfig).length > 0) {
const loaderConfigPath = path.join(root, "config/loaders.yaml");
Expand Down
11 changes: 9 additions & 2 deletions helpers/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export type TemplateDataSource = {
type: TemplateDataSourceType;
config: TemplateDataSourceConfig;
};
export type TemplateDataSourceType = "file" | "web";
export type TemplateDataSourceType = "file" | "web" | "db";
export type TemplateObservability = "none" | "opentelemetry";
// Config for both file and folder
export type FileSourceConfig = {
Expand All @@ -25,8 +25,15 @@ export type WebSourceConfig = {
prefix?: string;
depth?: number;
};
/**
 * Settings collected for a database ("db") data source.
 * Both fields are optional at the type level.
 */
export type DbSourceConfig = {
  /** Connection string (URI) of the database. */
  uri?: string;
  /** SQL query used to fetch data from the database. */
  queries?: string;
};

export type TemplateDataSourceConfig = FileSourceConfig | WebSourceConfig;
/**
 * Config payload carried by a data source: one of the file/folder,
 * web, or database config shapes.
 */
export type TemplateDataSourceConfig =
  | FileSourceConfig
  | WebSourceConfig
  | DbSourceConfig;

export type CommunityProjectConfig = {
owner: string;
Expand Down
125 changes: 85 additions & 40 deletions questions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ export const getDataSourceChoices = (
title: "Use website content (requires Chrome)",
value: "web",
});
choices.push({
title: "Use data from a database (Mysql, PostgreSQL)",
value: "db",
});
}
return choices;
};
Expand Down Expand Up @@ -629,52 +633,93 @@ export const askQuestions = async (
// user doesn't want another data source or any data source
break;
}
if (selectedSource === "exampleFile") {
program.dataSources.push(EXAMPLE_FILE);
} else if (selectedSource === "file" || selectedSource === "folder") {
// Select local data source
const selectedPaths = await selectLocalContextData(selectedSource);
for (const p of selectedPaths) {
switch (selectedSource) {
case "exampleFile": {
program.dataSources.push(EXAMPLE_FILE);
break;
}
case "file":
case "folder": {
const selectedPaths = await selectLocalContextData(selectedSource);
for (const p of selectedPaths) {
program.dataSources.push({
type: "file",
config: {
path: p,
},
});
}
break;
}
case "web": {
const { baseUrl } = await prompts(
{
type: "text",
name: "baseUrl",
message: "Please provide base URL of the website: ",
initial: "https://www.llamaindex.ai",
// Validates the base URL typed by the user.
// Returns `true` when acceptable, otherwise an error message string.
validate: (value: string) => {
  // Assume https when the user omitted the scheme.
  const candidate = value.includes("://") ? value : `https://${value}`;
  let urlObj: URL;
  try {
    // `new URL` throws on malformed input; report it as a validation
    // message instead of letting the exception escape the validator.
    urlObj = new URL(candidate);
  } catch {
    return `URL=${candidate} is not a valid URL`;
  }
  if (urlObj.protocol !== "https:" && urlObj.protocol !== "http:") {
    return `URL=${candidate} has invalid protocol, only allow http or https`;
  }
  return true;
},
},
handlers,
);

program.dataSources.push({
type: "file",
type: "web",
config: {
path: p,
baseUrl,
prefix: baseUrl,
depth: 1,
},
});
break;
}
} else if (selectedSource === "web") {
// Selected web data source
const { baseUrl } = await prompts(
{
type: "text",
name: "baseUrl",
message: "Please provide base URL of the website: ",
initial: "https://www.llamaindex.ai",
validate: (value: string) => {
if (!value.includes("://")) {
value = `https://${value}`;
}
const urlObj = new URL(value);
if (
urlObj.protocol !== "https:" &&
urlObj.protocol !== "http:"
) {
return `URL=${value} has invalid protocol, only allow http or https`;
}
return true;
case "db": {
marcusschiesser marked this conversation as resolved.
Show resolved Hide resolved
const dbPrompts: prompts.PromptObject<string>[] = [
{
type: "text",
name: "uri",
message:
"Please enter the connection string (URI) for the database.",
initial: "mysql+pymysql://user:pass@localhost:3306/mydb",
// Validates the database connection string typed by the user.
// Returns `true` when acceptable, otherwise an error message string.
validate: (value: string) => {
  if (!value) {
    return "Please provide a valid connection string";
  }
  const supportedPrefixes = [
    "mysql+pymysql://",
    // Matches the psycopg2 driver that is added as a dependency for
    // database data sources.
    "postgresql+psycopg2://",
    // Previously the only accepted PostgreSQL prefix; kept so existing
    // input keeps validating. NOTE(review): in SQLAlchemy this scheme
    // selects the psycopg (v3) driver, which is not installed by default.
    "postgresql+psycopg://",
  ];
  if (!supportedPrefixes.some((prefix) => value.startsWith(prefix))) {
    return "The connection string must start with 'mysql+pymysql://' for MySQL or 'postgresql+psycopg2://' for PostgreSQL";
  }
  return true;
},
},
},
handlers,
);

program.dataSources.push({
type: "web",
config: {
baseUrl,
prefix: baseUrl,
depth: 1,
},
});
// Only ask for a query, user can provide more complex queries in the config file later
{
type: (prev) => (prev ? "text" : null),
name: "queries",
message: "Please enter the SQL query to fetch data:",
initial: "SELECT * FROM mytable",
},
];
program.dataSources.push({
type: "db",
config: await prompts(dbPrompts, handlers),
});
}
}
}
}
Expand Down
19 changes: 13 additions & 6 deletions templates/components/loaders/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
from app.engine.loaders.db import DBLoaderConfig, get_db_documents

logger = logging.getLogger(__name__)

Expand All @@ -22,11 +23,17 @@ def get_documents():
logger.info(
f"Loading documents from loader: {loader_type}, config: {loader_config}"
)
if loader_type == "file":
document = get_file_documents(FileLoaderConfig(**loader_config))
documents.extend(document)
elif loader_type == "web":
document = get_web_documents(WebLoaderConfig(**loader_config))
documents.extend(document)
match loader_type:
case "file":
document = get_file_documents(FileLoaderConfig(**loader_config))
case "web":
document = get_web_documents(WebLoaderConfig(**loader_config))
case "db":
document = get_db_documents(
configs=[DBLoaderConfig(**cfg) for cfg in loader_config]
)
case _:
raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document)

return documents
26 changes: 26 additions & 0 deletions templates/components/loaders/python/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import logging
from typing import List
from pydantic import BaseModel, validator
from llama_index.core.indices.vector_store import VectorStoreIndex

logger = logging.getLogger(__name__)


class DBLoaderConfig(BaseModel):
    """Settings for one database source handled by the db loader."""

    # Connection URI passed to the database reader.
    uri: str
    # Queries executed against that connection, one load per query.
    queries: List[str]


def get_db_documents(configs: list[DBLoaderConfig]):
    """Load documents for every query of every configured database.

    Args:
        configs: Database loader entries; each one provides a connection
            ``uri`` and the list of ``queries`` to run against it.

    Returns:
        One list holding the documents produced by all queries of all
        entries, in configuration order. Empty when nothing is configured.
    """
    # The import stays local to the function, so the reader package is only
    # imported when this loader is actually called.
    from llama_index.readers.database import DatabaseReader

    docs = []
    for entry in configs:
        loader = DatabaseReader(uri=entry.uri)
        for query in entry.queries:
            logger.info(f"Loading data from database with query: {query}")
            docs.extend(loader.load_data(query=query))

    # Return the accumulated list. Returning the per-query result would drop
    # everything except the last query and fail when there are no queries.
    return docs
Loading