Fix zero3 compatibility issue for DPO (#781)
* fix zero3

* reformat

* reformat

* reformat

* reformat
Johnson-Wang authored Jun 20, 2024
1 parent 5f2bca4 · commit 72b645f
Showing 1 changed file with 20 additions and 18 deletions.
xtuner/model/dpo.py: 38 changes (20 additions, 18 deletions)
@@ -16,23 +16,6 @@
 from .sft import SupervisedFinetune
 
 
-def create_reference_model(model):
-    if is_deepspeed_zero3_enabled():
-        raise ValueError('DeepSpeed ZeRO-3 is enabled and is not compatible '
-                         'with `create_reference_model()`. Please instantiate '
-                         'your reference model directly with '
-                         '`AutoCausalLM.from_pretrained()`.')
-
-    parameter_names = [n for n, _ in model.named_parameters()]
-    ref_model = deepcopy(model)
-
-    # if no layers are shared, return copy of model
-    for param_name in parameter_names:
-        param = ref_model.get_parameter(param_name)
-        param.requires_grad = False
-    return ref_model.eval()
-
-
 class DPO(SupervisedFinetune):
     """A general class of DPO and its variants."""
 
@@ -50,7 +33,26 @@ def __init__(self,
         self.beta = beta
 
         if not self.use_lora:
-            self.ref_llm = create_reference_model(self.llm)
+            self.ref_llm = self.create_reference_model(ref_llm, **kwargs)
+
+    def create_reference_model(self, ref_llm=None, **kwargs):
+        ref_model = None
+        if ref_llm is None:
+            if is_deepspeed_zero3_enabled():
+                raise ValueError(
+                    'DeepSpeed ZeRO-3 is enabled and is not compatible '
+                    'with `deepcopy(self.llm)`. Please instantiate '
+                    'your reference model by modifying key `model.ref_llm` '
+                    'in your config with `AutoCausalLM.from_pretrained()`.')
+            ref_model = deepcopy(self.llm)
+        else:
+            ref_model = SupervisedFinetune(ref_llm, **kwargs).llm
+        # freeze parameters
+        parameter_names = [n for n, _ in ref_model.named_parameters()]
+        for param_name in parameter_names:
+            param = ref_model.get_parameter(param_name)
+            param.requires_grad = False
+        return ref_model.eval()
 
     def _gather_masked_logits(self, logits, labels, mask):
         logits = torch.gather(
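Background for the change: under DeepSpeed ZeRO-3 the model's parameters are partitioned across data-parallel ranks, so deepcopying the policy model does not produce a usable standalone reference model. This commit therefore lets the reference model be declared explicitly through the `ref_llm` key and only falls back to `deepcopy(self.llm)` when ZeRO-3 is disabled. As an illustration, the following is a minimal sketch of the relevant part of an xtuner mmengine-style DPO config for a ZeRO-3 run; the checkpoint name and dtype are placeholders and not part of this commit.

import torch
from transformers import AutoModelForCausalLM
from xtuner.model.dpo import DPO

# Placeholder checkpoint; substitute your own model path.
pretrained_model_name_or_path = 'internlm/internlm2-chat-1_8b-sft'

model = dict(
    type=DPO,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        torch_dtype=torch.bfloat16),
    # With ZeRO-3 enabled, `deepcopy(self.llm)` would raise the ValueError
    # above, so the reference model gets its own config entry and is built
    # through `SupervisedFinetune(ref_llm, **kwargs).llm` in
    # `create_reference_model`.
    ref_llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        torch_dtype=torch.bfloat16))

Without ZeRO-3, `ref_llm` can be omitted and DPO will simply deepcopy and freeze the policy model, as in the new method above.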
