diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py index 1b45859c9e..34a47ef52d 100644 --- a/composer/metrics/nlp.py +++ b/composer/metrics/nlp.py @@ -250,7 +250,7 @@ def sync( ): # this is based off the gather_all_tensors utility function in torchmetrics, except it works with non-tensor objects # (in particular, lists of strings). Link here: https://github.com/Lightning-AI/torchmetrics/blob/99d6d9d6ac4eb1b3398241df558604e70521e6b0/src/torchmetrics/utilities/distributed.py#L97-L148 - if should_sync: + if torch.distributed.is_initialized() and should_sync: print(f"Syncing") group = process_group or self.process_group world_size = torch.distributed.get_world_size(group) # pyright: ignore [reportGeneralTypeIssues]