Skip to content

Commit

Permalink
feat(answer): error handling (#4146)
Browse files Browse the repository at this point in the history
## Summary

Updates to the Answer Api client to make sure the error handling is
proper. Some little extras needed for it to work proper that will make
the subsequent review a bit easier.

## Why

The Answer Api is almost just a proxy for the stream coming from reveal.
And for some technical reasons, the call to the answer api will return
an error message but will not return downright error status codes like
400s when the stream fails.
That said, we can rely on the stream end message in case the stream is
not working properly, and since its better to give the user or the devs
all the information we can, this increment makes sure that the error
will appear in the logs and in the state in case we want to leverage the
error message in the front end component.

Co-authored-by: Danny Gauthier <[email protected]>
  • Loading branch information
dmgauthier and Danny Gauthier authored Jul 9, 2024
1 parent bd21dc3 commit cd679cc
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 162 deletions.
249 changes: 101 additions & 148 deletions packages/headless/src/api/knowledge/stream-answer-api.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import {
ArrayValue,
BooleanValue,
RecordValue,
Schema,
StringValue,
} from '@coveo/bueno';
import {
EventSourceMessage,
fetchEventSource,
} from '@microsoft/fetch-event-source';
import {selectFieldsToIncludeInCitation} from '../../features/generated-answer/generated-answer-selectors';
import {
GeneratedAnswerStyle,
GeneratedContentFormat,
Expand All @@ -30,31 +24,32 @@ import {answerSlice} from './answer-slice';
type StateNeededByAnswerAPI = ConfigurationSection &
GeneratedAnswerSection &
SearchSection &
DebugSection & {answer: ReturnType<typeof answerApi.reducer>};

interface ErrorPayload {
message?: string;
code?: number;
}

class FatalError extends Error {
constructor(public payload: ErrorPayload) {
super(payload.message);
}
}
DebugSection & {
answer: ReturnType<typeof answerApi.reducer>;
query?: QueryState;
searchHub?: string;
pipeline?: string;
};

interface GeneratedAnswerStream {
answerStyle: GeneratedAnswerStyle | undefined;
contentFormat: GeneratedContentFormat | undefined;
answer: string | undefined;
citations: GeneratedAnswerCitation[] | undefined;
generated: boolean;
answerStyle?: GeneratedAnswerStyle;
contentFormat?: GeneratedContentFormat;
answer?: string;
citations?: GeneratedAnswerCitation[];
generated?: boolean;
isStreaming: boolean;
isLoading: boolean;
error?: {message: string; code: number};
}
interface HeaderMessage {
answerStyle: GeneratedAnswerStyle;
contentFormat: GeneratedContentFormat;

interface StreamPayload
extends Pick<
GeneratedAnswerStream,
'answerStyle' | 'contentFormat' | 'citations'
> {
textDelta?: string;
padding?: string;
answerGenerated?: boolean;
}

type PayloadType =
Expand All @@ -63,55 +58,10 @@ type PayloadType =
| 'genqa.citationsType'
| 'genqa.endOfStreamType';

const headerMessageSchema = new Schema<HeaderMessage>({
answerStyle: new StringValue(),
contentFormat: new StringValue(),
});

const messageSchema = new Schema({
textDelta: new StringValue(),
});

const citationsSchema = new Schema({
citation: new ArrayValue(
new RecordValue({
values: {
clickUri: new StringValue(),
id: new StringValue(),
permanentid: new StringValue(),
text: new StringValue(),
title: new StringValue(),
uri: new StringValue(),
},
})
),
});

const validateHeaderMessage = (headerMessage: HeaderMessage) => {
headerMessageSchema.validate(headerMessage);
};

const validateMessage = (message: {textDelta: string}) => {
messageSchema.validate(message);
};

const validateCitationsMessage = (citations: {
citation: GeneratedAnswerCitation[];
}) => {
citationsSchema.validate(citations);
};

const validateEndOfStream = (endOfStream: {answerGenerated: boolean}) => {
new Schema({
answerGenerated: new BooleanValue(),
}).validate(endOfStream);
};

const handleHeaderMessage = (
draft: GeneratedAnswerStream,
payload: HeaderMessage
payload: Pick<GeneratedAnswerStream, 'answerStyle' | 'contentFormat'>
) => {
validateHeaderMessage(payload);
const {answerStyle, contentFormat} = payload;
draft.answerStyle = answerStyle;
draft.contentFormat = contentFormat;
Expand All @@ -121,84 +71,89 @@ const handleHeaderMessage = (

const handleMessage = (
draft: GeneratedAnswerStream,
payload: {textDelta: string}
payload: Pick<StreamPayload, 'textDelta'>
) => {
validateMessage(payload);
if (draft.answer === undefined) {
draft.answer = payload.textDelta;
} else {
} else if (typeof payload.textDelta === 'string') {
draft.answer = draft.answer.concat(payload.textDelta);
}
};

const handleCitations = (
draft: GeneratedAnswerStream,
payload: {citation: GeneratedAnswerCitation[]}
payload: Pick<StreamPayload, 'citations'>
) => {
validateCitationsMessage(payload);
draft.citations = payload.citation;
draft.citations = payload.citations;
};

const handleEndOfStream = (
draft: GeneratedAnswerStream,
payload: {answerGenerated: boolean}
payload: Pick<StreamPayload, 'answerGenerated'>
) => {
validateEndOfStream(payload);
draft.generated = payload.answerGenerated;
draft.isStreaming = false;
};

interface MessageType {
payloadType: PayloadType;
payload: string;
finishReason?: string;
errorMessage?: string;
code?: number;
}

const handleError = (
draft: GeneratedAnswerStream,
message: Required<MessageType>
) => {
draft.error = {
message: message.errorMessage,
code: message.code!,
};
draft.isStreaming = false;
draft.isLoading = false;
// Throwing an error here breaks the client and prevents the error from reaching the state.
console.error(`${message.errorMessage} - code ${message.code}`);
};

const updateCacheWithEvent = (
event: EventSourceMessage,
draft: GeneratedAnswerStream
) => {
const message: {payloadType: PayloadType; payload: string} = JSON.parse(
event.data
);
const parsedPayload = JSON.parse(message.payload);
const message: Required<MessageType> = JSON.parse(event.data);
if (message.finishReason === 'ERROR' && message.errorMessage) {
handleError(draft, message);
}

const parsedPayload: StreamPayload = message.payload.length
? JSON.parse(message.payload)
: {};

switch (message.payloadType) {
case 'genqa.headerMessageType':
handleHeaderMessage(draft, parsedPayload);
if (parsedPayload.answerStyle && parsedPayload.contentFormat) {
handleHeaderMessage(draft, parsedPayload);
}
break;
case 'genqa.messageType':
handleMessage(draft, parsedPayload);
if (parsedPayload.textDelta) {
handleMessage(draft, parsedPayload);
}
break;
case 'genqa.citationsType':
handleCitations(draft, parsedPayload);
if (parsedPayload.citations) {
handleCitations(draft, parsedPayload);
}
break;
case 'genqa.endOfStreamType':
handleEndOfStream(draft, parsedPayload);
if (parsedPayload.answerGenerated) {
handleEndOfStream(draft, parsedPayload);
}
break;
}
};

const onOpenStream = async (response: Response) => {
if (
response.ok &&
response.headers.get('content-type')?.includes('text/event-stream')
) {
return;
}

const isClientSideError =
response.status >= 400 && response.status < 500 && response.status !== 429;

if (isClientSideError) {
throw new FatalError({
message: 'Error opening stream',
code: response.status,
});
} else {
throw new Error();
}
};

const onError = (err: Error) => {
if (err instanceof FatalError) {
throw err;
}
};

export const answerApi = answerSlice.injectEndpoints({
overrideExisting: true,
endpoints: (builder) => ({
Expand All @@ -209,6 +164,7 @@ export const answerApi = answerSlice.injectEndpoints({
contentFormat: undefined,
answer: undefined,
citations: undefined,
error: undefined,
generated: false,
isStreaming: true,
isLoading: true,
Expand All @@ -224,11 +180,11 @@ export const answerApi = answerSlice.injectEndpoints({
* It cannot use the inferred state used by Redux, thus the casting.
* https://redux-toolkit.js.org/rtk-query/usage-with-typescript#typing-dispatch-and-getstate
*/
const {configuration} = getState() as unknown as StateNeededByAnswerAPI;
const {platformUrl, organizationId, accessToken, knowledge} =
configuration;
const {configuration, generatedAnswer} =
getState() as unknown as StateNeededByAnswerAPI;
const {platformUrl, organizationId, accessToken} = configuration;
await fetchEventSource(
`${platformUrl}/rest/organizations/${organizationId}/answer/v1/configs/${knowledge.answerConfigurationId}/generate`,
`${platformUrl}/rest/organizations/${organizationId}/answer/v1/configs/${generatedAnswer.answerConfigurationId}/generate`,
{
method: 'POST',
body: JSON.stringify(args),
Expand All @@ -239,49 +195,46 @@ export const answerApi = answerSlice.injectEndpoints({
'Accept-Encoding': '*',
},
fetch,
onopen: onOpenStream,
onmessage: (event) => {
updateCachedData((draft) => {
updateCacheWithEvent(event, draft);
});
},
onerror: onError,
onerror: (error) => {
throw error;
},
}
);
},
}),
}),
});

export const fetchAnswer = (
state: StateNeededByAnswerAPI & {
knowledge: ReturnType<typeof answerApi.reducer>;
query?: QueryState;
searchHub?: string;
pipeline?: string;
}
) => {
const query = selectQuery(state)?.q;
const constructAnswerQueryParams = (state: StateNeededByAnswerAPI) => {
const q = selectQuery(state)?.q;
const searchHub = selectSearchHub(state);
const pipeline = selectPipeline(state);

return answerApi.endpoints.getAnswer.initiate({
q: query,
const citationsFieldToInclude = selectFieldsToIncludeInCitation(state) ?? [];

return {
q,
pipelineRuleParameters: {
mlGenerativeQuestionAnswering: {
responseFormat: {
answerStyle: state.generatedAnswer.responseFormat.answerStyle,
},
citationsFieldToInclude,
},
},
...(searchHub?.length && {searchHub}),
...(pipeline?.length && {pipeline}),
});
};
};

export const selectAnswer = (
state: StateNeededByAnswerAPI & {
knowledge: ReturnType<typeof answerApi.reducer>;
query?: QueryState;
searchHub?: string;
pipeline?: string;
}
) =>
answerApi.endpoints.getAnswer.select({
q: selectQuery(state)?.q,
...(selectSearchHub(state)?.length && {searchHub: selectSearchHub(state)}),
...(selectPipeline(state)?.length && {pipeline: selectPipeline(state)}),
})(state);
export const fetchAnswer = (state: StateNeededByAnswerAPI) =>
answerApi.endpoints.getAnswer.initiate(constructAnswerQueryParams(state));

export const selectAnswer = (state: StateNeededByAnswerAPI) =>
answerApi.endpoints.getAnswer.select(constructAnswerQueryParams(state))(
state
);
3 changes: 2 additions & 1 deletion packages/headless/src/app/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
} from '@reduxjs/toolkit';
import {Logger} from 'pino';
import {getRelayInstanceFromState} from '../api/analytics/analytics-relay-client';
import {answerApi} from '../api/knowledge/stream-answer-api';
import {
disableAnalytics,
enableAnalytics,
Expand Down Expand Up @@ -403,7 +404,7 @@ function createMiddleware<Reducers extends ReducersMapObject>(
renewTokenMiddleware,
logActionErrorMiddleware(logger),
analyticsMiddleware,
].concat(options.middlewares || []);
].concat(answerApi.middleware, options.middlewares || []);
}

function shouldWarnAboutOrganizationEndpoints(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ export interface GeneratedAnswerPropsInitialState {
}

export interface GeneratedAnswerProps extends GeneratedAnswerPropsInitialState {
/**
* The answer configuration ID used to leverage coveo answer management capabilities.
*/
answerConfigurationId?: string;
/**
* A list of indexed fields to include in the citations returned with the generated answer.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ export const updateResponseFormat = createAction(
})
);

export const updateAnswerConfigurationId = createAction(
'knowledge/updateAnswerConfigurationId',
(payload: string) => validatePayload(payload, stringValue)
);

export const registerFieldsToIncludeInCitations = createAction(
'generatedAnswer/registerFieldsToIncludeInCitations',
(payload: string[]) => validatePayload<string[]>(payload, nonEmptyStringArray)
Expand Down
Loading

0 comments on commit cd679cc

Please sign in to comment.