Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

💥 Simpler credentials passing around #918

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions packages/hub/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,21 @@ Learn how to find free models using the hub package in this [interactive tutoria

```ts
import { createRepo, uploadFiles, uploadFilesWithProgress, deleteFile, deleteRepo, listFiles, whoAmI } from "@huggingface/hub";
import type { RepoDesignation, Credentials } from "@huggingface/hub";
import type { RepoDesignation } from "@huggingface/hub";

const repo: RepoDesignation = { type: "model", name: "myname/some-model" };
const credentials: Credentials = { accessToken: "hf_..." };

const {name: username} = await whoAmI({credentials});
const {name: username} = await whoAmI({accessToken: "hf_..."});

for await (const model of listModels({search: {owner: username}, credentials})) {
for await (const model of listModels({search: {owner: username}, accessToken: "hf_..."})) {
console.log("My model:", model);
}

await createRepo({ repo, credentials, license: "mit" });
await createRepo({ repo, accessToken: "hf_...", license: "mit" });

await uploadFiles({
repo,
credentials,
accessToken: "hf_...",
files: [
// path + blob content
{
Expand All @@ -70,23 +69,23 @@ await uploadFiles({

for await (const progressEvent of await uploadFilesWithProgress({
repo,
credentials,
accessToken: "hf_...",
files: [
...
],
})) {
console.log(progressEvent);
}

await deleteFile({repo, credentials, path: "myfile.bin"});
await deleteFile({repo, accessToken: "hf_...", path: "myfile.bin"});

await (await downloadFile({ repo, path: "README.md" })).text();

for await (const fileInfo of listFiles({repo})) {
console.log(fileInfo);
}

await deleteRepo({ repo, credentials });
await deleteRepo({ repo, accessToken: "hf_..." });
```

## OAuth Login
Expand Down
16 changes: 4 additions & 12 deletions packages/hub/src/lib/commit.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ describe("commit", () => {
};

await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
repo,
license: "mit",
Expand All @@ -50,9 +48,7 @@ describe("commit", () => {
await commit({
repo,
title: "Some commit",
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
operations: [
{
Expand Down Expand Up @@ -135,9 +131,7 @@ size ${lfsContent.length}
};

await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
repo,
hubUrl: TEST_HUB_URL,
});
Expand All @@ -163,9 +157,7 @@ size ${lfsContent.length}
);
await commit({
repo,
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
title: "upload model",
operations,
Expand Down
16 changes: 8 additions & 8 deletions packages/hub/src/lib/commit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import type {
ApiPreuploadRequest,
ApiPreuploadResponse,
} from "../types/api/api-commit";
import type { Credentials, RepoDesignation } from "../types/public";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { chunk } from "../utils/chunk";
import { promisesQueue } from "../utils/promisesQueue";
Expand Down Expand Up @@ -54,12 +54,11 @@ type CommitBlob = Omit<CommitFile, "content"> & { content: Blob };
export type CommitOperation = CommitDeletedEntry | CommitFile /* | CommitRenameFile */;
type CommitBlobOperation = Exclude<CommitOperation, CommitFile> | CommitBlob;

export interface CommitParams {
export type CommitParams = {
title: string;
description?: string;
repo: RepoDesignation;
operations: CommitOperation[];
credentials?: Credentials;
/** @default "main" */
branch?: string;
/**
Expand All @@ -82,7 +81,8 @@ export interface CommitParams {
*/
fetch?: typeof fetch;
abortSignal?: AbortSignal;
}
// Credentials are optional due to custom fetch functions or cookie auth
} & Partial<CredentialsParams>;

export interface CommitOutput {
pullRequestUrl?: string;
Expand Down Expand Up @@ -121,7 +121,7 @@ export type CommitProgressEvent =
* Can be exposed later to offer fine-tuned progress info
*/
export async function* commitIter(params: CommitParams): AsyncGenerator<CommitProgressEvent, CommitOutput> {
checkCredentials(params.credentials);
const accessToken = checkCredentials(params);
const repoId = toRepoId(params.repo);
yield { event: "phase", phase: "preuploading" };

Expand Down Expand Up @@ -189,7 +189,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
{
method: "POST",
headers: {
...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }),
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
"Content-Type": "application/json",
},
body: JSON.stringify(payload),
Expand Down Expand Up @@ -263,7 +263,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
{
method: "POST",
headers: {
...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }),
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
Accept: "application/vnd.git-lfs+json",
"Content-Type": "application/vnd.git-lfs+json",
},
Expand Down Expand Up @@ -468,7 +468,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
{
method: "POST",
headers: {
...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }),
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
"Content-Type": "application/x-ndjson",
},
body: [
Expand Down
27 changes: 14 additions & 13 deletions packages/hub/src/lib/count-commits.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { Credentials, RepoDesignation } from "../types/public";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

export async function countCommits(params: {
credentials?: Credentials;
repo: RepoDesignation;
/**
* Revision to list commits from. Defaults to the default branch.
*/
revision?: string;
hubUrl?: string;
fetch?: typeof fetch;
}): Promise<number> {
checkCredentials(params.credentials);
export async function countCommits(
params: {
repo: RepoDesignation;
/**
* Revision to list commits from. Defaults to the default branch.
*/
revision?: string;
hubUrl?: string;
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<number> {
const accessToken = checkCredentials(params);
const repoId = toRepoId(params.repo);

// Could upgrade to 1000 commits per page
Expand All @@ -23,7 +24,7 @@ export async function countCommits(params: {
}?limit=1`;

const res: Response = await (params.fetch ?? fetch)(url, {
headers: params.credentials ? { Authorization: `Bearer ${params.credentials.accessToken}` } : {},
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
});

if (!res.ok) {
Expand Down
12 changes: 3 additions & 9 deletions packages/hub/src/lib/create-repo.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ describe("createRepo", () => {
const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;

const result = await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
repo: {
name: repoName,
type: "model",
Expand Down Expand Up @@ -62,9 +60,7 @@ describe("createRepo", () => {
const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;

const result = await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
repo: repoName,
files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
Expand All @@ -88,9 +84,7 @@ describe("createRepo", () => {
const repoName = `datasets/${TEST_USER}/TEST-${insecureRandomString()}`;

const result = await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
repo: repoName,
files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
Expand Down
41 changes: 21 additions & 20 deletions packages/hub/src/lib/create-repo.ts
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiCreateRepoPayload } from "../types/api/api-create-repo";
import type { Credentials, RepoDesignation, SpaceSdk } from "../types/public";
import type { CredentialsParams, RepoDesignation, SpaceSdk } from "../types/public";
import { base64FromBytes } from "../utils/base64FromBytes";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

export async function createRepo(params: {
repo: RepoDesignation;
credentials: Credentials;
private?: boolean;
license?: string;
/**
* Only a few lightweight files are supported at repo creation
*/
files?: Array<{ content: ArrayBuffer | Blob; path: string }>;
/** @required for when {@link repo.type} === "space" */
sdk?: SpaceSdk;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): Promise<{ repoUrl: string }> {
checkCredentials(params.credentials);
export async function createRepo(
params: {
repo: RepoDesignation;
private?: boolean;
license?: string;
/**
* Only a few lightweight files are supported at repo creation
*/
files?: Array<{ content: ArrayBuffer | Blob; path: string }>;
/** @required for when {@link repo.type} === "space" */
sdk?: SpaceSdk;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
} & CredentialsParams
): Promise<{ repoUrl: string }> {
const accessToken = checkCredentials(params);
const repoId = toRepoId(params.repo);
const [namespace, repoName] = repoId.name.split("/");

Expand Down Expand Up @@ -61,7 +62,7 @@ export async function createRepo(params: {
: undefined,
} satisfies ApiCreateRepoPayload),
headers: {
Authorization: `Bearer ${params.credentials.accessToken}`,
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
},
});
Expand Down
9 changes: 3 additions & 6 deletions packages/hub/src/lib/delete-file.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@ describe("deleteFile", () => {
it("should delete a file", async () => {
const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
const repo = { type: "model", name: repoName } satisfies RepoId;
const credentials = {
accessToken: TEST_ACCESS_TOKEN,
};

try {
const result = await createRepo({
credentials,
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
repo,
files: [
Expand All @@ -39,7 +36,7 @@ describe("deleteFile", () => {

assert.strictEqual(await content?.text(), "file1");

await deleteFile({ path: "file1", repo, credentials, hubUrl: TEST_HUB_URL });
await deleteFile({ path: "file1", repo, accessToken: TEST_ACCESS_TOKEN, hubUrl: TEST_HUB_URL });

content = await downloadFile({
repo,
Expand All @@ -59,7 +56,7 @@ describe("deleteFile", () => {
} finally {
await deleteRepo({
repo,
credentials,
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
});
}
Expand Down
29 changes: 15 additions & 14 deletions packages/hub/src/lib/delete-file.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import type { Credentials } from "../types/public";
import type { CredentialsParams } from "../types/public";
import type { CommitOutput, CommitParams } from "./commit";
import { commit } from "./commit";

export function deleteFile(params: {
credentials: Credentials;
repo: CommitParams["repo"];
path: string;
commitTitle?: CommitParams["title"];
commitDescription?: CommitParams["description"];
hubUrl?: CommitParams["hubUrl"];
fetch?: CommitParams["fetch"];
branch?: CommitParams["branch"];
isPullRequest?: CommitParams["isPullRequest"];
parentCommit?: CommitParams["parentCommit"];
}): Promise<CommitOutput> {
export function deleteFile(
params: {
repo: CommitParams["repo"];
path: string;
commitTitle?: CommitParams["title"];
commitDescription?: CommitParams["description"];
hubUrl?: CommitParams["hubUrl"];
fetch?: CommitParams["fetch"];
branch?: CommitParams["branch"];
isPullRequest?: CommitParams["isPullRequest"];
parentCommit?: CommitParams["parentCommit"];
} & CredentialsParams
): Promise<CommitOutput> {
return commit({
credentials: params.credentials,
...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }),
repo: params.repo,
operations: [
{
Expand Down
Loading
Loading