
fixing multi-gpu sync
pic-o committed Feb 6, 2024
1 parent 0654172 commit e02907a
Showing 1 changed file with 12 additions and 11 deletions.
RWKV-v5/src/model.py — 23 changes: 12 additions & 11 deletions
@@ -915,21 +915,22 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool = False, is_valid
         total_mask_sum = torch.sum(seq_mask)
         avg_mask_sum = ( total_mask_sum / B )
 
-        # Do a quick return, if there is no tokens of value to learn from due to full masking
-        if num_devices > 1 and total_mask_sum == 0:
-            return 0
+        # # Do a quick return, if there is no tokens of value to learn from due to full masking
+        # # DO NOT DO THIS : This causes multi node / multi GPU to go out of sync
+        # if num_devices <= 1 and total_mask_sum == 0:
+        #     return 0
 
         # Checkpoint steps
         def checkpointed_step(idx, targets, mask, last_shift_states,
                               last_wkv_states):
-            # Skip if there is no tokens of value to learn from
-            if idx.shape[1] == 0:
-                # Prepare dummy loss
-                train_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_()
-                sample_loss = train_loss.clone().detach().requires_grad_(False)
-
-                # Return the checkpoint values
-                return sample_loss, train_loss, last_shift_states, last_wkv_states, 0
+            # # Skip if there is no tokens of value to learn from
+            # if idx.shape[1] == 0:
+            #     # Prepare dummy loss
+            #     train_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_()
+            #     sample_loss = train_loss.clone().detach().requires_grad_(False)
+
+            #     # Return the checkpoint values
+            #     return sample_loss, train_loss, last_shift_states, last_wkv_states, 0
 
             # Get the logits, and the new states
             logits, new_shift_states, new_wkv_states = self(

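The removed early returns are the substance of the fix: in multi-GPU / multi-node training, every rank must run the same forward and backward graph on every step, because the backward pass drives the gradient all-reduce collectives. If one rank hits a fully masked batch and returns 0 while the other ranks continue, the collectives no longer line up and the ranks go out of sync (typically a hang). The safer pattern is to keep the forward/backward path on every rank and let a fully masked batch simply contribute a zero loss. The sketch below is illustrative only, assuming a plain PyTorch cross-entropy setup; masked_loss and its arguments are hypothetical names, not the repository's actual code.

    # Minimal sketch (not RWKV-v5 code): avoid rank-dependent early returns
    # so that backward() and DDP's gradient all-reduce run on every rank.
    import torch
    import torch.nn.functional as F

    def masked_loss(logits, targets, mask):
        # Per-token cross entropy, then zero out the masked positions.
        token_loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            reduction="none",
        )
        mask = mask.view(-1).to(token_loss.dtype)
        # clamp(min=1) avoids a divide-by-zero without any rank-dependent
        # branch: a fully masked batch yields loss == 0, but the tensor stays
        # attached to the graph, so every rank still runs the same backward
        # pass and the collectives stay in lockstep.
        return (token_loss * mask).sum() / mask.sum().clamp(min=1.0)

On a fully masked batch this produces all-zero gradients on that rank, which is harmless once gradients are averaged across ranks.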
