Skip to content

Commit

Permalink
Merge pull request #9254 from weseek/feat/155690-implement-openai-thr…
Browse files Browse the repository at this point in the history
…ead-model

feat: Implement OpenAI thread-relation model
  • Loading branch information
yuki-takei authored Oct 21, 2024
2 parents e0ffd83 + cbfd89c commit 74cdd33
Show file tree
Hide file tree
Showing 11 changed files with 279 additions and 37 deletions.
3 changes: 2 additions & 1 deletion apps/app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@
"@testing-library/user-event": "^14.5.2",
"@types/express": "^4.17.21",
"@types/jest": "^29.5.2",
"@types/node-cron": "^3.0.11",
"@types/react-input-autosize": "^2.2.4",
"@types/react-scroll": "^1.8.4",
"@types/react-stickynode": "^4.0.3",
Expand Down Expand Up @@ -275,8 +276,8 @@
"react-hotkeys": "^2.0.0",
"react-input-autosize": "^3.0.0",
"react-toastify": "^9.1.3",
"remark-github-admonitions-to-directives": "^2.0.0",
"rehype-rewrite": "^4.0.2",
"remark-github-admonitions-to-directives": "^2.0.0",
"replacestream": "^4.0.3",
"sass": "^1.53.0",
"simple-load-script": "^1.0.2",
Expand Down
57 changes: 57 additions & 0 deletions apps/app/src/features/openai/server/models/thread-relation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import type mongoose from 'mongoose';
import { type Model, type Document, Schema } from 'mongoose';

import { getOrCreateModel } from '~/server/util/mongoose-utils';

// Number of days a thread relation stays valid after it is created or refreshed.
const DAYS_UNTIL_EXPIRATION = 30;

/**
 * Builds the expiration timestamp for a thread relation.
 *
 * @returns a new Date that lies DAYS_UNTIL_EXPIRATION calendar days after the current time
 */
const generateExpirationDate = (): Date => {
  const expirationDate = new Date();
  // setDate() rolls over month/year boundaries automatically
  expirationDate.setDate(expirationDate.getDate() + DAYS_UNTIL_EXPIRATION);
  return expirationDate;
};

/**
 * Plain data shape of a thread relation: links a user to a thread and
 * carries the expiration timestamp of that link.
 */
interface ThreadRelation {
// ObjectId of the related user
userId: mongoose.Types.ObjectId;
// Identifier of the related thread
threadId: string;
// Point in time from which the relation counts as expired
expiredAt: Date;
}

/**
 * Mongoose document type for a thread relation, including its instance methods.
 */
interface ThreadRelationDocument extends ThreadRelation, Document {
// Resets `expiredAt` to a freshly generated expiration date and saves the document
updateThreadExpiration(): Promise<void>;
}

/**
 * Mongoose model type for thread relations, including its static methods.
 */
interface ThreadRelationModel extends Model<ThreadRelationDocument> {
// Finds relations whose `expiredAt` is at or before the current time;
// at most `limit` documents are returned (100 when `limit` is omitted)
getExpiredThreadRelations(limit?: number): Promise<ThreadRelationDocument[] | undefined>;
}

// Schema of the thread relation collection.
const schema = new Schema<ThreadRelationDocument, ThreadRelationModel>({
// Reference to the 'User' model; mandatory
userId: {
type: Schema.Types.ObjectId,
ref: 'User',
required: true,
},
// Thread identifier; mandatory and declared unique
threadId: {
type: String,
required: true,
unique: true,
},
// Expiration timestamp; when not supplied it defaults to the value
// produced by generateExpirationDate at document creation
expiredAt: {
type: Date,
default: generateExpirationDate,
required: true,
},
});

/**
 * Fetches thread relations that have already expired.
 *
 * @param limit - maximum number of documents to return; 100 when omitted
 * @returns documents whose `expiredAt` is at or before the current time
 */
schema.statics.getExpiredThreadRelations = async function(limit?: number): Promise<ThreadRelationDocument[] | undefined> {
  const now = new Date();
  const maxCount = limit ?? 100;
  const expiredQuery = this.find({ expiredAt: { $lte: now } }).limit(maxCount);
  return expiredQuery.exec();
};

/**
 * Pushes the expiration of this relation forward to a freshly generated
 * expiration date and persists the document.
 */
