From 571469d05b311d596a8312bd234400fceba9381f Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Thu, 14 Nov 2024 13:37:05 +0100 Subject: [PATCH] split data and assistant stream processing (#3673) --- content/docs/06-advanced/04-caching.mdx | 4 +- .../01-migration-guide-4-0.mdx | 8 + .../app/api/use-chat-cache/route.ts | 4 +- packages/ai/core/generate-text/stream-text.ts | 21 +- packages/ai/streams/assistant-response.ts | 18 +- packages/ai/streams/index.ts | 26 +- packages/ai/streams/stream-data.ts | 8 +- packages/react/src/use-assistant.ts | 4 +- packages/react/src/use-assistant.ui.test.tsx | 22 +- packages/react/src/use-chat.ui.test.tsx | 34 +-- packages/solid/src/use-chat.ui.test.tsx | 32 +-- packages/svelte/src/use-assistant.ts | 4 +- .../src/assistant-stream-parts.test.ts | 20 ++ .../ui-utils/src/assistant-stream-parts.ts | 220 ++++++++++++++++ ...arts.test.ts => data-stream-parts.test.ts} | 58 ++--- .../{stream-parts.ts => data-stream-parts.ts} | 181 +++----------- packages/ui-utils/src/index.ts | 15 +- .../src/process-assistant-stream.test.ts | 234 ++++++++++++++++++ .../ui-utils/src/process-assistant-stream.ts | 66 +++++ .../process-data-procotol-response.test.ts | 54 ++-- .../ui-utils/src/process-data-stream.test.ts | 47 +--- packages/ui-utils/src/process-data-stream.ts | 6 +- .../src/test/create-data-protocol-stream.ts | 4 +- packages/vue/src/use-assistant.ts | 6 +- packages/vue/src/use-assistant.ui.test.tsx | 22 +- packages/vue/src/use-chat.ui.test.tsx | 38 +-- 26 files changed, 789 insertions(+), 367 deletions(-) create mode 100644 packages/ui-utils/src/assistant-stream-parts.test.ts create mode 100644 packages/ui-utils/src/assistant-stream-parts.ts rename packages/ui-utils/src/{stream-parts.test.ts => data-stream-parts.test.ts} (80%) rename packages/ui-utils/src/{stream-parts.ts => data-stream-parts.ts} (69%) create mode 100644 packages/ui-utils/src/process-assistant-stream.test.ts create mode 100644 packages/ui-utils/src/process-assistant-stream.ts diff --git a/content/docs/06-advanced/04-caching.mdx b/content/docs/06-advanced/04-caching.mdx index 59d6d02d666d..546dda794449 100644 --- a/content/docs/06-advanced/04-caching.mdx +++ b/content/docs/06-advanced/04-caching.mdx @@ -18,7 +18,7 @@ This example uses [Vercel KV](https://vercel.com/storage/kv) and Next.js to cach ```tsx filename="app/api/chat/route.ts" import { openai } from '@ai-sdk/openai'; -import { formatStreamPart, streamText } from 'ai'; +import { formatDataStreamPart, streamText } from 'ai'; import kv from '@vercel/kv'; // Allow streaming responses up to 30 seconds @@ -36,7 +36,7 @@ export async function POST(req: Request) { // Check if we have a cached response const cached = await kv.get(key); if (cached != null) { - return new Response(formatStreamPart('text', cached), { + return new Response(formatDataStreamPart('text', cached), { status: 200, headers: { 'Content-Type': 'text/plain' }, }); diff --git a/content/docs/08-troubleshooting/01-migration-guide/01-migration-guide-4-0.mdx b/content/docs/08-troubleshooting/01-migration-guide/01-migration-guide-4-0.mdx index 3690d62c9e43..e4c74c0098ef 100644 --- a/content/docs/08-troubleshooting/01-migration-guide/01-migration-guide-4-0.mdx +++ b/content/docs/08-troubleshooting/01-migration-guide/01-migration-guide-4-0.mdx @@ -159,6 +159,14 @@ The following methods have been removed from the `streamText` result: - `pipeAIStreamToResponse` - `toAIStreamResponse` +### Renamed "formatStreamPart" to "formatDataStreamPart" + +The `formatStreamPart` function has been renamed to `formatDataStreamPart`. + +### Renamed "parseStreamPart" to "parseDataStreamPart" + +The `parseStreamPart` function has been renamed to `parseDataStreamPart`. + ### Removed `TokenUsage`, `CompletionTokenUsage` and `EmbeddingTokenUsage` types The `TokenUsage`, `CompletionTokenUsage` and `EmbeddingTokenUsage` types have been removed. diff --git a/examples/next-openai/app/api/use-chat-cache/route.ts b/examples/next-openai/app/api/use-chat-cache/route.ts index e713d3a856c9..6670ac385d0a 100644 --- a/examples/next-openai/app/api/use-chat-cache/route.ts +++ b/examples/next-openai/app/api/use-chat-cache/route.ts @@ -1,5 +1,5 @@ import { openai } from '@ai-sdk/openai'; -import { formatStreamPart, streamText } from 'ai'; +import { formatDataStreamPart, streamText } from 'ai'; // Allow streaming responses up to 30 seconds export const maxDuration = 30; @@ -16,7 +16,7 @@ export async function POST(req: Request) { // Check if we have a cached response const cached = cache.get(key); if (cached != null) { - return new Response(formatStreamPart('text', cached), { + return new Response(formatDataStreamPart('text', cached), { status: 200, headers: { 'Content-Type': 'text/plain' }, }); diff --git a/packages/ai/core/generate-text/stream-text.ts b/packages/ai/core/generate-text/stream-text.ts index 1065bb987bbb..a3a39209ee20 100644 --- a/packages/ai/core/generate-text/stream-text.ts +++ b/packages/ai/core/generate-text/stream-text.ts @@ -1,6 +1,5 @@ import { createIdGenerator } from '@ai-sdk/provider-utils'; -import { formatStreamPart } from '@ai-sdk/ui-utils'; -import { Span } from '@opentelemetry/api'; +import { formatDataStreamPart } from '@ai-sdk/ui-utils'; import { ServerResponse } from 'node:http'; import { InvalidArgumentError } from '../../errors/invalid-argument-error'; import { StreamData } from '../../streams/stream-data'; @@ -21,13 +20,11 @@ import { selectTelemetryAttributes } from '../telemetry/select-telemetry-attribu import { TelemetrySettings } from '../telemetry/telemetry-settings'; import { CoreTool } from '../tool'; import { - CallWarning, CoreToolChoice, FinishReason, LanguageModel, LogProbs, } from '../types/language-model'; -import { LanguageModelRequestMetadata } from '../types/language-model-request-metadata'; import { ProviderMetadata } from '../types/provider-metadata'; import { LanguageModelUsage } from '../types/usage'; import { @@ -1074,13 +1071,13 @@ However, the LLM results are expected to be small enough to not cause issues. const chunkType = chunk.type; switch (chunkType) { case 'text-delta': { - controller.enqueue(formatStreamPart('text', chunk.textDelta)); + controller.enqueue(formatDataStreamPart('text', chunk.textDelta)); break; } case 'tool-call-streaming-start': { controller.enqueue( - formatStreamPart('tool_call_streaming_start', { + formatDataStreamPart('tool_call_streaming_start', { toolCallId: chunk.toolCallId, toolName: chunk.toolName, }), @@ -1090,7 +1087,7 @@ However, the LLM results are expected to be small enough to not cause issues. case 'tool-call-delta': { controller.enqueue( - formatStreamPart('tool_call_delta', { + formatDataStreamPart('tool_call_delta', { toolCallId: chunk.toolCallId, argsTextDelta: chunk.argsTextDelta, }), @@ -1100,7 +1097,7 @@ However, the LLM results are expected to be small enough to not cause issues. case 'tool-call': { controller.enqueue( - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: chunk.toolCallId, toolName: chunk.toolName, args: chunk.args, @@ -1111,7 +1108,7 @@ However, the LLM results are expected to be small enough to not cause issues. case 'tool-result': { controller.enqueue( - formatStreamPart('tool_result', { + formatDataStreamPart('tool_result', { toolCallId: chunk.toolCallId, result: chunk.result, }), @@ -1121,14 +1118,14 @@ However, the LLM results are expected to be small enough to not cause issues. case 'error': { controller.enqueue( - formatStreamPart('error', getErrorMessage(chunk.error)), + formatDataStreamPart('error', getErrorMessage(chunk.error)), ); break; } case 'step-finish': { controller.enqueue( - formatStreamPart('finish_step', { + formatDataStreamPart('finish_step', { finishReason: chunk.finishReason, usage: sendUsage ? { @@ -1144,7 +1141,7 @@ However, the LLM results are expected to be small enough to not cause issues. case 'finish': { controller.enqueue( - formatStreamPart('finish_message', { + formatDataStreamPart('finish_message', { finishReason: chunk.finishReason, usage: sendUsage ? { diff --git a/packages/ai/streams/assistant-response.ts b/packages/ai/streams/assistant-response.ts index 6048e1c438e9..1a22b8d98c4c 100644 --- a/packages/ai/streams/assistant-response.ts +++ b/packages/ai/streams/assistant-response.ts @@ -1,7 +1,7 @@ import { AssistantMessage, DataMessage, - formatStreamPart, + formatAssistantStreamPart, } from '@ai-sdk/ui-utils'; /** @@ -54,19 +54,23 @@ export function AssistantResponse( const sendMessage = (message: AssistantMessage) => { controller.enqueue( - textEncoder.encode(formatStreamPart('assistant_message', message)), + textEncoder.encode( + formatAssistantStreamPart('assistant_message', message), + ), ); }; const sendDataMessage = (message: DataMessage) => { controller.enqueue( - textEncoder.encode(formatStreamPart('data_message', message)), + textEncoder.encode( + formatAssistantStreamPart('data_message', message), + ), ); }; const sendError = (errorMessage: string) => { controller.enqueue( - textEncoder.encode(formatStreamPart('error', errorMessage)), + textEncoder.encode(formatAssistantStreamPart('error', errorMessage)), ); }; @@ -78,7 +82,7 @@ export function AssistantResponse( case 'thread.message.created': { controller.enqueue( textEncoder.encode( - formatStreamPart('assistant_message', { + formatAssistantStreamPart('assistant_message', { id: value.data.id, role: 'assistant', content: [{ type: 'text', text: { value: '' } }], @@ -94,7 +98,7 @@ export function AssistantResponse( if (content?.type === 'text' && content.text?.value != null) { controller.enqueue( textEncoder.encode( - formatStreamPart('text', content.text.value), + formatAssistantStreamPart('text', content.text.value), ), ); } @@ -116,7 +120,7 @@ export function AssistantResponse( // send the threadId and messageId as the first message: controller.enqueue( textEncoder.encode( - formatStreamPart('assistant_control_data', { + formatAssistantStreamPart('assistant_control_data', { threadId, messageId, }), diff --git a/packages/ai/streams/index.ts b/packages/ai/streams/index.ts index 08a5eec285ca..f82a7cff6b74 100644 --- a/packages/ai/streams/index.ts +++ b/packages/ai/streams/index.ts @@ -1,26 +1,28 @@ // forwarding exports from ui-utils: export { - formatStreamPart, - parseStreamPart, + formatAssistantStreamPart, + formatDataStreamPart, + parseAssistantStreamPart, + parseDataStreamPart, + processDataProtocolResponse, processDataStream, processTextStream, - processDataProtocolResponse, } from '@ai-sdk/ui-utils'; export type { - AssistantStatus, - UseAssistantOptions, - Message, - CreateMessage, - DataMessage, AssistantMessage, - JSONValue, + AssistantStatus, + Attachment, ChatRequest, ChatRequestOptions, - ToolInvocation, - StreamPart, + CreateMessage, + DataMessage, + DataStreamPart, IdGenerator, + JSONValue, + Message, RequestOptions, - Attachment, + ToolInvocation, + UseAssistantOptions, } from '@ai-sdk/ui-utils'; export { generateId } from '@ai-sdk/provider-utils'; diff --git a/packages/ai/streams/stream-data.ts b/packages/ai/streams/stream-data.ts index 2554afe0d3cd..e7d2139bc8fb 100644 --- a/packages/ai/streams/stream-data.ts +++ b/packages/ai/streams/stream-data.ts @@ -1,4 +1,4 @@ -import { JSONValue, formatStreamPart } from '@ai-sdk/ui-utils'; +import { JSONValue, formatDataStreamPart } from '@ai-sdk/ui-utils'; import { HANGING_STREAM_WARNING_TIME_MS } from '../util/constants'; /** @@ -66,7 +66,7 @@ export class StreamData { } this.controller.enqueue( - this.encoder.encode(formatStreamPart('data', [value])), + this.encoder.encode(formatDataStreamPart('data', [value])), ); } @@ -80,7 +80,7 @@ export class StreamData { } this.controller.enqueue( - this.encoder.encode(formatStreamPart('message_annotations', [value])), + this.encoder.encode(formatDataStreamPart('message_annotations', [value])), ); } } @@ -95,7 +95,7 @@ export function createStreamDataTransformer() { return new TransformStream({ transform: async (chunk, controller) => { const message = decoder.decode(chunk); - controller.enqueue(encoder.encode(formatStreamPart('text', message))); + controller.enqueue(encoder.encode(formatDataStreamPart('text', message))); }, }); } diff --git a/packages/react/src/use-assistant.ts b/packages/react/src/use-assistant.ts index 568f4cc7f486..ca70372a6235 100644 --- a/packages/react/src/use-assistant.ts +++ b/packages/react/src/use-assistant.ts @@ -5,7 +5,7 @@ import { Message, UseAssistantOptions, generateId, - processDataStream, + processAssistantStream, } from '@ai-sdk/ui-utils'; import { useCallback, useRef, useState } from 'react'; @@ -176,7 +176,7 @@ export function useAssistant({ throw new Error('The response body is empty.'); } - await processDataStream({ + await processAssistantStream({ stream: response.body, onStreamPart: async ({ type, value }) => { switch (type) { diff --git a/packages/react/src/use-assistant.ui.test.tsx b/packages/react/src/use-assistant.ui.test.tsx index 37dd58405b49..4e3cbfcae851 100644 --- a/packages/react/src/use-assistant.ui.test.tsx +++ b/packages/react/src/use-assistant.ui.test.tsx @@ -1,4 +1,4 @@ -import { formatStreamPart } from '@ai-sdk/ui-utils'; +import { formatAssistantStreamPart } from '@ai-sdk/ui-utils'; import { mockFetchDataStream, mockFetchDataStreamWithGenerator, @@ -49,11 +49,11 @@ describe('stream data stream', () => { const { requestBody } = mockFetchDataStream({ url: 'https://example.com/api/assistant', chunks: [ - formatStreamPart('assistant_control_data', { + formatAssistantStreamPart('assistant_control_data', { threadId: 't0', messageId: 'm0', }), - formatStreamPart('assistant_message', { + formatAssistantStreamPart('assistant_message', { id: 'm0', role: 'assistant', content: [{ type: 'text', text: { value: '' } }], @@ -109,14 +109,14 @@ describe('stream data stream', () => { const encoder = new TextEncoder(); yield encoder.encode( - formatStreamPart('assistant_control_data', { + formatAssistantStreamPart('assistant_control_data', { threadId: 't0', messageId: 'm1', }), ); yield encoder.encode( - formatStreamPart('assistant_message', { + formatAssistantStreamPart('assistant_message', { id: 'm1', role: 'assistant', content: [{ type: 'text', text: { value: '' } }], @@ -203,11 +203,11 @@ describe('thread management', () => { const { requestBody } = mockFetchDataStream({ url: 'https://example.com/api/assistant', chunks: [ - formatStreamPart('assistant_control_data', { + formatAssistantStreamPart('assistant_control_data', { threadId: 't0', messageId: 'm0', }), - formatStreamPart('assistant_message', { + formatAssistantStreamPart('assistant_message', { id: 'm0', role: 'assistant', content: [{ type: 'text', text: { value: '' } }], @@ -250,11 +250,11 @@ describe('thread management', () => { const { requestBody } = mockFetchDataStream({ url: 'https://example.com/api/assistant', chunks: [ - formatStreamPart('assistant_control_data', { + formatAssistantStreamPart('assistant_control_data', { threadId: 't1', messageId: 'm0', }), - formatStreamPart('assistant_message', { + formatAssistantStreamPart('assistant_message', { id: 'm0', role: 'assistant', content: [{ type: 'text', text: { value: '' } }], @@ -297,11 +297,11 @@ describe('thread management', () => { const { requestBody } = mockFetchDataStream({ url: 'https://example.com/api/assistant', chunks: [ - formatStreamPart('assistant_control_data', { + formatAssistantStreamPart('assistant_control_data', { threadId: 't3', messageId: 'm0', }), - formatStreamPart('assistant_message', { + formatAssistantStreamPart('assistant_message', { id: 'm0', role: 'assistant', content: [{ type: 'text', text: { value: '' } }], diff --git a/packages/react/src/use-chat.ui.test.tsx b/packages/react/src/use-chat.ui.test.tsx index 56d377d74dbe..df0b5a45c346 100644 --- a/packages/react/src/use-chat.ui.test.tsx +++ b/packages/react/src/use-chat.ui.test.tsx @@ -1,7 +1,7 @@ /* eslint-disable @next/next/no-img-element */ import { withTestServer } from '@ai-sdk/provider-utils/test'; import { - formatStreamPart, + formatDataStreamPart, generateId, getTextFromDataUrl, Message, @@ -240,11 +240,11 @@ describe('data protocol stream', () => { url: '/api/chat', type: 'stream-values', content: [ - formatStreamPart('text', 'Hello'), - formatStreamPart('text', ','), - formatStreamPart('text', ' world'), - formatStreamPart('text', '.'), - formatStreamPart('finish_message', { + formatDataStreamPart('text', 'Hello'), + formatDataStreamPart('text', ','), + formatDataStreamPart('text', ' world'), + formatDataStreamPart('text', '.'), + formatDataStreamPart('finish_message', { finishReason: 'stop', usage: { completionTokens: 1, promptTokens: 3 }, }), @@ -749,7 +749,7 @@ describe('onToolCall', () => { url: '/api/chat', type: 'stream-values', content: [ - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: 'tool-call-0', toolName: 'test-tool', args: { testArg: 'test-value' }, @@ -813,7 +813,7 @@ describe('tool invocations', () => { await userEvent.click(screen.getByTestId('do-append')); streamController.enqueue( - formatStreamPart('tool_call_streaming_start', { + formatDataStreamPart('tool_call_streaming_start', { toolCallId: 'tool-call-0', toolName: 'test-tool', }), @@ -826,7 +826,7 @@ describe('tool invocations', () => { }); streamController.enqueue( - formatStreamPart('tool_call_delta', { + formatDataStreamPart('tool_call_delta', { toolCallId: 'tool-call-0', argsTextDelta: '{"testArg":"t', }), @@ -839,7 +839,7 @@ describe('tool invocations', () => { }); streamController.enqueue( - formatStreamPart('tool_call_delta', { + formatDataStreamPart('tool_call_delta', { toolCallId: 'tool-call-0', argsTextDelta: 'est-value"}}', }), @@ -852,7 +852,7 @@ describe('tool invocations', () => { }); streamController.enqueue( - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: 'tool-call-0', toolName: 'test-tool', args: { testArg: 'test-value' }, @@ -866,7 +866,7 @@ describe('tool invocations', () => { }); streamController.enqueue( - formatStreamPart('tool_result', { + formatDataStreamPart('tool_result', { toolCallId: 'tool-call-0', result: 'test-result', }), @@ -890,7 +890,7 @@ describe('tool invocations', () => { await userEvent.click(screen.getByTestId('do-append')); streamController.enqueue( - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: 'tool-call-0', toolName: 'test-tool', args: { testArg: 'test-value' }, @@ -904,7 +904,7 @@ describe('tool invocations', () => { }); streamController.enqueue( - formatStreamPart('tool_result', { + formatDataStreamPart('tool_result', { toolCallId: 'tool-call-0', result: 'test-result', }), @@ -973,7 +973,7 @@ describe('maxSteps', () => { url: '/api/chat', type: 'stream-values', content: [ - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: 'tool-call-0', toolName: 'test-tool', args: { testArg: 'test-value' }, @@ -983,7 +983,7 @@ describe('maxSteps', () => { { url: '/api/chat', type: 'stream-values', - content: [formatStreamPart('text', 'final result')], + content: [formatDataStreamPart('text', 'final result')], }, ], async () => { @@ -1058,7 +1058,7 @@ describe('maxSteps', () => { url: '/api/chat', type: 'stream-values', content: [ - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: 'tool-call-0', toolName: 'test-tool', args: { testArg: 'test-value' }, diff --git a/packages/solid/src/use-chat.ui.test.tsx b/packages/solid/src/use-chat.ui.test.tsx index 897d9633d2fc..a7e989c10488 100644 --- a/packages/solid/src/use-chat.ui.test.tsx +++ b/packages/solid/src/use-chat.ui.test.tsx @@ -1,6 +1,6 @@ /** @jsxImportSource solid-js */ import { withTestServer } from '@ai-sdk/provider-utils/test'; -import { formatStreamPart, Message } from '@ai-sdk/ui-utils'; +import { formatDataStreamPart, Message } from '@ai-sdk/ui-utils'; import { mockFetchDataStream } from '@ai-sdk/ui-utils/test'; import { cleanup, findByText, render, screen } from '@solidjs/testing-library'; import '@testing-library/jest-dom'; @@ -216,11 +216,11 @@ describe('data protocol stream', () => { url: '/api/chat', type: 'stream-values', content: [ - formatStreamPart('text', 'Hello'), - formatStreamPart('text', ','), - formatStreamPart('text', ' world'), - formatStreamPart('text', '.'), - formatStreamPart('finish_message', { + formatDataStreamPart('text', 'Hello'), + formatDataStreamPart('text', ','), + formatDataStreamPart('text', ' world'), + formatDataStreamPart('text', '.'), + formatDataStreamPart('finish_message', { finishReason: 'stop', usage: { completionTokens: 1, promptTokens: 3 }, }), @@ -443,7 +443,7 @@ describe('onToolCall', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: [ - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: 'tool-call-0', toolName: 'test-tool', args: { testArg: 'test-value' }, @@ -467,7 +467,7 @@ describe('maxSteps', () => { async onToolCall({ toolCall }) { mockFetchDataStream({ url: 'https://example.com/api/chat', - chunks: [formatStreamPart('text', 'final result')], + chunks: [formatDataStreamPart('text', 'final result')], }); return `test-tool-response: ${toolCall.toolName} ${ @@ -508,7 +508,7 @@ describe('maxSteps', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: [ - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: 'tool-call-0', toolName: 'test-tool', args: { testArg: 'test-value' }, @@ -529,7 +529,7 @@ describe('maxSteps', () => { async onToolCall({ toolCall }) { mockFetchDataStream({ url: 'https://example.com/api/chat', - chunks: [formatStreamPart('error', 'some failure')], + chunks: [formatDataStreamPart('error', 'some failure')], maxCalls: 1, }); @@ -583,7 +583,7 @@ describe('maxSteps', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: [ - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: 'tool-call-0', toolName: 'test-tool', args: { testArg: 'test-value' }, @@ -643,7 +643,7 @@ describe('form actions', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: ['Hello', ',', ' world', '.'].map(token => - formatStreamPart('text', token), + formatDataStreamPart('text', token), ), }); @@ -665,7 +665,7 @@ describe('form actions', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: ['How', ' can', ' I', ' help', ' you', '?'].map(token => - formatStreamPart('text', token), + formatDataStreamPart('text', token), ), }); @@ -724,7 +724,7 @@ describe('form actions (with options)', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: ['Hello', ',', ' world', '.'].map(token => - formatStreamPart('text', token), + formatDataStreamPart('text', token), ), }); @@ -746,7 +746,7 @@ describe('form actions (with options)', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: ['How', ' can', ' I', ' help', ' you', '?'].map(token => - formatStreamPart('text', token), + formatDataStreamPart('text', token), ), }); @@ -761,7 +761,7 @@ describe('form actions (with options)', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: ['The', ' sky', ' is', ' blue.'].map(token => - formatStreamPart('text', token), + formatDataStreamPart('text', token), ), }); diff --git a/packages/svelte/src/use-assistant.ts b/packages/svelte/src/use-assistant.ts index aa061eca9e91..98341ca282c6 100644 --- a/packages/svelte/src/use-assistant.ts +++ b/packages/svelte/src/use-assistant.ts @@ -5,7 +5,7 @@ import type { Message, UseAssistantOptions, } from '@ai-sdk/ui-utils'; -import { generateId, processDataStream } from '@ai-sdk/ui-utils'; +import { generateId, processAssistantStream } from '@ai-sdk/ui-utils'; import { Readable, Writable, get, writable } from 'svelte/store'; // use function to allow for mocking in tests: @@ -143,7 +143,7 @@ export function useAssistant({ throw new Error('The response body is empty.'); } - await processDataStream({ + await processAssistantStream({ stream: response.body, onStreamPart: async ({ type, value }) => { switch (type) { diff --git a/packages/ui-utils/src/assistant-stream-parts.test.ts b/packages/ui-utils/src/assistant-stream-parts.test.ts new file mode 100644 index 000000000000..11e5559622e0 --- /dev/null +++ b/packages/ui-utils/src/assistant-stream-parts.test.ts @@ -0,0 +1,20 @@ +import { + formatAssistantStreamPart, + parseAssistantStreamPart, +} from './assistant-stream-parts'; + +describe('text stream part', () => { + it('should format a text stream part', () => { + expect(formatAssistantStreamPart('text', 'value\nvalue')).toEqual( + '0:"value\\nvalue"\n', + ); + }); + + it('should parse a text line', () => { + const input = '0:"Hello, world!"'; + expect(parseAssistantStreamPart(input)).toEqual({ + type: 'text', + value: 'Hello, world!', + }); + }); +}); diff --git a/packages/ui-utils/src/assistant-stream-parts.ts b/packages/ui-utils/src/assistant-stream-parts.ts new file mode 100644 index 000000000000..959378b94f3e --- /dev/null +++ b/packages/ui-utils/src/assistant-stream-parts.ts @@ -0,0 +1,220 @@ +import { AssistantMessage, DataMessage, JSONValue } from './types'; + +export type AssistantStreamString = + `${(typeof StreamStringPrefixes)[keyof typeof StreamStringPrefixes]}:${string}\n`; + +export interface AssistantStreamPart< + CODE extends string, + NAME extends string, + TYPE, +> { + code: CODE; + name: NAME; + parse: (value: JSONValue) => { type: NAME; value: TYPE }; +} + +const textStreamPart: AssistantStreamPart<'0', 'text', string> = { + code: '0', + name: 'text', + parse: (value: JSONValue) => { + if (typeof value !== 'string') { + throw new Error('"text" parts expect a string value.'); + } + return { type: 'text', value }; + }, +}; + +const errorStreamPart: AssistantStreamPart<'3', 'error', string> = { + code: '3', + name: 'error', + parse: (value: JSONValue) => { + if (typeof value !== 'string') { + throw new Error('"error" parts expect a string value.'); + } + return { type: 'error', value }; + }, +}; + +const assistantMessageStreamPart: AssistantStreamPart< + '4', + 'assistant_message', + AssistantMessage +> = { + code: '4', + name: 'assistant_message', + parse: (value: JSONValue) => { + if ( + value == null || + typeof value !== 'object' || + !('id' in value) || + !('role' in value) || + !('content' in value) || + typeof value.id !== 'string' || + typeof value.role !== 'string' || + value.role !== 'assistant' || + !Array.isArray(value.content) || + !value.content.every( + item => + item != null && + typeof item === 'object' && + 'type' in item && + item.type === 'text' && + 'text' in item && + item.text != null && + typeof item.text === 'object' && + 'value' in item.text && + typeof item.text.value === 'string', + ) + ) { + throw new Error( + '"assistant_message" parts expect an object with an "id", "role", and "content" property.', + ); + } + + return { + type: 'assistant_message', + value: value as AssistantMessage, + }; + }, +}; + +const assistantControlDataStreamPart: AssistantStreamPart< + '5', + 'assistant_control_data', + { + threadId: string; + messageId: string; + } +> = { + code: '5', + name: 'assistant_control_data', + parse: (value: JSONValue) => { + if ( + value == null || + typeof value !== 'object' || + !('threadId' in value) || + !('messageId' in value) || + typeof value.threadId !== 'string' || + typeof value.messageId !== 'string' + ) { + throw new Error( + '"assistant_control_data" parts expect an object with a "threadId" and "messageId" property.', + ); + } + + return { + type: 'assistant_control_data', + value: { + threadId: value.threadId, + messageId: value.messageId, + }, + }; + }, +}; + +const dataMessageStreamPart: AssistantStreamPart< + '6', + 'data_message', + DataMessage +> = { + code: '6', + name: 'data_message', + parse: (value: JSONValue) => { + if ( + value == null || + typeof value !== 'object' || + !('role' in value) || + !('data' in value) || + typeof value.role !== 'string' || + value.role !== 'data' + ) { + throw new Error( + '"data_message" parts expect an object with a "role" and "data" property.', + ); + } + + return { + type: 'data_message', + value: value as DataMessage, + }; + }, +}; + +const assistantStreamParts = [ + textStreamPart, + errorStreamPart, + assistantMessageStreamPart, + assistantControlDataStreamPart, + dataMessageStreamPart, +] as const; + +type AssistantStreamParts = + | typeof textStreamPart + | typeof errorStreamPart + | typeof assistantMessageStreamPart + | typeof assistantControlDataStreamPart + | typeof dataMessageStreamPart; + +type AssistantStreamPartValueType = { + [P in AssistantStreamParts as P['name']]: ReturnType['value']; +}; + +export type AssistantStreamPartType = + | ReturnType + | ReturnType + | ReturnType + | ReturnType + | ReturnType; + +export const assistantStreamPartsByCode = { + [textStreamPart.code]: textStreamPart, + [errorStreamPart.code]: errorStreamPart, + [assistantMessageStreamPart.code]: assistantMessageStreamPart, + [assistantControlDataStreamPart.code]: assistantControlDataStreamPart, + [dataMessageStreamPart.code]: dataMessageStreamPart, +} as const; + +export const StreamStringPrefixes = { + [textStreamPart.name]: textStreamPart.code, + [errorStreamPart.name]: errorStreamPart.code, + [assistantMessageStreamPart.name]: assistantMessageStreamPart.code, + [assistantControlDataStreamPart.name]: assistantControlDataStreamPart.code, + [dataMessageStreamPart.name]: dataMessageStreamPart.code, +} as const; + +export const validCodes = assistantStreamParts.map(part => part.code); + +export const parseAssistantStreamPart = ( + line: string, +): AssistantStreamPartType => { + const firstSeparatorIndex = line.indexOf(':'); + + if (firstSeparatorIndex === -1) { + throw new Error('Failed to parse stream string. No separator found.'); + } + + const prefix = line.slice(0, firstSeparatorIndex); + + if (!validCodes.includes(prefix as keyof typeof assistantStreamPartsByCode)) { + throw new Error(`Failed to parse stream string. Invalid code ${prefix}.`); + } + + const code = prefix as keyof typeof assistantStreamPartsByCode; + + const textValue = line.slice(firstSeparatorIndex + 1); + const jsonValue: JSONValue = JSON.parse(textValue); + + return assistantStreamPartsByCode[code].parse(jsonValue); +}; + +export function formatAssistantStreamPart< + T extends keyof AssistantStreamPartValueType, +>(type: T, value: AssistantStreamPartValueType[T]): AssistantStreamString { + const streamPart = assistantStreamParts.find(part => part.name === type); + + if (!streamPart) { + throw new Error(`Invalid stream part type: ${type}`); + } + + return `${streamPart.code}:${JSON.stringify(value)}\n`; +} diff --git a/packages/ui-utils/src/stream-parts.test.ts b/packages/ui-utils/src/data-stream-parts.test.ts similarity index 80% rename from packages/ui-utils/src/stream-parts.test.ts rename to packages/ui-utils/src/data-stream-parts.test.ts index cb1615eb89bf..1ca7aeaf5a3c 100644 --- a/packages/ui-utils/src/stream-parts.test.ts +++ b/packages/ui-utils/src/data-stream-parts.test.ts @@ -2,18 +2,18 @@ import { ToolCall as CoreToolCall, ToolResult as CoreToolResult, } from '@ai-sdk/provider-utils'; -import { formatStreamPart, parseStreamPart } from './stream-parts'; +import { formatDataStreamPart, parseDataStreamPart } from './data-stream-parts'; -describe('stream-parts', () => { - describe('formatStreamPart', () => { +describe('data-stream-parts', () => { + describe('formatDataStreamPart', () => { it('should escape newlines in text', () => { - expect(formatStreamPart('text', 'value\nvalue')).toEqual( + expect(formatDataStreamPart('text', 'value\nvalue')).toEqual( '0:"value\\nvalue"\n', ); }); it('should escape newlines in data objects', () => { - expect(formatStreamPart('data', [{ test: 'value\nvalue' }])).toEqual( + expect(formatDataStreamPart('data', [{ test: 'value\nvalue' }])).toEqual( '2:[{"test":"value\\nvalue"}]\n', ); }); @@ -23,7 +23,7 @@ describe('stream-parts', () => { it('should parse a text line', () => { const input = '0:"Hello, world!"'; - expect(parseStreamPart(input)).toEqual({ + expect(parseDataStreamPart(input)).toEqual({ type: 'text', value: 'Hello, world!', }); @@ -32,7 +32,7 @@ describe('stream-parts', () => { it('should parse a data line', () => { const input = '2:[{"test":"value"}]'; const expectedOutput = { type: 'data', value: [{ test: 'value' }] }; - expect(parseStreamPart(input)).toEqual(expectedOutput); + expect(parseDataStreamPart(input)).toEqual(expectedOutput); }); it('should parse a message data line', () => { @@ -41,22 +41,22 @@ describe('stream-parts', () => { type: 'message_annotations', value: [{ test: 'value' }], }; - expect(parseStreamPart(input)).toEqual(expectedOutput); + expect(parseDataStreamPart(input)).toEqual(expectedOutput); }); it('should throw an error if the input does not contain a colon separator', () => { const input = 'invalid stream string'; - expect(() => parseStreamPart(input)).toThrow(); + expect(() => parseDataStreamPart(input)).toThrow(); }); it('should throw an error if the input contains an invalid type', () => { const input = '55:test'; - expect(() => parseStreamPart(input)).toThrow(); + expect(() => parseDataStreamPart(input)).toThrow(); }); it("should throw error if the input's JSON is invalid", () => { const input = '0:{"test":"value"'; - expect(() => parseStreamPart(input)).toThrow(); + expect(() => parseDataStreamPart(input)).toThrow(); }); }); }); @@ -69,7 +69,7 @@ describe('tool_call stream part', () => { args: { test: 'value' }, }; - expect(formatStreamPart('tool_call', toolCall)).toEqual( + expect(formatDataStreamPart('tool_call', toolCall)).toEqual( `9:${JSON.stringify(toolCall)}\n`, ); }); @@ -83,7 +83,7 @@ describe('tool_call stream part', () => { const input = `9:${JSON.stringify(toolCall)}`; - expect(parseStreamPart(input)).toEqual({ + expect(parseDataStreamPart(input)).toEqual({ type: 'tool_call', value: toolCall, }); @@ -100,7 +100,7 @@ describe('tool_result stream part', () => { result: 'result', }; - expect(formatStreamPart('tool_result', toolResult)).toEqual( + expect(formatDataStreamPart('tool_result', toolResult)).toEqual( `a:${JSON.stringify(toolResult)}\n`, ); }); @@ -113,7 +113,7 @@ describe('tool_result stream part', () => { const input = `a:${JSON.stringify(toolResult)}`; - expect(parseStreamPart(input)).toEqual({ + expect(parseDataStreamPart(input)).toEqual({ type: 'tool_result', value: toolResult, }); @@ -123,7 +123,7 @@ describe('tool_result stream part', () => { describe('tool_call_streaming_start stream part', () => { it('should format a tool_call_streaming_start stream part', () => { expect( - formatStreamPart('tool_call_streaming_start', { + formatDataStreamPart('tool_call_streaming_start', { toolCallId: 'tc_0', toolName: 'example_tool', }), @@ -133,7 +133,7 @@ describe('tool_call_streaming_start stream part', () => { it('should parse a tool_call_streaming_start stream part', () => { const input = `b:{"toolCallId":"tc_0","toolName":"example_tool"}`; - expect(parseStreamPart(input)).toEqual({ + expect(parseDataStreamPart(input)).toEqual({ type: 'tool_call_streaming_start', value: { toolCallId: 'tc_0', toolName: 'example_tool' }, }); @@ -143,7 +143,7 @@ describe('tool_call_streaming_start stream part', () => { describe('tool_call_delta stream part', () => { it('should format a tool_call_delta stream part', () => { expect( - formatStreamPart('tool_call_delta', { + formatDataStreamPart('tool_call_delta', { toolCallId: 'tc_0', argsTextDelta: 'delta', }), @@ -152,7 +152,7 @@ describe('tool_call_delta stream part', () => { it('should parse a tool_call_delta stream part', () => { const input = `c:{"toolCallId":"tc_0","argsTextDelta":"delta"}`; - expect(parseStreamPart(input)).toEqual({ + expect(parseDataStreamPart(input)).toEqual({ type: 'tool_call_delta', value: { toolCallId: 'tc_0', argsTextDelta: 'delta' }, }); @@ -162,7 +162,7 @@ describe('tool_call_delta stream part', () => { describe('finish_message stream part', () => { it('should format a finish_message stream part', () => { expect( - formatStreamPart('finish_message', { + formatDataStreamPart('finish_message', { finishReason: 'stop', usage: { promptTokens: 10, completionTokens: 20 }, }), @@ -173,7 +173,7 @@ describe('finish_message stream part', () => { it('should format a finish_message stream part without usage information', () => { expect( - formatStreamPart('finish_message', { + formatDataStreamPart('finish_message', { finishReason: 'stop', }), ).toEqual(`d:{"finishReason":"stop"}\n`); @@ -181,7 +181,7 @@ describe('finish_message stream part', () => { it('should parse a finish_message stream part', () => { const input = `d:{"finishReason":"stop","usage":{"promptTokens":10,"completionTokens":20}}`; - expect(parseStreamPart(input)).toEqual({ + expect(parseDataStreamPart(input)).toEqual({ type: 'finish_message', value: { finishReason: 'stop', @@ -192,7 +192,7 @@ describe('finish_message stream part', () => { it('should parse a finish_message with null completion and prompt tokens', () => { const input = `d:{"finishReason":"stop","usage":{"promptTokens":null,"completionTokens":null}}`; - expect(parseStreamPart(input)).toEqual({ + expect(parseDataStreamPart(input)).toEqual({ type: 'finish_message', value: { finishReason: 'stop', @@ -203,7 +203,7 @@ describe('finish_message stream part', () => { it('should parse a finish_message without usage information', () => { const input = `d:{"finishReason":"stop"}`; - expect(parseStreamPart(input)).toEqual({ + expect(parseDataStreamPart(input)).toEqual({ type: 'finish_message', value: { finishReason: 'stop', @@ -215,7 +215,7 @@ describe('finish_message stream part', () => { describe('finish_step stream part', () => { it('should format a finish_step stream part', () => { expect( - formatStreamPart('finish_step', { + formatDataStreamPart('finish_step', { finishReason: 'stop', usage: { promptTokens: 10, completionTokens: 20 }, isContinued: false, @@ -227,7 +227,7 @@ describe('finish_step stream part', () => { it('should format a finish_step stream part without usage or continue information ', () => { expect( - formatStreamPart('finish_step', { + formatDataStreamPart('finish_step', { finishReason: 'stop', isContinued: false, }), @@ -236,7 +236,7 @@ describe('finish_step stream part', () => { it('should parse a finish_step stream part', () => { const input = `e:{"finishReason":"stop","usage":{"promptTokens":10,"completionTokens":20},"isContinued":true}`; - expect(parseStreamPart(input)).toEqual({ + expect(parseDataStreamPart(input)).toEqual({ type: 'finish_step', value: { finishReason: 'stop', @@ -248,7 +248,7 @@ describe('finish_step stream part', () => { it('should parse a finish_step with null completion and prompt tokens', () => { const input = `e:{"finishReason":"stop","usage":{"promptTokens":null,"completionTokens":null}}`; - expect(parseStreamPart(input)).toEqual({ + expect(parseDataStreamPart(input)).toEqual({ type: 'finish_step', value: { finishReason: 'stop', @@ -260,7 +260,7 @@ describe('finish_step stream part', () => { it('should parse a finish_step without usage information', () => { const input = `e:{"finishReason":"stop","usage":null}`; - expect(parseStreamPart(input)).toEqual({ + expect(parseDataStreamPart(input)).toEqual({ type: 'finish_step', value: { finishReason: 'stop', diff --git a/packages/ui-utils/src/stream-parts.ts b/packages/ui-utils/src/data-stream-parts.ts similarity index 69% rename from packages/ui-utils/src/stream-parts.ts rename to packages/ui-utils/src/data-stream-parts.ts index d40b4bfdf169..6ff653d1f44c 100644 --- a/packages/ui-utils/src/stream-parts.ts +++ b/packages/ui-utils/src/data-stream-parts.ts @@ -3,18 +3,22 @@ import { ToolCall as CoreToolCall, ToolResult as CoreToolResult, } from '@ai-sdk/provider-utils'; -import { AssistantMessage, DataMessage, JSONValue } from './types'; +import { JSONValue } from './types'; -export type StreamString = - `${(typeof StreamStringPrefixes)[keyof typeof StreamStringPrefixes]}:${string}\n`; +export type DataStreamString = + `${(typeof DataStreamStringPrefixes)[keyof typeof DataStreamStringPrefixes]}:${string}\n`; -export interface StreamPart { +export interface DataStreamPart< + CODE extends string, + NAME extends string, + TYPE, +> { code: CODE; name: NAME; parse: (value: JSONValue) => { type: NAME; value: TYPE }; } -const textStreamPart: StreamPart<'0', 'text', string> = { +const textStreamPart: DataStreamPart<'0', 'text', string> = { code: '0', name: 'text', parse: (value: JSONValue) => { @@ -25,7 +29,7 @@ const textStreamPart: StreamPart<'0', 'text', string> = { }, }; -const dataStreamPart: StreamPart<'2', 'data', Array> = { +const dataStreamPart: DataStreamPart<'2', 'data', Array> = { code: '2', name: 'data', parse: (value: JSONValue) => { @@ -37,7 +41,7 @@ const dataStreamPart: StreamPart<'2', 'data', Array> = { }, }; -const errorStreamPart: StreamPart<'3', 'error', string> = { +const errorStreamPart: DataStreamPart<'3', 'error', string> = { code: '3', name: 'error', parse: (value: JSONValue) => { @@ -48,108 +52,7 @@ const errorStreamPart: StreamPart<'3', 'error', string> = { }, }; -const assistantMessageStreamPart: StreamPart< - '4', - 'assistant_message', - AssistantMessage -> = { - code: '4', - name: 'assistant_message', - parse: (value: JSONValue) => { - if ( - value == null || - typeof value !== 'object' || - !('id' in value) || - !('role' in value) || - !('content' in value) || - typeof value.id !== 'string' || - typeof value.role !== 'string' || - value.role !== 'assistant' || - !Array.isArray(value.content) || - !value.content.every( - item => - item != null && - typeof item === 'object' && - 'type' in item && - item.type === 'text' && - 'text' in item && - item.text != null && - typeof item.text === 'object' && - 'value' in item.text && - typeof item.text.value === 'string', - ) - ) { - throw new Error( - '"assistant_message" parts expect an object with an "id", "role", and "content" property.', - ); - } - - return { - type: 'assistant_message', - value: value as AssistantMessage, - }; - }, -}; - -const assistantControlDataStreamPart: StreamPart< - '5', - 'assistant_control_data', - { - threadId: string; - messageId: string; - } -> = { - code: '5', - name: 'assistant_control_data', - parse: (value: JSONValue) => { - if ( - value == null || - typeof value !== 'object' || - !('threadId' in value) || - !('messageId' in value) || - typeof value.threadId !== 'string' || - typeof value.messageId !== 'string' - ) { - throw new Error( - '"assistant_control_data" parts expect an object with a "threadId" and "messageId" property.', - ); - } - - return { - type: 'assistant_control_data', - value: { - threadId: value.threadId, - messageId: value.messageId, - }, - }; - }, -}; - -const dataMessageStreamPart: StreamPart<'6', 'data_message', DataMessage> = { - code: '6', - name: 'data_message', - parse: (value: JSONValue) => { - if ( - value == null || - typeof value !== 'object' || - !('role' in value) || - !('data' in value) || - typeof value.role !== 'string' || - value.role !== 'data' - ) { - throw new Error( - '"data_message" parts expect an object with a "role" and "data" property.', - ); - } - - return { - type: 'data_message', - value: value as DataMessage, - }; - }, -}; - -const messageAnnotationsStreamPart: StreamPart< +const messageAnnotationsStreamPart: DataStreamPart< '8', 'message_annotations', Array @@ -165,7 +68,7 @@ const messageAnnotationsStreamPart: StreamPart< }, }; -const toolCallStreamPart: StreamPart< +const toolCallStreamPart: DataStreamPart< '9', 'tool_call', CoreToolCall @@ -195,7 +98,7 @@ const toolCallStreamPart: StreamPart< }, }; -const toolResultStreamPart: StreamPart< +const toolResultStreamPart: DataStreamPart< 'a', 'tool_result', Omit, 'args' | 'toolName'> @@ -225,7 +128,7 @@ const toolResultStreamPart: StreamPart< }, }; -const toolCallStreamingStartStreamPart: StreamPart< +const toolCallStreamingStartStreamPart: DataStreamPart< 'b', 'tool_call_streaming_start', { toolCallId: string; toolName: string } @@ -253,7 +156,7 @@ const toolCallStreamingStartStreamPart: StreamPart< }, }; -const toolCallDeltaStreamPart: StreamPart< +const toolCallDeltaStreamPart: DataStreamPart< 'c', 'tool_call_delta', { toolCallId: string; argsTextDelta: string } @@ -284,7 +187,7 @@ const toolCallDeltaStreamPart: StreamPart< }, }; -const finishMessageStreamPart: StreamPart< +const finishMessageStreamPart: DataStreamPart< 'd', 'finish_message', { @@ -345,7 +248,7 @@ const finishMessageStreamPart: StreamPart< }, }; -const finishStepStreamPart: StreamPart< +const finishStepStreamPart: DataStreamPart< 'e', 'finish_step', { @@ -413,13 +316,10 @@ const finishStepStreamPart: StreamPart< }, }; -const streamParts = [ +const dataStreamParts = [ textStreamPart, dataStreamPart, errorStreamPart, - assistantMessageStreamPart, - assistantControlDataStreamPart, - dataMessageStreamPart, messageAnnotationsStreamPart, toolCallStreamPart, toolResultStreamPart, @@ -429,14 +329,10 @@ const streamParts = [ finishStepStreamPart, ] as const; -// union type of all stream parts -type StreamParts = +type DataStreamParts = | typeof textStreamPart | typeof dataStreamPart | typeof errorStreamPart - | typeof assistantMessageStreamPart - | typeof assistantControlDataStreamPart - | typeof dataMessageStreamPart | typeof messageAnnotationsStreamPart | typeof toolCallStreamPart | typeof toolResultStreamPart @@ -448,17 +344,14 @@ type StreamParts = /** * Maps the type of a stream part to its value type. */ -type StreamPartValueType = { - [P in StreamParts as P['name']]: ReturnType['value']; +type DataStreamPartValueType = { + [P in DataStreamParts as P['name']]: ReturnType['value']; }; -export type StreamPartType = +export type DataStreamPartType = | ReturnType | ReturnType | ReturnType - | ReturnType - | ReturnType - | ReturnType | ReturnType | ReturnType | ReturnType @@ -467,13 +360,10 @@ export type StreamPartType = | ReturnType | ReturnType; -export const streamPartsByCode = { +export const dataStreamPartsByCode = { [textStreamPart.code]: textStreamPart, [dataStreamPart.code]: dataStreamPart, [errorStreamPart.code]: errorStreamPart, - [assistantMessageStreamPart.code]: assistantMessageStreamPart, - [assistantControlDataStreamPart.code]: assistantControlDataStreamPart, - [dataMessageStreamPart.code]: dataMessageStreamPart, [messageAnnotationsStreamPart.code]: messageAnnotationsStreamPart, [toolCallStreamPart.code]: toolCallStreamPart, [toolResultStreamPart.code]: toolResultStreamPart, @@ -505,13 +395,10 @@ export const streamPartsByCode = { * 6: {"tool_call": {"id": "tool_0", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\\n\\"location\\": \\"Charlottesville, Virginia\\",\\n\\"format\\": \\"celsius\\"\\n}"}}} *``` */ -export const StreamStringPrefixes = { +export const DataStreamStringPrefixes = { [textStreamPart.name]: textStreamPart.code, [dataStreamPart.name]: dataStreamPart.code, [errorStreamPart.name]: errorStreamPart.code, - [assistantMessageStreamPart.name]: assistantMessageStreamPart.code, - [assistantControlDataStreamPart.name]: assistantControlDataStreamPart.code, - [dataMessageStreamPart.name]: dataMessageStreamPart.code, [messageAnnotationsStreamPart.name]: messageAnnotationsStreamPart.code, [toolCallStreamPart.name]: toolCallStreamPart.code, [toolResultStreamPart.name]: toolResultStreamPart.code, @@ -522,7 +409,7 @@ export const StreamStringPrefixes = { [finishStepStreamPart.name]: finishStepStreamPart.code, } as const; -export const validCodes = streamParts.map(part => part.code); +export const validCodes = dataStreamParts.map(part => part.code); /** Parses a stream part from a string. @@ -531,7 +418,7 @@ Parses a stream part from a string. @returns The parsed stream part. @throws An error if the string cannot be parsed. */ -export const parseStreamPart = (line: string): StreamPartType => { +export const parseDataStreamPart = (line: string): DataStreamPartType => { const firstSeparatorIndex = line.indexOf(':'); if (firstSeparatorIndex === -1) { @@ -540,16 +427,16 @@ export const parseStreamPart = (line: string): StreamPartType => { const prefix = line.slice(0, firstSeparatorIndex); - if (!validCodes.includes(prefix as keyof typeof streamPartsByCode)) { + if (!validCodes.includes(prefix as keyof typeof dataStreamPartsByCode)) { throw new Error(`Failed to parse stream string. Invalid code ${prefix}.`); } - const code = prefix as keyof typeof streamPartsByCode; + const code = prefix as keyof typeof dataStreamPartsByCode; const textValue = line.slice(firstSeparatorIndex + 1); const jsonValue: JSONValue = JSON.parse(textValue); - return streamPartsByCode[code].parse(jsonValue); + return dataStreamPartsByCode[code].parse(jsonValue); }; /** @@ -558,11 +445,11 @@ and appends a new line. It ensures type-safety for the part type and value. */ -export function formatStreamPart( +export function formatDataStreamPart( type: T, - value: StreamPartValueType[T], -): StreamString { - const streamPart = streamParts.find(part => part.name === type); + value: DataStreamPartValueType[T], +): DataStreamString { + const streamPart = dataStreamParts.find(part => part.name === type); if (!streamPart) { throw new Error(`Invalid stream part type: ${type}`); diff --git a/packages/ui-utils/src/index.ts b/packages/ui-utils/src/index.ts index 2eecc054144c..89d068bfe97a 100644 --- a/packages/ui-utils/src/index.ts +++ b/packages/ui-utils/src/index.ts @@ -5,16 +5,25 @@ export { generateId } from '@ai-sdk/provider-utils'; // Export stream data utilities for custom stream implementations, // both on the client and server side. // NOTE: this is experimental / internal and may change without notice +export { + formatAssistantStreamPart, + parseAssistantStreamPart, +} from './assistant-stream-parts'; +export type { + AssistantStreamPart, + AssistantStreamString, +} from './assistant-stream-parts'; export { callChatApi } from './call-chat-api'; export { callCompletionApi } from './call-completion-api'; +export { formatDataStreamPart, parseDataStreamPart } from './data-stream-parts'; +export type { DataStreamPart, DataStreamString } from './data-stream-parts'; export { getTextFromDataUrl } from './data-url'; export type { DeepPartial } from './deep-partial'; export { isDeepEqualData } from './is-deep-equal-data'; export { parsePartialJson } from './parse-partial-json'; export { processDataProtocolResponse } from './process-data-protocol-response'; -export { processTextStream } from './process-text-stream'; +export { processAssistantStream } from './process-assistant-stream'; export { processDataStream } from './process-data-stream'; +export { processTextStream } from './process-text-stream'; export { asSchema, jsonSchema, zodSchema } from './schema'; export type { Schema } from './schema'; -export { formatStreamPart, parseStreamPart } from './stream-parts'; -export type { StreamPart, StreamString } from './stream-parts'; diff --git a/packages/ui-utils/src/process-assistant-stream.test.ts b/packages/ui-utils/src/process-assistant-stream.test.ts new file mode 100644 index 000000000000..5bc297b2f34f --- /dev/null +++ b/packages/ui-utils/src/process-assistant-stream.test.ts @@ -0,0 +1,234 @@ +import { describe, expect, it, vi } from 'vitest'; +import { AssistantStreamPartType } from './assistant-stream-parts'; +import { processAssistantStream } from './process-assistant-stream'; + +function createReadableStream( + chunks: Uint8Array[], +): ReadableStream { + return new ReadableStream({ + start(controller) { + chunks.forEach(chunk => controller.enqueue(chunk)); + controller.close(); + }, + }); +} + +function encodeText(text: string): Uint8Array { + return new TextEncoder().encode(text); +} + +describe('processDataStream', () => { + // Basic Functionality Tests + it('should process a simple text stream part', async () => { + const chunks = [encodeText('0:"Hello"\n')]; + const stream = createReadableStream(chunks); + const receivedParts: AssistantStreamPartType[] = []; + + await processAssistantStream({ + stream, + onStreamPart: async part => { + receivedParts.push(part); + }, + }); + + expect(receivedParts).toHaveLength(1); + expect(receivedParts[0]).toEqual({ + type: 'text', + value: 'Hello', + }); + }); + + it('should handle multiple stream parts in sequence', async () => { + const chunks = [encodeText('0:"Hello"\n0:"123"\n3:"error"\n')]; + const stream = createReadableStream(chunks); + const receivedParts: AssistantStreamPartType[] = []; + + await processAssistantStream({ + stream, + onStreamPart: async part => { + receivedParts.push(part); + }, + }); + + expect(receivedParts).toHaveLength(3); + expect(receivedParts[0]).toEqual({ type: 'text', value: 'Hello' }); + expect(receivedParts[1]).toEqual({ type: 'text', value: '123' }); + expect(receivedParts[2]).toEqual({ type: 'error', value: 'error' }); + }); + + // Edge Environment Specific Tests + it('should handle chunks that split JSON values', async () => { + const chunks = [encodeText('0:"Hel'), encodeText('lo"\n')]; + const stream = createReadableStream(chunks); + const receivedParts: AssistantStreamPartType[] = []; + + await processAssistantStream({ + stream, + onStreamPart: async part => { + receivedParts.push(part); + }, + }); + + expect(receivedParts).toHaveLength(1); + expect(receivedParts[0]).toEqual({ type: 'text', value: 'Hello' }); + }); + + it('should handle chunks that split at newlines', async () => { + const chunks = [encodeText('0:"Hello"\n'), encodeText('0:"World"\n')]; + const stream = createReadableStream(chunks); + const receivedParts: AssistantStreamPartType[] = []; + + await processAssistantStream({ + stream, + onStreamPart: async part => { + receivedParts.push(part); + }, + }); + + expect(receivedParts).toHaveLength(2); + expect(receivedParts[0]).toEqual({ type: 'text', value: 'Hello' }); + expect(receivedParts[1]).toEqual({ type: 'text', value: 'World' }); + }); + + it('should handle chunks that split unicode characters', async () => { + const emoji = '👋'; + const encoded = encodeText(`0:"Hello ${emoji}"\n`); + const splitPoint = encoded.length - 3; // Split in the middle of emoji bytes + + const chunks = [encoded.slice(0, splitPoint), encoded.slice(splitPoint)]; + const stream = createReadableStream(chunks); + const receivedParts: AssistantStreamPartType[] = []; + + await processAssistantStream({ + stream, + onStreamPart: async part => { + receivedParts.push(part); + }, + }); + + expect(receivedParts).toHaveLength(1); + expect(receivedParts[0]).toEqual({ type: 'text', value: `Hello ${emoji}` }); + }); + + // Error Cases + it('should throw on malformed JSON', async () => { + const chunks = [encodeText('0:{malformed]]\n')]; + const stream = createReadableStream(chunks); + + await expect( + processAssistantStream({ + stream, + onStreamPart: async () => {}, + }), + ).rejects.toThrow(); + }); + + it('should throw on invalid stream part codes', async () => { + const chunks = [encodeText('x:"invalid"\n')]; + const stream = createReadableStream(chunks); + + await expect( + processAssistantStream({ + stream, + onStreamPart: async () => {}, + }), + ).rejects.toThrow('Invalid code'); + }); + + // Edge Cases + it('should handle empty chunks', async () => { + const chunks = [ + new Uint8Array(0), + encodeText('0:"Hello"\n'), + new Uint8Array(0), + encodeText('0:"World"\n'), + ]; + const stream = createReadableStream(chunks); + const receivedParts: AssistantStreamPartType[] = []; + + await processAssistantStream({ + stream, + onStreamPart: async part => { + receivedParts.push(part); + }, + }); + + expect(receivedParts).toHaveLength(2); + expect(receivedParts[0]).toEqual({ type: 'text', value: 'Hello' }); + expect(receivedParts[1]).toEqual({ type: 'text', value: 'World' }); + }); + + it('should handle very large messages', async () => { + const largeString = 'x'.repeat(1024 * 1024); // 1MB string + const chunks = [encodeText(`0:"${largeString}"\n`)]; + const stream = createReadableStream(chunks); + const receivedParts: AssistantStreamPartType[] = []; + + await processAssistantStream({ + stream, + onStreamPart: async part => { + receivedParts.push(part); + }, + }); + + expect(receivedParts).toHaveLength(1); + expect(receivedParts[0]).toEqual({ type: 'text', value: largeString }); + }); + + it('should handle multiple newlines', async () => { + const chunks = [encodeText('0:"Hello"\n\n0:"World"\n')]; + const stream = createReadableStream(chunks); + const receivedParts: AssistantStreamPartType[] = []; + + await processAssistantStream({ + stream, + onStreamPart: async part => { + receivedParts.push(part); + }, + }); + + expect(receivedParts).toHaveLength(2); + expect(receivedParts[0]).toEqual({ type: 'text', value: 'Hello' }); + expect(receivedParts[1]).toEqual({ type: 'text', value: 'World' }); + }); + + // Cleanup and Resource Management + it('should properly release reader resources', async () => { + const mockRelease = vi.fn(); + const stream = new ReadableStream({ + start(controller) { + controller.enqueue(encodeText('0:"Hello"\n')); + controller.close(); + }, + cancel: mockRelease, + }); + + await processAssistantStream({ + stream, + onStreamPart: async () => {}, + }); + + // The reader should be automatically released when the stream is done + expect(mockRelease).not.toHaveBeenCalled(); + }); + + // Concurrency Tests + it('should handle rapid stream processing', async () => { + const parts = Array.from({ length: 100 }, (_, i) => `0:"Message ${i}"\n`); + const chunks = parts.map(encodeText); + const stream = createReadableStream(chunks); + const receivedParts: AssistantStreamPartType[] = []; + + await processAssistantStream({ + stream, + onStreamPart: async part => { + receivedParts.push(part); + }, + }); + + expect(receivedParts).toHaveLength(100); + receivedParts.forEach((part, i) => { + expect(part).toEqual({ type: 'text', value: `Message ${i}` }); + }); + }); +}); diff --git a/packages/ui-utils/src/process-assistant-stream.ts b/packages/ui-utils/src/process-assistant-stream.ts new file mode 100644 index 000000000000..f826cae48bbb --- /dev/null +++ b/packages/ui-utils/src/process-assistant-stream.ts @@ -0,0 +1,66 @@ +import { + AssistantStreamPartType, + parseAssistantStreamPart, +} from './assistant-stream-parts'; + +const NEWLINE = '\n'.charCodeAt(0); + +// concatenates all the chunks into a single Uint8Array +function concatChunks(chunks: Uint8Array[], totalLength: number) { + const concatenatedChunks = new Uint8Array(totalLength); + + let offset = 0; + for (const chunk of chunks) { + concatenatedChunks.set(chunk, offset); + offset += chunk.length; + } + chunks.length = 0; + + return concatenatedChunks; +} + +export async function processAssistantStream({ + stream, + onStreamPart, +}: { + stream: ReadableStream; + onStreamPart: (streamPart: AssistantStreamPartType) => Promise | void; +}): Promise { + // implementation note: this slightly more complex algorithm is required + // to pass the tests in the edge environment. + + const reader = stream.getReader(); + const decoder = new TextDecoder(); + const chunks: Uint8Array[] = []; + let totalLength = 0; + + while (true) { + const { value } = await reader.read(); + + if (value) { + chunks.push(value); + totalLength += value.length; + if (value[value.length - 1] !== NEWLINE) { + // if the last character is not a newline, we have not read the whole JSON value + continue; + } + } + + if (chunks.length === 0) { + break; // we have reached the end of the stream + } + + const concatenatedChunks = concatChunks(chunks, totalLength); + totalLength = 0; + + const streamParts = decoder + .decode(concatenatedChunks, { stream: true }) + .split('\n') + .filter(line => line !== '') // splitting leaves an empty string at the end + .map(parseAssistantStreamPart); + + for (const streamPart of streamParts) { + await onStreamPart(streamPart); + } + } +} diff --git a/packages/ui-utils/src/process-data-procotol-response.test.ts b/packages/ui-utils/src/process-data-procotol-response.test.ts index 7f6bbcab8c0e..e041025fc6e9 100644 --- a/packages/ui-utils/src/process-data-procotol-response.test.ts +++ b/packages/ui-utils/src/process-data-procotol-response.test.ts @@ -1,9 +1,9 @@ +import { LanguageModelV1FinishReason } from '@ai-sdk/provider'; import { describe, expect, it, vi } from 'vitest'; +import { formatDataStreamPart } from './data-stream-parts'; import { processDataProtocolResponse } from './process-data-protocol-response'; -import { formatStreamPart } from './stream-parts'; import { createDataProtocolStream } from './test/create-data-protocol-stream'; import { JSONValue, Message } from './types'; -import { LanguageModelV1FinishReason } from '@ai-sdk/provider'; let updateCalls: Array<{ newMessages: Message[]; @@ -46,14 +46,14 @@ describe('scenario: simple text response', () => { beforeEach(async () => { const stream = createDataProtocolStream([ - formatStreamPart('text', 'Hello, '), - formatStreamPart('text', 'world!'), - formatStreamPart('finish_step', { + formatDataStreamPart('text', 'Hello, '), + formatDataStreamPart('text', 'world!'), + formatDataStreamPart('finish_step', { finishReason: 'stop', usage: { completionTokens: 5, promptTokens: 10 }, isContinued: false, }), - formatStreamPart('finish_message', { + formatDataStreamPart('finish_message', { finishReason: 'stop', usage: { completionTokens: 5, promptTokens: 10 }, }), @@ -131,27 +131,27 @@ describe('scenario: server-side tool roundtrip', () => { beforeEach(async () => { const stream = createDataProtocolStream([ - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: 'tool-call-id', toolName: 'tool-name', args: { city: 'London' }, }), - formatStreamPart('tool_result', { + formatDataStreamPart('tool_result', { toolCallId: 'tool-call-id', result: { weather: 'sunny' }, }), - formatStreamPart('finish_step', { + formatDataStreamPart('finish_step', { finishReason: 'tool-calls', usage: { completionTokens: 5, promptTokens: 10 }, isContinued: false, }), - formatStreamPart('text', 'The weather in London is sunny.'), - formatStreamPart('finish_step', { + formatDataStreamPart('text', 'The weather in London is sunny.'), + formatDataStreamPart('finish_step', { finishReason: 'stop', usage: { completionTokens: 2, promptTokens: 4 }, isContinued: false, }), - formatStreamPart('finish_message', { + formatDataStreamPart('finish_message', { finishReason: 'stop', usage: { completionTokens: 7, promptTokens: 14 }, }), @@ -285,19 +285,19 @@ describe('scenario: server-side continue roundtrip', () => { beforeEach(async () => { const stream = createDataProtocolStream([ - formatStreamPart('text', 'The weather in London '), - formatStreamPart('finish_step', { + formatDataStreamPart('text', 'The weather in London '), + formatDataStreamPart('finish_step', { finishReason: 'length', usage: { completionTokens: 5, promptTokens: 10 }, isContinued: true, }), - formatStreamPart('text', 'is sunny.'), - formatStreamPart('finish_step', { + formatDataStreamPart('text', 'is sunny.'), + formatDataStreamPart('finish_step', { finishReason: 'stop', usage: { completionTokens: 2, promptTokens: 4 }, isContinued: false, }), - formatStreamPart('finish_message', { + formatDataStreamPart('finish_message', { finishReason: 'stop', usage: { completionTokens: 7, promptTokens: 14 }, }), @@ -375,18 +375,18 @@ describe('scenario: delayed message annotations in onFinish', () => { beforeEach(async () => { const stream = createDataProtocolStream([ - formatStreamPart('text', 'text'), - formatStreamPart('finish_step', { + formatDataStreamPart('text', 'text'), + formatDataStreamPart('finish_step', { finishReason: 'stop', usage: { completionTokens: 5, promptTokens: 10 }, isContinued: false, }), - formatStreamPart('finish_message', { + formatDataStreamPart('finish_message', { finishReason: 'stop', usage: { completionTokens: 5, promptTokens: 10 }, }), // delayed message annotations: - formatStreamPart('message_annotations', [ + formatDataStreamPart('message_annotations', [ { example: 'annotation', }, @@ -471,16 +471,16 @@ describe('scenario: message annotations in onChunk', () => { beforeEach(async () => { const stream = createDataProtocolStream([ - formatStreamPart('message_annotations', ['annotation1']), - formatStreamPart('text', 't1'), - formatStreamPart('message_annotations', ['annotation2']), - formatStreamPart('text', 't2'), - formatStreamPart('finish_step', { + formatDataStreamPart('message_annotations', ['annotation1']), + formatDataStreamPart('text', 't1'), + formatDataStreamPart('message_annotations', ['annotation2']), + formatDataStreamPart('text', 't2'), + formatDataStreamPart('finish_step', { finishReason: 'stop', usage: { completionTokens: 5, promptTokens: 10 }, isContinued: false, }), - formatStreamPart('finish_message', { + formatDataStreamPart('finish_message', { finishReason: 'stop', usage: { completionTokens: 5, promptTokens: 10 }, }), diff --git a/packages/ui-utils/src/process-data-stream.test.ts b/packages/ui-utils/src/process-data-stream.test.ts index d323e72f3da8..988a1a5eb628 100644 --- a/packages/ui-utils/src/process-data-stream.test.ts +++ b/packages/ui-utils/src/process-data-stream.test.ts @@ -1,6 +1,6 @@ -import { describe, it, expect, vi } from 'vitest'; +import { describe, expect, it, vi } from 'vitest'; +import { DataStreamPartType } from './data-stream-parts'; import { processDataStream } from './process-data-stream'; -import { StreamPartType } from './stream-parts'; function createReadableStream( chunks: Uint8Array[], @@ -22,7 +22,7 @@ describe('processDataStream', () => { it('should process a simple text stream part', async () => { const chunks = [encodeText('0:"Hello"\n')]; const stream = createReadableStream(chunks); - const receivedParts: StreamPartType[] = []; + const receivedParts: DataStreamPartType[] = []; await processDataStream({ stream, @@ -41,7 +41,7 @@ describe('processDataStream', () => { it('should handle multiple stream parts in sequence', async () => { const chunks = [encodeText('0:"Hello"\n2:[1,2,3]\n3:"error"\n')]; const stream = createReadableStream(chunks); - const receivedParts: StreamPartType[] = []; + const receivedParts: DataStreamPartType[] = []; await processDataStream({ stream, @@ -60,7 +60,7 @@ describe('processDataStream', () => { it('should handle chunks that split JSON values', async () => { const chunks = [encodeText('0:"Hel'), encodeText('lo"\n')]; const stream = createReadableStream(chunks); - const receivedParts: StreamPartType[] = []; + const receivedParts: DataStreamPartType[] = []; await processDataStream({ stream, @@ -76,7 +76,7 @@ describe('processDataStream', () => { it('should handle chunks that split at newlines', async () => { const chunks = [encodeText('0:"Hello"\n'), encodeText('0:"World"\n')]; const stream = createReadableStream(chunks); - const receivedParts: StreamPartType[] = []; + const receivedParts: DataStreamPartType[] = []; await processDataStream({ stream, @@ -97,7 +97,7 @@ describe('processDataStream', () => { const chunks = [encoded.slice(0, splitPoint), encoded.slice(splitPoint)]; const stream = createReadableStream(chunks); - const receivedParts: StreamPartType[] = []; + const receivedParts: DataStreamPartType[] = []; await processDataStream({ stream, @@ -144,7 +144,7 @@ describe('processDataStream', () => { encodeText('0:"World"\n'), ]; const stream = createReadableStream(chunks); - const receivedParts: StreamPartType[] = []; + const receivedParts: DataStreamPartType[] = []; await processDataStream({ stream, @@ -162,7 +162,7 @@ describe('processDataStream', () => { const largeString = 'x'.repeat(1024 * 1024); // 1MB string const chunks = [encodeText(`0:"${largeString}"\n`)]; const stream = createReadableStream(chunks); - const receivedParts: StreamPartType[] = []; + const receivedParts: DataStreamPartType[] = []; await processDataStream({ stream, @@ -178,7 +178,7 @@ describe('processDataStream', () => { it('should handle multiple newlines', async () => { const chunks = [encodeText('0:"Hello"\n\n0:"World"\n')]; const stream = createReadableStream(chunks); - const receivedParts: StreamPartType[] = []; + const receivedParts: DataStreamPartType[] = []; await processDataStream({ stream, @@ -192,31 +192,6 @@ describe('processDataStream', () => { expect(receivedParts[1]).toEqual({ type: 'text', value: 'World' }); }); - // Complex Stream Part Types - it('should correctly parse assistant message stream parts', async () => { - const message = { - id: '123', - role: 'assistant', - content: [{ type: 'text', text: { value: 'Hello' } }], - }; - const chunks = [encodeText(`4:${JSON.stringify(message)}\n`)]; - const stream = createReadableStream(chunks); - const receivedParts: StreamPartType[] = []; - - await processDataStream({ - stream, - onStreamPart: async part => { - receivedParts.push(part); - }, - }); - - expect(receivedParts).toHaveLength(1); - expect(receivedParts[0]).toEqual({ - type: 'assistant_message', - value: message, - }); - }); - // Cleanup and Resource Management it('should properly release reader resources', async () => { const mockRelease = vi.fn(); @@ -242,7 +217,7 @@ describe('processDataStream', () => { const parts = Array.from({ length: 100 }, (_, i) => `0:"Message ${i}"\n`); const chunks = parts.map(encodeText); const stream = createReadableStream(chunks); - const receivedParts: StreamPartType[] = []; + const receivedParts: DataStreamPartType[] = []; await processDataStream({ stream, diff --git a/packages/ui-utils/src/process-data-stream.ts b/packages/ui-utils/src/process-data-stream.ts index a6b0a1ac536f..d86d4a2c09da 100644 --- a/packages/ui-utils/src/process-data-stream.ts +++ b/packages/ui-utils/src/process-data-stream.ts @@ -1,4 +1,4 @@ -import { parseStreamPart, StreamPartType } from './stream-parts'; +import { parseDataStreamPart, DataStreamPartType } from './data-stream-parts'; const NEWLINE = '\n'.charCodeAt(0); @@ -21,7 +21,7 @@ export async function processDataStream({ onStreamPart, }: { stream: ReadableStream; - onStreamPart: (streamPart: StreamPartType) => Promise | void; + onStreamPart: (streamPart: DataStreamPartType) => Promise | void; }): Promise { // implementation note: this slightly more complex algorithm is required // to pass the tests in the edge environment. @@ -54,7 +54,7 @@ export async function processDataStream({ .decode(concatenatedChunks, { stream: true }) .split('\n') .filter(line => line !== '') // splitting leaves an empty string at the end - .map(parseStreamPart); + .map(parseDataStreamPart); for (const streamPart of streamParts) { await onStreamPart(streamPart); diff --git a/packages/ui-utils/src/test/create-data-protocol-stream.ts b/packages/ui-utils/src/test/create-data-protocol-stream.ts index dac4cb514c6d..54ea4ce7d06a 100644 --- a/packages/ui-utils/src/test/create-data-protocol-stream.ts +++ b/packages/ui-utils/src/test/create-data-protocol-stream.ts @@ -1,8 +1,8 @@ import { convertArrayToReadableStream } from '@ai-sdk/provider-utils/test'; -import { StreamString } from '../stream-parts'; +import { DataStreamString } from '../data-stream-parts'; export function createDataProtocolStream( - dataPartTexts: StreamString[], + dataPartTexts: DataStreamString[], ): ReadableStream { const encoder = new TextEncoder(); return convertArrayToReadableStream( diff --git a/packages/vue/src/use-assistant.ts b/packages/vue/src/use-assistant.ts index 384ba47c9849..622db3222730 100644 --- a/packages/vue/src/use-assistant.ts +++ b/packages/vue/src/use-assistant.ts @@ -3,15 +3,15 @@ */ import { isAbortError } from '@ai-sdk/provider-utils'; -import { generateId, processDataStream } from '@ai-sdk/ui-utils'; import type { AssistantStatus, CreateMessage, Message, UseAssistantOptions, } from '@ai-sdk/ui-utils'; -import { computed, readonly, ref } from 'vue'; +import { generateId, processAssistantStream } from '@ai-sdk/ui-utils'; import type { ComputedRef, Ref } from 'vue'; +import { computed, readonly, ref } from 'vue'; export type UseAssistantHelpers = { /** @@ -180,7 +180,7 @@ export function useAssistant({ throw new Error('The response body is empty'); } - await processDataStream({ + await processAssistantStream({ stream: response.body, onStreamPart: async ({ type, value }) => { switch (type) { diff --git a/packages/vue/src/use-assistant.ui.test.tsx b/packages/vue/src/use-assistant.ui.test.tsx index e5905c5728ee..e78fc6fba65a 100644 --- a/packages/vue/src/use-assistant.ui.test.tsx +++ b/packages/vue/src/use-assistant.ui.test.tsx @@ -1,4 +1,4 @@ -import { formatStreamPart } from '@ai-sdk/ui-utils'; +import { formatAssistantStreamPart } from '@ai-sdk/ui-utils'; import { mockFetchDataStream, mockFetchDataStreamWithGenerator, @@ -27,11 +27,11 @@ describe('stream data stream', () => { url: 'https://example.com/api/assistant', chunks: [ // Format the stream part - formatStreamPart('assistant_control_data', { + formatAssistantStreamPart('assistant_control_data', { threadId: 't0', messageId: 'm0', }), - formatStreamPart('assistant_message', { + formatAssistantStreamPart('assistant_message', { id: 'm0', role: 'assistant', content: [{ type: 'text', text: { value: '' } }], @@ -81,14 +81,14 @@ describe('stream data stream', () => { const encoder = new TextEncoder(); yield encoder.encode( - formatStreamPart('assistant_control_data', { + formatAssistantStreamPart('assistant_control_data', { threadId: 't0', messageId: 'm1', }), ); yield encoder.encode( - formatStreamPart('assistant_message', { + formatAssistantStreamPart('assistant_message', { id: 'm1', role: 'assistant', content: [{ type: 'text', text: { value: '' } }], @@ -139,11 +139,11 @@ describe('Thread management', () => { const { requestBody } = mockFetchDataStream({ url: 'https://example.com/api/assistant', chunks: [ - formatStreamPart('assistant_control_data', { + formatAssistantStreamPart('assistant_control_data', { threadId: 't0', messageId: 'm0', }), - formatStreamPart('assistant_message', { + formatAssistantStreamPart('assistant_message', { id: 'm0', role: 'assistant', content: [{ type: 'text', text: { value: '' } }], @@ -185,11 +185,11 @@ describe('Thread management', () => { const { requestBody } = mockFetchDataStream({ url: 'https://example.com/api/assistant', chunks: [ - formatStreamPart('assistant_control_data', { + formatAssistantStreamPart('assistant_control_data', { threadId: 't1', messageId: 'm0', }), - formatStreamPart('assistant_message', { + formatAssistantStreamPart('assistant_message', { id: 'm0', role: 'assistant', content: [{ type: 'text', text: { value: '' } }], @@ -232,11 +232,11 @@ describe('Thread management', () => { const { requestBody } = mockFetchDataStream({ url: 'https://example.com/api/assistant', chunks: [ - formatStreamPart('assistant_control_data', { + formatAssistantStreamPart('assistant_control_data', { threadId: 't3', messageId: 'm0', }), - formatStreamPart('assistant_message', { + formatAssistantStreamPart('assistant_message', { id: 'm0', role: 'assistant', content: [{ type: 'text', text: { value: '' } }], diff --git a/packages/vue/src/use-chat.ui.test.tsx b/packages/vue/src/use-chat.ui.test.tsx index a09a0936951b..cbae03f2a1ad 100644 --- a/packages/vue/src/use-chat.ui.test.tsx +++ b/packages/vue/src/use-chat.ui.test.tsx @@ -1,5 +1,5 @@ import { withTestServer } from '@ai-sdk/provider-utils/test'; -import { formatStreamPart } from '@ai-sdk/ui-utils'; +import { formatDataStreamPart } from '@ai-sdk/ui-utils'; import { mockFetchDataStream, mockFetchDataStreamWithGenerator, @@ -152,11 +152,11 @@ describe('data protocol stream', () => { url: '/api/chat', type: 'stream-values', content: [ - formatStreamPart('text', 'Hello'), - formatStreamPart('text', ','), - formatStreamPart('text', ' world'), - formatStreamPart('text', '.'), - formatStreamPart('finish_message', { + formatDataStreamPart('text', 'Hello'), + formatDataStreamPart('text', ','), + formatDataStreamPart('text', ' world'), + formatDataStreamPart('text', '.'), + formatDataStreamPart('finish_message', { finishReason: 'stop', usage: { completionTokens: 1, promptTokens: 3 }, }), @@ -337,7 +337,7 @@ describe('form actions', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: ['Hello', ',', ' world', '.'].map(token => - formatStreamPart('text', token), + formatDataStreamPart('text', token), ), }); @@ -356,7 +356,7 @@ describe('form actions', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: ['How', ' can', ' I', ' help', ' you', '?'].map(token => - formatStreamPart('text', token), + formatDataStreamPart('text', token), ), }); @@ -381,7 +381,7 @@ describe('form actions (with options)', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: ['Hello', ',', ' world', '.'].map(token => - formatStreamPart('text', token), + formatDataStreamPart('text', token), ), }); @@ -400,7 +400,7 @@ describe('form actions (with options)', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: ['How', ' can', ' I', ' help', ' you', '?'].map(token => - formatStreamPart('text', token), + formatDataStreamPart('text', token), ), }); @@ -415,7 +415,7 @@ describe('form actions (with options)', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: ['The', ' sky', ' is', ' blue.'].map(token => - formatStreamPart('text', token), + formatDataStreamPart('text', token), ), }); @@ -505,7 +505,7 @@ describe('onToolCall', () => { mockFetchDataStream({ url: 'https://example.com/api/chat', chunks: [ - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: 'tool-call-0', toolName: 'client-tool', args: { testArg: 'test-value' }, @@ -543,7 +543,7 @@ describe('tool invocations', () => { await userEvent.keyboard('{Enter}'); streamController.enqueue( - formatStreamPart('tool_call_streaming_start', { + formatDataStreamPart('tool_call_streaming_start', { toolCallId: 'tool-call-0', toolName: 'test-tool', }), @@ -558,7 +558,7 @@ describe('tool invocations', () => { }); streamController.enqueue( - formatStreamPart('tool_call_delta', { + formatDataStreamPart('tool_call_delta', { toolCallId: 'tool-call-0', argsTextDelta: '{"testArg":"t', }), @@ -571,7 +571,7 @@ describe('tool invocations', () => { }); streamController.enqueue( - formatStreamPart('tool_call_delta', { + formatDataStreamPart('tool_call_delta', { toolCallId: 'tool-call-0', argsTextDelta: 'est-value"}}', }), @@ -584,7 +584,7 @@ describe('tool invocations', () => { }); streamController.enqueue( - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: 'tool-call-0', toolName: 'test-tool', args: { testArg: 'test-value' }, @@ -598,7 +598,7 @@ describe('tool invocations', () => { }); streamController.enqueue( - formatStreamPart('tool_result', { + formatDataStreamPart('tool_result', { toolCallId: 'tool-call-0', result: 'test-result', }), @@ -624,7 +624,7 @@ describe('tool invocations', () => { await userEvent.keyboard('{Enter}'); streamController.enqueue( - formatStreamPart('tool_call', { + formatDataStreamPart('tool_call', { toolCallId: 'tool-call-0', toolName: 'test-tool', args: { testArg: 'test-value' }, @@ -638,7 +638,7 @@ describe('tool invocations', () => { }); streamController.enqueue( - formatStreamPart('tool_result', { + formatDataStreamPart('tool_result', { toolCallId: 'tool-call-0', result: 'test-result', }),