Commit a12239b: fix lint

HIT-cwh committed Jul 19, 2024
1 parent ba7afc7, commit a12239b

Showing 2 changed files with 11 additions and 7 deletions.
xtuner/model/dpo.py: 7 changes (5 additions, 2 deletions)

@@ -52,10 +52,13 @@ def __init__(self,
         self.beta = beta

         if ref_llm is not None:
-            ref_llm = self._build_llm_from_cfg(ref_llm, kwargs.get("use_varlen_attn"), kwargs.get("max_position_embeddings"))
+            ref_llm = self.build_llm_from_cfg(
+                ref_llm, kwargs.get('use_varlen_attn', False),
+                kwargs.get('max_position_embeddings', None))
             self.ref_llm = disable_grad(ref_llm)
         else:
-            self.ref_llm = None if self.use_lora else create_reference_model(self.llm)
+            self.ref_llm = None if self.use_lora else create_reference_model(
+                self.llm)

     def _gather_masked_logits(self, logits, labels, mask):
         logits = torch.gather(
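
For readers skimming the hunk above: the reformatted branch decides where the frozen DPO reference model comes from, either an explicitly configured ref_llm or a detached copy of the policy model (skipped entirely under LoRA). Below is a minimal sketch of that logic in plain PyTorch; freeze and build_reference_model are illustrative names and are not xtuner's disable_grad / create_reference_model helpers.

import copy
from typing import Optional

import torch.nn as nn


def freeze(model: nn.Module) -> nn.Module:
    """Stop gradient tracking so the model acts as a fixed reference."""
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)
    return model


def build_reference_model(policy: nn.Module,
                          ref_model: Optional[nn.Module],
                          use_lora: bool) -> Optional[nn.Module]:
    # Explicit reference model supplied by the user: just freeze it.
    if ref_model is not None:
        return freeze(ref_model)
    # With LoRA the frozen base weights can already serve as the reference,
    # so no copy is kept; otherwise clone and freeze the policy model.
    return None if use_lora else freeze(copy.deepcopy(policy))

During DPO training the frozen copy only ever runs forward passes to supply reference log-probabilities.
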
xtuner/model/sft.py: 11 changes (6 additions, 5 deletions)

@@ -80,7 +80,8 @@ def __init__(self,
                  max_position_embeddings=None):
         super().__init__()

-        self.llm = self._build_llm_from_cfg(llm, use_varlen_attn, max_position_embeddings)
+        self.llm = self.build_llm_from_cfg(llm, use_varlen_attn,
+                                           max_position_embeddings)

         if tokenizer is not None:
             if isinstance(tokenizer, dict):
@@ -115,19 +116,19 @@ def __init__(self,
         # the sequence.
         self.use_varlen_attn = use_varlen_attn

-
-    def _build_llm_from_cfg(self, llm_cfg, use_varlen_attn, max_position_embeddings):
+    def build_llm_from_cfg(self, llm_cfg, use_varlen_attn,
+                           max_position_embeddings):
         # For forward
         with LoadWoInit():
             if isinstance(llm_cfg, dict):
-                llm = self._dispatch_lm_model_cfg(llm_cfg, max_position_embeddings)
+                llm = self._dispatch_lm_model_cfg(llm_cfg,
+                                                  max_position_embeddings)
             llm = self._build_from_cfg_or_module(llm)

         llm.config.use_cache = False
         dispatch_modules(llm, use_varlen_attn=use_varlen_attn)
         return llm

-
     def gradient_checkpointing_enable(self):
         self.activation_checkpointing_enable()

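
As context for the build_llm_from_cfg hunk: the method instantiates the LLM from a config inside LoadWoInit, passes max_position_embeddings down (presumably to adjust the context window), and turns off the generation KV cache before training. The following is a hedged approximation using vanilla transformers only; xtuner's LoadWoInit, dispatch_modules, and _dispatch_lm_model_cfg helpers are not reproduced, and build_llm_for_training is an illustrative name.

from typing import Optional

from transformers import AutoConfig, AutoModelForCausalLM


def build_llm_for_training(model_name_or_path: str,
                           max_position_embeddings: Optional[int] = None):
    cfg = AutoConfig.from_pretrained(model_name_or_path)
    if (max_position_embeddings is not None
            and hasattr(cfg, 'max_position_embeddings')):
        # Enlarge the context window before instantiating weights (an
        # assumption about what the max_position_embeddings argument is for).
        cfg.max_position_embeddings = max_position_embeddings
    llm = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=cfg)
    # The KV cache only helps generation; training (especially with gradient
    # checkpointing) expects it off, mirroring `llm.config.use_cache = False`.
    llm.config.use_cache = False
    return llm
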
