Skip to content

Commit

Permalink
feat: implement artifact tool in TS (#328)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Marcus Schiesser <[email protected]>
  • Loading branch information
thucpn and marcusschiesser authored Oct 3, 2024
1 parent 27a1b9f commit 5a7216e
Show file tree
Hide file tree
Showing 25 changed files with 1,070 additions and 122 deletions.
5 changes: 5 additions & 0 deletions .changeset/modern-cars-travel.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"create-llama": patch
---

feat: implement artifact tool in TS
2 changes: 1 addition & 1 deletion e2e/shared/multiagent_template.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ test.describe(`Test multiagent template ${templateFramework} ${dataSource} ${tem
page,
}) => {
await page.goto(`http://localhost:${port}`);
await page.fill("form input", userMessage);
await page.fill("form textarea", userMessage);

const responsePromise = page.waitForResponse((res) =>
res.url().includes("/api/chat"),
Expand Down
2 changes: 1 addition & 1 deletion e2e/shared/streaming_template.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ test.describe(`Test streaming template ${templateFramework} ${dataSource} ${temp
}) => {
test.skip(templatePostInstallAction !== "runApp");
await page.goto(`http://localhost:${port}`);
await page.fill("form input", userMessage);
await page.fill("form textarea", userMessage);
const [response] = await Promise.all([
page.waitForResponse(
(res) => {
Expand Down
6 changes: 0 additions & 6 deletions helpers/env-variables.ts
Original file line number Diff line number Diff line change
Expand Up @@ -397,12 +397,6 @@ const getEngineEnvs = (): EnvVar[] => {
description:
"The number of similar embeddings to return when retrieving documents.",
},
{
name: "STREAM_TIMEOUT",
description:
"The time in milliseconds to wait for the stream to return a response.",
value: "60000",
},
];
};

Expand Down
20 changes: 20 additions & 0 deletions helpers/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,26 @@ For better results, you can specify the region parameter to get results from a s
},
],
},
{
display: "Artifact Code Generator",
name: "artifact",
dependencies: [],
supportedFrameworks: ["express", "nextjs"],
type: ToolType.LOCAL,
envVars: [
{
name: "E2B_API_KEY",
description:
"E2B_API_KEY key is required to run artifact code generator tool. Get it here: https://e2b.dev/docs/getting-started/api-key",
},
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for artifact code generator tool.",
value:
"You are a code assistant that can generate and execute code using its tools. Don't generate code yourself, use the provided tools instead. Do not show the code or sandbox url in chat, just describe the steps to build the application based on the code that is generated by your tools. Do not describe how to run the code, just the steps to build the application.",
},
],
},
{
display: "OpenAPI action",
name: "openapi_action.OpenAPIActionToolSpec",
Expand Down
129 changes: 129 additions & 0 deletions templates/components/engines/typescript/agent/tools/code-generator.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import type { JSONSchemaType } from "ajv";
import {
BaseTool,
ChatMessage,
JSONValue,
Settings,
ToolMetadata,
} from "llamaindex";

// prompt based on https://github.com/e2b-dev/ai-artifacts
const CODE_GENERATION_PROMPT = `You are a skilled software engineer. You do not make mistakes. Generate an artifact. You can install additional dependencies. You can use one of the following templates:\n
1. code-interpreter-multilang: "Runs code as a Jupyter notebook cell. Strong data analysis angle. Can use complex visualisation to explain results.". File: script.py. Dependencies installed: python, jupyter, numpy, pandas, matplotlib, seaborn, plotly. Port: none.
2. nextjs-developer: "A Next.js 13+ app that reloads automatically. Using the pages router.". File: pages/index.tsx. Dependencies installed: [email protected], typescript, @types/node, @types/react, @types/react-dom, postcss, tailwindcss, shadcn. Port: 3000.
3. vue-developer: "A Vue.js 3+ app that reloads automatically. Only when asked specifically for a Vue app.". File: app.vue. Dependencies installed: vue@latest, [email protected], tailwindcss. Port: 3000.
4. streamlit-developer: "A streamlit app that reloads automatically.". File: app.py. Dependencies installed: streamlit, pandas, numpy, matplotlib, request, seaborn, plotly. Port: 8501.
5. gradio-developer: "A gradio app. Gradio Blocks/Interface should be called demo.". File: app.py. Dependencies installed: gradio, pandas, numpy, matplotlib, request, seaborn, plotly. Port: 7860.
Provide detail information about the artifact you're about to generate in the following JSON format with the following keys:
commentary: Describe what you're about to do and the steps you want to take for generating the artifact in great detail.
template: Name of the template used to generate the artifact.
title: Short title of the artifact. Max 3 words.
description: Short description of the artifact. Max 1 sentence.
additional_dependencies: Additional dependencies required by the artifact. Do not include dependencies that are already included in the template.
has_additional_dependencies: Detect if additional dependencies that are not included in the template are required by the artifact.
install_dependencies_command: Command to install additional dependencies required by the artifact.
port: Port number used by the resulted artifact. Null when no ports are exposed.
file_path: Relative path to the file, including the file name.
code: Code generated by the artifact. Only runnable code is allowed.
Make sure to use the correct syntax for the programming language you're using. Make sure to generate only one code file. If you need to use CSS, make sure to include the CSS in the code file using Tailwind CSS syntax.
`;

// detail information to execute code
export type CodeArtifact = {
commentary: string;
template: string;
title: string;
description: string;
additional_dependencies: string[];
has_additional_dependencies: boolean;
install_dependencies_command: string;
port: number | null;
file_path: string;
code: string;
};

export type CodeGeneratorParameter = {
requirement: string;
oldCode?: string;
};

export type CodeGeneratorToolParams = {
metadata?: ToolMetadata<JSONSchemaType<CodeGeneratorParameter>>;
};

const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<CodeGeneratorParameter>> =
{
name: "artifact",
description: `Generate a code artifact based on the input. Don't call this tool if the user has not asked for code generation. E.g. if the user asks to write a description or specification, don't call this tool.`,
parameters: {
type: "object",
properties: {
requirement: {
type: "string",
description: "The description of the application you want to build.",
},
oldCode: {
type: "string",
description: "The existing code to be modified",
nullable: true,
},
},
required: ["requirement"],
},
};

export class CodeGeneratorTool implements BaseTool<CodeGeneratorParameter> {
metadata: ToolMetadata<JSONSchemaType<CodeGeneratorParameter>>;

constructor(params?: CodeGeneratorToolParams) {
this.metadata = params?.metadata || DEFAULT_META_DATA;
}

async call(input: CodeGeneratorParameter) {
try {
const artifact = await this.generateArtifact(
input.requirement,
input.oldCode,
);
return artifact as JSONValue;
} catch (error) {
return { isError: true };
}
}

// Generate artifact (code, environment, dependencies, etc.)
async generateArtifact(
query: string,
oldCode?: string,
): Promise<CodeArtifact> {
const userMessage = `
${query}
${oldCode ? `The existing code is: \n\`\`\`${oldCode}\`\`\`` : ""}
`;
const messages: ChatMessage[] = [
{ role: "system", content: CODE_GENERATION_PROMPT },
{ role: "user", content: userMessage },
];
try {
const response = await Settings.llm.chat({ messages });
const content = response.message.content.toString();
const jsonContent = content
.replace(/^```json\s*|\s*```$/g, "")
.replace(/^`+|`+$/g, "")
.trim();
const artifact = JSON.parse(jsonContent) as CodeArtifact;
return artifact;
} catch (error) {
console.log("Failed to generate artifact", error);
throw error;
}
}
}
4 changes: 4 additions & 0 deletions templates/components/engines/typescript/agent/tools/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { BaseToolWithCall } from "llamaindex";
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
import { CodeGeneratorTool, CodeGeneratorToolParams } from "./code-generator";
import {
DocumentGenerator,
DocumentGeneratorParams,
Expand Down Expand Up @@ -47,6 +48,9 @@ const toolFactory: Record<string, ToolCreator> = {
img_gen: async (config: unknown) => {
return [new ImgGeneratorTool(config as ImgGeneratorToolParams)];
},
artifact: async (config: unknown) => {
return [new CodeGeneratorTool(config as CodeGeneratorToolParams)];
},
document_generator: async (config: unknown) => {
return [new DocumentGenerator(config as DocumentGeneratorParams)];
},
Expand Down
86 changes: 68 additions & 18 deletions templates/components/llamaindex/typescript/streaming/annotations.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { JSONValue } from "ai";
import { JSONValue, Message } from "ai";
import { MessageContent, MessageContentDetail } from "llamaindex";

export type DocumentFileType = "csv" | "pdf" | "txt" | "docx";
Expand All @@ -21,13 +21,20 @@ type Annotation = {
data: object;
};

export function retrieveDocumentIds(annotations?: JSONValue[]): string[] {
if (!annotations) return [];
export function isValidMessages(messages: Message[]): boolean {
const lastMessage =
messages && messages.length > 0 ? messages[messages.length - 1] : null;
return lastMessage !== null && lastMessage.role === "user";
}

export function retrieveDocumentIds(messages: Message[]): string[] {
// retrieve document Ids from the annotations of all messages (if any)
const annotations = getAllAnnotations(messages);
if (annotations.length === 0) return [];

const ids: string[] = [];

for (const annotation of annotations) {
const { type, data } = getValidAnnotation(annotation);
for (const { type, data } of annotations) {
if (
type === "document_file" &&
"files" in data &&
Expand All @@ -37,9 +44,7 @@ export function retrieveDocumentIds(annotations?: JSONValue[]): string[] {
for (const file of files) {
if (Array.isArray(file.content.value)) {
// it's an array, so it's an array of doc IDs
for (const id of file.content.value) {
ids.push(id);
}
ids.push(...file.content.value);
}
}
}
Expand All @@ -48,24 +53,69 @@ export function retrieveDocumentIds(annotations?: JSONValue[]): string[] {
return ids;
}

export function convertMessageContent(
content: string,
annotations?: JSONValue[],
): MessageContent {
if (!annotations) return content;
export function retrieveMessageContent(messages: Message[]): MessageContent {
const userMessage = messages[messages.length - 1];
return [
{
type: "text",
text: content,
text: userMessage.content,
},
...convertAnnotations(annotations),
...retrieveLatestArtifact(messages),
...convertAnnotations(messages),
];
}

function convertAnnotations(annotations: JSONValue[]): MessageContentDetail[] {
function getAllAnnotations(messages: Message[]): Annotation[] {
return messages.flatMap((message) =>
(message.annotations ?? []).map((annotation) =>
getValidAnnotation(annotation),
),
);
}

// get latest artifact from annotations to append to the user message
function retrieveLatestArtifact(messages: Message[]): MessageContentDetail[] {
const annotations = getAllAnnotations(messages);
if (annotations.length === 0) return [];

for (const { type, data } of annotations.reverse()) {
if (
type === "tools" &&
"toolCall" in data &&
"toolOutput" in data &&
typeof data.toolCall === "object" &&
typeof data.toolOutput === "object" &&
data.toolCall !== null &&
data.toolOutput !== null &&
"name" in data.toolCall &&
data.toolCall.name === "artifact"
) {
const toolOutput = data.toolOutput as { output?: { code?: string } };
if (toolOutput.output?.code) {
return [
{
type: "text",
text: `The existing code is:\n\`\`\`\n${toolOutput.output.code}\n\`\`\``,
},
];
}
}
}
return [];
}

