Skip to content

Commit

Permalink
feat(frontend): regenerate with validation error (#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
634750802 authored Sep 9, 2024
1 parent c145a8f commit 721eea5
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 32 deletions.
87 changes: 86 additions & 1 deletion frontend/app/src/components/chat/chat-controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { chat, type Chat, type ChatMessage, type PostChatParams } from '@/api/ch
import { ChatMessageController, type OngoingState } from '@/components/chat/chat-message-controller';
import { AppChatStreamState, chatDataPartSchema, type ChatMessageAnnotation, fixChatInitialData } from '@/components/chat/chat-stream-state';
import { getErrorMessage } from '@/lib/errors';
import { trigger } from '@/lib/react';
import { type JSONValue, type StreamPart } from 'ai';
import EventEmitter from 'eventemitter3';

Expand All @@ -21,6 +22,12 @@ export interface ChatControllerEventsMap {
'post-initialized': [];
'post-finished': [];
'post-error': [error: unknown];

/**
* Experimental
*/
'ui:input-mount': [HTMLTextAreaElement | HTMLInputElement];
'ui:input-unmount': [HTMLTextAreaElement | HTMLInputElement];
}

export class ChatController extends EventEmitter<ChatControllerEventsMap> {
Expand All @@ -32,6 +39,8 @@ export class ChatController extends EventEmitter<ChatControllerEventsMap> {
private _postError: unknown = undefined;
private _postInitialized: boolean = false;

private _inputElement: HTMLTextAreaElement | HTMLInputElement | null = null;

get postState () {
return {
params: this._postParams,
Expand All @@ -40,7 +49,12 @@ export class ChatController extends EventEmitter<ChatControllerEventsMap> {
};
}

constructor (chat: Chat | undefined = undefined, messages: ChatMessage[] | undefined = [], initialPost: Omit<PostChatParams, 'chat_id'> | undefined = undefined) {
constructor (
chat: Chat | undefined = undefined,
messages: ChatMessage[] | undefined = [],
initialPost: Omit<PostChatParams, 'chat_id'> | undefined = undefined,
inputElement: HTMLInputElement | HTMLTextAreaElement | null,
) {
super();
if (chat) {
this.chat = chat;
Expand All @@ -51,6 +65,77 @@ export class ChatController extends EventEmitter<ChatControllerEventsMap> {
if (initialPost) {
this.post(initialPost);
}
this._inputElement = inputElement;
if (inputElement) {
this.emit('ui:input-mount', inputElement);
}
}

get inputElement () {
return this._inputElement;
}

set inputElement (value: HTMLInputElement | HTMLTextAreaElement | null) {
if (this._inputElement) {
if (value) {
if (value !== this._inputElement) {
const old = this._inputElement;
this._inputElement = null;
this.emit('ui:input-unmount', old);

this._inputElement = value;
this.emit('ui:input-mount', value);
}
} else {
const old = this._inputElement;
this._inputElement = null;
this.emit('ui:input-unmount', old);
}
} else {
if (value) {
this._inputElement = value;
this.emit('ui:input-mount', value);
}
}
}

private get _enabledInputElement () {
if (!this._inputElement) {
console.warn('Input element is not exists.');
return;
}
if (this._inputElement.disabled) {
console.warn('Input element is disabled currently.');
return;
}

return this._inputElement;
}

get inputEnabled () {
if (!this._inputElement) {
return false;
}

return !this._inputElement.disabled;
}

get input (): string {
return this._inputElement?.value ?? '';
}

set input (value: string) {
const inputElement = this._enabledInputElement;
if (inputElement) {
trigger(inputElement as HTMLTextAreaElement, HTMLTextAreaElement, value);
}
}

focusInput () {
const inputElement = this._enabledInputElement;
if (inputElement) {
inputElement.focus();
}
}

get messages (): ChatMessageController[] {
Expand Down
19 changes: 14 additions & 5 deletions frontend/app/src/components/chat/chat-hooks.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,32 @@ export interface ChatMessageGroup {
assistant: ChatMessageController | undefined;
}

export function useChatController (id: string | undefined, initialChat: Chat | undefined, initialMessages: ChatMessage[] | undefined) {
const { chats, newChat } = useChats();
export function useChatController (
id: string | undefined,
initialChat: Chat | undefined,
initialMessages: ChatMessage[] | undefined,
inputElement: HTMLInputElement | HTMLTextAreaElement | null = null,
) {
const { chats } = useChats();

// Create essential chat controller
const [controller, setController] = useState(() => {
const [controller] = useState(() => {
if (id) {
let controller = chats.get(id);
if (!controller) {
controller = new ChatController(initialChat, initialMessages);
controller = new ChatController(initialChat, initialMessages, undefined, inputElement);
chats.set(id, controller);
}
return controller;
} else {
return new ChatController(undefined, undefined, undefined);
return new ChatController(undefined, undefined, undefined, inputElement);
}
});

useEffect(() => {
controller.inputElement = inputElement;
}, [controller, inputElement]);

return controller;
}

Expand Down
9 changes: 5 additions & 4 deletions frontend/app/src/components/chat/conversation.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ import { MessageInput } from '@/components/chat/message-input';
import { SecuritySettingContext, withReCaptcha } from '@/components/security-setting-provider';
import { useSize } from '@/components/use-size';
import { cn } from '@/lib/utils';
import { type ChangeEvent, type FormEvent, type ReactNode, useContext, useEffect, useState } from 'react';

import { type ChangeEvent, type FormEvent, type ReactNode, useContext, useRef, useState } from 'react';

export interface ConversationProps {
chatId?: string;
Expand All @@ -26,7 +25,9 @@ export interface ConversationProps {
}

export function Conversation ({ open, chat, chatId, history, placeholder, preventMutateBrowserHistory = false, preventShiftMessageInput = false, className }: ConversationProps) {
const controller = useChatController(chatId, chat, history);
const [inputElement, setInputElement] = useState<HTMLTextAreaElement | null>(null);

const controller = useChatController(chatId, chat, history, inputElement);
const postState = useChatPostState(controller);
const groups = useChatMessageGroups(useChatMessageControllers(controller));

Expand Down Expand Up @@ -70,7 +71,7 @@ export function Conversation ({ open, chat, chatId, history, placeholder, preven
<div className="h-24"></div>
</div>
{size && open && <form className={cn('block h-max p-4 fixed bottom-0', preventShiftMessageInput && 'absolute pb-0')} onSubmit={submitWithReCaptcha} style={{ left: preventShiftMessageInput ? 0 : size.x, width: size.width }}>
<MessageInput className="w-full transition-all" disabled={disabled} inputProps={{ value: input, onChange: handleInputChange, disabled }} />
<MessageInput inputRef={setInputElement} className="w-full transition-all" disabled={disabled} inputProps={{ value: input, onChange: handleInputChange, disabled }} />
</form>}
</ChatControllerProvider>
);
Expand Down
4 changes: 2 additions & 2 deletions frontend/app/src/components/chat/message-input.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@
import { cn } from '@/lib/utils';
import isHotkey from 'is-hotkey';
import { ArrowRightIcon } from 'lucide-react';
import { type ChangeEvent, type RefObject, useCallback, useRef, useState } from 'react';
import { type ChangeEvent, type Ref, type RefObject, useCallback, useRef, useState } from 'react';
import TextareaAutosize, { type TextareaAutosizeProps } from 'react-textarea-autosize';
import useSWR from 'swr';

export interface MessageInputProps {
className?: string,
disabled?: boolean,
inputRef?: RefObject<HTMLTextAreaElement>,
inputRef?: Ref<HTMLTextAreaElement>,
inputProps?: TextareaAutosizeProps,
engine?: string,
onEngineChange?: (name: string) => void,
Expand Down
22 changes: 11 additions & 11 deletions frontend/app/src/components/chat/message-operations.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { Button } from '@/components/ui/button';
import { TooltipProvider } from '@/components/ui/tooltip';
import { cn } from '@/lib/utils';
import copy from 'copy-to-clipboard';
import { ClipboardCheckIcon, ClipboardIcon, MessageSquareHeartIcon, MessageSquarePlusIcon, RefreshCwIcon } from 'lucide-react';
import { ClipboardCheckIcon, ClipboardIcon, MessageSquareHeartIcon, MessageSquarePlusIcon } from 'lucide-react';
import { useState } from 'react';

export function MessageOperations ({ message }: { message: ChatMessageController }) {
Expand All @@ -22,16 +22,16 @@ export function MessageOperations ({ message }: { message: ChatMessageController
return (
<TooltipProvider>
<div className="flex items-center gap-2">
<Button
size="sm"
className="gap-1 text-xs px-2 py-1 h-max"
variant="ghost"
onClick={() => controller.regenerate(message.id)}
disabled
>
<RefreshCwIcon size="1em" />
Regenerate
</Button>
{/*<Button*/}
{/* size="sm"*/}
{/* className="gap-1 text-xs px-2 py-1 h-max"*/}
{/* variant="ghost"*/}
{/* onClick={() => controller.regenerate(message.id)}*/}
{/* disabled*/}
{/*>*/}
{/* <RefreshCwIcon size="1em" />*/}
{/* Regenerate*/}
{/*</Button>*/}

<MessageFeedback
initial={feedbackData}
Expand Down
2 changes: 1 addition & 1 deletion frontend/app/src/components/chat/use-ask.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export function useAsk (onFinish?: () => void) {
toastError('Failed to chat', getErrorMessage(error));
};

const controller = newChat(undefined, undefined, { content: message, chat_engine: engineRef.current, headers: options?.headers });
const controller = newChat(undefined, undefined, { content: message, chat_engine: engineRef.current, headers: options?.headers }, null);

controller.once('created', chat => {
controller.off('post-error', handleInitialError);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { getVerify, isFinalVerifyState, isVisibleVerifyState, type MessageVerifyResponse, verify, VerifyStatus } from '#experimental/chat-verify-service/api';
import { useAuth } from '@/components/auth/AuthProvider';
import { useChatMessageField, useChatMessageStreamState } from '@/components/chat/chat-hooks';
import { useChatMessageField, useChatMessageStreamState, useCurrentChatController } from '@/components/chat/chat-hooks';
import type { ChatMessageController } from '@/components/chat/chat-message-controller';
import { isNotFinished } from '@/components/chat/utils';
import { Button } from '@/components/ui/button';
Expand All @@ -11,7 +10,7 @@ import { cn } from '@/lib/utils';
import { AnimatePresence, motion } from 'framer-motion';
import Highlight from 'highlight.js/lib/core';
import sql from 'highlight.js/lib/languages/sql';
import { CheckCircle2Icon, CheckIcon, ChevronDownIcon, CircleMinus, Loader2Icon, TriangleAlertIcon, XIcon } from 'lucide-react';
import { CheckCircle2Icon, CheckIcon, ChevronDownIcon, CircleMinus, Loader2Icon, RefreshCwIcon, TriangleAlertIcon, XIcon } from 'lucide-react';
import { type ReactElement, useEffect, useMemo, useState } from 'react';
import { format } from 'sql-formatter';
import useSWR from 'swr';
Expand All @@ -21,6 +20,7 @@ Highlight.registerLanguage('sql', sql);

export function MessageVerify ({ user, assistant }: { user: ChatMessageController | undefined, assistant: ChatMessageController | undefined }) {
const [open, setOpen] = useState(false);
const controller = useCurrentChatController();
const messageState = useChatMessageStreamState(assistant);
const question = useChatMessageField(user, 'content');
const answer = useChatMessageField(assistant, 'content');
Expand All @@ -29,15 +29,13 @@ export function MessageVerify ({ user, assistant }: { user: ChatMessageControlle

const externalRequestId = `${chat_id}_${message_id}`;

const me = useAuth();
const [verifyId, setVerifyId] = useState<string>();
const [verifying, setVerifying] = useState(false);
const [verifyError, setVerifyError] = useState<unknown>();

const serviceUrl = useExperimentalFeatures().message_verify_service;
const isSuperuser = !!me.me?.is_superuser;

const shouldPoll = serviceUrl && !!verifyId && !!assistant; // Remove isSuperuser check
const shouldPoll = serviceUrl && !!verifyId && !!assistant;
const { data: result, isLoading: isLoadingResult, error: pollError } = useSWR(
shouldPoll && `experimental.chat-message.${assistant.id}.verify`, () => getVerify(serviceUrl, verifyId!),
{
Expand Down Expand Up @@ -115,8 +113,24 @@ export function MessageVerify ({ user, assistant }: { user: ChatMessageControlle
</motion.div>}
</AnimatePresence>
</CollapsibleContent>
<div className="my-2 px-4 text-xs text-muted-foreground">
Powered by <a className="underline font-bold" href="https://www.pingcap.com/tidb-serverless/" target="_blank">TiDB Serverless</a>
<div className="my-2 px-4 flex items-center flex-wrap justify-between">
<div className="text-xs text-muted-foreground">
Powered by <a className="underline font-bold" href="https://www.pingcap.com/tidb-serverless/" target="_blank">TiDB Serverless</a>
</div>
{result?.status === VerifyStatus.FAILED && controller.inputEnabled && (
<Button
size="sm"
className="gap-1 text-xs px-2 py-1 h-max"
variant="ghost"
onClick={() => {
controller.input = composeRegenerateMessage(result);
controller.focusInput();
}}
>
<RefreshCwIcon size="1em" />
Regenerate with validation messages
</Button>
)}
</div>
</Collapsible>
);
Expand Down Expand Up @@ -222,3 +236,14 @@ function MessageVerifyRun ({ run }: { run: MessageVerifyResponse.Run }) {
</div>
);
}

function composeRegenerateMessage (result: MessageVerifyResponse) {
return `Below are the results of my verification of the SQL examples mentioned in the above answer on TiDB Serverless. I hope to use this to verify the correctness of the answer:
${result.runs.map(run => (`Explain: ${run.explanation}
SQL: ${run.sql}
SQL Result: ${(run.sql_error_code || run.sql_error_message) ? `${run.sql_error_code ?? '?????'} ${run.sql_error_message}` : JSON.stringify(run.results)}
Validation Result: ${run.success ? 'Success' : 'Failed'}`)).join('\n\n')}
`;

}
7 changes: 7 additions & 0 deletions frontend/app/src/lib/react.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
export function trigger<T extends typeof HTMLTextAreaElement | typeof HTMLInputElement> (inputElement: InstanceType<T>, Element: T, value: string) {
// https://stackoverflow.com/questions/23892547/what-is-the-best-way-to-trigger-change-or-input-event-in-react-js
const set = Object.getOwnPropertyDescriptor(Element.prototype, 'value')!.set!;
set.call(inputElement, value);
const event = new Event('input', { bubbles: true });
inputElement.dispatchEvent(event);
}

0 comments on commit 721eea5

Please sign in to comment.