[Fix] fix initialization of ref_llm for full param dpo training with zero-3 #778

Merged 9 commits on Jul 19, 2024
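In short, the fix builds an explicitly configured `ref_llm` through the same config path used for the policy model and freezes it, instead of deepcopying the ZeRO-3-partitioned `self.llm`, which is what previously broke full-parameter DPO training under DeepSpeed ZeRO-3. A minimal, hypothetical config sketch of that path (the `llm`/`ref_llm` keys come from the diff below; the model path, dtype, and overall dict layout are placeholders in the usual xtuner style, not values taken from this PR):

# Hypothetical xtuner-style model config for full-parameter DPO under ZeRO-3.
# Key names (`llm`, `ref_llm`) follow the diff below; the model path and dtype
# are placeholders.
import torch
from transformers import AutoModelForCausalLM

from xtuner.model.dpo import DPO

model = dict(
    type=DPO,
    beta=0.1,
    loss_type='sigmoid',
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path='internlm/internlm2-chat-7b',
        torch_dtype=torch.bfloat16),
    # With the fix, an explicit ref_llm is built via _build_llm_from_cfg and
    # frozen with disable_grad(); no deepcopy of the sharded policy model.
    ref_llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path='internlm/internlm2-chat-7b',
        torch_dtype=torch.bfloat16))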
xtuner/model/dpo.py: 24 additions & 21 deletions
@@ -16,6 +16,26 @@
 from .sft import SupervisedFinetune
 
 
+def disable_grad(model):
+    # freeze parameters
+    parameter_names = [n for n, _ in model.named_parameters()]
+    for param_name in parameter_names:
+        param = model.get_parameter(param_name)
+        param.requires_grad = False
+    return model.eval()
+
+
+def create_reference_model(model):
+    if is_deepspeed_zero3_enabled():
+        raise ValueError('DeepSpeed ZeRO-3 is enabled and is not compatible '
+                         'with `create_reference_model()`. Please instantiate '
+                         'your reference model directly with '
+                         '`AutoCausalLM.from_pretrained()`.')
+    ref_model = deepcopy(model)
+    ref_model = disable_grad(ref_model)
+    return ref_model
+
+
 class DPO(SupervisedFinetune):
     """A general class of DPO and its variants."""
 
@@ -27,32 +47,15 @@ def __init__(self,
                  label_smoothing=0.0,
                  **kwargs):
         super().__init__(llm, **kwargs)
-        self.ref_llm = ref_llm
         self.loss_type = loss_type
         self.label_smoothing = label_smoothing
         self.beta = beta
 
-        if not self.use_lora:
-            self.ref_llm = self.create_reference_model(ref_llm, **kwargs)
-
-    def create_reference_model(self, ref_llm=None, **kwargs):
-        ref_model = None
-        if ref_llm is None:
-            if is_deepspeed_zero3_enabled():
-                raise ValueError(
-                    'DeepSpeed ZeRO-3 is enabled and is not compatible '
-                    'with `deepcopy(self.llm)`. Please instantiate '
-                    'your reference model by modifying key `model.ref_llm` '
-                    'in your config with `AutoCausalLM.from_pretrained()`.')
-            ref_model = deepcopy(self.llm)
+        if ref_llm is not None:
+            ref_llm = self._build_llm_from_cfg(ref_llm, kwargs.get("use_varlen_attn"), kwargs.get("max_position_embeddings"))
+            self.ref_llm = disable_grad(ref_llm)
         else:
-            ref_model = SupervisedFinetune(ref_llm, **kwargs).llm
-        # freeze parameters
-        parameter_names = [n for n, _ in ref_model.named_parameters()]
-        for param_name in parameter_names:
-            param = ref_model.get_parameter(param_name)
-            param.requires_grad = False
-        return ref_model.eval()
+            self.ref_llm = None if self.use_lora else create_reference_model(self.llm)
 
     def _gather_masked_logits(self, logits, labels, mask):
         logits = torch.gather(
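Since `disable_grad` and `create_reference_model` are now plain module-level helpers, their contract is easy to check in isolation. A small sanity-check sketch run outside ZeRO-3, using an `nn.Module` stand-in instead of a real LLM (hypothetical usage, not taken from the repo's tests):

import torch.nn as nn

from xtuner.model.dpo import create_reference_model

policy = nn.Linear(8, 8)              # stand-in for self.llm
ref = create_reference_model(policy)  # deepcopy -> freeze every param -> eval()

assert all(not p.requires_grad for p in ref.parameters())
assert not ref.training                      # reference copy is in eval mode
assert policy.training                       # policy model is left untouched
assert all(p.requires_grad for p in policy.parameters())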
xtuner/model/sft.py: 15 additions & 6 deletions
@@ -79,19 +79,15 @@ def __init__(self,
                  tokenizer=None,
                  max_position_embeddings=None):
         super().__init__()
-        with LoadWoInit():
-            if isinstance(llm, dict):
-                llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
-            self.llm = self._build_from_cfg_or_module(llm)
+
+        self.llm = self._build_llm_from_cfg(llm, use_varlen_attn, max_position_embeddings)
 
         if tokenizer is not None:
             if isinstance(tokenizer, dict):
                 tokenizer = BUILDER.build(tokenizer)
             smart_tokenizer_and_embedding_resize(tokenizer, self.llm)
 
-        self.llm.config.use_cache = False
-        dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn)
-
         if use_activation_checkpointing:
             # For backward compatibility
             if hasattr(self.llm, 'enable_input_require_grads'):
@@ -119,6 +115,19 @@ def __init__(self,
         # the sequence.
         self.use_varlen_attn = use_varlen_attn
 
+
+    def _build_llm_from_cfg(self, llm_cfg, use_varlen_attn, max_position_embeddings):
+        # For forward
+        with LoadWoInit():
+            if isinstance(llm_cfg, dict):
+                llm = self._dispatch_lm_model_cfg(llm_cfg, max_position_embeddings)
+            llm = self._build_from_cfg_or_module(llm)
+
+        llm.config.use_cache = False
+        dispatch_modules(llm, use_varlen_attn=use_varlen_attn)
+        return llm
+
+
     def gradient_checkpointing_enable(self):
         self.activation_checkpointing_enable()
 
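The helper factored out above is what lets a subclass materialize a second model from the same kind of config that builds `self.llm`; that is how `DPO` now constructs its reference model. A rough usage sketch under that assumption (the model path is a placeholder, and calling the private helper directly is only for illustration):

import torch
from transformers import AutoModelForCausalLM

from xtuner.model.sft import SupervisedFinetune

llm_cfg = dict(
    type=AutoModelForCausalLM.from_pretrained,
    pretrained_model_name_or_path='internlm/internlm2-chat-7b',
    torch_dtype=torch.bfloat16)

# __init__ now delegates to _build_llm_from_cfg when building self.llm ...
model = SupervisedFinetune(llm=llm_cfg)

# ... and DPO reuses the same helper to build a second, frozen reference model.
ref = model._build_llm_from_cfg(llm_cfg, use_varlen_attn=False,
                                max_position_embeddings=None)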