function convertAnnotations(messages: Message[]): MessageContentDetail[] {
// annotations from the last user message that has annotations
const annotations: Annotation[] =
messages
.slice()
.reverse()
.find((message) => message.role === "user" && message.annotations)
?.annotations?.map(getValidAnnotation) || [];
if (annotations.length === 0) return [];

const content: MessageContentDetail[] = [];
annotations.forEach((annotation: JSONValue) => {
const { type, data } = getValidAnnotation(annotation);
annotations.forEach(({ type, data }) => {
// convert image
if (type === "image" && "url" in data && typeof data.url === "string") {
content.push({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,6 @@ export function appendToolData(
});
}

export function createStreamTimeout(stream: StreamData) {
const timeout = Number(process.env.STREAM_TIMEOUT ?? 1000 * 60 * 5); // default to 5 minutes
const t = setTimeout(() => {
appendEventData(stream, `Stream timed out after ${timeout / 1000} seconds`);
stream.close();
}, timeout);
return t;
}

export function createCallbackManager(stream: StreamData) {
const callbackManager = new CallbackManager();

Expand Down
2 changes: 2 additions & 0 deletions templates/types/streaming/express/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import cors from "cors";
import "dotenv/config";
import express, { Express, Request, Response } from "express";
import { sandbox } from "./src/controllers/sandbox.controller";
import { initObservability } from "./src/observability";
import chatRouter from "./src/routes/chat.route";

Expand Down Expand Up @@ -40,6 +41,7 @@ app.get("/", (req: Request, res: Response) => {
});

app.use("/api/chat", chatRouter);
app.use("/api/sandbox", sandbox);

app.listen(port, () => {
console.log(`⚡️[server]: Server is running at http://localhost:${port}`);
Expand Down
2 changes: 1 addition & 1 deletion templates/types/streaming/express/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"llamaindex": "0.6.2",
"pdf2json": "3.0.5",
"ajv": "^8.12.0",
"@e2b/code-interpreter": "^0.0.5",
"@e2b/code-interpreter": "0.0.9-beta.3",
"got": "^14.4.1",
"@apidevtools/swagger-parser": "^10.1.0",
"formdata-node": "^6.0.3",
Expand Down
Loading

0 comments on commit 5a7216e

Please sign in to comment.