diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py index eeb85a415e..5f3eef3176 100644 --- a/composer/models/huggingface.py +++ b/composer/models/huggingface.py @@ -96,13 +96,21 @@ def forward(self, batch): return output def loss(self, outputs, batch): - return outputs['loss'] + if self.config.use_return_dict: + return outputs['loss'] + else: + # loss is at index 0 in the output tuple + return outputs[0] def eval_forward(self, batch, outputs: Optional[Any] = None): output = outputs if outputs else self.forward(batch) if self.use_logits: self.labels = batch.pop('labels') - output = output['logits'] + if self.config.use_return_dict: + output = output['logits'] + else: + # logits are at index 1 in the output tuple + output = output[1] # if we are in the single class case, then remove the classes dimension if output.shape[1] == 1: