WIP: Consolidate streaming logic #50

Draft · wants to merge 7 commits into main
2 changes: 1 addition & 1 deletion README.md
@@ -116,7 +116,7 @@ All public types, including error messages, are documented in [this file](/apps/
 Example of streaming GPT-4 results to the console:
 
 ```ts
-await ai.getCompletion(
+await window.ai.getCompletion(
   {
     messages: [{ role: "user", content: "Who are you?" }]
   },
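For context beyond the one-line diff, a complete version of the README's streaming call might look like the sketch below. The `onStreamResult` option and its `(result, error)` callback signature come from the `inpage.ts` changes later in this PR; the logging logic is illustrative only.

```ts
await window.ai.getCompletion(
  {
    messages: [{ role: "user", content: "Who are you?" }]
  },
  {
    // Partial outputs stream into this callback; the returned promise
    // resolves separately with the final, non-partial output.
    onStreamResult: (result, error) => {
      if (error) {
        console.error(error)
      } else if (result && "message" in result) {
        console.log(result.message.content)
      }
    }
  }
)
```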
2 changes: 1 addition & 1 deletion apps/extension/package.json
@@ -1,7 +1,7 @@
 {
   "name": "window-extension",
   "displayName": "Window AI",
-  "version": "0.0.6",
+  "version": "0.0.8",
   "private": true,
   "description": "Use your own AI models on the web",
   "author": "Alex Atallah <[email protected]>",
67 changes: 35 additions & 32 deletions apps/extension/src/background/ports/completion.ts
@@ -4,7 +4,6 @@ import type { PlasmoMessaging } from "@plasmohq/messaging"
 
 import type { PortRequest, PortResponse } from "~core/constants"
 import { PortName } from "~core/constants"
-import { configManager } from "~core/managers/config"
 import { transactionManager } from "~core/managers/transaction"
 import * as modelRouter from "~core/model-router"
 import { err, isErr, isOk, ok } from "~core/utils/result-monad"
@@ -33,50 +32,54 @@ const handler: PlasmoMessaging.PortHandler<
   // Save the incomplete txn
   await transactionManager.save(txn)
 
-  const config = await configManager.getWithDefault(txn.model)
-
-  if (request.shouldStream && modelRouter.isStreamable(config)) {
-    const replies: string[] = []
-    const errors: string[] = []
-
-    const results = await modelRouter.stream(txn)
-
-    for await (const result of results) {
-      if (isOk(result)) {
-        const outputs = [getOutput(txn.input, result.data)]
-        res.send({ response: ok(outputs), id })
-        replies.push(result.data)
-      } else {
-        res.send({ response: result, id })
-        errors.push(result.error)
-      }
-    }
-
-    txn.outputs = replies.length
-      ? [getOutput(txn.input, replies.join(""))]
-      : undefined
-    txn.error = errors.join("") || undefined
-  } else {
-    const result = await modelRouter.complete(txn)
+  const hasMultipleOutputs = txn.numOutputs && txn.numOutputs > 1
+  const replies: string[] = []
+  const errors: string[] = []
+  const results = await modelRouter.generate(txn)
+
+  for await (const result of results) {
     if (isOk(result)) {
-      const outputs = result.data.map((d) => getOutput(txn.input, d))
+      const outputs = [getOutput(txn.input, result.data, true)]
       res.send({ response: ok(outputs), id })
-      txn.outputs = outputs
+      replies.push(result.data)
     } else {
       res.send({ response: result, id })
-      txn.error = result.error
+      errors.push(result.error)
+      // TODO handle auth errors
+      // if (isAuthError(result.error)) {
+      //   await requestAuth()
+      // }
     }
   }
 
+  // Collect the replies and errors onto the txn
+  txn.outputs = !replies.length
+    ? undefined
+    : hasMultipleOutputs
+    ? replies.map((r) => getOutput(txn.input, r))
+    : [getOutput(txn.input, replies.join(""))]
+  txn.error = errors.join("\n") || undefined
+
+  // Send the final output to the client, as non-partial
+  if (txn.outputs) {
+    res.send({ response: ok(txn.outputs), id })
+  }
+
   // Update the completion with the reply
   await transactionManager.save(txn)
 }
 
-function getOutput(input: Input, result: string): Output {
+function getOutput(input: Input, result: string, isPartial?: boolean): Output {
   return isMessagesInput(input)
-    ? { message: { role: "assistant", content: result } }
-    : { text: result }
+    ? { message: { role: "assistant", content: result }, isPartial }
+    : { text: result, isPartial }
 }
 
+// function isAuthError(error: string) {
+//   return (
+//     error.startsWith(ErrorCode.ModelRejectedRequest) &&
+//     error.split(": ")[1].startsWith("401")
+//   )
+// }
+
 export default handler
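Net effect of this file: streaming and one-shot completions now share a single `modelRouter.generate` loop, with each chunk sent to the client marked partial and the aggregated result re-sent unmarked. A rough sketch of the `Output` shape implied by `getOutput` above; the authoritative type lives in the `window.ai` package, so treat this as an inference:

```ts
// Inferred from getOutput() above; the canonical Output type is
// exported by the window.ai package.
type ChatMessage = { role: "assistant"; content: string }

type Output =
  | { message: ChatMessage; isPartial?: boolean } // messages-style Input
  | { text: string; isPartial?: boolean } // prompt-style Input

// Port lifecycle per the handler above:
//   ok([chunk]) with isPartial: true   -- once per generated chunk
//   ok(outputs) with isPartial unset   -- single final aggregate
```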
26 changes: 20 additions & 6 deletions apps/extension/src/background/ports/model.ts
@@ -4,7 +4,7 @@ import type { PlasmoMessaging } from "@plasmohq/messaging"
 
 import type { PortRequest, PortResponse } from "~core/constants"
 import { PortName } from "~core/constants"
-import { configManager } from "~core/managers/config"
+import { AuthType, configManager } from "~core/managers/config"
 import { err, ok } from "~core/utils/result-monad"
 import { log } from "~core/utils/utils"
 
@@ -18,14 +18,28 @@ const handler: PlasmoMessaging.PortHandler<
     return res.send(err(ErrorCode.InvalidRequest))
   }
 
-  const { id } = req.body
-
-  const currentModel = await configManager.getDefault()
+  const { id, request } = req.body
+  if (request) {
+    // TODO handle other model providers here by checking request.baseUrl
+    // TODO request the user's permission to add the model provider
+    const { metadata, shouldSetDefault } = request
+    const config =
+      (await configManager.forAuthAndModel(AuthType.External)) ||
+      configManager.init(AuthType.External)
+    const newConfig = {
+      ...config,
+      authMetadata: metadata
+    }
+    await configManager.save(newConfig)
+    if (shouldSetDefault) {
+      await configManager.setDefault(newConfig)
+    }
+  }
 
+  // We're starting a request, so send the request to the extension UI
+  const config = await configManager.getDefault()
   res.send({
     id,
-    response: ok({ model: currentModel.id })
+    response: ok({ model: configManager.getCurrentModel(config) })
   })
 }
 
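With this change the Model port does double duty: a message without `request` just reads the current model, while one carrying `ModelProviderOptions` first upserts an external provider config. A sketch of the two request shapes, using the `PortRequest` type from `constants.ts` below; the `baseUrl` and `metadata` values are placeholders, not a documented provider:

```ts
import type { RequestID } from "window.ai"

import type { PortRequest } from "~core/constants"
import { PortName } from "~core/constants"

declare const id: RequestID

// Read-only query: resolves the current model without touching config.
const read: PortRequest[PortName.Model] = { id }

// Upsert an external provider config, optionally making it the default.
const upsert: PortRequest[PortName.Model] = {
  id,
  request: {
    baseUrl: "https://example-provider.local/v1", // placeholder
    metadata: { apiKey: "sk-..." }, // placeholder; shape is provider-defined
    shouldSetDefault: true
  }
}
```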
5 changes: 1 addition & 4 deletions apps/extension/src/background/ports/permission.ts
@@ -56,10 +56,7 @@ export async function requestPermission(
   requestId: string
 ) {
   const originData = request.transaction.origin
-  const origin = await originManager.getOrInit(
-    request.transaction.origin.id,
-    originData
-  )
+  const origin = await originManager.getOrInit(originData.id, originData)
   if (origin.permissions === "allow") {
     log("Permission granted by user settings: ", origin)
     return ok(true)
30 changes: 24 additions & 6 deletions apps/extension/src/contents/inpage.ts
@@ -39,17 +39,18 @@ export const windowAI: WindowAI<ModelID> = {
   },
   async getCompletion(input, options = {}) {
     const { onStreamResult } = _validateOptions(options)
-    const shouldStream = !!onStreamResult
     const shouldReturnMultiple = options.numOutputs && options.numOutputs > 1
     const requestId = _relayRequest(PortName.Completion, {
-      transaction: transactionManager.init(input, _getOriginData(), options),
-      shouldStream
+      transaction: transactionManager.init(input, _getOriginData(), options)
     })
     return new Promise((resolve, reject) => {
      _addResponseListener<CompletionResponse>(requestId, (res) => {
         if (isOk(res)) {
-          resolve(shouldReturnMultiple ? res.data : res.data[0])
-          onStreamResult && onStreamResult(res.data[0], null)
+          if (!res.data[0].isPartial) {
+            resolve(shouldReturnMultiple ? res.data : res.data[0])
+          } else {
+            onStreamResult && onStreamResult(res.data[0], null)
+          }
         } else {
           reject(res.error)
           onStreamResult && onStreamResult(null, res.error)
@@ -59,7 +60,7 @@ export const windowAI: WindowAI<ModelID> = {
   },
 
   async getCurrentModel() {
-    const requestId = _relayRequest(PortName.Model, {})
+    const requestId = _relayRequest(PortName.Model, undefined)
     return new Promise((resolve, reject) => {
       _addResponseListener<ModelResponse>(requestId, (res) => {
         if (isOk(res)) {
@@ -86,6 +87,23 @@
       }
     })
     return requestId
   },
+
+  BETA_updateModelProvider({ baseUrl, metadata, shouldSetDefault }) {
+    const requestId = _relayRequest(PortName.Model, {
+      baseUrl,
+      metadata,
+      shouldSetDefault
+    })
+    return new Promise((resolve, reject) => {
+      _addResponseListener<ModelResponse>(requestId, (res) => {
+        if (isOk(res)) {
+          resolve(res.data.model)
+        } else {
+          reject(res.error)
+        }
+      })
+    })
+  }
 }
 
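A page-side usage sketch for the new method; all values below are placeholders, and the metadata shape is whatever the external provider expects:

```ts
// Register an external (BETA) provider, make it the default, and read
// back the resulting model ID.
const model = await window.ai.BETA_updateModelProvider({
  baseUrl: "https://example-provider.local/v1", // placeholder
  metadata: { apiKey: "sk-..." }, // placeholder
  shouldSetDefault: true
})
console.log("current model:", model) // may be undefined if no default is set
```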
11 changes: 8 additions & 3 deletions apps/extension/src/core/components/pure/Dropdown.tsx
@@ -1,18 +1,23 @@
 import { ChevronUpDownIcon } from "@heroicons/react/24/solid"
 import { useState } from "react"
 
-export function Dropdown<T extends string>({
+export function Dropdown<T>({
   styled = false,
   children,
   choices,
+  getLabel,
   onSelect
 }: {
   styled?: boolean
   children: React.ReactNode
   choices: T[]
+  getLabel?: (choice: T) => string
   onSelect: (choice: T) => void
 }) {
   const [isOpen, setIsOpen] = useState(false)
+  const getLabelOrDefault =
+    getLabel ||
+    ((choice) => (typeof choice === "string" ? choice : JSON.stringify(choice)))
 
   return (
     <div>
@@ -41,14 +46,14 @@ export function Dropdown<T extends string>({
           <div className="py-1" role="none">
             {choices.map((choice) => (
               <button
-                key={choice}
+                key={getLabelOrDefault(choice)}
                 className="block text-left w-full px-4 py-2 pr-8 text-slate-700 hover:bg-indigo-100 hover:text-indigo-900 focus:outline-none focus:bg-indigo-100 focus:text-indigo-900"
                 role="menuitem"
                 onClick={() => {
                   setIsOpen(false)
                   onSelect(choice)
                 }}>
-                {choice}
+                {getLabelOrDefault(choice)}
               </button>
             ))}
           </div>
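Dropping the `T extends string` constraint lets callers pass object choices alongside a `getLabel`, while string choices keep working via the default. A minimal usage sketch; the `ProviderChoice` type and values are invented for illustration:

```tsx
type ProviderChoice = { id: string; label: string }

const providers: ProviderChoice[] = [
  { id: "openai", label: "OpenAI" },
  { id: "openrouter", label: "OpenRouter" }
]

function ProviderPicker() {
  return (
    <Dropdown
      choices={providers}
      getLabel={(choice) => choice.label}
      onSelect={(choice) => console.log("selected", choice.id)}>
      Select a provider
    </Dropdown>
  )
}
```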
14 changes: 9 additions & 5 deletions apps/extension/src/core/constants.ts
@@ -1,4 +1,9 @@
-import type { ErrorCode, Output, RequestID } from "window.ai"
+import type {
+  ErrorCode,
+  ModelProviderOptions,
+  Output,
+  RequestID
+} from "window.ai"
 
 import type { EventRequest, EventResponse } from "~background/ports/events"
 import type { ModelID } from "~public-interface"
@@ -19,7 +24,7 @@ export interface PortRequest {
     id?: RequestID
     request: { requesterId: RequestID; permitted?: boolean }
   }
-  [PortName.Model]: { id: RequestID; request: ModelRequest }
+  [PortName.Model]: { id: RequestID; request?: ModelRequest }
   [PortName.Events]: { id?: RequestID; request: EventRequest<unknown> }
 }

@@ -51,12 +56,11 @@ export enum ContentMessageType {
 
 export type CompletionRequest = {
   transaction: Transaction
-  shouldStream?: boolean
 }
 export type CompletionResponse = Result<Output[], ErrorCode | string>
 
-export type ModelRequest = {}
-export type ModelResponse = Result<{ model: ModelID }, ErrorCode>
+export type ModelRequest = ModelProviderOptions
+export type ModelResponse = Result<{ model: ModelID | undefined }, ErrorCode>
 
 export type { EventRequest, EventResponse }
 
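`ModelRequest` is now an alias for `ModelProviderOptions` from the window.ai package. Inferring from its usage in `inpage.ts` and `model.ts` in this PR, the shape is roughly the following; the real definition, including which fields are optional, lives in the package:

```ts
// Inferred shape only; the authoritative ModelProviderOptions type is
// exported by the window.ai package.
interface ModelProviderOptions {
  baseUrl: string
  metadata: Record<string, unknown> // stored as authMetadata on the config
  shouldSetDefault?: boolean
}
```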
16 changes: 13 additions & 3 deletions apps/extension/src/core/llm/index.ts
@@ -4,6 +4,7 @@ import { init as initCohere } from "./cohere"
 import { init as initLocal } from "./local"
 import { Model } from "./model"
 import { init as initOpenAI } from "./openai"
+import { init as initOpenRouter } from "./openrouter"
 import { init as initTogether } from "./together"
 
 // TODO configure basic in-memory lru cache
@@ -25,6 +26,16 @@ export const local = initLocal(
   }
 )
 
+export const openrouter = initOpenRouter(
+  {
+    debug: shouldDebugModels
+  },
+  {
+    max_tokens: DEFAULT_MAX_TOKENS,
+    presence_penalty: 0 // Using negative numbers causes 500s from davinci
+  }
+)
+
 export const openai = initOpenAI(
   {
     debug: shouldDebugModels
@@ -57,10 +68,9 @@ export const cohere = initCohere(
   }
 )
 
-export const modelCallers: { [K in ModelID]: Model } = {
+export const modelAPICallers: { [K in ModelID]: Model } = {
   [ModelID.GPT3]: openai,
   [ModelID.GPT4]: openai,
-  [ModelID.Cohere]: cohere,
-  [ModelID.Together]: together,
-  [ModelID.Local]: local
+  [ModelID.Together]: together
 }
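With the rename, `modelAPICallers` covers only hosted API models; Cohere and Local drop out of the map here. A hypothetical dispatch sketch, assuming `Model` exposes the completion-style method that `~core/model-router` consumes; the method name and signature are assumptions, not the repo's actual API:

```ts
import { modelAPICallers } from "~core/llm"
import type { ModelID } from "~public-interface"

// Hypothetical: look up the caller for a hosted model and run a completion.
async function callHostedModel(modelId: ModelID, prompt: string) {
  const caller = modelAPICallers[modelId]
  // `complete` and its argument shape are assumptions for illustration.
  return caller.complete({ prompt })
}
```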