
Commit

Merge branch 'main' into refactor-llm
pppppM committed Mar 29, 2024
2 parents 5fffd8c + 7de581c commit 0f31481
Showing 15 changed files with 457 additions and 401 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -16,6 +16,7 @@
🔍 Explore our models on
[![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🤗%20Huggingface)](https://huggingface.co/xtuner)
[![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🤖%20ModelScope)](https://www.modelscope.cn/organization/xtuner)
[![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🧰%20OpenXLab)](https://openxlab.org.cn/usercenter/xtuner)

English | [简体中文](README_zh-CN.md)

6 changes: 2 additions & 4 deletions requirements/runtime.txt
@@ -18,8 +18,6 @@ tiktoken
# limit pytorch version <= 2.1.2 as there may be some bugs in triton 2.2
torch<=2.1.2
torchvision<=0.16.2
# Minimum 4.34.0 to support added_tokens_decoder of tokenizer
# Exclude 4.34.1, 4.35.0, 4.35.1, 4.35.2 to avoid BC-break,
# see https://github.com/huggingface/transformers/pull/27020, https://github.com/huggingface/transformers/pull/27073
transformers>=4.34.0,!=4.34.1,!=4.35.0,!=4.35.1,!=4.35.2
# Minimum 4.36.0 to support `Cache` data structure used by KV Cache
transformers>=4.36.0
transformers_stream_generator
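
The tightened floor exists because the `Cache` data structure used for the KV cache only ships with transformers 4.36.0 and later. Below is a minimal environment check, sketched for illustration only; the `transformers.cache_utils` import path is an assumption based on the 4.36 release, not something taken from this diff.

from packaging.version import parse

import transformers

assert parse(transformers.__version__) >= parse('4.36.0'), (
    f'transformers>=4.36.0 is required, found {transformers.__version__}')

# Present from 4.36 onwards; older releases raise ImportError here.
from transformers.cache_utils import Cache, DynamicCache  # noqa: F401
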
12 changes: 8 additions & 4 deletions xtuner/engine/hooks/evaluate_chat_hook.py
@@ -29,7 +29,8 @@ def __init__(self,
every_n_iters=None,
max_new_tokens=600,
stop_word=None,
stop_words=[]):
stop_words=[],
generation_kwargs={}):
self.evaluation_inputs = evaluation_inputs
if isinstance(self.evaluation_inputs, str):
self.evaluation_inputs = [self.evaluation_inputs]
@@ -69,8 +70,9 @@ def __init__(self,
if image_processor is not None:
self.image_processor = BUILDER.build(image_processor)
self.stop_criteria = StoppingCriteriaList()

# default generation config
self.gen_config = GenerationConfig(
default_generation_kwargs = dict(
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.1,
@@ -79,8 +81,10 @@
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id is not None else
self.tokenizer.eos_token_id,
)
self.tokenizer.eos_token_id)
default_generation_kwargs.update(generation_kwargs)
self.gen_config = GenerationConfig(**default_generation_kwargs)

self.stop_criteria = StoppingCriteriaList()
for word in stop_words:
self.stop_criteria.append(
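
With this change the hook builds its GenerationConfig from a dict of defaults that is then updated with the user-supplied generation_kwargs, so a config can override individual fields without restating the rest. A hedged sketch of the merge order follows; the override values are illustrative and not taken from any xtuner config.

from transformers import GenerationConfig

default_generation_kwargs = dict(
    max_new_tokens=600,
    do_sample=True,
    temperature=0.1)
# e.g. EvaluateChatHook(..., generation_kwargs=dict(do_sample=False, ...))
user_generation_kwargs = dict(do_sample=False, repetition_penalty=1.1)

default_generation_kwargs.update(user_generation_kwargs)
gen_config = GenerationConfig(**default_generation_kwargs)
assert gen_config.do_sample is False  # the user value wins over the default
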
102 changes: 82 additions & 20 deletions xtuner/engine/hooks/throughput_hook.py
@@ -1,11 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Optional, Union

import torch
from mmengine import print_log
from mmengine.hooks import Hook
from mmengine.model.wrappers import is_model_wrapper
from torch.utils._pytree import tree_flatten

from xtuner.parallel.sequence import get_sequence_parallel_world_size

DATA_BATCH = Optional[Union[dict, tuple, list]]


@@ -20,12 +24,39 @@ def __init__(self,
hidden_size=None,
num_layers=None,
vocab_size=None,
mlp_ratio=None):
mlp_ratio=None,
is_casual=None):
self.use_activation_checkpointing = use_activation_checkpointing
self.hidden_size = hidden_size
self.num_layers = num_layers
self.vocab_size = vocab_size
self.mlp_ratio = mlp_ratio
self.is_casual = is_casual

@staticmethod
def _guess_is_casual_attn(model):
for module in model.modules():
if hasattr(module, 'is_causal'):
return module.is_causal
print_log(
'Could not determine whether causal attention is used; '
'FLOPs will be calculated assuming `casual = True`.', 'current')
return True

@staticmethod
def _get_batch_size_and_sequence_len(data_batch):
data_list, _ = tree_flatten(data_batch)
for data in data_list:
if isinstance(data, torch.Tensor):
return data.size(0), data.size(1)
raise RuntimeError('No tensor found in the batch')

@staticmethod
def _guess_use_activation_checkpointing(model):
for module in model.modules():
if hasattr(module, 'gradient_checkpointing'):
return module.gradient_checkpointing
return False

def before_run(self, runner) -> None:
if is_model_wrapper(runner.model):
@@ -41,20 +72,18 @@ def before_run(self, runner) -> None:
self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size /
model.config.hidden_size)
self.mlp_ratio *= 1.5 # has gate_proj
return
self.is_casual = self.is_casual if self.is_casual is not None \
else self._guess_is_casual_attn(model)

def _get_batch_size_and_sequence_len(self, data_batch):
data_list, _ = tree_flatten(data_batch)
for data in data_list:
if isinstance(data, torch.Tensor):
return data.size(0), data.size(1)
raise RuntimeError('No tensor found in the batch')
use_varlen_attn = getattr(model, 'use_varlen_attn', False)
if use_varlen_attn:
print_log(
'Using variable-length Flash Attention causes an inflation'
' in the FLOPs calculation.',
'current',
level=logging.WARNING)

def _guess_use_activation_checkpointing(self, model):
for module in model.modules():
if hasattr(module, 'gradient_checkpointing'):
return module.gradient_checkpointing
return False
return

def after_train_iter(self,
runner,
@@ -66,17 +95,50 @@ def after_train_iter(self,

batch_size, sequence_len = self._get_batch_size_and_sequence_len(
data_batch)
sequence_parallel_size = get_sequence_parallel_world_size()

message_hub = runner.message_hub
iter_time = message_hub.get_scalar('train/time').current()

flops_per_iteration = (
(3 + int(self.use_activation_checkpointing)) *
((8 + self.mlp_ratio * 4) * batch_size * sequence_len *
self.hidden_size**2 +
4 * batch_size * sequence_len**2 * self.hidden_size)
) * self.num_layers + \
6 * batch_size * sequence_len * self.hidden_size * self.vocab_size
# We consider a language model with l transformer layers,
# hidden size h, sequence length s, vocabulary size V, and
# training batch size B.
# An A_{m x k} x X_{k x n} matrix multiplication requires 2 * m * k * n
# FLOPs (the factor of 2 accounts for multiplies and adds).

# Attention Layer:
# qkv_proj + o_proj: 8B * s * h^2
# attn: 2B * s^2 * h (casual=False) and 2B * s^2 * h / 2 (casual=True)

# MLP Layer:
# up_proj + down_proj + gate_proj: 4B * s * h^2 * mlp_ratio
# (In Llama mlp_ratio = intermediate_size / hidden_size * 1.5
# (has gate_proj))

# The backward pass requires double the number of FLOPs since we
# need to calculate the gradients with respect to both input and
# weight tensors. In addition, we are using activation recomputation,
# which requires an additional forward pass before the backward pass.

# Sequence parallel affects the FLOPs calculation in attn.
# Suppose the sequence length on one GPU is s and the sequence
# parallel world size is `sp_size`, which means the total
# sequence length in the attention calculation is
# `s * sp_size` and the number of attention heads decreases to
# `num_heads / sp_size`. Hence, the FLOPs in the attn calculation are:
# 2B * (s * sp_size)^2 * (h / sp_size) (casual=False) and
# 2B * (s * sp_size)^2 * (h / sp_size) / 2 (casual=True)

flops_qkvo_proj = 8 * batch_size * sequence_len * self.hidden_size**2
flops_attn = 4 * batch_size * sequence_len**2 * self.hidden_size * \
sequence_parallel_size / (int(self.is_casual) + 1)
flops_mlp = 4 * self.mlp_ratio * batch_size * sequence_len * \
self.hidden_size**2
flops_wo_head = (3 + int(self.use_activation_checkpointing)) * (
flops_qkvo_proj + flops_attn + flops_mlp) * self.num_layers
flops_head = 3 * 2 * batch_size * sequence_len * self.hidden_size * \
self.vocab_size
flops_per_iteration = flops_wo_head + flops_head

avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
tokens_per_sec_per_gpu = batch_size * sequence_len / (
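
To make the accounting above concrete, the sketch below evaluates the same formula for an assumed Llama-7B-like shape (hidden size 4096, 32 layers, vocab 32000, intermediate size 11008). Every number here is an illustrative assumption, not a measurement from the hook.

# Worked example of the hook's FLOPs-per-iteration formula; all shapes
# and timings below are assumptions for illustration.
hidden_size = 4096
num_layers = 32
vocab_size = 32000
mlp_ratio = 11008 / 4096 * 1.5           # gate_proj included, as in the hook
batch_size, sequence_len = 1, 4096
sequence_parallel_size = 1               # no sequence parallelism
is_casual = True                         # spelling follows the hook's argument
use_activation_checkpointing = True      # adds one extra forward pass

flops_qkvo_proj = 8 * batch_size * sequence_len * hidden_size**2
flops_attn = (4 * batch_size * sequence_len**2 * hidden_size *
              sequence_parallel_size / (int(is_casual) + 1))
flops_mlp = 4 * mlp_ratio * batch_size * sequence_len * hidden_size**2
flops_wo_head = ((3 + int(use_activation_checkpointing)) *
                 (flops_qkvo_proj + flops_attn + flops_mlp) * num_layers)
flops_head = 3 * 2 * batch_size * sequence_len * hidden_size * vocab_size
flops_per_iteration = flops_wo_head + flops_head

iter_time = 1.0                          # seconds per iteration, assumed
print(f'{flops_per_iteration / 1e12 / iter_time:.2f} TFLOPs per GPU')
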
33 changes: 10 additions & 23 deletions xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Mapping, Optional, Sequence, Union
from typing import Optional, Union

import torch
import torch.distributed as dist
from mmengine import MessageHub
from mmengine.hooks import Hook
@@ -11,20 +10,6 @@

class VarlenAttnArgsToMessageHubHook(Hook):

args = ('cumulative_len', 'max_seqlen')

def cast_data(self, data):
if isinstance(data, Mapping):
return {key: self.cast_data(data[key]) for key in data}
elif isinstance(data, (str, bytes)) or data is None:
return data
elif isinstance(data, Sequence):
return type(data)(self.cast_data(sample) for sample in data) # type: ignore # noqa: E501 # yapf:disable
elif isinstance(data, torch.Tensor):
return data.cuda()
else:
return data

def before_train_iter(self,
runner,
batch_idx: int,
@@ -35,10 +20,13 @@ def before_train_iter(self,
assert 'data' in data_batch.keys()
data = data_batch['data']

for arg in self.args:
assert arg in data
message_hub.update_info(f'{arg}_rank_{rank}',
self.cast_data(data.pop(arg)))
cumulative_len = data.pop('cumulative_len')
assert len(cumulative_len) == 1
cumulative_len = cumulative_len[0].cuda()
message_hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len)

max_seqlen = data.pop('max_seqlen')
message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen)

def after_train_iter(self,
runner,
@@ -47,6 +35,5 @@ def after_train_iter(self,
outputs: Optional[dict] = None) -> None:
rank = dist.get_rank()
message_hub = MessageHub.get_instance('varlen_attn_args')

for arg in self.args:
message_hub.update_info(f'{arg}_rank_{rank}', None)
message_hub.update_info(f'cumulative_len_rank_{rank}', None)
message_hub.update_info(f'max_seqlen_rank_{rank}', None)
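
The hook now publishes `cumulative_len` (moved to CUDA) and `max_seqlen` under per-rank keys on a named MessageHub and clears both after the iteration. A hedged sketch of the consuming side follows; the key names mirror the diff, but the reader function itself is illustrative and not xtuner's attention code.

import torch.distributed as dist
from mmengine import MessageHub


def read_varlen_attn_args():
    rank = dist.get_rank()
    message_hub = MessageHub.get_instance('varlen_attn_args')
    # cumulative_len holds the packed-sequence boundaries for this rank;
    # max_seqlen is the longest single sequence in the packed batch.
    cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
    max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
    return cumulative_len, max_seqlen
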
66 changes: 65 additions & 1 deletion xtuner/model/llava.py
@@ -1,13 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from collections import OrderedDict

import torch
import torch.nn as nn
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training
from transformers import AutoConfig

from xtuner.registry import BUILDER
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from .utils import (LoadWoInit, find_all_linear_names,
get_peft_model_state_dict, guess_load_checkpoint,
make_inputs_require_grad,
@@ -26,11 +30,15 @@ def __init__(self,
projector_depth=2,
llm_lora=None,
visual_encoder_lora=None,
use_activation_checkpointing=True):
use_activation_checkpointing=True,
max_position_embeddings=None):
super().__init__()
self.freeze_llm = freeze_llm
self.freeze_visual_encoder = freeze_visual_encoder
with LoadWoInit():
if isinstance(llm, dict):
llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)

self.llm = self._build_from_cfg_or_module(llm)
self.visual_encoder = self._build_from_cfg_or_module(
visual_encoder)
@@ -157,6 +165,62 @@ def state_dict(self, *args, **kwargs):
for k, v in state_dict.items() if 'projector.' in k})
return to_return

@staticmethod
def _prepare_for_long_context_training(cfg, llm_cfg,
max_position_embeddings):

orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
if orig_rope_scaling is None:
orig_rope_scaling = {'factor': 1}

orig_rope_scaling_factor = orig_rope_scaling[
'factor'] if 'factor' in orig_rope_scaling.keys() else 1
orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
if orig_ctx_len:
orig_ctx_len *= orig_rope_scaling_factor
if max_position_embeddings > orig_ctx_len:
scaling_factor = float(
math.ceil(max_position_embeddings / orig_ctx_len))
llm_cfg.rope_scaling = {
'type': 'linear',
'factor': scaling_factor
}

# hardcode for internlm2
llm_cfg.attn_implementation = 'flash_attention_2'
cfg.config = llm_cfg

return cfg, llm_cfg

@staticmethod
def _prepare_for_flash_attn(cfg, llm_cfg):
cls_name = type(llm_cfg).__name__
SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
'MixtralConfig', 'Qwen2Config',
'Starcoder2Config', 'Starcoder2Config')
SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
'MistralConfig', 'MixtralConfig', 'Qwen2Config',
'Starcoder2Config', 'Starcoder2Config')

if SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
cfg.torch_dtype = torch.bfloat16 \
if torch.cuda.is_bf16_supported() else torch.float16
cfg.attn_implementation = 'flash_attention_2'
elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
cfg.attn_implementation = 'sdpa'

return cfg, llm_cfg

def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
llm_cfg = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=True)
cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
if max_position_embeddings is not None:
cfg, llm_cfg = self._prepare_for_long_context_training(
cfg, llm_cfg, max_position_embeddings)
return cfg

def _build_from_cfg_or_module(self, cfg_or_mod):
if isinstance(cfg_or_mod, nn.Module):
return cfg_or_mod
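
The new long-context path applies linear RoPE scaling whenever the requested max_position_embeddings exceeds the model's original (already-scaled) context length. A short sketch of that arithmetic with assumed lengths:

import math

# Assumed values for illustration only.
orig_ctx_len = 4096                  # llm_cfg.max_position_embeddings
orig_rope_scaling_factor = 1         # no pre-existing rope_scaling
max_position_embeddings = 32768      # requested training context

orig_ctx_len *= orig_rope_scaling_factor
if max_position_embeddings > orig_ctx_len:
    scaling_factor = float(math.ceil(max_position_embeddings / orig_ctx_len))
    rope_scaling = {'type': 'linear', 'factor': scaling_factor}
    print(rope_scaling)              # {'type': 'linear', 'factor': 8.0}
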
10 changes: 6 additions & 4 deletions xtuner/model/modules/dispatch/__init__.py
@@ -13,7 +13,7 @@
from .yi import yi_attn_forward

IS_LOW_VERSION_TRANSFORMERS = digit_version(
transformers.__version__) < digit_version('4.36')
transformers.__version__) < digit_version('4.38')
SUPPORT_FLASH1 = digit_version(torch.__version__) >= digit_version('2.0.0')
SUPPORT_FLASH2 = False

@@ -48,7 +48,7 @@ def dispatch_llama_attn_forward(model, use_varlen_attn):
if use_varlen_attn:
assert SUPPORT_FLASH2 and SUPPORT_TRITON, \
'flash_attn and triton is required if you want to use varlen_attn.'
elif not SUPPORT_FLASH:
elif not SUPPORT_FLASH2:
return

from .llama import (llama_attn_forward, llama_attn_forward_legacy,
@@ -57,8 +57,10 @@ def dispatch_llama_attn_forward(model, use_varlen_attn):

print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
for module in model.modules():
if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2',
'LlamaSdpaAttention'):
# Do not need to dispatch if
# type(module).__name__ == 'LlamaSdpaAttention', as flash_attn is
# required when using sequence parallel
if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2'):
if use_varlen_attn:
print_log('dispatch llama varlen attn forward', 'current')
if IS_LOW_VERSION_TRANSFORMERS:
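
The dispatcher swaps attention forwards by walking model.modules() and matching on the class name; after this change SDPA attention modules are skipped because flash_attn is required for sequence parallel anyway. A generic sketch of that rebinding pattern follows; custom_attn_forward is a placeholder, not xtuner's kernel.

import types


def custom_attn_forward(self, *args, **kwargs):
    # A flash-attention-backed forward would go here.
    raise NotImplementedError


def dispatch_attn_forward(model,
                          target_classes=('LlamaAttention',
                                          'LlamaFlashAttention2')):
    for module in model.modules():
        if type(module).__name__ in target_classes:
            # Rebind the forward on this instance only.
            module.forward = types.MethodType(custom_attn_forward, module)
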
(Remaining changed files not shown.)
