Skip to content

Commit

Permalink
feat: default to agent chat engine, use context chat for no tools and…
Browse files Browse the repository at this point in the history
… a datasource only
  • Loading branch information
marcusschiesser committed Apr 2, 2024
1 parent 11cd67c commit b6e16ea
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 64 deletions.
93 changes: 42 additions & 51 deletions helpers/python.ts
Original file line number Diff line number Diff line change
Expand Up @@ -223,58 +223,39 @@ export const installPythonTemplate = async ({
const compPath = path.join(templatesDir, "components");
const enginePath = path.join(root, "app", "engine");

if (dataSources.length > 0) {
// copy vector db component
const vectorDbDirName = vectorDb ?? "none";
const VectorDBPath = path.join(
compPath,
"vectordbs",
"python",
vectorDbDirName,
);
await copy("**", enginePath, {
parents: true,
cwd: VectorDBPath,
});
// Copy selected vector DB
await copy("**", enginePath, {
parents: true,
cwd: path.join(compPath, "vectordbs", "python", vectorDb ?? "none"),
});

// Copy loaders to enginePath
const loaderPath = path.join(enginePath, "loaders");
await copy("**", loaderPath, {
parents: true,
cwd: path.join(compPath, "loaders", "python"),
});
// Copy all loaders to enginePath
const loaderPath = path.join(enginePath, "loaders");
await copy("**", loaderPath, {
parents: true,
cwd: path.join(compPath, "loaders", "python"),
});

// Generate loaders config
await writeLoadersConfig(root, dataSources, useLlamaParse);
}
// write configuration for loaders
await writeLoadersConfig(root, dataSources, useLlamaParse);

// Copy engine code
if (tools && tools.length > 0) {
console.log("\nUsing agent chat engine\n");
await copy("**", enginePath, {
parents: true,
cwd: path.join(compPath, "engines", "python", "agent"),
});
// Write tool configs
const configContent: Record<string, any> = {};
tools.forEach((tool) => {
configContent[tool.name] = tool.config ?? {};
});
const configFilePath = path.join(root, "config/tools.yaml");
await fs.mkdir(path.join(root, "config"), { recursive: true });
await fs.writeFile(configFilePath, yaml.stringify(configContent));
} else if (dataSources.length > 0) {
console.log("\nUsing context chat engine\n");
await copy("**", enginePath, {
parents: true,
cwd: path.join(compPath, "engines", "python", "chat"),
});
// Select engine code based on data sources and tools
let engine;
tools = tools ?? [];
if (dataSources.length > 0 && tools.length === 0) {
console.log("\nNo tools selected - use optimized context chat engine\n");
engine = "chat";
} else {
console.log(
"\nUsing simple chat as neither a datasource nor tools are selected\n",
);
engine = "agent";
await writeToolsConfig(root, tools);
}

// Copy engine code
await copy("**", enginePath, {
parents: true,
cwd: path.join(compPath, "engines", "python", engine),
});

const addOnDependencies = dataSources
.map((ds) => getAdditionalDependencies(vectorDb, ds, tools))
.flat();
Expand All @@ -290,11 +271,23 @@ export const installPythonTemplate = async ({
});
};

async function writeToolsConfig(root: string, tools: Tool[]) {
if (tools.length === 0) return; // no tools selected, no config need
const configContent: Record<string, any> = {};
tools.forEach((tool) => {
configContent[tool.name] = tool.config ?? {};
});
const configFilePath = path.join(root, "config", "tools.yaml");
await fs.mkdir(path.join(root, "config"), { recursive: true });
await fs.writeFile(configFilePath, yaml.stringify(configContent));
}

async function writeLoadersConfig(
root: string,
dataSources: TemplateDataSource[],
useLlamaParse?: boolean,
) {
if (dataSources.length === 0) return; // no datasources, no config needed
const loaderConfig = new Document({});
// Web loader config
if (dataSources.some((ds) => ds.type === "web")) {
Expand Down Expand Up @@ -361,9 +354,7 @@ async function writeLoadersConfig(
}

// Write loaders config
if (Object.keys(loaderConfig).length > 0) {
const loaderConfigPath = path.join(root, "config/loaders.yaml");
await fs.mkdir(path.join(root, "config"), { recursive: true });
await fs.writeFile(loaderConfigPath, yaml.stringify(loaderConfig));
}
const loaderConfigPath = path.join(root, "config/loaders.yaml");
await fs.mkdir(path.join(root, "config"), { recursive: true });
await fs.writeFile(loaderConfigPath, yaml.stringify(loaderConfig));
}
9 changes: 5 additions & 4 deletions templates/components/engines/python/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ def get_chat_engine():
top_k = os.getenv("TOP_K", "3")
tools = []

# Add query tool
# Add query tool if index exists
index = get_index()
query_engine = index.as_query_engine(similarity_top_k=int(top_k))
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
tools.append(query_engine_tool)
if index is not None:
query_engine = index.as_query_engine(similarity_top_k=int(top_k))
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
tools.append(query_engine_tool)

# Add additional tools
tools += ToolFactory.from_env()
Expand Down
10 changes: 6 additions & 4 deletions templates/components/engines/python/agent/tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import yaml
import importlib

Expand Down Expand Up @@ -26,8 +27,9 @@ def create_tool(tool_name: str, **kwargs) -> list[FunctionTool]:
@staticmethod
def from_env() -> list[FunctionTool]:
tools = []
with open("config/tools.yaml", "r") as f:
tool_configs = yaml.safe_load(f)
for name, config in tool_configs.items():
tools += ToolFactory.create_tool(name, **config)
if os.path.exists("config/tools.yaml"):
with open("config/tools.yaml", "r") as f:
tool_configs = yaml.safe_load(f)
for name, config in tool_configs.items():
tools += ToolFactory.create_tool(name, **config)
return tools
8 changes: 7 additions & 1 deletion templates/components/engines/python/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@ def get_chat_engine():
system_prompt = os.getenv("SYSTEM_PROMPT")
top_k = os.getenv("TOP_K", 3)

return get_index().as_chat_engine(
index = get_index()
if index is None:
raise Exception(
"StorageContext is empty - call 'python app/engine/generate.py' to generate the storage first"
)

return index.as_chat_engine(
similarity_top_k=int(top_k),
system_prompt=system_prompt,
chat_mode="condense_plus_context",
Expand Down
5 changes: 1 addition & 4 deletions templates/components/vectordbs/python/none/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
def get_index():
# check if storage already exists
if not os.path.exists(STORAGE_DIR):
raise Exception(
"StorageContext is empty - call 'python app/engine/generate.py' to generate the storage first"
)

return None
# load the existing index
logger.info(f"Loading index from {STORAGE_DIR}...")
storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
Expand Down

0 comments on commit b6e16ea

Please sign in to comment.