diff --git a/.changeset/modern-cars-travel.md b/.changeset/modern-cars-travel.md new file mode 100644 index 00000000..369abcad --- /dev/null +++ b/.changeset/modern-cars-travel.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +feat: implement artifact tool in TS diff --git a/e2e/shared/multiagent_template.spec.ts b/e2e/shared/multiagent_template.spec.ts index d76d320c..c330b2c9 100644 --- a/e2e/shared/multiagent_template.spec.ts +++ b/e2e/shared/multiagent_template.spec.ts @@ -66,7 +66,7 @@ test.describe(`Test multiagent template ${templateFramework} ${dataSource} ${tem page, }) => { await page.goto(`http://localhost:${port}`); - await page.fill("form input", userMessage); + await page.fill("form textarea", userMessage); const responsePromise = page.waitForResponse((res) => res.url().includes("/api/chat"), diff --git a/e2e/shared/streaming_template.spec.ts b/e2e/shared/streaming_template.spec.ts index 74c7eb4e..91183a91 100644 --- a/e2e/shared/streaming_template.spec.ts +++ b/e2e/shared/streaming_template.spec.ts @@ -72,7 +72,7 @@ test.describe(`Test streaming template ${templateFramework} ${dataSource} ${temp }) => { test.skip(templatePostInstallAction !== "runApp"); await page.goto(`http://localhost:${port}`); - await page.fill("form input", userMessage); + await page.fill("form textarea", userMessage); const [response] = await Promise.all([ page.waitForResponse( (res) => { diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts index 49e0ab87..5777701c 100644 --- a/helpers/env-variables.ts +++ b/helpers/env-variables.ts @@ -397,12 +397,6 @@ const getEngineEnvs = (): EnvVar[] => { description: "The number of similar embeddings to return when retrieving documents.", }, - { - name: "STREAM_TIMEOUT", - description: - "The time in milliseconds to wait for the stream to return a response.", - value: "60000", - }, ]; }; diff --git a/helpers/tools.ts b/helpers/tools.ts index 97bde8b6..b65957e7 100644 --- a/helpers/tools.ts +++ b/helpers/tools.ts @@ -162,6 +162,26 @@ For better results, you can specify the region parameter to get results from a s }, ], }, + { + display: "Artifact Code Generator", + name: "artifact", + dependencies: [], + supportedFrameworks: ["express", "nextjs"], + type: ToolType.LOCAL, + envVars: [ + { + name: "E2B_API_KEY", + description: + "E2B_API_KEY key is required to run artifact code generator tool. Get it here: https://e2b.dev/docs/getting-started/api-key", + }, + { + name: TOOL_SYSTEM_PROMPT_ENV_VAR, + description: "System prompt for artifact code generator tool.", + value: + "You are a code assistant that can generate and execute code using its tools. Don't generate code yourself, use the provided tools instead. Do not show the code or sandbox url in chat, just describe the steps to build the application based on the code that is generated by your tools. Do not describe how to run the code, just the steps to build the application.", + }, + ], + }, { display: "OpenAPI action", name: "openapi_action.OpenAPIActionToolSpec", diff --git a/templates/components/engines/typescript/agent/tools/code-generator.ts b/templates/components/engines/typescript/agent/tools/code-generator.ts new file mode 100644 index 00000000..eedcfa51 --- /dev/null +++ b/templates/components/engines/typescript/agent/tools/code-generator.ts @@ -0,0 +1,129 @@ +import type { JSONSchemaType } from "ajv"; +import { + BaseTool, + ChatMessage, + JSONValue, + Settings, + ToolMetadata, +} from "llamaindex"; + +// prompt based on https://github.com/e2b-dev/ai-artifacts +const CODE_GENERATION_PROMPT = `You are a skilled software engineer. You do not make mistakes. Generate an artifact. You can install additional dependencies. You can use one of the following templates:\n + +1. code-interpreter-multilang: "Runs code as a Jupyter notebook cell. Strong data analysis angle. Can use complex visualisation to explain results.". File: script.py. Dependencies installed: python, jupyter, numpy, pandas, matplotlib, seaborn, plotly. Port: none. + +2. nextjs-developer: "A Next.js 13+ app that reloads automatically. Using the pages router.". File: pages/index.tsx. Dependencies installed: nextjs@14.2.5, typescript, @types/node, @types/react, @types/react-dom, postcss, tailwindcss, shadcn. Port: 3000. + +3. vue-developer: "A Vue.js 3+ app that reloads automatically. Only when asked specifically for a Vue app.". File: app.vue. Dependencies installed: vue@latest, nuxt@3.13.0, tailwindcss. Port: 3000. + +4. streamlit-developer: "A streamlit app that reloads automatically.". File: app.py. Dependencies installed: streamlit, pandas, numpy, matplotlib, request, seaborn, plotly. Port: 8501. + +5. gradio-developer: "A gradio app. Gradio Blocks/Interface should be called demo.". File: app.py. Dependencies installed: gradio, pandas, numpy, matplotlib, request, seaborn, plotly. Port: 7860. + +Provide detail information about the artifact you're about to generate in the following JSON format with the following keys: + +commentary: Describe what you're about to do and the steps you want to take for generating the artifact in great detail. +template: Name of the template used to generate the artifact. +title: Short title of the artifact. Max 3 words. +description: Short description of the artifact. Max 1 sentence. +additional_dependencies: Additional dependencies required by the artifact. Do not include dependencies that are already included in the template. +has_additional_dependencies: Detect if additional dependencies that are not included in the template are required by the artifact. +install_dependencies_command: Command to install additional dependencies required by the artifact. +port: Port number used by the resulted artifact. Null when no ports are exposed. +file_path: Relative path to the file, including the file name. +code: Code generated by the artifact. Only runnable code is allowed. + +Make sure to use the correct syntax for the programming language you're using. Make sure to generate only one code file. If you need to use CSS, make sure to include the CSS in the code file using Tailwind CSS syntax. +`; + +// detail information to execute code +export type CodeArtifact = { + commentary: string; + template: string; + title: string; + description: string; + additional_dependencies: string[]; + has_additional_dependencies: boolean; + install_dependencies_command: string; + port: number | null; + file_path: string; + code: string; +}; + +export type CodeGeneratorParameter = { + requirement: string; + oldCode?: string; +}; + +export type CodeGeneratorToolParams = { + metadata?: ToolMetadata>; +}; + +const DEFAULT_META_DATA: ToolMetadata> = + { + name: "artifact", + description: `Generate a code artifact based on the input. Don't call this tool if the user has not asked for code generation. E.g. if the user asks to write a description or specification, don't call this tool.`, + parameters: { + type: "object", + properties: { + requirement: { + type: "string", + description: "The description of the application you want to build.", + }, + oldCode: { + type: "string", + description: "The existing code to be modified", + nullable: true, + }, + }, + required: ["requirement"], + }, + }; + +export class CodeGeneratorTool implements BaseTool { + metadata: ToolMetadata>; + + constructor(params?: CodeGeneratorToolParams) { + this.metadata = params?.metadata || DEFAULT_META_DATA; + } + + async call(input: CodeGeneratorParameter) { + try { + const artifact = await this.generateArtifact( + input.requirement, + input.oldCode, + ); + return artifact as JSONValue; + } catch (error) { + return { isError: true }; + } + } + + // Generate artifact (code, environment, dependencies, etc.) + async generateArtifact( + query: string, + oldCode?: string, + ): Promise { + const userMessage = ` + ${query} + ${oldCode ? `The existing code is: \n\`\`\`${oldCode}\`\`\`` : ""} + `; + const messages: ChatMessage[] = [ + { role: "system", content: CODE_GENERATION_PROMPT }, + { role: "user", content: userMessage }, + ]; + try { + const response = await Settings.llm.chat({ messages }); + const content = response.message.content.toString(); + const jsonContent = content + .replace(/^```json\s*|\s*```$/g, "") + .replace(/^`+|`+$/g, "") + .trim(); + const artifact = JSON.parse(jsonContent) as CodeArtifact; + return artifact; + } catch (error) { + console.log("Failed to generate artifact", error); + throw error; + } + } +} diff --git a/templates/components/engines/typescript/agent/tools/index.ts b/templates/components/engines/typescript/agent/tools/index.ts index b29af048..062e2eb0 100644 --- a/templates/components/engines/typescript/agent/tools/index.ts +++ b/templates/components/engines/typescript/agent/tools/index.ts @@ -1,5 +1,6 @@ import { BaseToolWithCall } from "llamaindex"; import { ToolsFactory } from "llamaindex/tools/ToolsFactory"; +import { CodeGeneratorTool, CodeGeneratorToolParams } from "./code-generator"; import { DocumentGenerator, DocumentGeneratorParams, @@ -47,6 +48,9 @@ const toolFactory: Record = { img_gen: async (config: unknown) => { return [new ImgGeneratorTool(config as ImgGeneratorToolParams)]; }, + artifact: async (config: unknown) => { + return [new CodeGeneratorTool(config as CodeGeneratorToolParams)]; + }, document_generator: async (config: unknown) => { return [new DocumentGenerator(config as DocumentGeneratorParams)]; }, diff --git a/templates/components/llamaindex/typescript/streaming/annotations.ts b/templates/components/llamaindex/typescript/streaming/annotations.ts index 211886a1..13842c7a 100644 --- a/templates/components/llamaindex/typescript/streaming/annotations.ts +++ b/templates/components/llamaindex/typescript/streaming/annotations.ts @@ -1,4 +1,4 @@ -import { JSONValue } from "ai"; +import { JSONValue, Message } from "ai"; import { MessageContent, MessageContentDetail } from "llamaindex"; export type DocumentFileType = "csv" | "pdf" | "txt" | "docx"; @@ -21,13 +21,20 @@ type Annotation = { data: object; }; -export function retrieveDocumentIds(annotations?: JSONValue[]): string[] { - if (!annotations) return []; +export function isValidMessages(messages: Message[]): boolean { + const lastMessage = + messages && messages.length > 0 ? messages[messages.length - 1] : null; + return lastMessage !== null && lastMessage.role === "user"; +} + +export function retrieveDocumentIds(messages: Message[]): string[] { + // retrieve document Ids from the annotations of all messages (if any) + const annotations = getAllAnnotations(messages); + if (annotations.length === 0) return []; const ids: string[] = []; - for (const annotation of annotations) { - const { type, data } = getValidAnnotation(annotation); + for (const { type, data } of annotations) { if ( type === "document_file" && "files" in data && @@ -37,9 +44,7 @@ export function retrieveDocumentIds(annotations?: JSONValue[]): string[] { for (const file of files) { if (Array.isArray(file.content.value)) { // it's an array, so it's an array of doc IDs - for (const id of file.content.value) { - ids.push(id); - } + ids.push(...file.content.value); } } } @@ -48,24 +53,69 @@ export function retrieveDocumentIds(annotations?: JSONValue[]): string[] { return ids; } -export function convertMessageContent( - content: string, - annotations?: JSONValue[], -): MessageContent { - if (!annotations) return content; +export function retrieveMessageContent(messages: Message[]): MessageContent { + const userMessage = messages[messages.length - 1]; return [ { type: "text", - text: content, + text: userMessage.content, }, - ...convertAnnotations(annotations), + ...retrieveLatestArtifact(messages), + ...convertAnnotations(messages), ]; } -function convertAnnotations(annotations: JSONValue[]): MessageContentDetail[] { +function getAllAnnotations(messages: Message[]): Annotation[] { + return messages.flatMap((message) => + (message.annotations ?? []).map((annotation) => + getValidAnnotation(annotation), + ), + ); +} + +// get latest artifact from annotations to append to the user message +function retrieveLatestArtifact(messages: Message[]): MessageContentDetail[] { + const annotations = getAllAnnotations(messages); + if (annotations.length === 0) return []; + + for (const { type, data } of annotations.reverse()) { + if ( + type === "tools" && + "toolCall" in data && + "toolOutput" in data && + typeof data.toolCall === "object" && + typeof data.toolOutput === "object" && + data.toolCall !== null && + data.toolOutput !== null && + "name" in data.toolCall && + data.toolCall.name === "artifact" + ) { + const toolOutput = data.toolOutput as { output?: { code?: string } }; + if (toolOutput.output?.code) { + return [ + { + type: "text", + text: `The existing code is:\n\`\`\`\n${toolOutput.output.code}\n\`\`\``, + }, + ]; + } + } + } + return []; +} + +function convertAnnotations(messages: Message[]): MessageContentDetail[] { + // annotations from the last user message that has annotations + const annotations: Annotation[] = + messages + .slice() + .reverse() + .find((message) => message.role === "user" && message.annotations) + ?.annotations?.map(getValidAnnotation) || []; + if (annotations.length === 0) return []; + const content: MessageContentDetail[] = []; - annotations.forEach((annotation: JSONValue) => { - const { type, data } = getValidAnnotation(annotation); + annotations.forEach(({ type, data }) => { // convert image if (type === "image" && "url" in data && typeof data.url === "string") { content.push({ diff --git a/templates/components/llamaindex/typescript/streaming/events.ts b/templates/components/llamaindex/typescript/streaming/events.ts index 0df964a2..c14af55d 100644 --- a/templates/components/llamaindex/typescript/streaming/events.ts +++ b/templates/components/llamaindex/typescript/streaming/events.ts @@ -69,15 +69,6 @@ export function appendToolData( }); } -export function createStreamTimeout(stream: StreamData) { - const timeout = Number(process.env.STREAM_TIMEOUT ?? 1000 * 60 * 5); // default to 5 minutes - const t = setTimeout(() => { - appendEventData(stream, `Stream timed out after ${timeout / 1000} seconds`); - stream.close(); - }, timeout); - return t; -} - export function createCallbackManager(stream: StreamData) { const callbackManager = new CallbackManager(); diff --git a/templates/types/streaming/express/index.ts b/templates/types/streaming/express/index.ts index 801b8d02..c0fc67b5 100644 --- a/templates/types/streaming/express/index.ts +++ b/templates/types/streaming/express/index.ts @@ -2,6 +2,7 @@ import cors from "cors"; import "dotenv/config"; import express, { Express, Request, Response } from "express"; +import { sandbox } from "./src/controllers/sandbox.controller"; import { initObservability } from "./src/observability"; import chatRouter from "./src/routes/chat.route"; @@ -40,6 +41,7 @@ app.get("/", (req: Request, res: Response) => { }); app.use("/api/chat", chatRouter); +app.use("/api/sandbox", sandbox); app.listen(port, () => { console.log(`⚡️[server]: Server is running at http://localhost:${port}`); diff --git a/templates/types/streaming/express/package.json b/templates/types/streaming/express/package.json index 39d23f85..a5eca116 100644 --- a/templates/types/streaming/express/package.json +++ b/templates/types/streaming/express/package.json @@ -24,7 +24,7 @@ "llamaindex": "0.6.2", "pdf2json": "3.0.5", "ajv": "^8.12.0", - "@e2b/code-interpreter": "^0.0.5", + "@e2b/code-interpreter": "0.0.9-beta.3", "got": "^14.4.1", "@apidevtools/swagger-parser": "^10.1.0", "formdata-node": "^6.0.3", diff --git a/templates/types/streaming/express/src/controllers/chat.controller.ts b/templates/types/streaming/express/src/controllers/chat.controller.ts index 9e4901b1..a8220657 100644 --- a/templates/types/streaming/express/src/controllers/chat.controller.ts +++ b/templates/types/streaming/express/src/controllers/chat.controller.ts @@ -1,64 +1,34 @@ -import { - JSONValue, - LlamaIndexAdapter, - Message, - StreamData, - streamToResponse, -} from "ai"; +import { LlamaIndexAdapter, Message, StreamData, streamToResponse } from "ai"; import { Request, Response } from "express"; import { ChatMessage, Settings } from "llamaindex"; import { createChatEngine } from "./engine/chat"; import { - convertMessageContent, + isValidMessages, retrieveDocumentIds, + retrieveMessageContent, } from "./llamaindex/streaming/annotations"; -import { - createCallbackManager, - createStreamTimeout, -} from "./llamaindex/streaming/events"; +import { createCallbackManager } from "./llamaindex/streaming/events"; import { generateNextQuestions } from "./llamaindex/streaming/suggestion"; export const chat = async (req: Request, res: Response) => { // Init Vercel AI StreamData and timeout const vercelStreamData = new StreamData(); - const streamTimeout = createStreamTimeout(vercelStreamData); try { const { messages, data }: { messages: Message[]; data?: any } = req.body; - const userMessage = messages.pop(); - if (!messages || !userMessage || userMessage.role !== "user") { + if (!isValidMessages(messages)) { return res.status(400).json({ error: "messages are required in the request body and the last message must be from the user", }); } - let annotations = userMessage.annotations; - if (!annotations) { - // the user didn't send any new annotations with the last message - // so use the annotations from the last user message that has annotations - // REASON: GPT4 doesn't consider MessageContentDetail from previous messages, only strings - annotations = messages - .slice() - .reverse() - .find( - (message) => message.role === "user" && message.annotations, - )?.annotations; - } - - // retrieve document Ids from the annotations of all messages (if any) and create chat engine with index - const allAnnotations: JSONValue[] = [...messages, userMessage].flatMap( - (message) => { - return message.annotations ?? []; - }, - ); - const ids = retrieveDocumentIds(allAnnotations); + // retrieve document ids from the annotations of all messages (if any) + const ids = retrieveDocumentIds(messages); + // create chat engine with index using the document ids const chatEngine = await createChatEngine(ids, data); - // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format - const userMessageContent = convertMessageContent( - userMessage.content, - annotations, - ); + // retrieve user message content from Vercel/AI format + const userMessageContent = retrieveMessageContent(messages); // Setup callbacks const callbackManager = createCallbackManager(vercelStreamData); @@ -96,7 +66,5 @@ export const chat = async (req: Request, res: Response) => { return res.status(500).json({ detail: (error as Error).message, }); - } finally { - clearTimeout(streamTimeout); } }; diff --git a/templates/types/streaming/express/src/controllers/sandbox.controller.ts b/templates/types/streaming/express/src/controllers/sandbox.controller.ts new file mode 100644 index 00000000..6013d138 --- /dev/null +++ b/templates/types/streaming/express/src/controllers/sandbox.controller.ts @@ -0,0 +1,140 @@ +/* + * Copyright 2023 FoundryLabs, Inc. + * Portions of this file are copied from the e2b project (https://github.com/e2b-dev/ai-artifacts) + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { + CodeInterpreter, + ExecutionError, + Result, + Sandbox, +} from "@e2b/code-interpreter"; +import { Request, Response } from "express"; +import { saveDocument } from "./llamaindex/documents/helper"; + +type CodeArtifact = { + commentary: string; + template: string; + title: string; + description: string; + additional_dependencies: string[]; + has_additional_dependencies: boolean; + install_dependencies_command: string; + port: number | null; + file_path: string; + code: string; +}; + +const sandboxTimeout = 10 * 60 * 1000; // 10 minute in ms + +export const maxDuration = 60; + +export type ExecutionResult = { + template: string; + stdout: string[]; + stderr: string[]; + runtimeError?: ExecutionError; + outputUrls: Array<{ url: string; filename: string }>; + url: string; +}; + +export const sandbox = async (req: Request, res: Response) => { + const { artifact }: { artifact: CodeArtifact } = req.body; + + let sbx: Sandbox | CodeInterpreter | undefined = undefined; + + // Create a interpreter or a sandbox + if (artifact.template === "code-interpreter-multilang") { + sbx = await CodeInterpreter.create({ + metadata: { template: artifact.template }, + timeoutMs: sandboxTimeout, + }); + console.log("Created code interpreter", sbx.sandboxID); + } else { + sbx = await Sandbox.create(artifact.template, { + metadata: { template: artifact.template, userID: "default" }, + timeoutMs: sandboxTimeout, + }); + console.log("Created sandbox", sbx.sandboxID); + } + + // Install packages + if (artifact.has_additional_dependencies) { + if (sbx instanceof CodeInterpreter) { + await sbx.notebook.execCell(artifact.install_dependencies_command); + console.log( + `Installed dependencies: ${artifact.additional_dependencies.join(", ")} in code interpreter ${sbx.sandboxID}`, + ); + } else if (sbx instanceof Sandbox) { + await sbx.commands.run(artifact.install_dependencies_command); + console.log( + `Installed dependencies: ${artifact.additional_dependencies.join(", ")} in sandbox ${sbx.sandboxID}`, + ); + } + } + + // Copy code to fs + if (artifact.code && Array.isArray(artifact.code)) { + artifact.code.forEach(async (file) => { + await sbx.files.write(file.file_path, file.file_content); + console.log(`Copied file to ${file.file_path} in ${sbx.sandboxID}`); + }); + } else { + await sbx.files.write(artifact.file_path, artifact.code); + console.log(`Copied file to ${artifact.file_path} in ${sbx.sandboxID}`); + } + + // Execute code or return a URL to the running sandbox + if (artifact.template === "code-interpreter-multilang") { + const result = await (sbx as CodeInterpreter).notebook.execCell( + artifact.code || "", + ); + await (sbx as CodeInterpreter).close(); + const outputUrls = await downloadCellResults(result.results); + + return res.status(200).json({ + template: artifact.template, + stdout: result.logs.stdout, + stderr: result.logs.stderr, + runtimeError: result.error, + outputUrls: outputUrls, + }); + } else { + return res.status(200).json({ + template: artifact.template, + url: `https://${sbx?.getHost(artifact.port || 80)}`, + }); + } +}; + +async function downloadCellResults( + cellResults?: Result[], +): Promise> { + if (!cellResults) return []; + const results = await Promise.all( + cellResults.map(async (res) => { + const formats = res.formats(); // available formats in the result + const formatResults = await Promise.all( + formats.map(async (ext) => { + const filename = `${crypto.randomUUID()}.${ext}`; + const base64 = res[ext as keyof Result]; + const buffer = Buffer.from(base64, "base64"); + const fileurl = await saveDocument(filename, buffer); + return { url: fileurl, filename }; + }), + ); + return formatResults; + }), + ); + return results.flat(); +} diff --git a/templates/types/streaming/nextjs/app/api/chat/route.ts b/templates/types/streaming/nextjs/app/api/chat/route.ts index fbb4774c..397ea326 100644 --- a/templates/types/streaming/nextjs/app/api/chat/route.ts +++ b/templates/types/streaming/nextjs/app/api/chat/route.ts @@ -1,17 +1,15 @@ import { initObservability } from "@/app/observability"; -import { JSONValue, LlamaIndexAdapter, Message, StreamData } from "ai"; +import { LlamaIndexAdapter, Message, StreamData } from "ai"; import { ChatMessage, Settings } from "llamaindex"; import { NextRequest, NextResponse } from "next/server"; import { createChatEngine } from "./engine/chat"; import { initSettings } from "./engine/settings"; import { - convertMessageContent, + isValidMessages, retrieveDocumentIds, + retrieveMessageContent, } from "./llamaindex/streaming/annotations"; -import { - createCallbackManager, - createStreamTimeout, -} from "./llamaindex/streaming/events"; +import { createCallbackManager } from "./llamaindex/streaming/events"; import { generateNextQuestions } from "./llamaindex/streaming/suggestion"; initObservability(); @@ -23,13 +21,11 @@ export const dynamic = "force-dynamic"; export async function POST(request: NextRequest) { // Init Vercel AI StreamData and timeout const vercelStreamData = new StreamData(); - const streamTimeout = createStreamTimeout(vercelStreamData); try { const body = await request.json(); const { messages, data }: { messages: Message[]; data?: any } = body; - const userMessage = messages.pop(); - if (!messages || !userMessage || userMessage.role !== "user") { + if (!isValidMessages(messages)) { return NextResponse.json( { error: @@ -39,33 +35,13 @@ export async function POST(request: NextRequest) { ); } - let annotations = userMessage.annotations; - if (!annotations) { - // the user didn't send any new annotations with the last message - // so use the annotations from the last user message that has annotations - // REASON: GPT4 doesn't consider MessageContentDetail from previous messages, only strings - annotations = messages - .slice() - .reverse() - .find( - (message) => message.role === "user" && message.annotations, - )?.annotations; - } - - // retrieve document Ids from the annotations of all messages (if any) and create chat engine with index - const allAnnotations: JSONValue[] = [...messages, userMessage].flatMap( - (message) => { - return message.annotations ?? []; - }, - ); - const ids = retrieveDocumentIds(allAnnotations); + // retrieve document ids from the annotations of all messages (if any) + const ids = retrieveDocumentIds(messages); + // create chat engine with index using the document ids const chatEngine = await createChatEngine(ids, data); - // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format - const userMessageContent = convertMessageContent( - userMessage.content, - annotations, - ); + // retrieve user message content from Vercel/AI format + const userMessageContent = retrieveMessageContent(messages); // Setup callbacks const callbackManager = createCallbackManager(vercelStreamData); @@ -110,7 +86,5 @@ export async function POST(request: NextRequest) { status: 500, }, ); - } finally { - clearTimeout(streamTimeout); } } diff --git a/templates/types/streaming/nextjs/app/api/sandbox/route.ts b/templates/types/streaming/nextjs/app/api/sandbox/route.ts new file mode 100644 index 00000000..cfc20087 --- /dev/null +++ b/templates/types/streaming/nextjs/app/api/sandbox/route.ts @@ -0,0 +1,142 @@ +/* + * Copyright 2023 FoundryLabs, Inc. + * Portions of this file are copied from the e2b project (https://github.com/e2b-dev/ai-artifacts) + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { + CodeInterpreter, + ExecutionError, + Result, + Sandbox, +} from "@e2b/code-interpreter"; +import { saveDocument } from "../chat/llamaindex/documents/helper"; + +type CodeArtifact = { + commentary: string; + template: string; + title: string; + description: string; + additional_dependencies: string[]; + has_additional_dependencies: boolean; + install_dependencies_command: string; + port: number | null; + file_path: string; + code: string; +}; + +const sandboxTimeout = 10 * 60 * 1000; // 10 minute in ms + +export const maxDuration = 60; + +export type ExecutionResult = { + template: string; + stdout: string[]; + stderr: string[]; + runtimeError?: ExecutionError; + outputUrls: Array<{ url: string; filename: string }>; + url: string; +}; + +export async function POST(req: Request) { + const { artifact }: { artifact: CodeArtifact } = await req.json(); + + let sbx: Sandbox | CodeInterpreter | undefined = undefined; + + // Create a interpreter or a sandbox + if (artifact.template === "code-interpreter-multilang") { + sbx = await CodeInterpreter.create({ + metadata: { template: artifact.template }, + timeoutMs: sandboxTimeout, + }); + console.log("Created code interpreter", sbx.sandboxID); + } else { + sbx = await Sandbox.create(artifact.template, { + metadata: { template: artifact.template, userID: "default" }, + timeoutMs: sandboxTimeout, + }); + console.log("Created sandbox", sbx.sandboxID); + } + + // Install packages + if (artifact.has_additional_dependencies) { + if (sbx instanceof CodeInterpreter) { + await sbx.notebook.execCell(artifact.install_dependencies_command); + console.log( + `Installed dependencies: ${artifact.additional_dependencies.join(", ")} in code interpreter ${sbx.sandboxID}`, + ); + } else if (sbx instanceof Sandbox) { + await sbx.commands.run(artifact.install_dependencies_command); + console.log( + `Installed dependencies: ${artifact.additional_dependencies.join(", ")} in sandbox ${sbx.sandboxID}`, + ); + } + } + + // Copy code to fs + if (artifact.code && Array.isArray(artifact.code)) { + artifact.code.forEach(async (file) => { + await sbx.files.write(file.file_path, file.file_content); + console.log(`Copied file to ${file.file_path} in ${sbx.sandboxID}`); + }); + } else { + await sbx.files.write(artifact.file_path, artifact.code); + console.log(`Copied file to ${artifact.file_path} in ${sbx.sandboxID}`); + } + + // Execute code or return a URL to the running sandbox + if (artifact.template === "code-interpreter-multilang") { + const result = await (sbx as CodeInterpreter).notebook.execCell( + artifact.code || "", + ); + await (sbx as CodeInterpreter).close(); + const outputUrls = await downloadCellResults(result.results); + return new Response( + JSON.stringify({ + template: artifact.template, + stdout: result.logs.stdout, + stderr: result.logs.stderr, + runtimeError: result.error, + outputUrls: outputUrls, + }), + ); + } else { + return new Response( + JSON.stringify({ + template: artifact.template, + url: `https://${sbx?.getHost(artifact.port || 80)}`, + }), + ); + } +} + +async function downloadCellResults( + cellResults?: Result[], +): Promise> { + if (!cellResults) return []; + const results = await Promise.all( + cellResults.map(async (res) => { + const formats = res.formats(); // available formats in the result + const formatResults = await Promise.all( + formats.map(async (ext) => { + const filename = `${crypto.randomUUID()}.${ext}`; + const base64 = res[ext as keyof Result]; + const buffer = Buffer.from(base64, "base64"); + const fileurl = await saveDocument(filename, buffer); + return { url: fileurl, filename }; + }), + ); + return formatResults; + }), + ); + return results.flat(); +} diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx index 9d1cb44e..326cc969 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx @@ -1,8 +1,9 @@ import { JSONValue } from "ai"; +import React from "react"; import { Button } from "../button"; import { DocumentPreview } from "../document-preview"; import FileUploader from "../file-uploader"; -import { Input } from "../input"; +import { Textarea } from "../textarea"; import UploadImagePreview from "../upload-image-preview"; import { ChatHandler } from "./chat.interface"; import { useFile } from "./hooks/use-file"; @@ -54,6 +55,7 @@ export default function ChatInput( }; const onSubmit = (e: React.FormEvent) => { + e.preventDefault(); const annotations = getAnnotations(); if (annotations.length) { handleSubmitWithAnnotations(e, annotations); @@ -76,6 +78,13 @@ export default function ChatInput( } }; + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter" && !e.shiftKey) { + e.preventDefault(); + onSubmit(e as unknown as React.FormEvent); + } + }; + return (
)}
- ; + case "artifact": + return ( + + ); default: return null; } diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/index.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/index.tsx index 375b1d4c..47ec2ba8 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/index.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/index.tsx @@ -37,11 +37,13 @@ function ChatMessageContent({ isLoading, append, isLastMessage, + artifactVersion, }: { message: Message; isLoading: boolean; append: Pick["append"]; isLastMessage: boolean; + artifactVersion: number | undefined; }) { const annotations = message.annotations as MessageAnnotation[] | undefined; if (!annotations?.length) return ; @@ -104,7 +106,9 @@ function ChatMessageContent({ }, { order: -1, - component: toolData[0] ? : null, + component: toolData[0] ? ( + + ) : null, }, { order: 0, @@ -142,11 +146,13 @@ export default function ChatMessage({ isLoading, append, isLastMessage, + artifactVersion, }: { chatMessage: Message; isLoading: boolean; append: Pick["append"]; isLastMessage: boolean; + artifactVersion: number | undefined; }) { const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 2000 }); return ( @@ -158,6 +164,7 @@ export default function ChatMessage({ isLoading={isLoading} append={append} isLastMessage={isLastMessage} + artifactVersion={artifactVersion} /> +
+ + + Code + Preview + + +
+ +
+
+ + {runtimeError && } + + {sandboxUrl && } + {outputUrls && } + +
+ + ); +} + +function RunTimeError({ + runtimeError, +}: { + runtimeError: { name: string; value: string; tracebackRaw: string[] }; +}) { + const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 1000 }); + const contentToCopy = `Fix this error:\n${runtimeError.name}\n${runtimeError.value}\n${runtimeError.tracebackRaw.join("\n")}`; + return ( + + + Runtime Error: + + + +
+

{runtimeError.name}

+

{runtimeError.value}

+ {runtimeError.tracebackRaw.map((trace, index) => ( +
+              {trace}
+            
+ ))} +
+ +
+
+ ); +} + +function CodeSandboxPreview({ url }: { url: string }) { + const [loading, setLoading] = useState(true); + const iframeRef = useRef(null); + + useEffect(() => { + if (!loading && iframeRef.current) { + iframeRef.current.focus(); + } + }, [loading]); + + return ( + <> +