schema.methods.updateThreadExpiration = async function(): Promise<void> {
  const renewedExpiration = generateExpirationDate();
  this.expiredAt = renewedExpiration;
  await this.save();
};

// Default export: the 'ThreadRelation' model obtained through the project's getOrCreateModel helper.
// NOTE(review): presumably the helper reuses an already-registered model instead of redefining it — confirm in mongoose-utils.
export default getOrCreateModel<ThreadRelationDocument, ThreadRelationModel>('ThreadRelation', schema);
29 changes: 7 additions & 22 deletions apps/app/src/features/openai/server/routes/thread.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
import type { IUserHasId } from '@growi/core/dist/interfaces';
import type { Request, RequestHandler } from 'express';
import type { ValidationChain } from 'express-validator';
import { body } from 'express-validator';
import { filterXSS } from 'xss';

import type Crowi from '~/server/crowi';
import { apiV3FormValidator } from '~/server/middlewares/apiv3-form-validator';
import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
import loggerFactory from '~/utils/logger';

import { openaiClient } from '../services';
import { getOpenaiService } from '../services/openai';

import { certifyAiService } from './middlewares/certify-ai-service';

const logger = loggerFactory('growi:routes:apiv3:openai:thread');

type CreateThreadReq = Request<undefined, ApiV3Response, {
userMessage: string,
threadId?: string,
}>
// Request type of the thread endpoint: the body may carry an optional `threadId`,
// and the request object is expected to have a `user` attached.
type CreateThreadReq = Request<undefined, ApiV3Response, { threadId?: string }> & { user: IUserHasId };

// Factory signature: takes the Crowi instance and returns the ordered list of Express handlers for the route.
type CreateThreadFactory = (crowi: Crowi) => RequestHandler[];

