Add chat completion method #645

Merged: 31 commits from the chatCompletion branch into main, May 13, 2024
Changes from 11 commits (of 31)

Commits
332648d  implement chat completion (radames, May 1, 2024)
d63db89  missing import type (radames, May 1, 2024)
26dd3b1  fix chatCompletion input type (radames, May 1, 2024)
a90c0e7  🩹 (coyotte508, May 2, 2024)
a3928f4  ✅ Update tests file (coyotte508, May 2, 2024)
1d509c3  ✅ Update tests (coyotte508, May 2, 2024)
4e9ba18  🐛 One more test (coyotte508, May 2, 2024)
c535893  ✅ More tests (coyotte508, May 2, 2024)
63d5151  ✅ one more test (coyotte508, May 2, 2024)
9d2f737  ✅ Fix last test (coyotte508, May 2, 2024)
0f881e4  Merge branch 'main' into chatCompletion (coyotte508, May 3, 2024)
6a9ad56  remove skips (radames, May 4, 2024)
23637bf  recorded tapes.json (radames, May 4, 2024)
ca54d67  add chat chatCompletion hint to change url (radames, May 4, 2024)
7ff57f2  add chatCompletion test with modelid (radames, May 4, 2024)
91ec869  tests (radames, May 4, 2024)
8fd2621  test with error message (radames, May 4, 2024)
c9b95a5  test (radames, May 4, 2024)
5ab21ca  better error handling (radames, May 5, 2024)
074aa76  Merge branch 'main' into chatCompletion (radames, May 8, 2024)
32ad989  add chat completion example to inference README.md (radames, May 8, 2024)
3d8bfc6  fix (radames, May 8, 2024)
5e3a9d6  📝 Update README.md (coyotte508, May 8, 2024)
0ca9ad0  return_full_text not compatible here (radames, May 8, 2024)
72cfa24  remove return_full_text (radames, May 8, 2024)
779b828  tests (radames, May 8, 2024)
b66fcf3  Update packages/inference/README.md (radames, May 9, 2024)
87ee635  fix chat completion example (radames, May 9, 2024)
b901f5b  ♻️ Do not sent `options` (coyotte508, May 11, 2024)
5f2b488  record test (radames, May 11, 2024)
6502858  Merge branch 'main' into chatCompletion (radames, May 11, 2024)
packages/inference/src/HfInference.ts (4 additions & 4 deletions)

```diff
@@ -14,9 +14,9 @@ type TaskWithNoAccessToken = {
 	) => ReturnType<Task[key]>;
 };
 
-type TaskWithNoAccessTokenNoModel = {
+type TaskWithNoAccessTokenNoEndpointUrl = {
 	[key in keyof Task]: (
-		args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "model">,
+		args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "endpointUrl">,
 		options?: Parameters<Task[key]>[1]
 	) => ReturnType<Task[key]>;
 };
@@ -57,12 +57,12 @@ export class HfInferenceEndpoint {
 			enumerable: false,
 			value: (params: RequestArgs, options: Options) =>
 				// eslint-disable-next-line @typescript-eslint/no-explicit-any
-				fn({ ...params, accessToken, model: endpointUrl } as any, { ...defaultOptions, ...options }),
+				fn({ ...params, accessToken, endpointUrl } as any, { ...defaultOptions, ...options }),
 		});
 	}
 }
 
 export interface HfInference extends TaskWithNoAccessToken {}
 
-export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoModel {}
+export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoEndpointUrl {}
```
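For context (not part of the diff): with this rename, `HfInferenceEndpoint` forwards its constructor URL as `endpointUrl` rather than overloading `model`. A minimal usage sketch, with a hypothetical endpoint URL and token:

```ts
import { HfInferenceEndpoint } from "@huggingface/inference";

// Hypothetical dedicated endpoint URL and token, for illustration only
const hf = new HfInferenceEndpoint("https://my-endpoint.example.cloud", "hf_...");

// Task methods behave as before; internally the URL now travels as `endpointUrl`
const out = await hf.chatCompletion({
  messages: [{ role: "user", content: "Hello!" }],
  max_tokens: 100,
});
console.log(out.choices[0].message);
```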
packages/inference/src/lib/isEmpty.ts (8 additions & 0 deletions)

```diff
@@ -0,0 +1,8 @@
+export function isObjectEmpty(object: object): boolean {
+	for (const prop in object) {
+		if (Object.prototype.hasOwnProperty.call(object, prop)) {
+			return false;
+		}
+	}
+	return true;
+}
```
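This helper lets `makeRequestOptions` (below) skip serializing an empty `options` object into the request body. Its behavior in brief (illustrative values):

```ts
isObjectEmpty({});                       // true
isObjectEmpty({ wait_for_model: true }); // false
// Only own enumerable keys count; inherited properties are ignored
isObjectEmpty(Object.create({ a: 1 })); // true
```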
packages/inference/src/lib/makeRequestOptions.ts (12 additions & 5 deletions)

```diff
@@ -1,4 +1,6 @@
 import type { InferenceTask, Options, RequestArgs } from "../types";
+import { isObjectEmpty } from "../lib/isEmpty";
+import { omit } from "../utils/omit";
 import { HF_HUB_URL } from "./getDefaultTask";
 import { isUrl } from "./isUrl";
 
@@ -24,8 +26,7 @@ export async function makeRequestOptions(
 		taskHint?: InferenceTask;
 	}
 ): Promise<{ url: string; info: RequestInit }> {
-	// eslint-disable-next-line @typescript-eslint/no-unused-vars
-	const { accessToken, model: _model, ...otherArgs } = args;
+	const { accessToken, endpointUrl, ...otherArgs } = args;
 	let { model } = args;
 	const {
 		forceTask: task,
@@ -78,10 +79,16 @@
 	}
 
 	const url = (() => {
+		if (endpointUrl && isUrl(model)) {
+			throw new TypeError("Both model and endpointUrl cannot be URLs");
+		}
 		if (isUrl(model)) {
 			console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
 			return model;
 		}
 
+		if (endpointUrl) {
+			return endpointUrl;
+		}
 		if (task) {
 			return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
 		}
@@ -105,8 +112,8 @@
 		body: binary
 			? args.data
 			: JSON.stringify({
-					...otherArgs,
-					options: options && otherOptions,
+					...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs),
+					...(otherOptions && !isObjectEmpty(otherOptions) && { options: otherOptions }),
 				}),
 		...(credentials && { credentials }),
 		signal: options?.signal,
```
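The resulting URL precedence, as a simplified self-contained sketch (the `/models/<model>` fallback and the exact `isUrl` check are assumptions based on the serverless Inference API; they are not shown in the excerpt above):

```ts
const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co"; // assumed value
const isUrl = (value?: string): boolean => /^http(s?):\/\//.test(value ?? ""); // simplified check

function resolveUrl(model?: string, endpointUrl?: string, task?: string): string {
  if (endpointUrl && model && isUrl(model)) {
    throw new TypeError("Both model and endpointUrl cannot be URLs");
  }
  if (model && isUrl(model)) {
    return model; // deprecated: model-as-URL still wins, for backwards compatibility
  }
  if (endpointUrl) {
    return endpointUrl; // dedicated endpoint
  }
  if (task) {
    return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
  }
  return `${HF_INFERENCE_API_BASE_URL}/models/${model}`; // serverless default
}
```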
packages/inference/src/tasks/custom/streamingRequest.ts (3 additions & 0 deletions)

```diff
@@ -67,6 +67,9 @@ export async function* streamingRequest<T>(
 			onChunk(value);
 			for (const event of events) {
 				if (event.data.length > 0) {
+					if (event.data === "[DONE]") {
+						return;
+					}
 					const data = JSON.parse(event.data);
 					if (typeof data === "object" && data !== null && "error" in data) {
 						throw new Error(data.error);
```
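Background for the new guard: OpenAI-compatible servers terminate the event stream with a literal `[DONE]` sentinel rather than a JSON payload, so the wire traffic looks roughly like this (illustrative frames):

```ts
// data: {"choices":[{"delta":{"content":"Hel"}}], ...}
// data: {"choices":[{"delta":{"content":"lo"}}], ...}
// data: [DONE]
//
// Without the early return, JSON.parse("[DONE]") would throw a SyntaxError
// instead of letting the generator finish cleanly.
```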
packages/inference/src/tasks/index.ts (2 additions & 0 deletions)

```diff
@@ -30,6 +30,8 @@ export * from "./nlp/textGenerationStream";
 export * from "./nlp/tokenClassification";
 export * from "./nlp/translation";
 export * from "./nlp/zeroShotClassification";
+export * from "./nlp/chatCompletion";
+export * from "./nlp/chatCompletionStream";
 
 // Multimodal tasks
 export * from "./multimodal/documentQuestionAnswering";
```
packages/inference/src/tasks/nlp/chatCompletion.ts (31 additions & 0 deletions)

```diff
@@ -0,0 +1,31 @@
+import { InferenceOutputError } from "../../lib/InferenceOutputError";
+import type { BaseArgs, Options } from "../../types";
+import { request } from "../custom/request";
+import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
+
+/**
+ * Use the chat completion endpoint to generate a response to a prompt, using the OpenAI-style message completion API (non-streaming)
+ */
+
+export async function chatCompletion(
+	args: BaseArgs & ChatCompletionInput,
+	options?: Options
+): Promise<ChatCompletionOutput> {
+	const res = await request<ChatCompletionOutput>(args, {
+		...options,
+		taskHint: "text-generation",
+	});
+	const isValidOutput =
+		typeof res === "object" &&
+		Array.isArray(res?.choices) &&
+		typeof res?.created === "number" &&
+		typeof res?.id === "string" &&
+		typeof res?.model === "string" &&
+		typeof res?.system_fingerprint === "string" &&
+		typeof res?.usage === "object";
+
+	if (!isValidOutput) {
+		throw new InferenceOutputError("Expected ChatCompletionOutput");
+	}
+	return res;
+}
```
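A minimal usage sketch of the non-streaming call (the model id, prompt, and token below are placeholders):

```ts
import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_..."); // hypothetical token

const out = await hf.chatCompletion({
  model: "mistralai/Mistral-7B-Instruct-v0.2",
  messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
  max_tokens: 500,
  temperature: 0.1,
});
console.log(out.choices[0].message.content);
```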
packages/inference/src/tasks/nlp/chatCompletionStream.ts (16 additions & 0 deletions)

```diff
@@ -0,0 +1,16 @@
+import type { BaseArgs, Options } from "../../types";
+import { streamingRequest } from "../custom/streamingRequest";
+import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
+
+/**
+ * Use the chat completion endpoint to generate a response to a prompt. Same as `chatCompletion` but returns a generator that can be read one token at a time
+ */
+export async function* chatCompletionStream(
+	args: BaseArgs & ChatCompletionInput,
+	options?: Options
+): AsyncGenerator<ChatCompletionStreamOutput> {
+	yield* streamingRequest<ChatCompletionStreamOutput>(args, {
+		...options,
+		taskHint: "text-generation",
+	});
+}
```
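And a streaming sketch (same placeholder model and token), consuming the generator with `for await`:

```ts
import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_..."); // hypothetical token

let out = "";
for await (const chunk of hf.chatCompletionStream({
  model: "mistralai/Mistral-7B-Instruct-v0.2",
  messages: [{ role: "user", content: "Can you help me solve an equation?" }],
  max_tokens: 500,
})) {
  if (chunk.choices && chunk.choices.length > 0) {
    out += chunk.choices[0].delta.content ?? "";
  }
}
console.log(out);
```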
packages/inference/src/tasks/nlp/textGenerationStream.ts (1 addition & 0 deletions)

```diff
@@ -67,6 +67,7 @@ export interface TextGenerationStreamDetails {
 }
 
 export interface TextGenerationStreamOutput {
+	index?: number;
 	/** Generated token, one at a time */
 	token: TextGenerationStreamToken;
 	/**
```
packages/inference/src/types.ts (14 additions & 3 deletions)

```diff
@@ -1,4 +1,5 @@
 import type { PipelineType } from "@huggingface/tasks";
+import type { ChatCompletionInput } from "@huggingface/tasks";
 
 export interface Options {
 	/**
@@ -32,7 +33,7 @@ export interface Options {
 	signal?: AbortSignal;
 
 	/**
-	 * Credentials to use for the request. If this is a string, it will be passed straight on. If it's a boolean, true will be "include" and false will not send credentials at all (which defaults to "same-origin" inside browsers).
+	 * (Default: "same-origin"). String | Boolean. Credentials to use for the request. If this is a string, it will be passed straight on. If it's a boolean, true will be "include" and false will not send credentials at all.
 	 */
 	includeCredentials?: string | boolean;
 }
@@ -47,15 +48,25 @@ export interface BaseArgs {
 	 */
 	accessToken?: string;
 	/**
-	 * The model to use. Can be a full URL for a dedicated inference endpoint.
+	 * The model to use.
 	 *
 	 * If not specified, will call huggingface.co/api/tasks to get the default model for the task.
+	 *
+	 * /!\ Legacy behavior allows this to be an URL, but this is deprecated and will be removed in the future.
+	 * Use the `endpointUrl` parameter instead.
 	 */
 	model?: string;
+
+	/**
+	 * The URL of the endpoint to use.
+	 *
+	 * If specified, requests are sent to this URL instead of the default URL resolved from the model.
+	 */
+	endpointUrl?: string;
 }
 
 export type RequestArgs = BaseArgs &
-	({ data: Blob | ArrayBuffer } | { inputs: unknown }) & {
+	({ data: Blob | ArrayBuffer } | { inputs: unknown } | ChatCompletionInput) & {
 		parameters?: Record<string, unknown>;
 		accessToken?: string;
 	};
```
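To illustrate the migration the new field enables (the endpoint URL and token below are hypothetical):

```ts
import { textGeneration } from "@huggingface/inference";

// Before (deprecated): the endpoint URL was passed through `model`
// await textGeneration({ accessToken: "hf_...", model: "https://my-endpoint.example.cloud", inputs: "Hello" });

// After: the URL goes in `endpointUrl`
const out = await textGeneration({
  accessToken: "hf_...",
  endpointUrl: "https://my-endpoint.example.cloud",
  inputs: "Hello",
});
console.log(out.generated_text);
```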