From 9dc1142beee0f88090b0d5320770041d446bffae Mon Sep 17 00:00:00 2001 From: pppppM Date: Sun, 7 Apr 2024 16:16:18 +0800 Subject: [PATCH] remove old collate fn --- xtuner/dataset/hybrid/__init__.py | 3 +- xtuner/dataset/hybrid/collate.py | 54 ------------------------------- xtuner/dataset/hybrid/dataset.py | 36 +++++---------------- xtuner/model/text/__init__.py | 1 + xtuner/model/text/finetune.py | 1 - 5 files changed, 10 insertions(+), 85 deletions(-) delete mode 100644 xtuner/dataset/hybrid/collate.py diff --git a/xtuner/dataset/hybrid/__init__.py b/xtuner/dataset/hybrid/__init__.py index 3e91de358..fcd3922df 100644 --- a/xtuner/dataset/hybrid/__init__.py +++ b/xtuner/dataset/hybrid/__init__.py @@ -1,9 +1,8 @@ -from .collate import text_collate_fn +# Copyright (c) OpenMMLab. All rights reserved. from .dataset import TextDataset from .mappings import map_protocol, map_sequential, openai_to_raw_training __all__ = [ - 'text_collate_fn', 'TextDataset', 'map_protocol', 'map_sequential', diff --git a/xtuner/dataset/hybrid/collate.py b/xtuner/dataset/hybrid/collate.py deleted file mode 100644 index 25d1fc773..000000000 --- a/xtuner/dataset/hybrid/collate.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, Sequence - -import torch -from torch.nn.utils.rnn import pad_sequence - -from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX - - -def text_collate_fn(instances: Sequence[Dict], - pad_index: int = DEFAULT_PAD_TOKEN_INDEX, - return_hf_format: bool = False): - - input_ids = [] - labels = [] - cumulative_len = [] - position_ids = [] - - for i, data in enumerate(instances): - input_ids.append(torch.LongTensor(data['input_ids'])) - labels.append(torch.LongTensor(data['labels'])) - position_ids.append(torch.IntTensor(data['position_ids'])) - - if 'cumulative_len' in data: - cumulative_len.append(torch.IntTensor(data['cumulative_len'])) - - if len(instances) > 1: - input_ids = pad_sequence( - input_ids, batch_first=True, padding_value=pad_index) - labels = pad_sequence( - labels, batch_first=True, padding_value=IGNORE_INDEX) - position_ids = pad_sequence( - position_ids, batch_first=True, padding_value=0) - else: - input_ids = torch.stack(input_ids) - labels = torch.stack(labels) - position_ids = torch.stack(position_ids) - - if len(cumulative_len) == 0: - cumulative_len = None - - # breakpoint() - data_dict = { - 'input_ids': input_ids, - 'position_ids': position_ids, - 'attention_mask': input_ids.ne(pad_index), - 'labels': labels, - 'cumulative_len': cumulative_len, - } - - if return_hf_format: - return data_dict - else: - return {'data': data_dict, 'data_samples': None} diff --git a/xtuner/dataset/hybrid/dataset.py b/xtuner/dataset/hybrid/dataset.py index 39fa2c100..d394e85eb 100644 --- a/xtuner/dataset/hybrid/dataset.py +++ b/xtuner/dataset/hybrid/dataset.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. import functools import json import os @@ -270,8 +271,8 @@ def load_dataset( """ if self.is_cached(cache_dir): print_log( - f'{cache_dir} is cached dataset that will be loaded ' - 'directly; `data_files` and `data_dir` will become' + f'{cache_dir} is a cached dataset that will be loaded ' + 'directly; `data_files` and `data_dir` will become ' 'invalid.', logger='current') @@ -359,7 +360,7 @@ def tokenize_dataset(self, dataset: List[dict]) -> List[dict]: `labels` and `num_tokens`. `input_ids` and `labels` are lists of int, and they should have equal lengths. - `num_tokens` is an integer,the length of `input_ids`. 
+ `num_tokens` is an integer, the length of `input_ids`. """ def openai_to_raw_training(item: dict) -> Dict: @@ -574,48 +575,27 @@ def __getitem__(self, item: int) -> Dict[str, List]: stop_words=['<|im_end|>'], ) - from xtuner.dataset.hybrid.mappings import openai_to_raw_training - - data_dir = './llava_data/LLaVA-Instruct-150K/' - image_dir = './llava_data/llava_images/' - data_files = 'llava_v1_5_mix665k.json' - dataset = TextDataset( 'internlm/internlm2-chat-1_8b', chat_template, sample_ratio=1, max_length=32 * 1024, - data_dir=data_dir, - data_files=data_files, + data_dir='converted_alpaca', + cache_dir='cached_alpaca', pack_to_max_length=True, - mappings=[openai_to_raw_training], num_proc=4) print(dataset[0]) - dataset.cache('cached_llava') - dataset = TextDataset( - 'internlm/internlm2-chat-1_8b', - chat_template, - sample_ratio=1, - max_length=32 * 1024, - cache_dir='cached_llava', - pack_to_max_length=True, - mappings=[ - openai_to_raw_training, - ], - num_proc=4) - print(dataset[0]) - from mmengine.dataset import DefaultSampler from torch.utils.data import DataLoader - from xtuner.dataset.hybrid.collate import text_collate_fn + from xtuner.model import TextFinetune loader = DataLoader( dataset, 4, num_workers=0, - collate_fn=text_collate_fn, + collate_fn=TextFinetune.dataloader_collate_fn, sampler=DefaultSampler(dataset, shuffle=True)) for data in tqdm(loader): diff --git a/xtuner/model/text/__init__.py b/xtuner/model/text/__init__.py index 3c5a77f77..65375fac6 100644 --- a/xtuner/model/text/__init__.py +++ b/xtuner/model/text/__init__.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. from .finetune import TextFinetune __all__ = ['TextFinetune'] diff --git a/xtuner/model/text/finetune.py b/xtuner/model/text/finetune.py index 4e31fdc25..3845c3936 100644 --- a/xtuner/model/text/finetune.py +++ b/xtuner/model/text/finetune.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. - from collections import OrderedDict from typing import Dict, List, Optional, Union
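
Note for callers of the removed collate module: the collate logic is now expected to come from the model class. Below is a minimal sketch of the replacement wiring, based only on the updated __main__ example in xtuner/dataset/hybrid/dataset.py inside this patch; the chat_template construction and the exact output layout of TextFinetune.dataloader_collate_fn are not shown here, so they are marked as assumptions in the comments.

# Sketch only: replaces the deleted xtuner.dataset.hybrid.collate.text_collate_fn
# with TextFinetune.dataloader_collate_fn, mirroring the example in dataset.py.
from mmengine.dataset import DefaultSampler
from torch.utils.data import DataLoader

from xtuner.dataset.hybrid import TextDataset
from xtuner.model import TextFinetune

# Assumption: chat_template is built the same way as in the dataset.py
# __main__ example (internlm2 chat template with stop_words=['<|im_end|>']);
# the exact class is not part of this patch, so it is left elided here.
chat_template = ...

dataset = TextDataset(
    'internlm/internlm2-chat-1_8b',
    chat_template,
    sample_ratio=1,
    max_length=32 * 1024,
    data_dir='converted_alpaca',
    cache_dir='cached_alpaca',
    pack_to_max_length=True,
    num_proc=4)

loader = DataLoader(
    dataset,
    batch_size=4,
    num_workers=0,
    # The collate function now lives on the model class instead of
    # xtuner.dataset.hybrid.collate.
    collate_fn=TextFinetune.dataloader_collate_fn,
    sampler=DefaultSampler(dataset, shuffle=True))

for batch in loader:
    # Assumption: the batch keeps the layout the removed text_collate_fn
    # produced, i.e. {'data': {'input_ids', 'labels', 'position_ids',
    # 'attention_mask', 'cumulative_len'}, 'data_samples': None}.
    break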