Expand All @@ -32,24 +30,11 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
return [
accessTokenParser, loginRequiredStrictly, certifyAiService, validator, apiV3FormValidator,
async(req: CreateThreadReq, res: ApiV3Response) => {
const openaiService = getOpenaiService();
if (openaiService == null) {
return res.apiv3Err('OpenaiService is not available', 503);
}

try {
const vectorStore = await openaiService.getOrCreateVectorStoreForPublicScope();
const threadId = req.body.threadId;
const thread = threadId == null
? await openaiClient.beta.threads.create({
tool_resources: {
file_search: {
vector_store_ids: [vectorStore.vectorStoreId],
},
},
})
: await openaiClient.beta.threads.retrieve(threadId);

const openaiService = getOpenaiService();
const filterdThreadId = req.body.threadId != null ? filterXSS(req.body.threadId) : undefined;
const vectorStore = await openaiService?.getOrCreateVectorStoreForPublicScope();
const thread = await openaiService?.getOrCreateThread(req.user._id, vectorStore?.vectorStoreId, filterdThreadId);
return res.apiv3({ thread });
}
catch (err) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@ export class AzureOpenaiClientDelegator implements IOpenaiClientDelegator {
// TODO: initialize openaiVectorStoreId property
}

/**
 * Creates a new thread whose file_search tool resource points at the given vector store.
 *
 * @param vectorStoreId - ID placed into `tool_resources.file_search.vector_store_ids`
 * @returns the thread returned by the client
 */
async createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread> {
  const fileSearch = { vector_store_ids: [vectorStoreId] };
  return this.client.beta.threads.create({ tool_resources: { file_search: fileSearch } });
}

/**
 * Retrieves an existing thread through the client.
 *
 * @param threadId - ID of the thread to look up
 * @returns the thread returned by the client
 */
async retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread> {
  const { threads } = this.client.beta;
  return threads.retrieve(threadId);
}

/**
 * Deletes a thread through the client.
 *
 * @param threadId - ID of the thread to delete
 * @returns the deletion result returned by the client
 */
async deleteThread(threadId: string): Promise<OpenAI.Beta.Threads.ThreadDeleted> {
  const { threads } = this.client.beta;
  return threads.del(threadId);
}

/**
 * Creates a vector store named after the given scope type.
 *
 * @param scopeType - scope the vector store is created for; appended to the store name
 * @returns the vector store returned by the client
 */
async createVectorStore(scopeType: VectorStoreScopeType): Promise<OpenAI.Beta.VectorStores.VectorStore> {
  // Name format is `growi-vector-store-<scopeType>` (a stray `{` used to leak into the name)
  return this.client.beta.vectorStores.create({ name: `growi-vector-store-${scopeType}` });
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import type { Uploadable } from 'openai/uploads';
import type { VectorStoreScopeType } from '~/features/openai/server/models/vector-store';

export interface IOpenaiClientDelegator {
createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread>
retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread>
deleteThread(threadId: string): Promise<OpenAI.Beta.Threads.ThreadDeleted>
retrieveVectorStore(vectorStoreId: string): Promise<OpenAI.Beta.VectorStores.VectorStore>
createVectorStore(scopeType:VectorStoreScopeType): Promise<OpenAI.Beta.VectorStores.VectorStore>
uploadFile(file: Uploadable): Promise<OpenAI.Files.FileObject>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ export class OpenaiClientDelegator implements IOpenaiClientDelegator {
this.client = new OpenAI({ apiKey });
}

/**
 * Creates a new thread whose file_search tool resource points at the given vector store.
 *
 * @param vectorStoreId - ID placed into `tool_resources.file_search.vector_store_ids`
 * @returns the thread returned by the client
 */
async createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread> {
  const toolResources = {
    file_search: { vector_store_ids: [vectorStoreId] },
  };
  return this.client.beta.threads.create({ tool_resources: toolResources });
}

/**
 * Retrieves an existing thread through the client.
 *
 * @param threadId - ID of the thread to look up
 * @returns the thread returned by the client
 */
async retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread> {
  const threadsApi = this.client.beta.threads;
  return threadsApi.retrieve(threadId);
}

/**
 * Deletes a thread through the client.
 *
 * @param threadId - ID of the thread to delete
 * @returns the deletion result returned by the client
 */
async deleteThread(threadId: string): Promise<OpenAI.Beta.Threads.ThreadDeleted> {
  const threadsApi = this.client.beta.threads;
  return threadsApi.del(threadId);
}

/**
 * Creates a vector store named `growi-vector-store-<scopeType>`.
 *
 * @param scopeType - scope the vector store is created for; appended to the store name
 * @returns the vector store returned by the client
 */
async createVectorStore(scopeType: VectorStoreScopeType): Promise<OpenAI.Beta.VectorStores.VectorStore> {
  const name = `growi-vector-store-${scopeType}`;
  return this.client.beta.vectorStores.create({ name });
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import OpenAI from 'openai';

import loggerFactory from '~/utils/logger';

const logger = loggerFactory('growi:service:openai');

// Error Code Reference
// https://platform.openai.com/docs/guides/error-codes/api-errors

// Error Handling Reference
// https://github.com/openai/openai-node/tree/d08bf1a8fa779e6a9349d92ddf65530dd84e686d?tab=readme-ov-file#handling-errors

// Optional callbacks that the error handler invokes for specific API error statuses.
type ErrorHandler = {
// Invoked when the API error has HTTP status 404
notFoundError?: () => Promise<void>;
}

/**
 * Dispatches an error raised by the OpenAI client to the matching callback.
 *
 * Errors that are not `OpenAI.APIError` instances are ignored (not even logged).
 * API errors are always logged; in addition, a 404 triggers `handler.notFoundError`
 * when that callback is provided. The error itself is never rethrown here.
 *
 * @param error - value caught by the caller
 * @param handler - callbacks keyed by error kind
 */
export const oepnaiApiErrorHandler = async(error: unknown, handler: ErrorHandler): Promise<void> => {
  if (error instanceof OpenAI.APIError) {
    logger.error(error);

    if (error.status !== 404) {
      return;
    }
    if (handler.notFoundError == null) {
      return;
    }

    await handler.notFoundError();
  }
};
90 changes: 76 additions & 14 deletions apps/app/src/features/openai/server/services/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import mongoose from 'mongoose';
import type OpenAI from 'openai';
import { toFile } from 'openai';

import ThreadRelationModel from '~/features/openai/server/models/thread-relation';
import VectorStoreModel, { VectorStoreScopeType, type VectorStoreDocument } from '~/features/openai/server/models/vector-store';
import VectorStoreFileRelationModel, {
type VectorStoreFileRelation,
Expand All @@ -19,8 +20,8 @@ import loggerFactory from '~/utils/logger';

import { OpenaiServiceTypes } from '../../interfaces/ai';


import { getClient } from './client-delegator';
import { oepnaiApiErrorHandler } from './openai-api-error-handler';

const BATCH_SIZE = 100;

Expand All @@ -29,7 +30,9 @@ const logger = loggerFactory('growi:service:openai');
let isVectorStoreForPublicScopeExist = false;

export interface IOpenaiService {
getOrCreateThread(userId: string, vectorStoreId?: string, threadId?: string): Promise<OpenAI.Beta.Threads.Thread | undefined>;
getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
deleteExpiredThreads(limit: number): Promise<void>;
createVectorStoreFile(pages: PageDocument[]): Promise<void>;
deleteVectorStoreFile(pageId: Types.ObjectId): Promise<void>;
rebuildVectorStoreAll(): Promise<void>;
Expand All @@ -42,6 +45,60 @@ class OpenaiService implements IOpenaiService {
return getClient({ openaiServiceType });
}

/**
 * Returns the thread identified by `threadId`, or creates a new one when no
 * `threadId` is given.
 *
 * - Without `threadId`: creates a thread bound to `vectorStoreId` and stores a
 *   ThreadRelation document linking it to `userId`.
 * - With `threadId`: looks up the ThreadRelation owned by `userId`, retrieves the
 *   thread entity and pushes the relation's expiration forward. If the API reports
 *   404 for the thread, the stale ThreadRelation document is removed before the
 *   error is rethrown.
 *
 * @param userId - user the thread belongs to
 * @param vectorStoreId - vector store to attach when a new thread is created
 * @param threadId - ID of an existing thread to retrieve
 * @returns the created or retrieved thread
 * @throws Error when neither ID is usable, when no matching ThreadRelation exists,
 *         or when the underlying client / database call fails
 */
public async getOrCreateThread(userId: string, vectorStoreId?: string, threadId?: string): Promise<OpenAI.Beta.Threads.Thread> {
  if (threadId == null) {
    // Guard explicitly instead of falling through to a lookup with an undefined threadId
    if (vectorStoreId == null) {
      throw new Error('Either vectorStoreId or threadId is required');
    }

    const thread = await this.client.createThread(vectorStoreId);
    await ThreadRelationModel.create({ userId, threadId: thread.id });
    return thread;
  }

  // Scope the lookup to the requesting user so that a thread ID belonging to
  // another user cannot be retrieved or have its expiration extended
  const threadRelation = await ThreadRelationModel.findOne({ threadId, userId });
  if (threadRelation == null) {
    throw new Error('ThreadRelation document is not exists');
  }

  // Check if a thread entity exists
  // If the thread entity does not exist, the thread-relation document is deleted
  try {
    const thread = await this.client.retrieveThread(threadRelation.threadId);

    // Update expiration date if thread entity exists
    await threadRelation.updateThreadExpiration();

    return thread;
  }
  catch (err) {
    await oepnaiApiErrorHandler(err, { notFoundError: async() => { await threadRelation.remove() } });
    // Rethrow the original error to keep its type, message and stack
    throw err instanceof Error ? err : new Error(String(err));
  }
}

/**
 * Deletes the threads of expired ThreadRelation documents and then removes
 * those documents.
 *
 * Threads are deleted one at a time. A relation is removed when its thread was
 * deleted successfully, or when the delete call failed with a 404 API error
 * (the thread entity no longer exists, so the relation is stale). Any other
 * failure is logged and the relation is kept so that a later run can retry it.
 *
 * @param limit - maximum number of expired relations processed in this run
 */
public async deleteExpiredThreads(limit: number): Promise<void> {
  const expiredThreadRelations = await ThreadRelationModel.getExpiredThreadRelations(limit);
  if (expiredThreadRelations == null || expiredThreadRelations.length === 0) {
    return;
  }

  const deletedThreadIds: string[] = [];
  for await (const expiredThreadRelation of expiredThreadRelations) {
    const { threadId } = expiredThreadRelation;
    try {
      const deleteThreadResponse = await this.client.deleteThread(threadId);
      logger.debug('Delete thread', deleteThreadResponse);
      deletedThreadIds.push(threadId);
    }
    catch (err) {
      // Without this, a relation whose thread is already gone would fail on every
      // run and keep occupying a slot of the limited batch forever
      let isThreadAlreadyGone = false;
      await oepnaiApiErrorHandler(err, { notFoundError: async() => { isThreadAlreadyGone = true } });

      if (isThreadAlreadyGone) {
        deletedThreadIds.push(threadId);
      }
      else {
        logger.error(err);
      }
    }
  }

  // Nothing to purge; skip the database round trip
  if (deletedThreadIds.length === 0) {
    return;
  }

  await ThreadRelationModel.deleteMany({ threadId: { $in: deletedThreadIds } });
}

public async getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument> {
const vectorStoreDocument = await VectorStoreModel.findOne({ scorpeType: VectorStoreScopeType.PUBLIC });

Expand All @@ -50,11 +107,17 @@ class OpenaiService implements IOpenaiService {
}

if (vectorStoreDocument != null && !isVectorStoreForPublicScopeExist) {
const vectorStore = await this.client.retrieveVectorStore(vectorStoreDocument.vectorStoreId);
if (vectorStore != null) {
try {
// Check if vector store entity exists
// If the vector store entity does not exist, the vector store document is deleted
await this.client.retrieveVectorStore(vectorStoreDocument.vectorStoreId);
isVectorStoreForPublicScopeExist = true;
return vectorStoreDocument;
}
catch (err) {
await oepnaiApiErrorHandler(err, { notFoundError: async() => { await vectorStoreDocument.remove() } });
throw new Error(err);
}
}

const newVectorStore = await this.client.createVectorStore(VectorStoreScopeType.PUBLIC);
Expand All @@ -74,7 +137,7 @@ class OpenaiService implements IOpenaiService {
return uploadedFile;
}

async createVectorStoreFile(pages: Array<PageDocument>): Promise<void> {
async createVectorStoreFile(pages: Array<HydratedDocument<PageDocument>>): Promise<void> {
const vectorStoreFileRelationsMap: Map<string, VectorStoreFileRelation> = new Map();
const processUploadFile = async(page: PageDocument) => {
if (page._id != null && page.grant === PageGrant.GRANT_PUBLIC && page.revision != null) {
Expand Down Expand Up @@ -112,22 +175,22 @@ class OpenaiService implements IOpenaiService {
}

try {
// Save vector store file relation
await VectorStoreFileRelationModel.upsertVectorStoreFileRelations(vectorStoreFileRelations);

// Create vector store file
const vectorStore = await this.getOrCreateVectorStoreForPublicScope();
const createVectorStoreFileBatchResponse = await this.client.createVectorStoreFileBatch(vectorStore.vectorStoreId, uploadedFileIds);
logger.debug('Create vector store file', createVectorStoreFileBatchResponse);

// Save vector store file relation
await VectorStoreFileRelationModel.upsertVectorStoreFileRelations(vectorStoreFileRelations);
}
catch (err) {
logger.error(err);

// Delete all uploaded files if createVectorStoreFileBatch fails
uploadedFileIds.forEach(async(fileId) => {
const deleteFileResponse = await this.client.deleteFile(fileId);
logger.debug('Delete vector store file (Due to createVectorStoreFileBatch failure)', deleteFileResponse);
});
const pageIds = pages.map(page => page._id);
for await (const pageId of pageIds) {
await this.deleteVectorStoreFile(pageId);
}
}

}
Expand All @@ -140,9 +203,8 @@ class OpenaiService implements IOpenaiService {
}

const deletedFileIds: string[] = [];
for (const fileId of vectorStoreFileRelation.fileIds) {
for await (const fileId of vectorStoreFileRelation.fileIds) {
try {
// eslint-disable-next-line no-await-in-loop
const deleteFileResponse = await this.client.deleteFile(fileId);
logger.debug('Delete vector store file', deleteFileResponse);
deletedFileIds.push(fileId);
Expand Down Expand Up @@ -174,7 +236,7 @@ class OpenaiService implements IOpenaiService {
const createVectorStoreFile = this.createVectorStoreFile.bind(this);
const createVectorStoreFileStream = new Transform({
objectMode: true,
async transform(chunk: PageDocument[], encoding, callback) {
async transform(chunk: HydratedDocument<PageDocument>[], encoding, callback) {
await createVectorStoreFile(chunk);
this.push(chunk);
callback();
Expand Down
Loading

0 comments on commit 74cdd33

Please sign in to comment.