Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
josejg committed Feb 13, 2024
1 parent e143dc6 commit e5758de
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions composer/metrics/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,14 +582,13 @@ def estimator(self, n: int, c: int, k: int) -> float:
return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

def _initialize_state(self, batch: dict[str, Any]):
dataset_size = batch['dataset_size']
device = batch['input_ids'].device
self.dataset_size = batch['dataset_size']
self.pass_at_k = batch['pass_at_k']
self.num_generations = batch['generations_per_sample']

# We need to defer the accumulator initialization because it depends on dataset size
self.add_state('correct', default=torch.zeros(dataset_size, device=device), dist_reduce_fx='sum')
self.add_state('total', default=torch.zeros(dataset_size, device=device), dist_reduce_fx='sum')
self.add_state('correct', default=torch.zeros(self.dataset_size), dist_reduce_fx='sum')
self.add_state('total', default=torch.zeros(self.dataset_size), dist_reduce_fx='sum')
dist.barrier()
self._initialized = True

Expand Down Expand Up @@ -655,13 +654,14 @@ def update(self, batch: Dict[str, Any], outputs: List[str], labels: List[str]):
def compute(self):
assert isinstance(self.correct, Tensor)
assert isinstance(self.total, Tensor)
assert (self.total == self.num_generations).all().item()
if not (self.total == self.num_generations).all().item():
raise ValueError(f"Some samples in the dataset have less than {self.num_generations} generations")

results = {}
dataset_size = len(self.correct)
n = self.num_generations

for k in self.pass_at_k:
results[f'pass@{k}'] = sum([self.estimator(n, c.item(), k) for c in self.correct]) / dataset_size
results[f'pass@{k}'] = sum([self.estimator(n, c.item(), k) for c in self.correct]) / self.dataset_size

if len(results) == 1: # backwards compatibility
return list(results.values())[0]
Expand Down

0 comments on commit e5758de

Please sign in to comment.