From ac599aa47c49bde2d557a6f7317347c940b29b17 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Fri, 2 Aug 2024 18:00:42 +0800 Subject: [PATCH 01/12] add dalle3 model --- app/client/platforms/openai.ts | 90 ++++++++++++++++++++++++---------- app/components/chat.tsx | 34 +++++++++++++ app/constant.ts | 7 ++- app/store/chat.ts | 5 ++ app/utils.ts | 4 ++ 5 files changed, 113 insertions(+), 27 deletions(-) diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 680125fe6c4..28de30051ea 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -33,6 +33,7 @@ import { getMessageTextContent, getMessageImages, isVisionModel, + isDalle3 as _isDalle3, } from "@/app/utils"; export interface OpenAIListModelResponse { @@ -58,6 +59,13 @@ export interface RequestPayload { max_tokens?: number; } +export interface DalleRequestPayload { + model: string; + prompt: string; + n: number; + size: "1024x1024" | "1792x1024" | "1024x1792"; +} + export class ChatGPTApi implements LLMApi { private disableListModels = true; @@ -101,19 +109,25 @@ export class ChatGPTApi implements LLMApi { } extractMessage(res: any) { + if (res.error) { + return "```\n" + JSON.stringify(res, null, 4) + "\n```"; + } + // dalle3 model return url, just return + if (res.data) { + const url = res.data?.at(0)?.url ?? ""; + return [ + { + type: "image_url", + image_url: { + url, + }, + }, + ]; + } return res.choices?.at(0)?.message?.content ?? ""; } async chat(options: ChatOptions) { - const visionModel = isVisionModel(options.config.model); - const messages: ChatOptions["messages"] = []; - for (const v of options.messages) { - const content = visionModel - ? await preProcessImageContent(v.content) - : getMessageTextContent(v); - messages.push({ role: v.role, content }); - } - const modelConfig = { ...useAppConfig.getState().modelConfig, ...useChatStore.getState().currentSession().mask.modelConfig, @@ -123,26 +137,48 @@ export class ChatGPTApi implements LLMApi { }, }; - const requestPayload: RequestPayload = { - messages, - stream: options.config.stream, - model: modelConfig.model, - temperature: modelConfig.temperature, - presence_penalty: modelConfig.presence_penalty, - frequency_penalty: modelConfig.frequency_penalty, - top_p: modelConfig.top_p, - // max_tokens: Math.max(modelConfig.max_tokens, 1024), - // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore. - }; + let requestPayload: RequestPayload | DalleRequestPayload; + + const isDalle3 = _isDalle3(options.config.model); + if (isDalle3) { + const prompt = getMessageTextContent(options.messages.slice(-1)?.pop()); + requestPayload = { + model: options.config.model, + prompt, + n: 1, + size: options.config?.size ?? "1024x1024", + }; + } else { + const visionModel = isVisionModel(options.config.model); + const messages: ChatOptions["messages"] = []; + for (const v of options.messages) { + const content = visionModel + ? await preProcessImageContent(v.content) + : getMessageTextContent(v); + messages.push({ role: v.role, content }); + } - // add max_tokens to vision model - if (visionModel && modelConfig.model.includes("preview")) { - requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000); + requestPayload = { + messages, + stream: options.config.stream, + model: modelConfig.model, + temperature: modelConfig.temperature, + presence_penalty: modelConfig.presence_penalty, + frequency_penalty: modelConfig.frequency_penalty, + top_p: modelConfig.top_p, + // max_tokens: Math.max(modelConfig.max_tokens, 1024), + // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore. + }; + + // add max_tokens to vision model + if (visionModel && modelConfig.model.includes("preview")) { + requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000); + } } console.log("[Request] openai payload: ", requestPayload); - const shouldStream = !!options.config.stream; + const shouldStream = !isDalle3 && !!options.config.stream; const controller = new AbortController(); options.onController?.(controller); @@ -168,13 +204,15 @@ export class ChatGPTApi implements LLMApi { model?.provider?.providerName === ServiceProvider.Azure, ); chatPath = this.path( - Azure.ChatPath( + (isDalle3 ? Azure.ImagePath : Azure.ChatPath)( (model?.displayName ?? model?.name) as string, useCustomConfig ? useAccessStore.getState().azureApiVersion : "", ), ); } else { - chatPath = this.path(OpenaiPath.ChatPath); + chatPath = this.path( + isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath, + ); } const chatPayload = { method: "POST", diff --git a/app/components/chat.tsx b/app/components/chat.tsx index bb4b611ad79..b95e85d45df 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -37,6 +37,7 @@ import AutoIcon from "../icons/auto.svg"; import BottomIcon from "../icons/bottom.svg"; import StopIcon from "../icons/pause.svg"; import RobotIcon from "../icons/robot.svg"; +import SizeIcon from "../icons/size.svg"; import PluginIcon from "../icons/plugin.svg"; import { @@ -60,6 +61,7 @@ import { getMessageTextContent, getMessageImages, isVisionModel, + isDalle3, } from "../utils"; import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; @@ -481,6 +483,11 @@ export function ChatActions(props: { const [showPluginSelector, setShowPluginSelector] = useState(false); const [showUploadImage, setShowUploadImage] = useState(false); + const [showSizeSelector, setShowSizeSelector] = useState(false); + const dalle3Sizes = ["1024x1024", "1792x1024", "1024x1792"]; + const currentSize = + chatStore.currentSession().mask.modelConfig?.size || "1024x1024"; + useEffect(() => { const show = isVisionModel(currentModel); setShowUploadImage(show); @@ -624,6 +631,33 @@ export function ChatActions(props: { /> )} + {isDalle3(currentModel) && ( + setShowSizeSelector(true)} + text={currentSize} + icon={} + /> + )} + + {showSizeSelector && ( + ({ + title: m, + value: m, + }))} + onClose={() => setShowSizeSelector(false)} + onSelection={(s) => { + if (s.length === 0) return; + const size = s[0]; + chatStore.updateCurrentSession((session) => { + session.mask.modelConfig.size = size; + }); + showToast(size); + }} + /> + )} + setShowPluginSelector(true)} text={Locale.Plugin.Name} diff --git a/app/constant.ts b/app/constant.ts index 5251b5b4fc9..b777872c8e0 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -146,6 +146,7 @@ export const Anthropic = { export const OpenaiPath = { ChatPath: "v1/chat/completions", + ImagePath: "v1/images/generations", UsagePath: "dashboard/billing/usage", SubsPath: "dashboard/billing/subscription", ListModelPath: "v1/models", @@ -154,7 +155,10 @@ export const OpenaiPath = { export const Azure = { ChatPath: (deployName: string, apiVersion: string) => `deployments/${deployName}/chat/completions?api-version=${apiVersion}`, - ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}", + // https://.openai.azure.com/openai/deployments//images/generations?api-version= + ImagePath: (deployName: string, apiVersion: string) => + `deployments/${deployName}/images/generations?api-version=${apiVersion}`, + ExampleEndpoint: "https://{resource-url}/openai", }; export const Google = { @@ -256,6 +260,7 @@ const openaiModels = [ "gpt-4-vision-preview", "gpt-4-turbo-2024-04-09", "gpt-4-1106-preview", + "dall-e-3", ]; const googleModels = [ diff --git a/app/store/chat.ts b/app/store/chat.ts index 5892ef0c8c6..7b47f3ec629 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -26,6 +26,7 @@ import { nanoid } from "nanoid"; import { createPersistStore } from "../utils/store"; import { collectModelsWithDefaultModel } from "../utils/model"; import { useAccessStore } from "./access"; +import { isDalle3 } from "../utils"; export type ChatMessage = RequestMessage & { date: string; @@ -541,6 +542,10 @@ export const useChatStore = createPersistStore( const config = useAppConfig.getState(); const session = get().currentSession(); const modelConfig = session.mask.modelConfig; + // skip summarize when using dalle3? + if (isDalle3(modelConfig.model)) { + return; + } const api: ClientApi = getClientApi(modelConfig.providerName); diff --git a/app/utils.ts b/app/utils.ts index 2f2c8ae95ab..a3c329b8239 100644 --- a/app/utils.ts +++ b/app/utils.ts @@ -265,3 +265,7 @@ export function isVisionModel(model: string) { visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo ); } + +export function isDalle3(model: string) { + return "dall-e-3" === model; +} From 1c24ca58c784775fb0d2cf9daa07949d329bd36a Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Fri, 2 Aug 2024 18:03:19 +0800 Subject: [PATCH 02/12] add dalle3 model --- app/icons/size.svg | 1 + 1 file changed, 1 insertion(+) create mode 100644 app/icons/size.svg diff --git a/app/icons/size.svg b/app/icons/size.svg new file mode 100644 index 00000000000..3da4fadfec6 --- /dev/null +++ b/app/icons/size.svg @@ -0,0 +1 @@ + From 46cb48023e6b2ffa52a44775b58a83a97dcffac2 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Fri, 2 Aug 2024 18:50:48 +0800 Subject: [PATCH 03/12] fix typescript error --- app/client/api.ts | 3 ++- app/client/platforms/openai.ts | 7 +++++-- app/components/chat.tsx | 5 +++-- app/store/config.ts | 2 ++ app/typing.ts | 2 ++ 5 files changed, 14 insertions(+), 5 deletions(-) diff --git a/app/client/api.ts b/app/client/api.ts index f10e4761887..88157e79cc7 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -6,7 +6,7 @@ import { ServiceProvider, } from "../constant"; import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store"; -import { ChatGPTApi } from "./platforms/openai"; +import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai"; import { GeminiProApi } from "./platforms/google"; import { ClaudeApi } from "./platforms/anthropic"; import { ErnieApi } from "./platforms/baidu"; @@ -42,6 +42,7 @@ export interface LLMConfig { stream?: boolean; presence_penalty?: number; frequency_penalty?: number; + size?: DalleRequestPayload["size"]; } export interface ChatOptions { diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 28de30051ea..54309e29f7e 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -13,6 +13,7 @@ import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import { collectModelsWithDefaultModel } from "@/app/utils/model"; import { preProcessImageContent } from "@/app/utils/chat"; import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare"; +import { DalleSize } from "@/app/typing"; import { ChatOptions, @@ -63,7 +64,7 @@ export interface DalleRequestPayload { model: string; prompt: string; n: number; - size: "1024x1024" | "1792x1024" | "1024x1792"; + size: DalleSize; } export class ChatGPTApi implements LLMApi { @@ -141,7 +142,9 @@ export class ChatGPTApi implements LLMApi { const isDalle3 = _isDalle3(options.config.model); if (isDalle3) { - const prompt = getMessageTextContent(options.messages.slice(-1)?.pop()); + const prompt = getMessageTextContent( + options.messages.slice(-1)?.pop() as any, + ); requestPayload = { model: options.config.model, prompt, diff --git a/app/components/chat.tsx b/app/components/chat.tsx index b95e85d45df..67ea80c4a85 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -69,6 +69,7 @@ import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; import dynamic from "next/dynamic"; import { ChatControllerPool } from "../client/controller"; +import { DalleSize } from "../typing"; import { Prompt, usePromptStore } from "../store/prompt"; import Locale from "../locales"; @@ -484,9 +485,9 @@ export function ChatActions(props: { const [showUploadImage, setShowUploadImage] = useState(false); const [showSizeSelector, setShowSizeSelector] = useState(false); - const dalle3Sizes = ["1024x1024", "1792x1024", "1024x1792"]; + const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"]; const currentSize = - chatStore.currentSession().mask.modelConfig?.size || "1024x1024"; + chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024"; useEffect(() => { const show = isVisionModel(currentModel); diff --git a/app/store/config.ts b/app/store/config.ts index 1eaafe12b1d..705a9d87c40 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -1,4 +1,5 @@ import { LLMModel } from "../client/api"; +import { DalleSize } from "../typing"; import { getClientConfig } from "../config/client"; import { DEFAULT_INPUT_TEMPLATE, @@ -60,6 +61,7 @@ export const DEFAULT_CONFIG = { compressMessageLengthThreshold: 1000, enableInjectSystemPrompts: true, template: config?.template ?? DEFAULT_INPUT_TEMPLATE, + size: "1024x1024" as DalleSize, }, }; diff --git a/app/typing.ts b/app/typing.ts index b09722ab902..86320358157 100644 --- a/app/typing.ts +++ b/app/typing.ts @@ -7,3 +7,5 @@ export interface RequestMessage { role: MessageRole; content: string; } + +export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792"; From 8c83fe23a1661d37644626e8d71130d96ce413f9 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Fri, 2 Aug 2024 20:58:21 +0800 Subject: [PATCH 04/12] using b64_json for dall-e-3 --- app/client/platforms/openai.ts | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 54309e29f7e..ee9a70913bd 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -11,7 +11,11 @@ import { } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import { collectModelsWithDefaultModel } from "@/app/utils/model"; -import { preProcessImageContent } from "@/app/utils/chat"; +import { + preProcessImageContent, + uploadImage, + base64Image2Blob, +} from "@/app/utils/chat"; import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare"; import { DalleSize } from "@/app/typing"; @@ -63,6 +67,7 @@ export interface RequestPayload { export interface DalleRequestPayload { model: string; prompt: string; + response_format: "url" | "b64_json"; n: number; size: DalleSize; } @@ -109,13 +114,18 @@ export class ChatGPTApi implements LLMApi { return cloudflareAIGatewayUrl([baseUrl, path].join("/")); } - extractMessage(res: any) { + async extractMessage(res: any) { if (res.error) { return "```\n" + JSON.stringify(res, null, 4) + "\n```"; } - // dalle3 model return url, just return + // dalle3 model return url, using url create image message if (res.data) { - const url = res.data?.at(0)?.url ?? ""; + let url = res.data?.at(0)?.url ?? ""; + const b64_json = res.data?.at(0)?.b64_json ?? ""; + if (!url && b64_json) { + // uploadImage + url = await uploadImage(base64Image2Blob(b64_json, "image/png")); + } return [ { type: "image_url", @@ -148,6 +158,8 @@ export class ChatGPTApi implements LLMApi { requestPayload = { model: options.config.model, prompt, + // URLs are only valid for 60 minutes after the image has been generated. + response_format: "b64_json", // using b64_json, and save image in CacheStorage n: 1, size: options.config?.size ?? "1024x1024", }; @@ -227,7 +239,7 @@ export class ChatGPTApi implements LLMApi { // make a fetch request const requestTimeoutId = setTimeout( () => controller.abort(), - REQUEST_TIMEOUT_MS, + isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dalle3 using b64_json is slow. ); if (shouldStream) { @@ -358,7 +370,7 @@ export class ChatGPTApi implements LLMApi { clearTimeout(requestTimeoutId); const resJson = await res.json(); - const message = this.extractMessage(resJson); + const message = await this.extractMessage(resJson); options.onFinish(message); } } catch (e) { From 4a8e85c28a293c765ce73af6afb34aaa4840290e Mon Sep 17 00:00:00 2001 From: Dogtiti <499960698@qq.com> Date: Fri, 2 Aug 2024 22:16:08 +0800 Subject: [PATCH 05/12] fix: empty response --- app/client/platforms/openai.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index ee9a70913bd..8b03d1397e6 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -135,7 +135,7 @@ export class ChatGPTApi implements LLMApi { }, ]; } - return res.choices?.at(0)?.message?.content ?? ""; + return res.choices?.at(0)?.message?.content ?? res; } async chat(options: ChatOptions) { From 8a4b8a84d67bb7431c5ce88046d94963dceebad7 Mon Sep 17 00:00:00 2001 From: frostime Date: Sat, 3 Aug 2024 17:16:05 +0800 Subject: [PATCH 06/12] =?UTF-8?q?=E2=9C=A8=20feat:=20=E8=B0=83=E6=95=B4?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=88=97=E8=A1=A8=EF=BC=8C=E5=B0=86=E8=87=AA?= =?UTF-8?q?=E5=AE=9A=E4=B9=89=E6=A8=A1=E5=9E=8B=E6=94=BE=E5=9C=A8=E5=89=8D?= =?UTF-8?q?=E9=9D=A2=E6=98=BE=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/utils/model.ts | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/app/utils/model.ts b/app/utils/model.ts index 4de0eb8d96a..6b1485e32ad 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -22,15 +22,6 @@ export function collectModelTable( } > = {}; - // default models - models.forEach((m) => { - // using @ as fullName - modelTable[`${m.name}@${m?.provider?.id}`] = { - ...m, - displayName: m.name, // 'provider' is copied over if it exists - }; - }); - // server custom models customModels .split(",") @@ -89,6 +80,15 @@ export function collectModelTable( } }); + // default models + models.forEach((m) => { + // using @ as fullName + modelTable[`${m.name}@${m?.provider?.id}`] = { + ...m, + displayName: m.name, // 'provider' is copied over if it exists + }; + }); + return modelTable; } @@ -99,13 +99,16 @@ export function collectModelTableWithDefaultModel( ) { let modelTable = collectModelTable(models, customModels); if (defaultModel && defaultModel !== "") { - if (defaultModel.includes('@')) { + if (defaultModel.includes("@")) { if (defaultModel in modelTable) { modelTable[defaultModel].isDefault = true; } } else { for (const key of Object.keys(modelTable)) { - if (modelTable[key].available && key.split('@').shift() == defaultModel) { + if ( + modelTable[key].available && + key.split("@").shift() == defaultModel + ) { modelTable[key].isDefault = true; break; } From 1610675c8f956345b799be92fc1dbf4ba81c18f2 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Mon, 5 Aug 2024 11:36:35 +0800 Subject: [PATCH 07/12] remove hash.js --- app/utils/hmac.ts | 246 +++++++++++++++++++++++++++++++++++++++++++ app/utils/tencent.ts | 18 +--- package.json | 1 - yarn.lock | 15 +-- 4 files changed, 251 insertions(+), 29 deletions(-) create mode 100644 app/utils/hmac.ts diff --git a/app/utils/hmac.ts b/app/utils/hmac.ts new file mode 100644 index 00000000000..96292dac357 --- /dev/null +++ b/app/utils/hmac.ts @@ -0,0 +1,246 @@ +// From https://gist.github.com/guillermodlpa/f6d955f838e9b10d1ef95b8e259b2c58 +// From https://gist.github.com/stevendesu/2d52f7b5e1f1184af3b667c0b5e054b8 + +// To ensure cross-browser support even without a proper SubtleCrypto +// impelmentation (or without access to the impelmentation, as is the case with +// Chrome loaded over HTTP instead of HTTPS), this library can create SHA-256 +// HMAC signatures using nothing but raw JavaScript + +/* eslint-disable no-magic-numbers, id-length, no-param-reassign, new-cap */ + +// By giving internal functions names that we can mangle, future calls to +// them are reduced to a single byte (minor space savings in minified file) +const uint8Array = Uint8Array; +const uint32Array = Uint32Array; +const pow = Math.pow; + +// Will be initialized below +// Using a Uint32Array instead of a simple array makes the minified code +// a bit bigger (we lose our `unshift()` hack), but comes with huge +// performance gains +const DEFAULT_STATE = new uint32Array(8); +const ROUND_CONSTANTS: number[] = []; + +// Reusable object for expanded message +// Using a Uint32Array instead of a simple array makes the minified code +// 7 bytes larger, but comes with huge performance gains +const M = new uint32Array(64); + +// After minification the code to compute the default state and round +// constants is smaller than the output. More importantly, this serves as a +// good educational aide for anyone wondering where the magic numbers come +// from. No magic numbers FTW! +function getFractionalBits(n: number) { + return ((n - (n | 0)) * pow(2, 32)) | 0; +} + +let n = 2; +let nPrime = 0; +while (nPrime < 64) { + // isPrime() was in-lined from its original function form to save + // a few bytes + let isPrime = true; + // Math.sqrt() was replaced with pow(n, 1/2) to save a few bytes + // var sqrtN = pow(n, 1 / 2); + // So technically to determine if a number is prime you only need to + // check numbers up to the square root. However this function only runs + // once and we're only computing the first 64 primes (up to 311), so on + // any modern CPU this whole function runs in a couple milliseconds. + // By going to n / 2 instead of sqrt(n) we net 8 byte savings and no + // scaling performance cost + for (let factor = 2; factor <= n / 2; factor++) { + if (n % factor === 0) { + isPrime = false; + } + } + if (isPrime) { + if (nPrime < 8) { + DEFAULT_STATE[nPrime] = getFractionalBits(pow(n, 1 / 2)); + } + ROUND_CONSTANTS[nPrime] = getFractionalBits(pow(n, 1 / 3)); + + nPrime++; + } + + n++; +} + +// For cross-platform support we need to ensure that all 32-bit words are +// in the same endianness. A UTF-8 TextEncoder will return BigEndian data, +// so upon reading or writing to our ArrayBuffer we'll only swap the bytes +// if our system is LittleEndian (which is about 99% of CPUs) +const LittleEndian = !!new uint8Array(new uint32Array([1]).buffer)[0]; + +function convertEndian(word: number) { + if (LittleEndian) { + return ( + // byte 1 -> byte 4 + (word >>> 24) | + // byte 2 -> byte 3 + (((word >>> 16) & 0xff) << 8) | + // byte 3 -> byte 2 + ((word & 0xff00) << 8) | + // byte 4 -> byte 1 + (word << 24) + ); + } else { + return word; + } +} + +function rightRotate(word: number, bits: number) { + return (word >>> bits) | (word << (32 - bits)); +} + +function sha256(data: Uint8Array) { + // Copy default state + const STATE = DEFAULT_STATE.slice(); + + // Caching this reduces occurrences of ".length" in minified JavaScript + // 3 more byte savings! :D + const legth = data.length; + + // Pad data + const bitLength = legth * 8; + const newBitLength = 512 - ((bitLength + 64) % 512) - 1 + bitLength + 65; + + // "bytes" and "words" are stored BigEndian + const bytes = new uint8Array(newBitLength / 8); + const words = new uint32Array(bytes.buffer); + + bytes.set(data, 0); + // Append a 1 + bytes[legth] = 0b10000000; + // Store length in BigEndian + words[words.length - 1] = convertEndian(bitLength); + + // Loop iterator (avoid two instances of "var") -- saves 2 bytes + let round; + + // Process blocks (512 bits / 64 bytes / 16 words at a time) + for (let block = 0; block < newBitLength / 32; block += 16) { + const workingState = STATE.slice(); + + // Rounds + for (round = 0; round < 64; round++) { + let MRound; + // Expand message + if (round < 16) { + // Convert to platform Endianness for later math + MRound = convertEndian(words[block + round]); + } else { + const gamma0x = M[round - 15]; + const gamma1x = M[round - 2]; + MRound = + M[round - 7] + + M[round - 16] + + (rightRotate(gamma0x, 7) ^ + rightRotate(gamma0x, 18) ^ + (gamma0x >>> 3)) + + (rightRotate(gamma1x, 17) ^ + rightRotate(gamma1x, 19) ^ + (gamma1x >>> 10)); + } + + // M array matches platform endianness + M[round] = MRound |= 0; + + // Computation + const t1 = + (rightRotate(workingState[4], 6) ^ + rightRotate(workingState[4], 11) ^ + rightRotate(workingState[4], 25)) + + ((workingState[4] & workingState[5]) ^ + (~workingState[4] & workingState[6])) + + workingState[7] + + MRound + + ROUND_CONSTANTS[round]; + const t2 = + (rightRotate(workingState[0], 2) ^ + rightRotate(workingState[0], 13) ^ + rightRotate(workingState[0], 22)) + + ((workingState[0] & workingState[1]) ^ + (workingState[2] & (workingState[0] ^ workingState[1]))); + for (let i = 7; i > 0; i--) { + workingState[i] = workingState[i - 1]; + } + workingState[0] = (t1 + t2) | 0; + workingState[4] = (workingState[4] + t1) | 0; + } + + // Update state + for (round = 0; round < 8; round++) { + STATE[round] = (STATE[round] + workingState[round]) | 0; + } + } + + // Finally the state needs to be converted to BigEndian for output + // And we want to return a Uint8Array, not a Uint32Array + return new uint8Array( + new uint32Array( + STATE.map(function (val) { + return convertEndian(val); + }), + ).buffer, + ); +} + +function hmac(key: Uint8Array, data: ArrayLike) { + if (key.length > 64) key = sha256(key); + + if (key.length < 64) { + const tmp = new Uint8Array(64); + tmp.set(key, 0); + key = tmp; + } + + // Generate inner and outer keys + const innerKey = new Uint8Array(64); + const outerKey = new Uint8Array(64); + for (let i = 0; i < 64; i++) { + innerKey[i] = 0x36 ^ key[i]; + outerKey[i] = 0x5c ^ key[i]; + } + + // Append the innerKey + const msg = new Uint8Array(data.length + 64); + msg.set(innerKey, 0); + msg.set(data, 64); + + // Has the previous message and append the outerKey + const result = new Uint8Array(64 + 32); + result.set(outerKey, 0); + result.set(sha256(msg), 64); + + // Hash the previous message + return sha256(result); +} + +// Convert a string to a Uint8Array, SHA-256 it, and convert back to string +const encoder = new TextEncoder(); + +export function sign( + inputKey: string | Uint8Array, + inputData: string | Uint8Array, +) { + const key = + typeof inputKey === "string" ? encoder.encode(inputKey) : inputKey; + const data = + typeof inputData === "string" ? encoder.encode(inputData) : inputData; + return hmac(key, data); +} + +export function hex(bin: Uint8Array) { + return bin.reduce((acc, val) => { + const hexVal = "00" + val.toString(16); + return acc + hexVal.substring(hexVal.length - 2); + }, ""); +} + +export function hash(str: string) { + return hex(sha256(encoder.encode(str))); +} + +export function hashWithSecret(str: string, secret: string) { + return hex(sign(secret, str)).toString(); +} diff --git a/app/utils/tencent.ts b/app/utils/tencent.ts index f0cdd21ee17..92772703cf8 100644 --- a/app/utils/tencent.ts +++ b/app/utils/tencent.ts @@ -1,19 +1,9 @@ -import hash from "hash.js"; +import { sign, hash as getHash, hex } from "./hmac"; // 使用 SHA-256 和 secret 进行 HMAC 加密 -function sha256(message: any, secret = "", encoding?: string) { - return hash - .hmac(hash.sha256 as any, secret) - .update(message) - .digest(encoding as any); -} - -// 使用 SHA-256 进行哈希 -function getHash(message: any, encoding = "hex") { - return hash - .sha256() - .update(message) - .digest(encoding as any); +function sha256(message: any, secret: any, encoding?: string) { + const result = sign(secret, message); + return encoding == "hex" ? hex(result).toString() : result; } function getDate(timestamp: number) { diff --git a/package.json b/package.json index 001b28eac06..eb0a5ef6735 100644 --- a/package.json +++ b/package.json @@ -26,7 +26,6 @@ "@vercel/speed-insights": "^1.0.2", "emoji-picker-react": "^4.9.2", "fuse.js": "^7.0.0", - "hash.js": "^1.1.7", "heic2any": "^0.0.4", "html-to-image": "^1.11.11", "lodash-es": "^4.17.21", diff --git a/yarn.lock b/yarn.lock index 09bf322964d..793c845d722 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3799,14 +3799,6 @@ has@^1.0.3: dependencies: function-bind "^1.1.1" -hash.js@^1.1.7: - version "1.1.7" - resolved "https://registry.npmjs.org/hash.js/-/hash.js-1.1.7.tgz#0babca538e8d4ee4a0f8988d68866537a003cf42" - integrity sha512-taOaskGt4z4SOANNseOviYDvjEJinIkRgmp7LbKP2YTTmVxWBl87s/uzK9r+44BclBSp2X7K1hqeNfz9JbBeXA== - dependencies: - inherits "^2.0.3" - minimalistic-assert "^1.0.1" - hast-util-from-dom@^4.0.0: version "4.2.0" resolved "https://registry.yarnpkg.com/hast-util-from-dom/-/hast-util-from-dom-4.2.0.tgz#25836ddecc3cc0849d32749c2a7aec03e94b59a7" @@ -3970,7 +3962,7 @@ inflight@^1.0.4: once "^1.3.0" wrappy "1" -inherits@2, inherits@^2.0.3: +inherits@2: version "2.0.4" resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.4.tgz#0fa2c64f932917c3433a0ded55363aae37416b7c" integrity sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ== @@ -4962,11 +4954,6 @@ mimic-fn@^4.0.0: resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-4.0.0.tgz#60a90550d5cb0b239cca65d893b1a53b29871ecc" integrity sha512-vqiC06CuhBTUdZH+RYl8sFrL096vA45Ok5ISO6sE/Mr1jRbGH4Csnhi8f3wKVl7x8mO4Au7Ir9D3Oyv1VYMFJw== -minimalistic-assert@^1.0.1: - version "1.0.1" - resolved "https://registry.npmjs.org/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz#2e194de044626d4a10e7f7fbc00ce73e83e4d5c7" - integrity sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A== - minimatch@^3.0.4, minimatch@^3.0.5, minimatch@^3.1.1, minimatch@^3.1.2: version "3.1.2" resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b" From 4a95dcb6e96aa020ca2db1ea3c3175eb7d3fce84 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Mon, 5 Aug 2024 12:45:25 +0800 Subject: [PATCH 08/12] hotfix get wrong llm --- app/client/api.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/app/client/api.ts b/app/client/api.ts index f10e4761887..abff459c5f8 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -118,6 +118,7 @@ export class ClientApi { break; case ModelProvider.Qwen: this.llm = new QwenApi(); + break; case ModelProvider.Hunyuan: this.llm = new HunyuanApi(); break; From 141ce2c99ae61e12dae21e34eba644bede70d310 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Mon, 5 Aug 2024 12:59:27 +0800 Subject: [PATCH 09/12] reduce cloudflare functions build size --- app/api/[provider]/[...path]/route.ts | 64 +++++++++++++++++++ .../[...path]/route.ts => alibaba.ts} | 26 +------- .../[...path]/route.ts => anthropic.ts} | 28 +------- .../{azure/[...path]/route.ts => azure.ts} | 30 +-------- .../{baidu/[...path]/route.ts => baidu.ts} | 26 +------- .../[...path]/route.ts => bytedance.ts} | 26 +------- .../{google/[...path]/route.ts => google.ts} | 6 +- .../[...path]/route.ts => moonshot.ts} | 26 +------- .../{openai/[...path]/route.ts => openai.ts} | 30 +-------- .../[...path]/route.ts => stability.ts} | 7 +- 10 files changed, 80 insertions(+), 189 deletions(-) create mode 100644 app/api/[provider]/[...path]/route.ts rename app/api/{alibaba/[...path]/route.ts => alibaba.ts} (91%) rename app/api/{anthropic/[...path]/route.ts => anthropic.ts} (92%) rename app/api/{azure/[...path]/route.ts => azure.ts} (66%) rename app/api/{baidu/[...path]/route.ts => baidu.ts} (91%) rename app/api/{bytedance/[...path]/route.ts => bytedance.ts} (90%) rename app/api/{google/[...path]/route.ts => google.ts} (96%) rename app/api/{moonshot/[...path]/route.ts => moonshot.ts} (91%) rename app/api/{openai/[...path]/route.ts => openai.ts} (82%) rename app/api/{stability/[...path]/route.ts => stability.ts} (95%) diff --git a/app/api/[provider]/[...path]/route.ts b/app/api/[provider]/[...path]/route.ts new file mode 100644 index 00000000000..6d028ac364d --- /dev/null +++ b/app/api/[provider]/[...path]/route.ts @@ -0,0 +1,64 @@ +import { ApiPath } from "@/app/constant"; +import { NextRequest, NextResponse } from "next/server"; +import { handle as openaiHandler } from "../../openai"; +import { handle as azureHandler } from "../../azure"; +import { handle as googleHandler } from "../../google"; +import { handle as anthropicHandler } from "../../anthropic"; +import { handle as baiduHandler } from "../../baidu"; +import { handle as bytedanceHandler } from "../../bytedance"; +import { handle as alibabaHandler } from "../../alibaba"; +import { handle as moonshotHandler } from "../../moonshot"; +import { handle as stabilityHandler } from "../../stability"; + +async function handle( + req: NextRequest, + { params }: { params: { provider: string; path: string[] } }, +) { + const apiPath = `/api/${params.provider}`; + console.log(`[${params.provider} Route] params `, params); + switch (apiPath) { + case ApiPath.Azure: + return azureHandler(req, { params }); + case ApiPath.Google: + return googleHandler(req, { params }); + case ApiPath.Anthropic: + return anthropicHandler(req, { params }); + case ApiPath.Baidu: + return baiduHandler(req, { params }); + case ApiPath.ByteDance: + return bytedanceHandler(req, { params }); + case ApiPath.Alibaba: + return alibabaHandler(req, { params }); + // case ApiPath.Tencent: using "/api/tencent" + case ApiPath.Moonshot: + return moonshotHandler(req, { params }); + case ApiPath.Stability: + return stabilityHandler(req, { params }); + default: + return openaiHandler(req, { params }); + } +} + +export const GET = handle; +export const POST = handle; + +export const runtime = "edge"; +export const preferredRegion = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; diff --git a/app/api/alibaba/[...path]/route.ts b/app/api/alibaba.ts similarity index 91% rename from app/api/alibaba/[...path]/route.ts rename to app/api/alibaba.ts index c97ce593473..675d9f301aa 100644 --- a/app/api/alibaba/[...path]/route.ts +++ b/app/api/alibaba.ts @@ -14,7 +14,7 @@ import type { RequestPayload } from "@/app/client/platforms/openai"; const serverConfig = getServerSideConfig(); -async function handle( +export async function handle( req: NextRequest, { params }: { params: { path: string[] } }, ) { @@ -40,30 +40,6 @@ async function handle( } } -export const GET = handle; -export const POST = handle; - -export const runtime = "edge"; -export const preferredRegion = [ - "arn1", - "bom1", - "cdg1", - "cle1", - "cpt1", - "dub1", - "fra1", - "gru1", - "hnd1", - "iad1", - "icn1", - "kix1", - "lhr1", - "pdx1", - "sfo1", - "sin1", - "syd1", -]; - async function request(req: NextRequest) { const controller = new AbortController(); diff --git a/app/api/anthropic/[...path]/route.ts b/app/api/anthropic.ts similarity index 92% rename from app/api/anthropic/[...path]/route.ts rename to app/api/anthropic.ts index 20f8d52e062..3d49f4c88c4 100644 --- a/app/api/anthropic/[...path]/route.ts +++ b/app/api/anthropic.ts @@ -9,13 +9,13 @@ import { } from "@/app/constant"; import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; -import { auth } from "../../auth"; +import { auth } from "./auth"; import { isModelAvailableInServer } from "@/app/utils/model"; import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare"; const ALLOWD_PATH = new Set([Anthropic.ChatPath, Anthropic.ChatPath1]); -async function handle( +export async function handle( req: NextRequest, { params }: { params: { path: string[] } }, ) { @@ -56,30 +56,6 @@ async function handle( } } -export const GET = handle; -export const POST = handle; - -export const runtime = "edge"; -export const preferredRegion = [ - "arn1", - "bom1", - "cdg1", - "cle1", - "cpt1", - "dub1", - "fra1", - "gru1", - "hnd1", - "iad1", - "icn1", - "kix1", - "lhr1", - "pdx1", - "sfo1", - "sin1", - "syd1", -]; - const serverConfig = getServerSideConfig(); async function request(req: NextRequest) { diff --git a/app/api/azure/[...path]/route.ts b/app/api/azure.ts similarity index 66% rename from app/api/azure/[...path]/route.ts rename to app/api/azure.ts index 4a17de0c8ab..e2cb0c7e66b 100644 --- a/app/api/azure/[...path]/route.ts +++ b/app/api/azure.ts @@ -2,10 +2,10 @@ import { getServerSideConfig } from "@/app/config/server"; import { ModelProvider } from "@/app/constant"; import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; -import { auth } from "../../auth"; -import { requestOpenai } from "../../common"; +import { auth } from "./auth"; +import { requestOpenai } from "./common"; -async function handle( +export async function handle( req: NextRequest, { params }: { params: { path: string[] } }, ) { @@ -31,27 +31,3 @@ async function handle( return NextResponse.json(prettyObject(e)); } } - -export const GET = handle; -export const POST = handle; - -export const runtime = "edge"; -export const preferredRegion = [ - "arn1", - "bom1", - "cdg1", - "cle1", - "cpt1", - "dub1", - "fra1", - "gru1", - "hnd1", - "iad1", - "icn1", - "kix1", - "lhr1", - "pdx1", - "sfo1", - "sin1", - "syd1", -]; diff --git a/app/api/baidu/[...path]/route.ts b/app/api/baidu.ts similarity index 91% rename from app/api/baidu/[...path]/route.ts rename to app/api/baidu.ts index 94c9963c7e9..f4315d186da 100644 --- a/app/api/baidu/[...path]/route.ts +++ b/app/api/baidu.ts @@ -14,7 +14,7 @@ import { getAccessToken } from "@/app/utils/baidu"; const serverConfig = getServerSideConfig(); -async function handle( +export async function handle( req: NextRequest, { params }: { params: { path: string[] } }, ) { @@ -52,30 +52,6 @@ async function handle( } } -export const GET = handle; -export const POST = handle; - -export const runtime = "edge"; -export const preferredRegion = [ - "arn1", - "bom1", - "cdg1", - "cle1", - "cpt1", - "dub1", - "fra1", - "gru1", - "hnd1", - "iad1", - "icn1", - "kix1", - "lhr1", - "pdx1", - "sfo1", - "sin1", - "syd1", -]; - async function request(req: NextRequest) { const controller = new AbortController(); diff --git a/app/api/bytedance/[...path]/route.ts b/app/api/bytedance.ts similarity index 90% rename from app/api/bytedance/[...path]/route.ts rename to app/api/bytedance.ts index 336c837f037..cb65b106109 100644 --- a/app/api/bytedance/[...path]/route.ts +++ b/app/api/bytedance.ts @@ -12,7 +12,7 @@ import { isModelAvailableInServer } from "@/app/utils/model"; const serverConfig = getServerSideConfig(); -async function handle( +export async function handle( req: NextRequest, { params }: { params: { path: string[] } }, ) { @@ -38,30 +38,6 @@ async function handle( } } -export const GET = handle; -export const POST = handle; - -export const runtime = "edge"; -export const preferredRegion = [ - "arn1", - "bom1", - "cdg1", - "cle1", - "cpt1", - "dub1", - "fra1", - "gru1", - "hnd1", - "iad1", - "icn1", - "kix1", - "lhr1", - "pdx1", - "sfo1", - "sin1", - "syd1", -]; - async function request(req: NextRequest) { const controller = new AbortController(); diff --git a/app/api/google/[...path]/route.ts b/app/api/google.ts similarity index 96% rename from app/api/google/[...path]/route.ts rename to app/api/google.ts index 83a7ce794c1..98fe469bfb7 100644 --- a/app/api/google/[...path]/route.ts +++ b/app/api/google.ts @@ -1,5 +1,5 @@ import { NextRequest, NextResponse } from "next/server"; -import { auth } from "../../auth"; +import { auth } from "./auth"; import { getServerSideConfig } from "@/app/config/server"; import { ApiPath, @@ -11,9 +11,9 @@ import { prettyObject } from "@/app/utils/format"; const serverConfig = getServerSideConfig(); -async function handle( +export async function handle( req: NextRequest, - { params }: { params: { path: string[] } }, + { params }: { params: { provider: string; path: string[] } }, ) { console.log("[Google Route] params ", params); diff --git a/app/api/moonshot/[...path]/route.ts b/app/api/moonshot.ts similarity index 91% rename from app/api/moonshot/[...path]/route.ts rename to app/api/moonshot.ts index 14bc0a40d92..247dd618321 100644 --- a/app/api/moonshot/[...path]/route.ts +++ b/app/api/moonshot.ts @@ -14,7 +14,7 @@ import type { RequestPayload } from "@/app/client/platforms/openai"; const serverConfig = getServerSideConfig(); -async function handle( +export async function handle( req: NextRequest, { params }: { params: { path: string[] } }, ) { @@ -40,30 +40,6 @@ async function handle( } } -export const GET = handle; -export const POST = handle; - -export const runtime = "edge"; -export const preferredRegion = [ - "arn1", - "bom1", - "cdg1", - "cle1", - "cpt1", - "dub1", - "fra1", - "gru1", - "hnd1", - "iad1", - "icn1", - "kix1", - "lhr1", - "pdx1", - "sfo1", - "sin1", - "syd1", -]; - async function request(req: NextRequest) { const controller = new AbortController(); diff --git a/app/api/openai/[...path]/route.ts b/app/api/openai.ts similarity index 82% rename from app/api/openai/[...path]/route.ts rename to app/api/openai.ts index 77059c151fc..6d11d679215 100644 --- a/app/api/openai/[...path]/route.ts +++ b/app/api/openai.ts @@ -3,8 +3,8 @@ import { getServerSideConfig } from "@/app/config/server"; import { ModelProvider, OpenaiPath } from "@/app/constant"; import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; -import { auth } from "../../auth"; -import { requestOpenai } from "../../common"; +import { auth } from "./auth"; +import { requestOpenai } from "./common"; const ALLOWD_PATH = new Set(Object.values(OpenaiPath)); @@ -20,7 +20,7 @@ function getModels(remoteModelRes: OpenAIListModelResponse) { return remoteModelRes; } -async function handle( +export async function handle( req: NextRequest, { params }: { params: { path: string[] } }, ) { @@ -70,27 +70,3 @@ async function handle( return NextResponse.json(prettyObject(e)); } } - -export const GET = handle; -export const POST = handle; - -export const runtime = "edge"; -export const preferredRegion = [ - "arn1", - "bom1", - "cdg1", - "cle1", - "cpt1", - "dub1", - "fra1", - "gru1", - "hnd1", - "iad1", - "icn1", - "kix1", - "lhr1", - "pdx1", - "sfo1", - "sin1", - "syd1", -]; diff --git a/app/api/stability/[...path]/route.ts b/app/api/stability.ts similarity index 95% rename from app/api/stability/[...path]/route.ts rename to app/api/stability.ts index 4b2bcc30527..2646ace858e 100644 --- a/app/api/stability/[...path]/route.ts +++ b/app/api/stability.ts @@ -3,7 +3,7 @@ import { getServerSideConfig } from "@/app/config/server"; import { ModelProvider, STABILITY_BASE_URL } from "@/app/constant"; import { auth } from "@/app/api/auth"; -async function handle( +export async function handle( req: NextRequest, { params }: { params: { path: string[] } }, ) { @@ -97,8 +97,3 @@ async function handle( clearTimeout(timeoutId); } } - -export const GET = handle; -export const POST = handle; - -export const runtime = "edge"; From b023a00445682fcb336fe231ffe7c667632c0d15 Mon Sep 17 00:00:00 2001 From: frostime Date: Mon, 5 Aug 2024 16:37:22 +0800 Subject: [PATCH 10/12] =?UTF-8?q?=F0=9F=94=A8=20refactor(model):=20?= =?UTF-8?q?=E6=9B=B4=E6=94=B9=E5=8E=9F=E5=85=88=E7=9A=84=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E6=96=B9=E6=B3=95=EF=BC=8C=E5=9C=A8=20collect=20table=20?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E5=90=8E=E9=9D=A2=E5=A2=9E=E5=8A=A0=E9=A2=9D?= =?UTF-8?q?=E5=A4=96=E7=9A=84=20sort=20=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/utils/model.ts | 50 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/app/utils/model.ts b/app/utils/model.ts index 6b1485e32ad..b117b5eb64a 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -7,6 +7,29 @@ const customProvider = (providerName: string) => ({ providerType: "custom", }); +const sortModelTable = ( + models: ReturnType, + rule: "custom-first" | "default-first", +) => + models.sort((a, b) => { + if (a.provider === undefined && b.provider === undefined) { + return 0; + } + + let aIsCustom = a.provider?.providerType === "custom"; + let bIsCustom = b.provider?.providerType === "custom"; + + if (aIsCustom === bIsCustom) { + return 0; + } + + if (aIsCustom) { + return rule === "custom-first" ? -1 : 1; + } else { + return rule === "custom-first" ? 1 : -1; + } + }); + export function collectModelTable( models: readonly LLMModel[], customModels: string, @@ -22,6 +45,15 @@ export function collectModelTable( } > = {}; + // default models + models.forEach((m) => { + // using @ as fullName + modelTable[`${m.name}@${m?.provider?.id}`] = { + ...m, + displayName: m.name, // 'provider' is copied over if it exists + }; + }); + // server custom models customModels .split(",") @@ -80,15 +112,6 @@ export function collectModelTable( } }); - // default models - models.forEach((m) => { - // using @ as fullName - modelTable[`${m.name}@${m?.provider?.id}`] = { - ...m, - displayName: m.name, // 'provider' is copied over if it exists - }; - }); - return modelTable; } @@ -126,7 +149,9 @@ export function collectModels( customModels: string, ) { const modelTable = collectModelTable(models, customModels); - const allModels = Object.values(modelTable); + let allModels = Object.values(modelTable); + + allModels = sortModelTable(allModels, "custom-first"); return allModels; } @@ -141,7 +166,10 @@ export function collectModelsWithDefaultModel( customModels, defaultModel, ); - const allModels = Object.values(modelTable); + let allModels = Object.values(modelTable); + + allModels = sortModelTable(allModels, "custom-first"); + return allModels; } From 150fc84b9b55fe07da2fefa73b2cbee255d9de14 Mon Sep 17 00:00:00 2001 From: frostime Date: Mon, 5 Aug 2024 19:43:32 +0800 Subject: [PATCH 11/12] =?UTF-8?q?=E2=9C=A8=20feat(model):=20=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=20sorted=20=E5=AD=97=E6=AE=B5=EF=BC=8C=E5=B9=B6?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E8=AF=A5=E5=AD=97=E6=AE=B5=E5=AF=B9=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=88=97=E8=A1=A8=E8=BF=9B=E8=A1=8C=E6=8E=92=E5=BA=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 在 Model 和 Provider 类型中增加 sorted 字段(api.ts) 2. 默认模型在初始化的时候,自动设置默认 sorted 字段,从 1000 开始自增长(constant.ts) 3. 自定义模型更新的时候,自动分配 sorted 字段(model.ts) --- app/client/api.ts | 2 ++ app/constant.ts | 19 ++++++++++++++++++ app/utils/model.ts | 49 +++++++++++++++++++++++++++------------------- 3 files changed, 50 insertions(+), 20 deletions(-) diff --git a/app/client/api.ts b/app/client/api.ts index f10e4761887..b13e0f8a4c0 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -64,12 +64,14 @@ export interface LLMModel { displayName?: string; available: boolean; provider: LLMModelProvider; + sorted: number; } export interface LLMModelProvider { id: string; providerName: string; providerType: string; + sorted: number; } export abstract class LLMApi { diff --git a/app/constant.ts b/app/constant.ts index 5251b5b4fc9..8ca17c4b359 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -320,86 +320,105 @@ const tencentModels = [ const moonshotModes = ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]; +let seq = 1000; // 内置的模型序号生成器从1000开始 export const DEFAULT_MODELS = [ ...openaiModels.map((name) => ({ name, available: true, + sorted: seq++, // Global sequence sort(index) provider: { id: "openai", providerName: "OpenAI", providerType: "openai", + sorted: 1, // 这里是固定的,确保顺序与之前内置的版本一致 }, })), ...openaiModels.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "azure", providerName: "Azure", providerType: "azure", + sorted: 2, }, })), ...googleModels.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "google", providerName: "Google", providerType: "google", + sorted: 3, }, })), ...anthropicModels.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "anthropic", providerName: "Anthropic", providerType: "anthropic", + sorted: 4, }, })), ...baiduModels.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "baidu", providerName: "Baidu", providerType: "baidu", + sorted: 5, }, })), ...bytedanceModels.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "bytedance", providerName: "ByteDance", providerType: "bytedance", + sorted: 6, }, })), ...alibabaModes.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "alibaba", providerName: "Alibaba", providerType: "alibaba", + sorted: 7, }, })), ...tencentModels.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "tencent", providerName: "Tencent", providerType: "tencent", + sorted: 8, }, })), ...moonshotModes.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "moonshot", providerName: "Moonshot", providerType: "moonshot", + sorted: 9, }, })), ] as const; diff --git a/app/utils/model.ts b/app/utils/model.ts index b117b5eb64a..0b62b53be09 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -1,32 +1,39 @@ import { DEFAULT_MODELS } from "../constant"; import { LLMModel } from "../client/api"; +const CustomSeq = { + val: -1000, //To ensure the custom model located at front, start from -1000, refer to constant.ts + cache: new Map(), + next: (id: string) => { + if (CustomSeq.cache.has(id)) { + return CustomSeq.cache.get(id) as number; + } else { + let seq = CustomSeq.val++; + CustomSeq.cache.set(id, seq); + return seq; + } + }, +}; + const customProvider = (providerName: string) => ({ id: providerName.toLowerCase(), providerName: providerName, providerType: "custom", + sorted: CustomSeq.next(providerName), }); -const sortModelTable = ( - models: ReturnType, - rule: "custom-first" | "default-first", -) => +/** + * Sorts an array of models based on specified rules. + * + * First, sorted by provider; if the same, sorted by model + */ +const sortModelTable = (models: ReturnType) => models.sort((a, b) => { - if (a.provider === undefined && b.provider === undefined) { - return 0; - } - - let aIsCustom = a.provider?.providerType === "custom"; - let bIsCustom = b.provider?.providerType === "custom"; - - if (aIsCustom === bIsCustom) { - return 0; - } - - if (aIsCustom) { - return rule === "custom-first" ? -1 : 1; + if (a.provider && b.provider) { + let cmp = a.provider.sorted - b.provider.sorted; + return cmp === 0 ? a.sorted - b.sorted : cmp; } else { - return rule === "custom-first" ? 1 : -1; + return a.sorted - b.sorted; } }); @@ -40,6 +47,7 @@ export function collectModelTable( available: boolean; name: string; displayName: string; + sorted: number; provider?: LLMModel["provider"]; // Marked as optional isDefault?: boolean; } @@ -107,6 +115,7 @@ export function collectModelTable( displayName: displayName || customModelName, available, provider, // Use optional chaining + sorted: CustomSeq.next(`${customModelName}@${provider?.id}`), }; } } @@ -151,7 +160,7 @@ export function collectModels( const modelTable = collectModelTable(models, customModels); let allModels = Object.values(modelTable); - allModels = sortModelTable(allModels, "custom-first"); + allModels = sortModelTable(allModels); return allModels; } @@ -168,7 +177,7 @@ export function collectModelsWithDefaultModel( ); let allModels = Object.values(modelTable); - allModels = sortModelTable(allModels, "custom-first"); + allModels = sortModelTable(allModels); return allModels; } From 3486954e073665b4bcaa4d41096b1341e4c497ff Mon Sep 17 00:00:00 2001 From: frostime Date: Mon, 5 Aug 2024 20:26:48 +0800 Subject: [PATCH 12/12] =?UTF-8?q?=F0=9F=90=9B=20fix(openai):=20=E4=B8=8A?= =?UTF-8?q?=E6=AC=A1=20commit=20=E5=90=8E=20openai.ts=20=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E4=B8=AD=E5=87=BA=E7=8E=B0=E7=B1=BB=E5=9E=8B=E4=B8=8D=E5=8C=B9?= =?UTF-8?q?=E9=85=8D=E7=9A=84=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/client/platforms/openai.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 680125fe6c4..d95aebe87b2 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -411,13 +411,17 @@ export class ChatGPTApi implements LLMApi { return []; } + //由于目前 OpenAI 的 disableListModels 默认为 true,所以当前实际不会运行到这场 + let seq = 1000; //同 Constant.ts 中的排序保持一致 return chatModels.map((m) => ({ name: m.id, available: true, + sorted: seq++, provider: { id: "openai", providerName: "OpenAI", providerType: "openai", + sorted: 1, }, })); }