Update LLaST implementation
ChenX17 committed Jul 13, 2024
1 parent 56fe3b5 commit 10531db
Showing 19 changed files with 1,108 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -4,7 +4,7 @@ repos:
rev: 5.0.4
hooks:
- id: flake8
- args: ["--exclude=xtuner/model/transformers_models/*"]
+ args: ["--exclude=xtuner/model/transformers_models/*,xtuner/evaluation/metrics/sacrebleu.py"]
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
3 changes: 2 additions & 1 deletion xtuner/dataset/__init__.py
@@ -7,6 +7,7 @@
load_intern_repo_tokenized_dataset,
load_intern_repo_untokenized_dataset)
from .json_dataset import load_json_file
+ from .llast import LLaSTDataset
from .llava import LLaVADataset
from .modelscope import process_ms_dataset
from .moss_sft import MOSSSFTDataset
@@ -24,5 +25,5 @@
'load_intern_repo_tokenized_dataset',
'load_intern_repo_untokenized_dataset', 'build_packed_dataset',
'RefCOCOJsonDataset', 'RefCOCOJsonEvalDataset', 'InvRefCOCOJsonDataset',
- 'load_json_file'
+ 'load_json_file', 'LLaSTDataset'
]
5 changes: 4 additions & 1 deletion xtuner/dataset/collate_fns/__init__.py
@@ -1,5 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .default_collate_fn import default_collate_fn
+ from .llast_collate_fn import llast_audiomask_mel_collate_fn
from .mmlu_collate_fn import mmlu_collate_fn

- __all__ = ['default_collate_fn', 'mmlu_collate_fn']
+ __all__ = [
+     'default_collate_fn', 'mmlu_collate_fn', 'llast_audiomask_mel_collate_fn'
+ ]
60 changes: 60 additions & 0 deletions xtuner/dataset/collate_fns/llast_collate_fn.py
@@ -0,0 +1,60 @@
# Copyright (c) LLaST. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence

import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.utils import (DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX,
                          LLAST_AUDIO_PADDING_TOKEN_INDEX)


def llast_audiomask_mel_collate_fn(
        instances: Sequence[Dict],
        pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
        return_hf_format: bool = False) -> Dict[str, torch.Tensor]:
    """Add audio tokens and conduct padding operation."""
    input_ids = []
    labels = []
    feats_lens = []
    has_audio = any(inst.get('audio_tokens') is not None for inst in instances)

    if has_audio:
        audio_tokens = []
    for example in instances:
        input_ids.append(torch.tensor(example['input_ids']))
        labels.append(torch.tensor(example['labels']))
        if has_audio:
            audio_tokens.append(example['audio_tokens'])
            feats_lens.append(torch.tensor(example['audio_lens']))
    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
        # padding audio tokens
        padded_audio_tokens = pad_sequence(
            audio_tokens,
            batch_first=True,
            padding_value=LLAST_AUDIO_PADDING_TOKEN_INDEX)

    else:
        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)
        padded_audio_tokens = torch.stack(audio_tokens)

    data_dict = {
        'input_ids': input_ids,
        'attention_mask': input_ids.ne(pad_index),
        'labels': labels
    }

    if has_audio:
        audio_lens = torch.stack(feats_lens)
        data_dict['audio_tokens'] = padded_audio_tokens
        data_dict['audio_lens'] = audio_lens

    if return_hf_format:
        return data_dict
    else:
        return {'data': data_dict, 'data_samples': instances}
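
For reference, a minimal usage sketch of the new collate function follows (not part of the commit). The instance keys ('input_ids', 'labels', 'audio_tokens', 'audio_lens') mirror what the function reads; the mel feature shapes and token ids are illustrative assumptions.

import torch

from xtuner.dataset.collate_fns import llast_audiomask_mel_collate_fn

# Two hypothetical samples: 'audio_tokens' is assumed to be a (frames, mel-bins)
# feature tensor and 'audio_lens' its frame count.
instances = [
    dict(input_ids=[1, 2, 3, 4], labels=[-100, -100, 3, 4],
         audio_tokens=torch.randn(120, 80), audio_lens=120),
    dict(input_ids=[1, 2], labels=[-100, 2],
         audio_tokens=torch.randn(90, 80), audio_lens=90),
]

batch = llast_audiomask_mel_collate_fn(instances)
data = batch['data']
print(data['input_ids'].shape)     # torch.Size([2, 4]) after right padding
print(data['audio_tokens'].shape)  # torch.Size([2, 120, 80]) after mel padding
print(data['audio_lens'])          # tensor([120, 90])

By default the padded tensors are returned nested under 'data' together with the raw instances as 'data_samples'; passing return_hf_format=True returns the flat dict directly.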
19 changes: 15 additions & 4 deletions xtuner/dataset/huggingface.py
@@ -65,8 +65,8 @@ def add_template_to_dataset(dataset, template_map_fn, map_num_proc):


def tokenize_dataset(dataset, tokenizer, max_length, with_image_token,
- input_ids_with_output, remove_unused_columns,
- map_num_proc):
+ with_audio_token, input_ids_with_output,
+ remove_unused_columns, map_num_proc):
assert (tokenizer is not None) and (max_length is not None), \
f'({tokenizer}, {max_length})'
if isinstance(tokenizer, dict) or isinstance(
@@ -78,6 +78,7 @@ def tokenize_dataset(dataset, tokenizer, max_length, with_image_token,
tokenizer=tokenizer,
max_length=max_length,
with_image_token=with_image_token,
+ with_audio_token=with_audio_token,
input_ids_with_output=input_ids_with_output),
remove_columns=list(dataset.column_names)
if remove_unused_columns else None,
@@ -112,6 +113,7 @@ def process(dataset,
use_varlen_attn=False,
input_ids_with_output=True,
with_image_token=False,
+ with_audio_token=False,
map_num_proc=32):
"""Post-process the dataset loaded from the Hugging Face Hub, or a local
dataset.
@@ -153,6 +155,9 @@ def process(dataset,
with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
IMAGE_TOKEN_INDEX. Typically set it to True during the training
of VLM.
+ with_audio_token: Whether to convert DEFAULT_AUDIO_TOKEN to
+ LLAST_AUDIO_TOKEN_INDEX. Typically set it to True during the
+ training of SLM.
map_num_proc: Max number of processes when mapping the dataset.
"""
if use_varlen_attn:
@@ -197,7 +202,8 @@ def process(dataset,

if do_dataset_tokenization:
dataset = tokenize_dataset(dataset, tokenizer, max_length,
- with_image_token, input_ids_with_output,
+ with_image_token, with_audio_token,
+ input_ids_with_output,
remove_unused_columns, map_num_proc)

if input_ids_with_output:
@@ -213,7 +219,7 @@
shuffle_before_pack, map_num_proc)

# add 'length'
- dataset = dataset.map(get_lengths, num_proc=map_num_proc)
+ dataset = dataset.map(get_lengths, num_proc=1)
setattr(dataset, 'length', dataset['length'])

return dataset
@@ -234,6 +240,7 @@ def process_hf_dataset(dataset,
use_varlen_attn=False,
input_ids_with_output=True,
with_image_token=False,
+ with_audio_token=False,
map_num_proc=32):
"""Post-process the dataset loaded from the Hugging Face Hub, or a local
dataset.
@@ -275,6 +282,9 @@ def process_hf_dataset(dataset,
with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
IMAGE_TOKEN_INDEX. Typically set it to True during the training
of VLM.
+ with_audio_token: Whether to convert DEFAULT_AUDIO_TOKEN to
+ LLAST_AUDIO_TOKEN_INDEX. Typically set it to True during the
+ training of SLM.
map_num_proc: Max number of processes when mapping the dataset.
"""
kwargs = dict(
Expand All @@ -293,6 +303,7 @@ def process_hf_dataset(dataset,
use_varlen_attn=use_varlen_attn,
input_ids_with_output=input_ids_with_output,
with_image_token=with_image_token,
+ with_audio_token=with_audio_token,
map_num_proc=map_num_proc)
if not (dist.is_available() and dist.is_initialized()):
return process(**kwargs)
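
To show where the new flag plugs in, here is a hedged, config-style sketch (not part of the commit). Only with_audio_token and the collate function come from this change; the tokenizer, dataset path and batch settings are placeholders.

from datasets import load_dataset
from transformers import AutoTokenizer

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import llast_audiomask_mel_collate_fn

# Placeholder model and corpus identifiers, used only for illustration.
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path='internlm/internlm2-chat-7b',
    trust_remote_code=True)

train_dataset = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='covost2'),  # placeholder speech-translation corpus
    tokenizer=tokenizer,
    max_length=2048,
    with_audio_token=True,  # new flag: convert DEFAULT_AUDIO_TOKEN during tokenization
    map_num_proc=32)

train_dataloader = dict(
    batch_size=8,
    num_workers=2,
    dataset=train_dataset,
    collate_fn=dict(type=llast_audiomask_mel_collate_fn))

In an actual xtuner config these dicts are built lazily by the runner; the point is simply that with_audio_token=True enables the audio-token conversion added in this commit, mirroring how with_image_token is used for VLM training.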