Support IterableDataset (#1596)
hjh0119 authored Aug 9, 2024
1 parent aca5a7c commit 5abab36
Showing 16 changed files with 723 additions and 360 deletions.
1 change: 1 addition & 0 deletions docs/source/LLM/命令行参数.md
@@ -23,6 +23,7 @@
- LLAVA model: `https://github.com/haotian-liu/LLaVA.git`
- `--sft_type`: Fine-tuning method, default is `'lora'`. Options include: 'lora', 'full', 'longlora', 'adalora', 'ia3', 'llamapro', 'adapter', 'vera', 'boft', 'fourierft'. If you want to use qlora, set `--sft_type lora --quantization_bit 4`.
- `--packing`: Pack the dataset to `max_length`; default is `False`.
- `--streaming`: Whether to use streaming data processing; default is `False`.
- `--freeze_parameters`: When sft_type is set to 'full', freeze the bottommost parameters of the model. The range is 0. ~ 1., default is `0.`. This provides a compromise between lora and full-parameter fine-tuning.
- `--additional_trainable_parameters`: A supplement to freeze_parameters, only allowed when sft_type is set to 'full'; default is `[]`. For example, if you want to additionally train the embedding layer while training 50% of the parameters, you can set `--freeze_parameters 0.5 --additional_trainable_parameters transformer.wte`; all parameters whose names start with `transformer.wte` will be activated. You can also set `--freeze_parameters 1 --additional_trainable_parameters xxx` to customize the trainable layers.
- `--tuner_backend`: Backend support for lora and qlora; default is `'peft'`. Options include: 'swift', 'peft', 'unsloth'.
1 change: 1 addition & 0 deletions docs/source_en/LLM/Command-line-parameters.md
@@ -22,6 +22,7 @@
- LLAVA model: `https://github.com/haotian-liu/LLaVA.git`
- `--sft_type`: Fine-tuning method, default is `'lora'`. Options include: 'lora', 'full', 'longlora', 'adalora', 'ia3', 'llamapro', 'adapter', 'vera', 'boft', 'fourierft'. If using qlora, you need to set `--sft_type lora --quantization_bit 4`.
- `--packing`: Pack the dataset to `max_length`; default is `False`.
- `--streaming`: Whether to stream the dataset as an iterable dataset; default is `False`. A usage sketch follows this parameter list.
- `--freeze_parameters`: When sft_type is set to 'full', freeze the bottommost parameters of the model. Range is 0. ~ 1., default is `0.`. This provides a compromise between lora and full fine-tuning.
- `--additional_trainable_parameters`: In addition to freeze_parameters, only allowed when sft_type is 'full'; default is `[]`. For example, if you want to train the embedding layer in addition to 50% of the parameters, you can set `--freeze_parameters 0.5 --additional_trainable_parameters transformer.wte`; all parameters starting with `transformer.wte` will be activated. You can also set `--freeze_parameters 1 --additional_trainable_parameters xxx` to customize the trainable layers.
- `--tuner_backend`: Backend support for lora, qlora, default is `'peft'`. Options include: 'swift', 'peft', 'unsloth'.
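For orientation, the sketch below shows how the new streaming parameters documented above might be combined in a fine-tuning run. It is a minimal, hedged example: it assumes the `SftArguments` / `sft_main` Python entrypoints mirror the CLI flags, and the model and dataset names are placeholders rather than part of this commit.

```python
# Hedged usage sketch, not part of this commit: assumes the `SftArguments` /
# `sft_main` Python entrypoints mirror the CLI flags documented above.
# The model and dataset names are placeholders.
from swift.llm import SftArguments, sft_main

args = SftArguments(
    model_type='qwen2-7b-instruct',   # placeholder model
    dataset=['alpaca-zh'],            # placeholder dataset
    sft_type='lora',
    streaming=True,                   # new in this commit
    max_steps=2000,                   # required: a streamed dataset has no known length
    streaming_val_size=200,           # carve a small validation split from the stream
    streaming_buffer_size=16384,      # shuffle window for the stream
)
sft_main(args)
```

Because the length of a streamed dataset is unknown, `max_steps` must be set explicitly; the new `_handle_streaming_args` check further down in this commit raises otherwise.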
18 changes: 12 additions & 6 deletions swift/llm/rlhf.py
@@ -24,6 +24,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
logger.info(f'args: {args}')
seed_everything(args.seed)
training_args = args.training_args
streaming = args.streaming
if is_torch_npu_available():
print(f'device_count: {torch.npu.device_count()}')
else:
@@ -169,6 +170,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
if val_dataset is None:
training_args.evaluation_strategy = IntervalStrategy.NO
training_args.do_eval = False
training_args.eval_strategy = IntervalStrategy.NO

template_kwargs = {}
template_info = TEMPLATE_MAPPING[args.template_type]
@@ -183,10 +185,10 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:

template: Template = get_template(
args.template_type, tokenizer, args.system, args.max_length, args.truncation_strategy, model=model)
if not template.support_multi_round and 'history' in train_dataset[0]:
if not template.support_multi_round and 'history' in next(iter(train_dataset)):
logger.info(
'The current template does not support multi-turn dialogue. The chatml template is used by default. \
You can also use the --model_type parameter to specify the template.')
You can also use the --model_type parameter to specify the template.')
template: Template = get_template(
'chatml', tokenizer, args.system, args.max_length, args.truncation_strategy, model=model)
args.system = template.default_system
@@ -206,6 +208,8 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
trainer_kwargs['is_vision'] = args.is_vision
model.config.model_type += '_' # add suffix to avoid checks in hfDPOTrainer

trainer_kwargs['streaming'] = streaming

trainer = trainer_cls(
model=model,
train_dataset=train_dataset,
@@ -227,7 +231,8 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
last_model_checkpoint = getattr(trainer.state, 'last_model_checkpoint', None)
logger.info(f'last_model_checkpoint: {last_model_checkpoint}')
logger.info(f'best_model_checkpoint: {trainer.state.best_model_checkpoint}')
train_time = get_time_info(trainer.state.log_history, len(train_dataset))
if not streaming:
train_time = get_time_info(trainer.state.log_history, len(train_dataset))
# Visualization
if is_master():
if 'tensorboard' in args.training_args.report_to:
@@ -239,15 +244,16 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
trainer.push_to_hub()
run_info = {
'memory': trainer.perf['memory'],
'train_time': train_time,
'last_model_checkpoint': last_model_checkpoint,
'best_model_checkpoint': trainer.state.best_model_checkpoint,
'best_metric': trainer.state.best_metric,
'global_step': trainer.state.global_step,
'log_history': trainer.state.log_history,
'model_info': model_info,
'dataset_info': trainer.dataset_info,
'model_info': model_info
}
if not streaming:
run_info.update({'train_time': train_time})
run_info.update({'dataset_info': trainer.dataset_info})
if is_master():
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
append_to_jsonl(jsonl_path, run_info)
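The `train_dataset[0]` → `next(iter(train_dataset))` change above is the crux of `IterableDataset` support: a map-style `Dataset` allows random access and `len()`, while an `IterableDataset` can only be consumed by iteration. A small standalone illustration, not part of the commit, using only the Hugging Face `datasets` API:

```python
# Minimal illustration of why train_dataset[0] is replaced by next(iter(train_dataset)).
from datasets import Dataset

ds = Dataset.from_dict({'query': ['hi', 'hello'], 'response': ['a', 'b']})
streamed = ds.to_iterable_dataset()   # an IterableDataset

print(ds[0])                  # map-style: random access works
print(next(iter(streamed)))   # streaming: must peek via an iterator
# streamed[0] or len(streamed) would raise a TypeError, which is why this commit
# guards length-based computations such as get_time_info(...) with `if not streaming`.
```

The same constraint explains why `train_time` and `dataset_info` are only added to `run_info` when `streaming` is disabled.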
35 changes: 24 additions & 11 deletions swift/llm/sft.py
@@ -36,7 +36,10 @@ def _get_train_val_dataset(args: SftArguments) -> Tuple[HfDataset, Optional[HfDa
args.dataset_seed,
check_dataset_strategy=args.check_dataset_strategy,
model_name=args.model_name,
model_author=args.model_author)
model_author=args.model_author,
streaming=args.streaming,
streaming_val_size=args.streaming_val_size,
streaming_buffer_size=args.streaming_buffer_size)
if len(args.val_dataset) > 0:
# Loading val dataset
_, val_dataset = get_dataset(
@@ -45,7 +48,10 @@ def _get_train_val_dataset(args: SftArguments) -> Tuple[HfDataset, Optional[HfDa
args.dataset_seed,
check_dataset_strategy=args.check_dataset_strategy,
model_name=args.model_name,
model_author=args.model_author)
model_author=args.model_author,
streaming=args.streaming,
streaming_val_size=args.streaming_val_size,
streaming_buffer_size=args.streaming_buffer_size)

train_dataset, val_dataset = args._handle_dataset_compat(train_dataset, val_dataset)
logger.info(f'train_dataset: {train_dataset}')
@@ -111,6 +117,7 @@ def llm_sft_megatron(args: SftArguments) -> Dict[str, Any]:
def llm_sft(args: SftArguments) -> Dict[str, Any]:
logger.info(f'args: {args}')
is_generation = TEMPLATE_MAPPING[args.template_type].get('is_generation', False)
streaming = args.streaming
if is_generation and type(args) is SftArguments:
logger.warning(f"Please check if args.template_type: '{args.template_type}' is correct. "
'Currently, SFT is in progress, but the template is used for PT.')
@@ -267,7 +274,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
fsdp_flatten_parameters=False)

train_dataset, val_dataset = _get_train_val_dataset(args)
training_args.train_dataset_sample = train_dataset.shape[0] if train_dataset is not None else 0 # torchacc
if use_torchacc():
training_args.train_dataset_sample = train_dataset.shape[0] if train_dataset is not None else 0
template_kwargs = {}
template_kwargs['use_loss_scale'] = args.use_loss_scale
if args.loss_scale_config_path is not None:
@@ -288,6 +296,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
args.truncation_strategy,
model=model,
**template_kwargs)
if streaming:
template.encode = partial(template.encode, streaming=streaming)
args.system = template.default_system
logger.info(f'system: {args.system}')
logger.info(f'args.lazy_tokenize: {args.lazy_tokenize}')
@@ -307,10 +317,11 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
dataset_info['val_dataset'] = stat_dataset(val_dataset)
elif not args.lazy_tokenize:
dataset_info = {}
logger.info(f'Using num_proc: {args.preprocess_num_proc}')
train_dataset = dataset_map(train_dataset, template.encode, args.preprocess_num_proc)
if not streaming:
logger.info(f'Using num_proc: {args.preprocess_num_proc}')
train_dataset = dataset_map(train_dataset, template.encode, args.preprocess_num_proc, streaming=streaming)
if val_dataset is not None:
val_dataset = dataset_map(val_dataset, template.encode, args.preprocess_num_proc)
val_dataset = dataset_map(val_dataset, template.encode, args.preprocess_num_proc, streaming=streaming)
if args.test_oom_error:
train_dataset = sort_by_max_length(train_dataset, 20000)
# Data analysis
@@ -321,11 +332,11 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
raise AttributeError('Failed to access dataset attributes,train_dataset is None. This might be because:\n'
'(1) The dataset contains None for input or labels;\n'
"(2) The 'max_length' setting is too short causing data truncation.")
td0, tkwargs0 = train_dataset.data[0]
td0, tkwargs0 = train_dataset.data[0] if not streaming else (next(iter(train_dataset)), {})
print_example(td0, tokenizer, tkwargs0)
dataset_info['train_dataset'] = stat_dataset(train_dataset)
dataset_info['train_dataset'] = stat_dataset(train_dataset) if not streaming else None
if val_dataset is not None:
dataset_info['val_dataset'] = stat_dataset(val_dataset)
dataset_info['val_dataset'] = stat_dataset(val_dataset) if not streaming else None
else:
dataset_info = None
td0, tkwargs0 = template.encode(train_dataset[0])
@@ -395,7 +406,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
last_model_checkpoint = getattr(trainer.state, 'last_model_checkpoint', None)
logger.info(f'last_model_checkpoint: {last_model_checkpoint}')
logger.info(f'best_model_checkpoint: {trainer.state.best_model_checkpoint}')
train_time = get_time_info(trainer.state.log_history, len(train_dataset))
if not streaming:
train_time = get_time_info(trainer.state.log_history, len(train_dataset))
# Visualization
if is_master() and not use_torchacc():
if 'tensorboard' in args.training_args.report_to:
@@ -407,7 +419,6 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
trainer.push_to_hub()
run_info = {
'memory': trainer.perf['memory'],
'train_time': train_time,
'last_model_checkpoint': last_model_checkpoint,
'best_model_checkpoint': trainer.state.best_model_checkpoint,
'best_metric': trainer.state.best_metric,
@@ -416,6 +427,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
'model_info': model_info,
'dataset_info': dataset_info,
}
if not streaming:
run_info.update({'train_time': train_time})
for key in ['gen_time', 'gen_len']:
if trainer.perf[key] != 0:
run_info[key] = trainer.perf[key]
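The `streaming=streaming` argument threaded through `dataset_map` and `template.encode` above suggests that preprocessing switches from an eager, multi-process `Dataset.map` to a lazy `IterableDataset.map` that encodes examples as the trainer pulls them. Below is a hedged sketch of that dispatch; the real helper in `swift/llm/utils` almost certainly differs in signature and return type, and `preprocess_func` merely stands in for `template.encode`.

```python
# Hedged sketch of a dataset_map-style dispatcher; not the repository's implementation.
from typing import Callable, Union
from datasets import Dataset, IterableDataset

def dataset_map_sketch(dataset: Union[Dataset, IterableDataset],
                       preprocess_func: Callable[[dict], dict],
                       num_proc: int = 1,
                       streaming: bool = False) -> Union[Dataset, IterableDataset]:
    if streaming:
        # IterableDataset.map is lazy and has no num_proc argument:
        # examples are encoded on the fly as batches are drawn from the stream.
        return dataset.map(preprocess_func)
    # Map-style datasets are tokenized eagerly, optionally with multiple workers.
    return dataset.map(preprocess_func, num_proc=num_proc)
```

This also explains why the `Using num_proc` log line above is now emitted only when `streaming` is disabled.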
2 changes: 1 addition & 1 deletion swift/llm/utils/__init__.py
@@ -5,7 +5,7 @@
get_model_list_client_async, inference_client, inference_client_async)
from .dataset import (DATASET_MAPPING, DatasetName, HfDataset, get_dataset, get_dataset_from_repo,
load_dataset_from_local, load_ms_dataset, register_dataset, register_dataset_info,
register_local_dataset, sample_dataset)
register_local_dataset, sample_dataset, standard_keys)
from .media import MediaCache, MediaTag
from .model import (MODEL_MAPPING, GetModelTokenizerFunction, LoRATM, ModelType, get_additional_saved_files,
get_default_lora_target_modules, get_default_template_type, get_model_tokenizer,
67 changes: 61 additions & 6 deletions swift/llm/utils/argument.py
@@ -13,6 +13,7 @@
import torch.distributed as dist
import transformers
from datasets import Dataset as HfDataset
from datasets import IterableDataset as HfIterableDataset
from datasets import concatenate_datasets
from packaging import version
from torch import dtype as Dtype
Expand All @@ -34,6 +35,7 @@
from .utils import is_lmdeploy_available, is_quant_model, is_vllm_available

logger = get_logger()
DATASET_TYPE = Union[HfDataset, HfIterableDataset]


def is_adapter(sft_type: str) -> bool:
@@ -374,11 +376,14 @@ def _register_self_cognition(self: Union['SftArguments', 'InferArguments']) -> N
'Representing the model name and model author in Chinese and English.')
setattr(self, k, v)

def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_dataset: Optional[HfDataset],
val_dataset: Optional[HfDataset]) -> Tuple[Optional[HfDataset], Optional[HfDataset]]:
def _handle_dataset_compat(
self: Union['SftArguments', 'InferArguments'], train_dataset: Optional[DATASET_TYPE],
val_dataset: Optional[DATASET_TYPE]) -> Tuple[Optional[DATASET_TYPE], Optional[DATASET_TYPE]]:
# compatibility. (Deprecated)
streaming = getattr(self, 'streaming', False)
random_state = np.random.RandomState(self.dataset_seed)
val_dataset_sample = self.val_dataset_sample

if train_dataset is not None and self.train_dataset_sample >= 0:
train_dataset_sample = min(self.train_dataset_sample, train_dataset.shape[0])
if train_dataset.shape[0] > train_dataset_sample:
@@ -388,10 +393,13 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_
if val_dataset_sample is None:
val_dataset_sample = max(int(train_dataset_sample * self.dataset_test_ratio), 1)
if val_dataset is not None and val_dataset_sample is not None and val_dataset_sample >= 0:
if val_dataset.shape[0] > val_dataset_sample:
if not streaming and val_dataset.shape[0] > val_dataset_sample:
logger.info(f'val_dataset_sample: {val_dataset_sample}')
val_idxs = random_state.permutation(val_dataset_sample)
val_dataset = val_dataset.select(val_idxs)
elif streaming:
val_dataset = val_dataset.shuffle(
seed=self.dataset_seed, buffer_size=self.streaming_buffer_size).take(val_dataset_sample)

if (train_dataset is None or not hasattr(self, 'train_dataset_mix_ratio') or self.train_dataset_mix_ratio <= 0
or len(self.train_dataset_mix_ds) == 0):
@@ -401,7 +409,11 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_
logger.info(f'train_dataset_mix_ds: {self.train_dataset_mix_ds}')
logger.info(f'len(train_dataset): {len(train_dataset)}, mix_dataset_sample: {mix_dataset_sample}')
mixed_dataset = get_dataset(
self.train_dataset_mix_ds, 0.0, random_state, check_dataset_strategy=self.check_dataset_strategy)[0]
self.train_dataset_mix_ds,
0.0,
random_state,
check_dataset_strategy=self.check_dataset_strategy,
streaming=streaming)[0]
if len(mixed_dataset) < mix_dataset_sample:
logger.warn(f'The length of dataset used for mixin: {self.train_dataset_mix_ds} are '
'lesser than the ratio required by the `train_dataset_mix_ratio` '
@@ -590,7 +602,10 @@ class SftArguments(ArgumentsBase):
max_length: int = 2048 # -1: no limit
truncation_strategy: Literal['delete', 'truncation_left'] = 'delete'
check_dataset_strategy: Literal['none', 'discard', 'error', 'warning'] = 'none'

# streaming dataset
streaming: bool = False
streaming_val_size: int = 0
streaming_buffer_size: int = 16384
# Chinese name and English name
model_name: List[str] = field(default_factory=lambda: [None, None], metadata={'help': "e.g. ['小黄', 'Xiao Huang']"})
model_author: List[str] = field(
@@ -1025,7 +1040,8 @@ def __post_init__(self) -> None:
if self.gradient_accumulation_steps is None:
self.gradient_accumulation_steps = math.ceil(16 / self.batch_size / self.world_size)
template_info = TEMPLATE_MAPPING[self.template_type]
if self.lazy_tokenize is None:
self._handle_streaming_args()
if self.lazy_tokenize is None and not self.streaming:
self.lazy_tokenize = template_info.get('lazy_tokenize', False)
logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}')
if self.dataloader_num_workers is None:
@@ -1095,6 +1111,9 @@ def _init_training_args(self) -> None:
else:
kwargs['evaluation_strategy'] = self.evaluation_strategy

if 'accelerator_config' in parameters:
kwargs['accelerator_config'] = {'dispatch_batches': False}

training_args = Seq2SeqTrainingArguments(
output_dir=self.output_dir,
logging_dir=self.logging_dir,
@@ -1181,6 +1200,42 @@ def _handle_pai_compat(self) -> None:
self.add_output_dir_suffix = False
logger.info(f'Setting args.add_output_dir_suffix: {self.add_output_dir_suffix}')

def _handle_streaming_args(self) -> None:
if not self.streaming:
return
if self.max_steps == -1:
raise ValueError('Please specify `max_steps` in streaming mode.')

if self.packing:
self.packing = False
logger.warning('Packing is not supported for streaming dataset, set to False')

if self.test_oom_error:
self.test_oom_error = False
logger.warning('test_oom_error is not supported for streaming dataset, set to False')

if self.lazy_tokenize:
self.lazy_tokenize = False
logger.info('lazy_tokenize set to False in streaming dataset')

if self.train_dataset_mix_ratio > 0:
logger.warning('train_dataset_mix_ratio is not supported for streaming dataset, set to 0')
self.train_dataset_mix_ratio = 0

if self.dataset_test_ratio > 0:
logger.info('Set dataset_test_ratio to 0 in streaming mode.'
'You can manually set val_dataset and val_dataset_sample.'
'or set streaming_val_size instead to split from train dataset')
self.dataset_test_ratio = 0

if self.train_dataset_sample > 0:
logger.warning('train_dataset_sample is not supported for streaming dataset, set to -1')
self.train_dataset_sample = -1

if self.dataloader_num_workers is None or self.dataloader_num_workers > 0:
logger.info('Set dataloader_num_workers to 0 in streaming mode')
self.dataloader_num_workers = 0


@dataclass
class InferArguments(ArgumentsBase):
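The new `streaming_val_size` and `streaming_buffer_size` arguments, together with the `shuffle(...).take(...)` call in `_handle_dataset_compat`, indicate that in streaming mode the validation split is carved off the front of a buffer-shuffled stream instead of being sampled by index. A minimal sketch of that pattern with the Hugging Face `datasets` API follows; the dataset name is a placeholder, and the repository's `get_dataset` may implement the split differently.

```python
# Hedged sketch of a take/skip validation split on a streamed dataset.
from datasets import load_dataset

stream = load_dataset('tatsu-lab/alpaca', split='train', streaming=True)  # placeholder dataset

streaming_buffer_size = 16384   # approximate shuffling window, matching the new default
streaming_val_size = 200        # examples reserved for validation

shuffled = stream.shuffle(seed=42, buffer_size=streaming_buffer_size)
val_dataset = shuffled.take(streaming_val_size)     # first N shuffled examples
train_dataset = shuffled.skip(streaming_val_size)   # the rest of the stream
```

Because only `buffer_size` examples are held in memory at a time, the shuffle is approximate; the 16384 default trades memory use against mixing quality.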