From 665dd50fbcb04489b3d110e66c24d5be82f8c98b Mon Sep 17 00:00:00 2001 From: Jason Senthil Date: Mon, 9 Sep 2024 16:17:48 -0700 Subject: [PATCH] Name the forward pass thread in the trainer loop (#895) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/895 Internal # Context With the sched_ext effort we are trying to build custom Linux schedulers that provide a small performance boost to AI training and improve the resource isolation on the trainer hosts. The latter is necessary to avoid cases where noisy neighbor processes, like data loaders, slow down the GPU training. More details in this note: https://fb.workplace.com/notes/1118655556176038 By naming the forward pass thread we can use its name and assign it a higher priority at the Linux scheduler level. The backward pass is named inside the PyTorch implementation but the forward pass needs to be named at the application level. We did the same thing in PyPer, APS, MVAI which are the largest trainer frameworks for reco models, consuming 70%+ of fleet level GPU hours for recommender systems. # This Diff Adds core lines ``` if torch.multiprocessing._get_thread_name() != "trainer_main": torch.multiprocessing._set_thread_name("trainer_main") ``` to train/eval/predict scripts. We check the preexisting name first to avoid renaming a thread that already has this name. 
Reviewed By: diego-urgell Differential Revision: D61924982 fbshipit-source-id: cad51567361d6cc33d2f7d662401178360ad605c --- tests/framework/test_fit.py | 40 +++++++++++++++++++++++++++++++++- torchtnt/framework/evaluate.py | 16 ++++++++++++++ torchtnt/framework/predict.py | 16 ++++++++++++++ torchtnt/framework/train.py | 15 +++++++++++++ 4 files changed, 86 insertions(+), 1 deletion(-) diff --git a/tests/framework/test_fit.py b/tests/framework/test_fit.py index 7f38abb861..f72d7fa82f 100644 --- a/tests/framework/test_fit.py +++ b/tests/framework/test_fit.py @@ -10,7 +10,7 @@ import math import unittest from typing import Tuple -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import torch from torch import nn @@ -20,9 +20,12 @@ from torchtnt.framework.state import ActivePhase, State from torchtnt.framework.unit import EvalUnit, TrainUnit, TTrainUnit from torchtnt.utils.timer import Timer +from torchtnt.utils.version import is_torch_version_geq class FitTest(unittest.TestCase): + TORCH_VERSION_GEQ_2_5_0: bool = is_torch_version_geq("2.5.0") + def test_fit_evaluate_every_n_epochs(self) -> None: """ Test fit entry point with evaluate_every_n_epochs=1 @@ -347,6 +350,41 @@ def test_error_message(self) -> None: log.output, ) + @unittest.skipUnless(TORCH_VERSION_GEQ_2_5_0, "test requires PyTorch 2.5.0+") + @patch( + "torch.multiprocessing._get_thread_name", side_effect=["foo", "trainer_main"] + ) + @patch("torch.multiprocessing._set_thread_name") + def test_fit_set_thread_name( + self, mock_set_thread_name: MagicMock, mock_get_thread_name: MagicMock + ) -> None: + """ + Test fit entry point with evaluate_every_n_epochs=1 + """ + input_dim = 2 + train_dataset_len = 10 + eval_dataset_len = 10 + batch_size = 1 + + my_unit = DummyFitUnit(input_dim=input_dim) + + train_dataloader = generate_random_dataloader( + train_dataset_len, input_dim, batch_size + ) + eval_dataloader = generate_random_dataloader( + eval_dataset_len, input_dim, batch_size + 
) + + fit( + my_unit, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + max_epochs=1, + evaluate_every_n_epochs=1, + ) + self.assertEqual(mock_get_thread_name.call_count, 2) + mock_set_thread_name.assert_called_once() + class UnitWithError(TrainUnit[int], EvalUnit[int]): def train_step(self, state: State, data: int) -> None: diff --git a/torchtnt/framework/evaluate.py b/torchtnt/framework/evaluate.py index 50c263b7c9..8c61794219 100644 --- a/torchtnt/framework/evaluate.py +++ b/torchtnt/framework/evaluate.py @@ -24,6 +24,7 @@ from torchtnt.framework.unit import TEvalData, TEvalUnit from torchtnt.framework.utils import get_timing_context from torchtnt.utils.timer import get_timer_summary, TimerProtocol +from torchtnt.utils.version import is_torch_version_geq logger: logging.Logger = logging.getLogger(__name__) @@ -162,6 +163,21 @@ def _evaluate_impl( # clear step_output to avoid retaining extra memory eval_state._step_output = None + + if ( + eval_unit.eval_progress.num_steps_completed_in_epoch + - prev_steps_in_epoch + == 5 + ): + # Set the trainer thread name to improve debuggability. We do it after + # 5 iterations to make sure that all the processes or thread pools + # spawned / forked from the current process have already been created + # and the trainer_main characterizes only the CPU thread that runs the + # forward pass and schedules GPU work. 
+ if is_torch_version_geq("2.5.0"): + if torch.multiprocessing._get_thread_name() != "trainer_main": + torch.multiprocessing._set_thread_name("trainer_main") + except StopIteration: stop_iteration_reached = True break diff --git a/torchtnt/framework/predict.py b/torchtnt/framework/predict.py index 499e5cdf4c..bf414ec72b 100644 --- a/torchtnt/framework/predict.py +++ b/torchtnt/framework/predict.py @@ -24,6 +24,7 @@ from torchtnt.framework.unit import TPredictData, TPredictUnit from torchtnt.framework.utils import get_timing_context from torchtnt.utils.timer import get_timer_summary, TimerProtocol +from torchtnt.utils.version import is_torch_version_geq logger: logging.Logger = logging.getLogger(__name__) @@ -170,6 +171,21 @@ def _predict_impl( # clear step_output to avoid retaining extra memory predict_state._step_output = None + + if ( + predict_unit.predict_progress.num_steps_completed_in_epoch + - prev_steps_in_epoch + == 5 + ): + # Set the trainer thread name to improve debuggability. We do it after + # 5 iterations to make sure that all the processes or thread pools + # spawned / forked from the current process have already been created + # and the trainer_main characterizes only the CPU thread that runs the + # forward pass and schedules GPU work. 
+ if is_torch_version_geq("2.5.0"): + if torch.multiprocessing._get_thread_name() != "trainer_main": + torch.multiprocessing._set_thread_name("trainer_main") + except StopIteration: stop_iteration_reached = True break diff --git a/torchtnt/framework/train.py b/torchtnt/framework/train.py index e78320a9d7..5d7299c779 100644 --- a/torchtnt/framework/train.py +++ b/torchtnt/framework/train.py @@ -27,6 +27,7 @@ from torchtnt.framework.unit import TTrainData, TTrainUnit from torchtnt.framework.utils import get_timing_context from torchtnt.utils.timer import get_timer_summary, TimerProtocol +from torchtnt.utils.version import is_torch_version_geq logger: logging.Logger = logging.getLogger(__name__) @@ -221,6 +222,20 @@ def _train_epoch_impl( # clear step_output to avoid retaining extra memory train_state._step_output = None + if ( + train_unit.train_progress.num_steps_completed_in_epoch + - prev_steps_in_epoch + == 5 + ): + # Set the trainer thread name to improve debuggability. We do it after + # 5 iterations to make sure that all the processes or thread pools + # spawned / forked from the current process have already been created + # and the trainer_main characterizes only the CPU thread that runs the + # forward pass and schedules GPU work. + if is_torch_version_geq("2.5.0"): + if torch.multiprocessing._get_thread_name() != "trainer_main": + torch.multiprocessing._set_thread_name("trainer_main") + if ( evaluate_every_n_steps and train_unit.train_progress.num_steps_completed