add pack dataset docstrings
pppppM committed Apr 7, 2024
1 parent 68afab4 commit cf8e8af
Showing 2 changed files with 63 additions and 11 deletions.
66 changes: 58 additions & 8 deletions xtuner/dataset/hybrid/_pack.py
@@ -6,6 +6,20 @@


class _PackDataset(torch.utils.data.Dataset):
"""The new dataset obtained by concatenating multiple raw data.
Args:
dataset (datasets.Dataset): The tokenized dataset.
max_length (int): The length of each data after concatenation.
Note:
The original dataset's type must be `datasets.Dataset`, others will be
very slow.
Note:
The data in the original dataset must have the `num_tokens` key,
recording the number of tokens for each piece of data.
"""

def __init__(self, dataset, max_length=2048):
super().__init__()
@@ -17,66 +31,102 @@ def __init__(self, dataset, max_length=2048):

self._ori_lens = dataset['num_tokens']

# The number of data items after packing
self._num_packed_samples = sum(self._ori_lens) // self.max_length
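# e.g. 12 total tokens with max_length == 5 yields 2 packed samples;
# the trailing 2 tokens are dropped (numbers are illustrative)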

# Shuffle the order of the original dataset.
# Packing proceeds according to the shuffled order.
# For example, if
# shfl_inds = [3, 1, 2, 0],
# self._ori_lens[3] + self._ori_lens[1] == max_length,
# self._ori_lens[2] + self._ori_lens[0] == max_length,
# then dataset[3] and dataset[1] will be combined into one packed
# sample, and dataset[2] and dataset[0] into another.
inds = list(range(len(self.dataset)))
random.shuffle(inds)
self.shfl_inds = inds

# per-item lengths in shuffled order, then their cumulative sums
shfl_lens = [self._ori_lens[i] for i in inds]
shfl_acc_lens = list(itertools.accumulate(shfl_lens))

self._shfl_item_rngs_left = [0] + shfl_acc_lens[:-1]
self._shfl_item_rngs_right = shfl_acc_lens
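# Illustrative example (assuming shfl_lens == [4, 4, 4]):
#   shfl_acc_lens              == [4, 8, 12]
#   self._shfl_item_rngs_left  == [0, 4, 8]
#   self._shfl_item_rngs_right == [4, 8, 12]
# i.e. shuffled item i owns the half-open token range
# [rngs_left[i], rngs_right[i]) of the concatenated stream.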

def _pack_ids_and_labels_in_range(self, begin: int, end: int):
"""Packs ids and labels in a given range using bisection method.
Args:
begin: Index indicating the beginning of the range.
end: Index indicating the end of the range.
Returns:
A tuple containing packed ids, labels, and cumulative lengths.
"""

# Use binary search to find the shuffled items that overlap the
# [begin, end) token range
left = bisect.bisect(self._shfl_item_rngs_right, begin)
right = bisect.bisect(self._shfl_item_rngs_left, end)
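# Continuing the illustrative ranges above: with begin == 5, end == 9,
# left == bisect.bisect([4, 8, 12], 5) == 1 and
# right == bisect.bisect([0, 4, 8], 9) == 3, so the loop below visits
# shuffled items 1 and 2, the only ones overlapping tokens [5, 9).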

trunc_ids = []
trunc_labels = []
cumulative_len = []
position_ids = []

for i in range(left, right):
cumulative_len.append(len(trunc_ids))

# Determine the real range to cut within the current original item
item_begin = self._shfl_item_rngs_left[i]
item_end = self._shfl_item_rngs_right[i]

# Calculate exact positions within the current dataset item
inner_l = max(begin, item_begin) - item_begin
inner_r = min(end, item_end) - item_begin
position_ids.extend(range(inner_r - inner_l))

# Get original data and labels
ori_idx = self.shfl_inds[i]
ori_input_ids = self.dataset[ori_idx]['input_ids']
ori_labels = self.dataset[ori_idx]['labels']

# Add original data and labels from calculated positions
# to trunc_ids and trunc_labels
trunc_ids.extend(ori_input_ids[inner_l:inner_r])
trunc_labels.extend(ori_labels[inner_l:inner_r])

# append max_length as the final boundary of the cumulative lengths
cumulative_len.append(self.max_length)
return trunc_ids, trunc_labels, cumulative_len, position_ids

def __len__(self):
return self._num_packed_samples

def __getitem__(self, item):
"""Returns a dict containing packed data in the given item.
Args:
item: An index to retrieve packed data.
Returns:
A dict including packed input_ids, labels, and cumulative_len.
"""
# The cumulative token offset at which this packed sample begins
begin = item * self.max_length
# The cumulative token offset at which this packed sample ends
end = (item + 1) * self.max_length

# Extract data within the range from the shuffled original dataset.
_res = self._pack_ids_and_labels_in_range(begin, end)
packed_ids, packed_labels, cumulative_len, position_ids = _res
assert self.max_length == len(packed_ids) == len(packed_labels)

packed = {
'input_ids': packed_ids,
'labels': packed_labels,
'num_tokens': self.max_length,
'cumulative_len': cumulative_len,
'position_ids': position_ids
}

return packed
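To make the packed layout concrete, here is a minimal usage sketch. The toy samples and field values are invented for illustration, and it assumes the two-argument constructor shown in this hunk is the full required interface:

import datasets

from xtuner.dataset.hybrid._pack import _PackDataset

# Toy tokenized dataset: every item carries `input_ids`, `labels`,
# and a `num_tokens` count, as the class docstring requires.
raw = [
    {'input_ids': [1] * 5, 'labels': [1] * 5, 'num_tokens': 5},
    {'input_ids': [2] * 3, 'labels': [2] * 3, 'num_tokens': 3},
    {'input_ids': [3] * 4, 'labels': [3] * 4, 'num_tokens': 4},
]
ds = datasets.Dataset.from_list(raw)

# 12 tokens in total, so max_length=4 yields 12 // 4 == 3 packed samples.
packed = _PackDataset(ds, max_length=4)
assert len(packed) == 3

sample = packed[0]
# Each packed sample is exactly max_length tokens. `cumulative_len`
# marks the boundaries between the original items inside it (ending
# at max_length), and `position_ids` restarts at 0 on each boundary.
assert len(sample['input_ids']) == len(sample['labels']) == 4
assert sample['cumulative_len'][-1] == 4
print(sample['cumulative_len'], sample['position_ids'])

Because the constructor shuffles the source order, the exact boundaries printed will vary from run to run.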
8 changes: 5 additions & 3 deletions xtuner/dataset/hybrid/dataset.py
@@ -355,9 +355,11 @@ def tokenize_dataset(self, dataset: List[dict]) -> List[dict]:
Returns:
List[dict]:
Each dict in the list must contain three keys: `input_ids`,
`labels` and `num_tokens`.
`input_ids` and `labels` are lists of int, and they should
have equal lengths.
`num_tokens` is an integer, the length of `input_ids`.
"""

def openai_to_raw_training(item: dict) -> Dict:
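To pin the contract down, here is a sketch of one element that would satisfy this docstring; the token ids and the -100 ignore-index below are illustrative, not values produced by xtuner:

# One element of the list returned by `tokenize_dataset`:
# `input_ids` and `labels` have equal lengths, and `num_tokens`
# records that length. -100 is the conventional ignore-index for
# positions that should not contribute to the loss.
tokenized_item = {
    'input_ids': [1, 9320, 420, 2],
    'labels': [-100, -100, 420, 2],
    'num_tokens': 4,
}
assert len(tokenized_item['input_ids']) == len(tokenized_item['labels'])
assert tokenized_item['num_tokens'] == len(tokenized_item['input_ids'])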
