diff --git a/tests/framework/test_fit.py b/tests/framework/test_fit.py index 7f38abb861..f72d7fa82f 100644 --- a/tests/framework/test_fit.py +++ b/tests/framework/test_fit.py @@ -10,7 +10,7 @@ import math import unittest from typing import Tuple -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import torch from torch import nn @@ -20,9 +20,12 @@ from torchtnt.framework.state import ActivePhase, State from torchtnt.framework.unit import EvalUnit, TrainUnit, TTrainUnit from torchtnt.utils.timer import Timer +from torchtnt.utils.version import is_torch_version_geq class FitTest(unittest.TestCase): + TORCH_VERSION_GEQ_2_5_0: bool = is_torch_version_geq("2.5.0") + def test_fit_evaluate_every_n_epochs(self) -> None: """ Test fit entry point with evaluate_every_n_epochs=1 @@ -347,6 +350,41 @@ def test_error_message(self) -> None: log.output, ) + @unittest.skipUnless(TORCH_VERSION_GEQ_2_5_0, "test requires PyTorch 2.5.0+") + @patch( + "torch.multiprocessing._get_thread_name", side_effect=["foo", "trainer_main"] + ) + @patch("torch.multiprocessing._set_thread_name") + def test_fit_set_thread_name( + self, mock_set_thread_name: MagicMock, mock_get_thread_name: MagicMock + ) -> None: + """ + Test that fit names the trainer thread once across the train and eval loops + """ + input_dim = 2 + train_dataset_len = 10 + eval_dataset_len = 10 + batch_size = 1 + + my_unit = DummyFitUnit(input_dim=input_dim) + + train_dataloader = generate_random_dataloader( + train_dataset_len, input_dim, batch_size + ) + eval_dataloader = generate_random_dataloader( + eval_dataset_len, input_dim, batch_size + ) + + fit( + my_unit, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + max_epochs=1, + evaluate_every_n_epochs=1, + ) + self.assertEqual(mock_get_thread_name.call_count, 2) + mock_set_thread_name.assert_called_once() + class UnitWithError(TrainUnit[int], EvalUnit[int]): def train_step(self, state: State, data: int) -> None: diff --git 
a/torchtnt/framework/evaluate.py b/torchtnt/framework/evaluate.py index 50c263b7c9..8c61794219 100644 --- a/torchtnt/framework/evaluate.py +++ b/torchtnt/framework/evaluate.py @@ -24,6 +24,7 @@ from torchtnt.framework.unit import TEvalData, TEvalUnit from torchtnt.framework.utils import get_timing_context from torchtnt.utils.timer import get_timer_summary, TimerProtocol +from torchtnt.utils.version import is_torch_version_geq logger: logging.Logger = logging.getLogger(__name__) @@ -162,6 +163,21 @@ def _evaluate_impl( # clear step_output to avoid retaining extra memory eval_state._step_output = None + + if ( + eval_unit.eval_progress.num_steps_completed_in_epoch + - prev_steps_in_epoch + == 5 + ): + # Set the trainer thread name to improve debuggability. We do it after + # 5 iterations to make sure that all the processes or thread pools + # spawned / forked from the current process have already been created + # and the trainer_main characterizes only the CPU thread that runs the + # forward pass and schedules GPU work. 
+ if is_torch_version_geq("2.5.0"): + if torch.multiprocessing._get_thread_name() != "trainer_main": + torch.multiprocessing._set_thread_name("trainer_main") + except StopIteration: stop_iteration_reached = True break diff --git a/torchtnt/framework/predict.py b/torchtnt/framework/predict.py index 499e5cdf4c..bf414ec72b 100644 --- a/torchtnt/framework/predict.py +++ b/torchtnt/framework/predict.py @@ -24,6 +24,7 @@ from torchtnt.framework.unit import TPredictData, TPredictUnit from torchtnt.framework.utils import get_timing_context from torchtnt.utils.timer import get_timer_summary, TimerProtocol +from torchtnt.utils.version import is_torch_version_geq logger: logging.Logger = logging.getLogger(__name__) @@ -170,6 +171,21 @@ def _predict_impl( # clear step_output to avoid retaining extra memory predict_state._step_output = None + + if ( + predict_unit.predict_progress.num_steps_completed_in_epoch + - prev_steps_in_epoch + == 5 + ): + # Set the trainer thread name to improve debuggability. We do it after + # 5 iterations to make sure that all the processes or thread pools + # spawned / forked from the current process have already been created + # and the trainer_main characterizes only the CPU thread that runs the + # forward pass and schedules GPU work. 
+ if is_torch_version_geq("2.5.0"): + if torch.multiprocessing._get_thread_name() != "trainer_main": + torch.multiprocessing._set_thread_name("trainer_main") + except StopIteration: stop_iteration_reached = True break diff --git a/torchtnt/framework/train.py b/torchtnt/framework/train.py index e78320a9d7..5d7299c779 100644 --- a/torchtnt/framework/train.py +++ b/torchtnt/framework/train.py @@ -27,6 +27,7 @@ from torchtnt.framework.unit import TTrainData, TTrainUnit from torchtnt.framework.utils import get_timing_context from torchtnt.utils.timer import get_timer_summary, TimerProtocol +from torchtnt.utils.version import is_torch_version_geq logger: logging.Logger = logging.getLogger(__name__) @@ -221,6 +222,20 @@ def _train_epoch_impl( # clear step_output to avoid retaining extra memory train_state._step_output = None + if ( + train_unit.train_progress.num_steps_completed_in_epoch + - prev_steps_in_epoch + == 5 + ): + # Set the trainer thread name to improve debuggability. We do it after + # 5 iterations to make sure that all the processes or thread pools + # spawned / forked from the current process have already been created + # and the trainer_main characterizes only the CPU thread that runs the + # forward pass and schedules GPU work. + if is_torch_version_geq("2.5.0"): + if torch.multiprocessing._get_thread_name() != "trainer_main": + torch.multiprocessing._set_thread_name("trainer_main") + if ( evaluate_every_n_steps and train_unit.train_progress.num_steps_completed