Skip to content

Commit

Permalink
remove path sync in TensorBoardLogger
Browse files — browse the repository at this point in the history
Reviewed By: galrotem

Differential Revision: D56645859

fbshipit-source-id: b14b03c2440876b50e835d4df4f8ff48658451e2
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Apr 26, 2024
1 parent 0159a07 commit 82a6b62
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 42 deletions.
24 changes: 0 additions & 24 deletions tests/utils/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,13 @@

from __future__ import annotations

import os
import tempfile
import unittest
from unittest.mock import Mock, patch

import torch.distributed.launcher as launcher
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from torch import distributed as dist

from torchtnt.utils.loggers.tensorboard import TensorBoardLogger
from torchtnt.utils.test_utils import get_pet_launch_config, skip_if_not_distributed


class TensorBoardLoggerTest(unittest.TestCase):
Expand Down Expand Up @@ -74,26 +70,6 @@ def test_log_rank_zero(self: TensorBoardLoggerTest) -> None:
logger = TensorBoardLogger(path=log_dir)
self.assertEqual(logger.writer, None)

@staticmethod
def _test_distributed() -> None:
    """Entry point run on each elastic worker: verify that every rank's
    logger ends up with rank 0's path after construction.

    Rank 0 constructs the logger with the "correct" path while every other
    rank passes an "invalid" one; the asserts then check that all ranks
    report rank 0's path (i.e. the path was synced across the process group).
    """
    # Gloo backend: CPU-only collective backend, suitable for unit tests.
    dist.init_process_group("gloo")
    rank = dist.get_rank()
    with tempfile.TemporaryDirectory() as log_dir:
        test_path = "correct"
        invalid_path = "invalid"
        if rank == 0:
            logger = TensorBoardLogger(os.path.join(log_dir, test_path))
        else:
            # Deliberately wrong path on non-zero ranks; it should be
            # overwritten by the broadcast from rank 0.
            logger = TensorBoardLogger(os.path.join(log_dir, invalid_path))

        # Every rank — including those constructed with invalid_path —
        # must expose rank 0's path.
        assert test_path in logger.path
        assert invalid_path not in logger.path

@skip_if_not_distributed
def test_multiple_workers(self: TensorBoardLoggerTest) -> None:
    """Launch ``_test_distributed`` on two workers via torch elastic to
    exercise the cross-rank path synchronization."""
    # PET (torch elastic) launch config for a 2-process local run.
    config = get_pet_launch_config(2)
    launcher.elastic_launch(config, entrypoint=self._test_distributed)()

def test_add_scalars_call_is_correctly_passed_to_summary_writer(
self: TensorBoardLoggerTest,
) -> None:
Expand Down
19 changes: 1 addition & 18 deletions torchtnt/utils/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,8 @@ class TensorBoardLogger(MetricLogger):

def __init__(self: TensorBoardLogger, path: str, *args: Any, **kwargs: Any) -> None:
self._writer: Optional[SummaryWriter] = None

self._path: str = path
self._rank: int = get_global_rank()
self._sync_path_to_workers(path)

if self._rank == 0:
logger.info(
Expand All @@ -69,22 +68,6 @@ def __init__(self: TensorBoardLogger, path: str, *args: Any, **kwargs: Any) -> N

atexit.register(self.close)

def _sync_path_to_workers(self: TensorBoardLogger, path: str) -> None:
    """Set ``self._path``, broadcasting rank 0's value to all other ranks
    when a torch.distributed process group is initialized.

    Args:
        path: the path this rank was constructed with; non-zero ranks may
            have it replaced by rank 0's path.
    """
    # No process group (single-process use): just take the local path.
    if not (dist.is_available() and dist.is_initialized()):
        self._path: str = path
        return

    pg = PGWrapper(dist.group.WORLD)
    # Rank 0 contributes its path; other ranks contribute a placeholder
    # that the broadcast overwrites.
    path_container: List[str] = [path] if self._rank == 0 else [""]
    pg.broadcast_object_list(path_container, 0)
    updated_path = path_container[0]
    # NOTE(review): when updated_path == path (always true on rank 0, and on
    # any rank whose local path already matches), self._path is never
    # assigned on this branch — presumably set elsewhere in __init__; verify.
    if updated_path != path:
        # because the logger only logs on rank 0, if users pass in a different path
        # the logger will output the wrong `path` property, so we update it to match
        # the correct path.
        logger.info(f"Updating TensorBoard path to match rank 0: {updated_path}")
        self._path: str = updated_path

@property
def writer(self: TensorBoardLogger) -> Optional[SummaryWriter]:
    """The underlying ``SummaryWriter``, or ``None`` if one was not
    created for this process (e.g. on a non-zero rank — see ``__init__``)."""
    return self._writer
Expand Down

0 comments on commit 82a6b62

Please sign in to comment.