diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py index c841acb8eb..a326d1a7dd 100644 --- a/composer/metrics/nlp.py +++ b/composer/metrics/nlp.py @@ -660,7 +660,7 @@ def compute(self): for k in self.pass_at_k: results[f'pass@{k}'] = sum([self.estimator(n, c.item(), k) for c in self.correct]) / dataset_size - if len(results) == 0: # backwards compatibility - return results[0] + if len(results) == 1: # backwards compatibility + return list(results.values())[0] return results