Skip to content

Commit

Permalink
Fix device_id in mmc4 pass (#250)
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-awadalla authored Aug 28, 2023
1 parent 9bebf9f commit a05dcba
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions open_flamingo/train/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def train_one_epoch(
with autocast():
loss_mmc4 = model(
vision_x=images,
lang_x=input_ids,
attention_mask=attention_mask,
lang_x=input_ids.to(device_id),
attention_mask=attention_mask.to(device_id),
labels=labels,
)[0]

Expand Down

0 comments on commit a05dcba

Please sign in to comment.