diff --git a/xtuner/dataset/collate_fns/preference_collate_fn.py b/xtuner/dataset/collate_fns/preference_collate_fn.py
index ca21613bb..4b6a7f5c3 100644
--- a/xtuner/dataset/collate_fns/preference_collate_fn.py
+++ b/xtuner/dataset/collate_fns/preference_collate_fn.py
@@ -58,14 +58,14 @@ def preference_collate_fn(instances: Sequence[Dict],
         labels = torch.stack(labels)
 
     if use_varlen_attn:
-        attention_mask = torch.ones_like(input_ids).bool()
+        attention_mask = None
         position_ids = torch.stack(position_ids, dim=0)
     else:
         # Some tokenizers have the same eos token and pad token, so input_ids
         # cannot be masked directly based on the pad token id.
         attention_mask = torch.zeros_like(input_ids).bool()
-        for i in ori_length:
-            attention_mask[:i] = True
+        for i, length in enumerate(ori_length):
+            attention_mask[i, :length] = True
 
         bs, seq_len = input_ids.shape
         position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
@@ -74,11 +74,12 @@ def preference_collate_fn(instances: Sequence[Dict],
         input_ids = pad_for_sequence_parallel(input_ids, pad_index)
         labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
         position_ids = pad_for_sequence_parallel(position_ids, 0)
-        # We use attention_mask to distinguish `input_ids` from
-        # (sequence parallel) pad tokens in `get_var_len_atten_logps` method of
-        # class `DPO` and `ORPO`
-        attention_mask = pad_for_sequence_parallel(attention_mask, 0)
+        if attention_mask is not None:
+            attention_mask = pad_for_sequence_parallel(attention_mask, 0)
 
         if use_varlen_attn:
+            # We use attention_mask to distinguish `input_ids` from
+            # (sequence parallel) pad tokens in `get_var_len_atten_logps`
+            # method of class `DPO` and `ORPO`
             (cumulative_len, attention_mask
              ) = pad_cumulative_len_for_sequence_parallel(cumulative_len)
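
For reviewers, a minimal standalone sketch (not part of the patch) of the mask fix in the first hunk: the old loop `attention_mask[:i] = True` slices batch *rows* rather than positions within each sample, so once the longest length in `ori_length` is processed every pad token ends up unmasked, while the fixed loop masks each row only up to its own sample length. The tensor values below are made-up for illustration.

```python
import torch

# Hypothetical padded batch: sample 0 has real length 3, sample 1 has length 5.
input_ids = torch.tensor([
    [1, 2, 3, 0, 0],
    [4, 5, 6, 7, 8],
])
ori_length = [3, 5]

# Old (buggy) loop: `[:i]` selects the first `i` rows of the batch,
# so both rows become entirely True and pad tokens are never masked out.
buggy_mask = torch.zeros_like(input_ids).bool()
for i in ori_length:
    buggy_mask[:i] = True

# Fixed loop: mask row `i` only up to that sample's own length.
fixed_mask = torch.zeros_like(input_ids).bool()
for i, length in enumerate(ori_length):
    fixed_mask[i, :length] = True

# buggy_mask: all positions True in both rows (pad positions included).
# fixed_mask: row 0 is True for its first 3 positions only; row 1 is fully True.
```

The second hunk follows from the same change: with `use_varlen_attn` the collate function now returns `attention_mask = None`, so the sequence-parallel padding of the mask is skipped and the mask used by `get_var_len_atten_logps` comes from `pad_cumulative_len_for_sequence_parallel` instead.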