Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gguf: Add type definitions for split.* metadata + sanity check #679

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions packages/gguf/src/gguf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,9 @@ export async function ggufAllShards(
*/
fetch?: typeof fetch;
additionalFetchHeaders?: Record<string, string>;
allowLocalFile?: boolean;
}
): Promise<{ shards: GGUFParseOutput[]; parameterCount: number }> {
): Promise<GGUFParseOutput & { parameterCount: number }> {
const ggufShardFileInfo = parseGgufShardFilename(url);
if (ggufShardFileInfo) {
const total = parseInt(ggufShardFileInfo.total);
Expand All @@ -414,15 +415,48 @@ export async function ggufAllShards(

const PARALLEL_DOWNLOADS = 20;
const shards = await promisesQueue(
urls.map((shardUrl) => () => gguf(shardUrl, { ...params, computeParametersCount: true })),
urls.map((shardUrl) => async () => {
const output = await gguf(shardUrl, { ...params, computeParametersCount: true });
return output;
}),
PARALLEL_DOWNLOADS
);

// Sanity check split.count parameter
const output: GGUFParseOutput<{ strict: false }> = shards[0];
const splitCount = output.metadata["split.count"];
if (splitCount !== shards.length) {
throw new Error(`Expect to "split.count" to be ${shards.length}, but got ${splitCount}`);
}

// Sanity check split.no parameter
for (let i = 0; i < shards.length; i++) {
const shard = shards[i];
if (!shard.metadata["split.count"]) {
continue;
}
const splitNo = shard.metadata["split.no"];
if (splitNo !== i) {
throw new Error(`Expect to "split.no" to be ${i}, but got ${splitNo}`);
} else if (i > 0) {
// skip first shard (already added)
output.tensorInfos = [...output.tensorInfos, ...shard.tensorInfos];
}
}

// Sanity check split.tensors.count parameter
const splitTensorsCount = output.metadata["split.tensors.count"];
if (splitTensorsCount !== output.tensorInfos.length) {
throw new Error(
`Expect to "split.tensors.count" to be ${output.tensorInfos.length}, but got ${splitTensorsCount}`
);
}

return {
shards,
...output,
parameterCount: shards.map(({ parameterCount }) => parameterCount).reduce((acc, val) => acc + val, 0),
};
} else {
const { metadata, tensorInfos, parameterCount } = await gguf(url, { ...params, computeParametersCount: true });
return { shards: [{ metadata, tensorInfos }], parameterCount };
return await gguf(url, { ...params, computeParametersCount: true });
}
}
7 changes: 7 additions & 0 deletions packages/gguf/src/types.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,12 @@ describe("gguf-types", () => {
// @ts-expect-error llama does not have ssm.* keys
model["mamba.ssm.conv_kernel"] = 0;
}

if (model["split.count"]) {
model["split.no"] = 123;
} else {
// @ts-expect-error not a split (shard) model
model["split.no"] = 123;
}
});
});
16 changes: 15 additions & 1 deletion packages/gguf/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ interface NoTokenizer {
"tokenizer.ggml.model"?: undefined;
}

/// Splits

/**
 * Metadata keys carried by a GGUF file that is one split (shard) of a larger model.
 */
interface Splits {
/** Index of the current split (counting from 0). */
"split.no": number;
/** Total number of splits (counting from 1). */
"split.count": number;
/** Total number of tensors from all splits. */
"split.tensors.count": number;
}
/**
 * Counterpart of `Splits` for a model that is not split:
 * `"split.count"` is either absent or explicitly `undefined`.
 */
interface NoSplits {
/** Never set on a non-split model. */
"split.count"?: undefined;
}

/// Models outside of llama.cpp: "rwkv" and "whisper"

export type RWKV = GGUFGeneralInfo<"rwkv"> &
Expand Down Expand Up @@ -126,7 +140,7 @@ export type GGUFMetadata<Options extends GGUFMetadataOptions = { strict: true }>
} & GGUFModelKV &
(Options extends { strict: true } ? unknown : Record<string, MetadataValue>);

export type GGUFModelKV = (NoModelMetadata | ModelMetadata) & (NoTokenizer | Tokenizer);
export type GGUFModelKV = (NoModelMetadata | ModelMetadata) & (NoTokenizer | Tokenizer) & (Splits | NoSplits);

export interface GGUFTensorInfo {
name: string;
Expand Down
Loading