Skip to content

Commit

Permalink
[Safetensors parser] Support optional file path (#563)
Browse files Browse the repository at this point in the history
### Before

Safetensors parser was in assumption that safetensors files are strictly
named `model.safetensors`, which is not the case always.

### This PR

Support optional file names so that other names like
`llama3.safetensors` would still work

---------

Co-authored-by: Eliott C. <[email protected]>
Co-authored-by: Julien Chaumond <[email protected]>
  • Loading branch information
3 people authored Mar 20, 2024
1 parent cf98cd7 commit 3bd9297
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
26 changes: 24 additions & 2 deletions packages/hub/src/lib/parse-safetensors-metadata.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { parseSafetensorsMetadata } from "./parse-safetensors-metadata";
import { sum } from "../utils/sum";

describe("parseSafetensorsMetadata", () => {
it("fetch info for single-file", async () => {
it("fetch info for single-file (with the default conventional filename)", async () => {
const parse = await parseSafetensorsMetadata({
repo: "bert-base-uncased",
computeParametersCount: true,
Expand All @@ -25,7 +25,7 @@ describe("parseSafetensorsMetadata", () => {
// total params = 110m
});

it("fetch info for sharded", async () => {
it("fetch info for sharded (with the default conventional filename)", async () => {
const parse = await parseSafetensorsMetadata({
repo: "bigscience/bloom",
computeParametersCount: true,
Expand Down Expand Up @@ -61,4 +61,26 @@ describe("parseSafetensorsMetadata", () => {
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 124_697_947);
// total params = 124m
});

it("fetch info for single-file with file path", async () => {
const parse = await parseSafetensorsMetadata({
repo: "CompVis/stable-diffusion-v1-4",
computeParametersCount: true,
path: "unet/diffusion_pytorch_model.safetensors",
});

assert(!parse.sharded);
assert.deepStrictEqual(parse.header.__metadata__, { format: "pt" });

// Example of one tensor (the header contains many tensors)

assert.deepStrictEqual(parse.header["up_blocks.3.resnets.0.norm2.bias"], {
dtype: "F32",
shape: [320],
data_offsets: [3_409_382_416, 3_409_383_696],
});

assert.deepStrictEqual(parse.parameterCount, { F32: 859_520_964 });
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 859_520_964);
});
});
25 changes: 19 additions & 6 deletions packages/hub/src/lib/parse-safetensors-metadata.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ import { downloadFile } from "./download-file";
import { fileExists } from "./file-exists";
import { promisesQueue } from "../utils/promisesQueue";

const SINGLE_FILE = "model.safetensors";
const INDEX_FILE = "model.safetensors.index.json";
export const SAFETENSORS_FILE = "model.safetensors";
export const SAFETENSORS_INDEX_FILE = "model.safetensors.index.json";
/// We advise model/library authors to use the filenames above for convention inside model repos,
/// but in some situations safetensors weights have different filenames.
export const RE_SAFETENSORS_FILE = /\.safetensors$/;
export const RE_SAFETENSORS_INDEX_FILE = /\.safetensors\.index\.json$/;
const PARALLEL_DOWNLOADS = 5;
const MAX_HEADER_LENGTH = 25_000_000;

Expand Down Expand Up @@ -154,6 +158,10 @@ async function parseShardedIndex(
export async function parseSafetensorsMetadata(params: {
/** Only models are supported */
repo: RepoDesignation;
/**
* Relative file path to safetensors file inside `repo`. Defaults to `SAFETENSORS_FILE` or `SAFETENSORS_INDEX_FILE` (whichever one exists).
*/
path?: string;
/**
* Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
*
Expand All @@ -176,6 +184,7 @@ export async function parseSafetensorsMetadata(params: {
*
* @default false
*/
path?: string;
computeParametersCount?: boolean;
hubUrl?: string;
credentials?: Credentials;
Expand All @@ -187,6 +196,7 @@ export async function parseSafetensorsMetadata(params: {
}): Promise<SafetensorsParseFromRepo>;
export async function parseSafetensorsMetadata(params: {
repo: RepoDesignation;
path?: string;
computeParametersCount?: boolean;
hubUrl?: string;
credentials?: Credentials;
Expand All @@ -203,17 +213,20 @@ export async function parseSafetensorsMetadata(params: {
throw new TypeError("Only model repos should contain safetensors files.");
}

if (await fileExists({ ...params, path: SINGLE_FILE })) {
const header = await parseSingleFile(SINGLE_FILE, params);
if (RE_SAFETENSORS_FILE.test(params.path ?? "") || (await fileExists({ ...params, path: SAFETENSORS_FILE }))) {
const header = await parseSingleFile(params.path ?? SAFETENSORS_FILE, params);
return {
sharded: false,
header,
...(params.computeParametersCount && {
parameterCount: computeNumOfParamsByDtypeSingleFile(header),
}),
};
} else if (await fileExists({ ...params, path: INDEX_FILE })) {
const { index, headers } = await parseShardedIndex(INDEX_FILE, params);
} else if (
RE_SAFETENSORS_INDEX_FILE.test(params.path ?? "") ||
(await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE }))
) {
const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params);
return {
sharded: true,
index,
Expand Down

0 comments on commit 3bd9297

Please sign in to comment.