add doc in data.py
SunMarc committed Jul 21, 2023
1 parent 7720b36 commit 437329a
Showing 1 changed file with 53 additions and 27 deletions.
80 changes: 53 additions & 27 deletions optimum/gptq/data.py
@@ -14,24 +14,35 @@


import random
from typing import Any, Dict, List, Union
from typing import Any, Dict, List

import numpy as np
import torch
from datasets import load_dataset


def prepare_dataset(
examples: List[Dict[str, Union[List[int], torch.LongTensor]]], pad_token_id: int = None, batch_size: int = 1
):
"""
Set of utilities for loading the most commonly used calibration datasets (the original datasets from the GPTQ paper) and making them easy to use during quantization
"""


def prepare_dataset(examples: List[Dict[str, torch.LongTensor]], batch_size: int = 1, pad_token_id: int = None):
"""
Prepare the dataset by making sure that the examples are in the right format and batched according to `batch_size`
Args:
examples (`List[Dict[str, torch.LongTensor]]`):
List of data to prepare
batch_size (`int`, *optional*, defaults to `1`):
Batch size of the data
pad_token_id (`int`, *optional*, defaults to `None`):
Pad token id of the model
Returns:
`List[Dict[str, torch.LongTensor]]`: Batched dataset
"""
new_examples = []
for example in examples:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
if isinstance(input_ids, List):
input_ids = [input_ids]
if isinstance(attention_mask, List):
attention_mask = [attention_mask]
new_examples.append(
{"input_ids": torch.LongTensor(input_ids), "attention_mask": torch.LongTensor(attention_mask)}
)
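For context, a minimal usage sketch of `prepare_dataset`, assuming a Hugging Face tokenizer; the model choice and variable names are illustrative and not part of this commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
texts = ["GPTQ calibrates on a handful of short samples.", "Another calibration sentence."]
examples = [tokenizer(t) for t in texts]  # each dict carries "input_ids" and "attention_mask"
batches = prepare_dataset(examples, batch_size=2, pad_token_id=tokenizer.eos_token_id)
# each batch holds "input_ids" and "attention_mask" LongTensors of shape (batch_size, seq_len)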
@@ -47,8 +58,24 @@ def prepare_dataset(


def collate_data(
blocks: List[Dict[str, List[torch.LongTensor]]], pad_token_id: int = None, contain_labels: bool = False
blocks: List[Dict[str, torch.LongTensor]],
contain_labels: bool = False,
pad_token_id: int = None,
) -> Dict[str, torch.LongTensor]:
"""
Collate data in `blocks`
Args:
blocks (`List[Dict[str, torch.LongTensor]]`):
List of dictionaries of tensors to batch together
pad_token_id (`int`, *optional*, defaults to `None`):
Pad token id of the model
contain_labels (`bool`, *optional*, defaults to `False`):
Set to `True` to also process the labels
Returns:
`Dict[str, torch.LongTensor]`: Batched data
"""

def pad_block(block, pads):
return torch.cat((pads.to(block.device), block), dim=-1).long()

@@ -83,7 +110,7 @@ def pad_block(block, pads):
return data
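The elided part of `collate_data` pads every block to the longest sequence in the batch before stacking; a minimal sketch of that idea with illustrative values (the exact elided code may differ):

import torch

blocks = [
    {"input_ids": torch.LongTensor([[1, 2, 3]]), "attention_mask": torch.LongTensor([[1, 1, 1]])},
    {"input_ids": torch.LongTensor([[4, 5]]), "attention_mask": torch.LongTensor([[1, 1]])},
]
pad_token_id = 0  # illustrative value
max_len = max(b["input_ids"].shape[-1] for b in blocks)
for b in blocks:
    pad_len = max_len - b["input_ids"].shape[-1]
    if pad_len > 0:
        # same pattern as pad_block: prepend pads along the last dimension
        b["input_ids"] = torch.cat((torch.full((1, pad_len), pad_token_id, dtype=torch.long), b["input_ids"]), dim=-1)
        b["attention_mask"] = torch.cat((torch.zeros((1, pad_len), dtype=torch.long), b["attention_mask"]), dim=-1)
data = {k: torch.cat([b[k] for b in blocks], dim=0) for k in ("input_ids", "attention_mask")}
# data["input_ids"] -> tensor([[1, 2, 3], [0, 4, 5]])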


def get_wikitext2(tokenizer, seqlen, nsamples, split="train"):
def get_wikitext2(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
if split == "train":
data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
elif split == "validation":
@@ -101,7 +128,7 @@ def get_wikitext2(tokenizer, seqlen, nsamples, split="train"):
return dataset
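The bodies of these getters are collapsed above; they follow the usual GPTQ calibration recipe of tokenizing the corpus once and slicing random fixed-length windows. A hedged sketch of that recipe, with an illustrative helper name (the elided code may differ in details):

import random
import torch

def sample_windows(tokenizer, text, seqlen, nsamples):
    enc = tokenizer(text, return_tensors="pt")  # tokenize the whole corpus once
    dataset = []
    for _ in range(nsamples):
        i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1)  # random start offset
        input_ids = enc.input_ids[:, i : i + seqlen]                # fixed-length window of seqlen tokens
        dataset.append({"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)})
    return dataset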


def get_c4(tokenizer, seqlen, nsamples, split="train"):
def get_c4(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
if split == "train":
data = load_dataset(
"allenai/c4", "allenai--c4", data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, split="train"
@@ -129,7 +156,7 @@ def get_c4(tokenizer, seqlen, nsamples, split="train"):
return dataset


def get_c4_new(tokenizer, seqlen, nsamples, split="train"):
def get_c4_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
if split == "train":
data = load_dataset(
"allenai/c4", "allenai--c4", data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, split="train"
@@ -157,7 +184,7 @@ def get_c4_new(tokenizer, seqlen, nsamples, split="train"):
return dataset


def get_ptb(tokenizer, seqlen, nsamples, split="train"):
def get_ptb(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
if split == "train":
data = load_dataset("ptb_text_only", "penn_treebank", split="train")
elif split == "validation":
@@ -176,7 +203,7 @@ def get_ptb(tokenizer, seqlen, nsamples, split="train"):
return dataset


def get_ptb_new(tokenizer, seqlen, nsamples, split="train"):
def get_ptb_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
if split == "train":
data = load_dataset("ptb_text_only", "penn_treebank", split="train")
elif split == "validation":
@@ -198,7 +225,7 @@ def get_dataset(
dataset_name: str, tokenizer: Any, nsamples: int = 128, seqlen: int = 2048, seed: int = 0, split: str = "train"
):
"""
Get the dataset from the original paper on GPTQ
Get the dataset used in the original GPTQ paper
Args:
dataset_name (`str`):
@@ -219,17 +246,16 @@
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
get_dataset_map = {
"wikitext2": get_wikitext2,
"c4": get_c4,
"c4-new": get_c4_new,
"ptb": get_ptb,
"ptb-new": get_ptb_new,
}
if split not in ["train", "validation"]:
    raise ValueError(f"The split needs to be 'train' or 'validation' but found {split}")
if dataset_name == "wikitext2":
return get_wikitext2(tokenizer=tokenizer, nsamples=nsamples, seqlen=seqlen)
elif dataset_name == "c4":
return get_c4(tokenizer=tokenizer, nsamples=nsamples, seqlen=seqlen)
elif dataset_name == "c4-new":
return get_c4_new(tokenizer=tokenizer, nsamples=nsamples, seqlen=seqlen)
elif dataset_name == "ptb":
return get_ptb(tokenizer=tokenizer, nsamples=nsamples, seqlen=seqlen)
elif dataset_name == "ptb-new":
return get_ptb_new(tokenizer=tokenizer, nsamples=nsamples, seqlen=seqlen)
else:
raise ValueError(f"Expected a value in ['wikitext2','c4','ptb','c4-new','ptb-new'] but found {dataset_name}")
if dataset_name not in get_dataset_map:
raise ValueError(f"Expected a value in {list(get_dataset_map.keys())} but found {dataset_name}")
get_dataset_fn = get_dataset_map[dataset_name]
return get_dataset_fn(tokenizer=tokenizer, nsamples=nsamples, seqlen=seqlen)
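Taken together, a hedged end-to-end usage sketch of the refactored entry point (the model choice is illustrative):

from transformers import AutoTokenizer
from optimum.gptq.data import get_dataset, prepare_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
# 128 calibration samples of 2048 tokens each, drawn from wikitext2
examples = get_dataset("wikitext2", tokenizer, nsamples=128, seqlen=2048, split="train")
# batch them for the quantization loop
batches = prepare_dataset(examples, batch_size=1, pad_token_id=tokenizer.eos_token_id)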
