
Commit

[Fix] fix initialization of ref_llm for full param dpo training with zero-3 (#778)

* Fix initialization of ref_llm

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update sft.py

* Update dpo.py

* Update dpo.py

* Update dpo.py
xu-song authored Jul 19, 2024
1 parent 381d1c8 commit ba7afc7
Showing 2 changed files with 39 additions and 27 deletions.
45 changes: 24 additions & 21 deletions xtuner/model/dpo.py
@@ -16,6 +16,26 @@
 from .sft import SupervisedFinetune
 
 
+def disable_grad(model):
+    # freeze parameters
+    parameter_names = [n for n, _ in model.named_parameters()]
+    for param_name in parameter_names:
+        param = model.get_parameter(param_name)
+        param.requires_grad = False
+    return model.eval()
+
+
+def create_reference_model(model):
+    if is_deepspeed_zero3_enabled():
+        raise ValueError('DeepSpeed ZeRO-3 is enabled and is not compatible '
+                         'with `create_reference_model()`. Please instantiate '
+                         'your reference model directly with '
+                         '`AutoCausalLM.from_pretrained()`.')
+    ref_model = deepcopy(model)
+    ref_model = disable_grad(ref_model)
+    return ref_model
+
+
 class DPO(SupervisedFinetune):
     """A general class of DPO and its variants."""
 
@@ -27,32 +47,15 @@ def __init__(self,
                  label_smoothing=0.0,
                  **kwargs):
         super().__init__(llm, **kwargs)
-        self.ref_llm = ref_llm
         self.loss_type = loss_type
         self.label_smoothing = label_smoothing
         self.beta = beta
 
-        if not self.use_lora:
-            self.ref_llm = self.create_reference_model(ref_llm, **kwargs)
-
-    def create_reference_model(self, ref_llm=None, **kwargs):
-        ref_model = None
-        if ref_llm is None:
-            if is_deepspeed_zero3_enabled():
-                raise ValueError(
-                    'DeepSpeed ZeRO-3 is enabled and is not compatible '
-                    'with `deepcopy(self.llm)`. Please instantiate '
-                    'your reference model by modifying key `model.ref_llm` '
-                    'in your config with `AutoCausalLM.from_pretrained()`.')
-            ref_model = deepcopy(self.llm)
+        if ref_llm is not None:
+            ref_llm = self._build_llm_from_cfg(ref_llm, kwargs.get("use_varlen_attn"), kwargs.get("max_position_embeddings"))
+            self.ref_llm = disable_grad(ref_llm)
         else:
-            ref_model = SupervisedFinetune(ref_llm, **kwargs).llm
-        # freeze parameters
-        parameter_names = [n for n, _ in ref_model.named_parameters()]
-        for param_name in parameter_names:
-            param = ref_model.get_parameter(param_name)
-            param.requires_grad = False
-        return ref_model.eval()
+            self.ref_llm = None if self.use_lora else create_reference_model(self.llm)
 
     def _gather_masked_logits(self, logits, labels, mask):
         logits = torch.gather(
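
Note (not part of the commit): with this change, a full-parameter DPO run under ZeRO-3 is expected to supply the reference model through the config's `model.ref_llm` entry, which `DPO.__init__` now builds via `_build_llm_from_cfg` and freezes with `disable_grad`, instead of deep-copying the already-sharded policy. A minimal, hedged sketch of such a config entry; the model path is a placeholder and the surrounding keys just follow the usual xtuner config style:

from transformers import AutoModelForCausalLM
from xtuner.model.dpo import DPO

# Placeholder path; any causal LM checkpoint would do.
pretrained_model_name_or_path = 'internlm/internlm2-chat-1_8b'

model = dict(
    type=DPO,
    beta=0.1,
    loss_type='sigmoid',
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path),
    # The reference model is built from its own cfg (ZeRO-3 friendly) and then
    # frozen with disable_grad(), rather than deep-copied from the sharded policy.
    ref_llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path))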
21 changes: 15 additions & 6 deletions xtuner/model/sft.py
@@ -79,19 +79,15 @@ def __init__(self,
                  tokenizer=None,
                  max_position_embeddings=None):
         super().__init__()
-        with LoadWoInit():
-            if isinstance(llm, dict):
-                llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
-            self.llm = self._build_from_cfg_or_module(llm)
+
+        self.llm = self._build_llm_from_cfg(llm, use_varlen_attn, max_position_embeddings)
 
         if tokenizer is not None:
             if isinstance(tokenizer, dict):
                 tokenizer = BUILDER.build(tokenizer)
             smart_tokenizer_and_embedding_resize(tokenizer, self.llm)
 
-        self.llm.config.use_cache = False
-        dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn)
 
         if use_activation_checkpointing:
             # For backward compatibility
             if hasattr(self.llm, 'enable_input_require_grads'):
@@ -119,6 +115,19 @@ def __init__(self,
         # the sequence.
         self.use_varlen_attn = use_varlen_attn
 
+
+    def _build_llm_from_cfg(self, llm_cfg, use_varlen_attn, max_position_embeddings):
+        # For forward
+        with LoadWoInit():
+            if isinstance(llm_cfg, dict):
+                llm = self._dispatch_lm_model_cfg(llm_cfg, max_position_embeddings)
+            llm = self._build_from_cfg_or_module(llm)
+
+        llm.config.use_cache = False
+        dispatch_modules(llm, use_varlen_attn=use_varlen_attn)
+        return llm
+
+
     def gradient_checkpointing_enable(self):
         self.activation_checkpointing_enable()
 
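Note (not part of the commit): the sft.py side is an extract-method refactor: model construction moves out of `SupervisedFinetune.__init__` into `_build_llm_from_cfg`, so `DPO` can build `ref_llm` from a fresh config instead of copying `self.llm`. A small framework-free sketch of the same pattern; `Policy`, `PreferenceTuner`, and `_build_from_cfg` are invented names for illustration only:

from copy import deepcopy

import torch.nn as nn


class Policy:                                   # stands in for SupervisedFinetune
    def __init__(self, cfg):
        self.model = self._build_from_cfg(cfg)

    def _build_from_cfg(self, cfg):             # stands in for _build_llm_from_cfg
        return nn.Linear(cfg['d_in'], cfg['d_out'])


class PreferenceTuner(Policy):                  # stands in for DPO
    def __init__(self, cfg, ref_cfg=None):
        super().__init__(cfg)
        if ref_cfg is not None:
            # Build the reference model from its own cfg (safe under ZeRO-3).
            ref = self._build_from_cfg(ref_cfg)
        else:
            # Fall back to a frozen deep copy of the policy (not ZeRO-3 safe).
            ref = deepcopy(self.model)
        for p in ref.parameters():              # mirrors disable_grad()
            p.requires_grad = False
        self.ref_model = ref.eval()


tuner = PreferenceTuner({'d_in': 8, 'd_out': 8}, {'d_in': 8, 'd_out': 8})
assert all(not p.requires_grad for p in tuner.ref_model.parameters())
assert all(p.requires_grad for p in tuner.model.parameters())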
