diff --git a/tests/utils/loggers/test_tensorboard.py b/tests/utils/loggers/test_tensorboard.py
index 00da45b781..709c437cf0 100644
--- a/tests/utils/loggers/test_tensorboard.py
+++ b/tests/utils/loggers/test_tensorboard.py
@@ -9,17 +9,13 @@
 from __future__ import annotations
 
-import os
 import tempfile
 import unittest
 from unittest.mock import Mock, patch
 
-import torch.distributed.launcher as launcher
 from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
-from torch import distributed as dist
 from torchtnt.utils.loggers.tensorboard import TensorBoardLogger
-from torchtnt.utils.test_utils import get_pet_launch_config, skip_if_not_distributed
 
 
 class TensorBoardLoggerTest(unittest.TestCase):
@@ -74,26 +70,6 @@ def test_log_rank_zero(self: TensorBoardLoggerTest) -> None:
             logger = TensorBoardLogger(path=log_dir)
             self.assertEqual(logger.writer, None)
 
-    @staticmethod
-    def _test_distributed() -> None:
-        dist.init_process_group("gloo")
-        rank = dist.get_rank()
-        with tempfile.TemporaryDirectory() as log_dir:
-            test_path = "correct"
-            invalid_path = "invalid"
-            if rank == 0:
-                logger = TensorBoardLogger(os.path.join(log_dir, test_path))
-            else:
-                logger = TensorBoardLogger(os.path.join(log_dir, invalid_path))
-
-            assert test_path in logger.path
-            assert invalid_path not in logger.path
-
-    @skip_if_not_distributed
-    def test_multiple_workers(self: TensorBoardLoggerTest) -> None:
-        config = get_pet_launch_config(2)
-        launcher.elastic_launch(config, entrypoint=self._test_distributed)()
-
     def test_add_scalars_call_is_correctly_passed_to_summary_writer(
         self: TensorBoardLoggerTest,
     ) -> None:
diff --git a/torchtnt/utils/loggers/tensorboard.py b/torchtnt/utils/loggers/tensorboard.py
index 95ceb137ca..5f360c3216 100644
--- a/torchtnt/utils/loggers/tensorboard.py
+++ b/torchtnt/utils/loggers/tensorboard.py
@@ -53,9 +53,8 @@ class TensorBoardLogger(MetricLogger):
 
     def __init__(self: TensorBoardLogger, path: str, *args: Any, **kwargs: Any) -> None:
         self._writer: Optional[SummaryWriter] = None
-
+        self._path: str = path
         self._rank: int = get_global_rank()
-        self._sync_path_to_workers(path)
 
         if self._rank == 0:
             logger.info(
@@ -69,22 +68,6 @@ def __init__(self: TensorBoardLogger, path: str, *args: Any, **kwargs: Any) -> N
 
         atexit.register(self.close)
 
-    def _sync_path_to_workers(self: TensorBoardLogger, path: str) -> None:
-        if not (dist.is_available() and dist.is_initialized()):
-            self._path: str = path
-            return
-
-        pg = PGWrapper(dist.group.WORLD)
-        path_container: List[str] = [path] if self._rank == 0 else [""]
-        pg.broadcast_object_list(path_container, 0)
-        updated_path = path_container[0]
-        if updated_path != path:
-            # because the logger only logs on rank 0, if users pass in a different path
-            # the logger will output the wrong `path` property, so we update it to match
-            # the correct path.
-            logger.info(f"Updating TensorBoard path to match rank 0: {updated_path}")
-            self._path: str = updated_path
-
     @property
     def writer(self: TensorBoardLogger) -> Optional[SummaryWriter]:
         return self._writer
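
With this change, `TensorBoardLogger` no longer broadcasts rank 0's directory, so each rank's `path` property simply reflects whatever path it was constructed with. For callers that still want every rank to see rank 0's directory, here is a minimal sketch of doing the broadcast at the call site instead; it assumes `torch.distributed` is initialized, and the helper name `make_rank0_synced_logger` is hypothetical, not part of the patch.

```python
from typing import List

import torch.distributed as dist

from torchtnt.utils.loggers.tensorboard import TensorBoardLogger


def make_rank0_synced_logger(path: str) -> TensorBoardLogger:
    # Hypothetical helper: broadcast rank 0's path so every rank constructs
    # the logger with the same directory, mirroring the removed
    # _sync_path_to_workers behavior at the call site.
    if dist.is_available() and dist.is_initialized():
        container: List[str] = [path if dist.get_rank() == 0 else ""]
        dist.broadcast_object_list(container, src=0)
        path = container[0]
    return TensorBoardLogger(path)
```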