Name the forward pass thread in the trainer loop (#895)
Summary:
Pull Request resolved: #895

Internal
# Context
As part of the sched_ext effort, we are building custom Linux schedulers that provide a small performance boost to AI training and improve resource isolation on trainer hosts. The latter is necessary to avoid cases where noisy-neighbor processes, such as data loaders, slow down GPU training.

More details in this note: https://fb.workplace.com/notes/1118655556176038

By naming the forward-pass thread, we can match it by name and assign it a higher priority at the Linux scheduler level. The backward-pass thread is already named inside the PyTorch implementation, but the forward-pass thread needs to be named at the application level.
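For illustration only (not part of this diff), a host-side tool could locate the named thread via `/proc/<pid>/task/<tid>/comm` and raise its priority. The `SCHED_FIFO` policy, the priority value, and the trainer PID below are placeholder assumptions:

```
import os

def find_threads_by_name(pid: int, name: str) -> list:
    """Return the TIDs of threads in `pid` whose comm matches `name`."""
    tids = []
    task_dir = f"/proc/{pid}/task"
    for tid in os.listdir(task_dir):
        try:
            with open(f"{task_dir}/{tid}/comm") as f:
                if f.read().strip() == name:
                    tids.append(int(tid))
        except FileNotFoundError:
            # The thread may have exited between listdir and open.
            continue
    return tids

def boost_priority(tid: int) -> None:
    # Requires CAP_SYS_NICE; SCHED_FIFO with priority 10 is only an example.
    os.sched_setscheduler(tid, os.SCHED_FIFO, os.sched_param(10))

if __name__ == "__main__":
    trainer_pid = 1234  # placeholder: the trainer process id
    for tid in find_threads_by_name(trainer_pid, "trainer_main"):
        boost_priority(tid)
```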

We did the same thing in PyPer, APS, and MVAI, the largest trainer frameworks for reco models, which consume 70%+ of fleet-level GPU hours for recommender systems.

# This Diff
Adds the core lines
```
if torch.multiprocessing._get_thread_name() != "trainer_main":
    torch.multiprocessing._set_thread_name("trainer_main")
```
to the train/eval/predict scripts. We check the preexisting name to avoid renaming the same thread on subsequent steps.
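As a minimal sketch (not part of the diff), the rename can be verified on Linux by reading the thread's `comm` file; this assumes PyTorch 2.5+ and a Linux host:

```
import threading
import torch

# Same guarded call as in the diff (PyTorch 2.5+ private API).
if torch.multiprocessing._get_thread_name() != "trainer_main":
    torch.multiprocessing._set_thread_name("trainer_main")

# Illustrative check: the kernel-visible thread name should now be "trainer_main".
tid = threading.get_native_id()
with open(f"/proc/self/task/{tid}/comm") as f:
    print(f.read().strip())  # expected: trainer_main
```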

Reviewed By: diego-urgell

Differential Revision: D61924982

fbshipit-source-id: cad51567361d6cc33d2f7d662401178360ad605c
JKSenthil authored and facebook-github-bot committed Sep 9, 2024
1 parent b5b0b03 commit 665dd50
Showing 4 changed files with 86 additions and 1 deletion.
40 changes: 39 additions & 1 deletion tests/framework/test_fit.py
@@ -10,7 +10,7 @@
import math
import unittest
from typing import Tuple
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import torch
from torch import nn
@@ -20,9 +20,12 @@
from torchtnt.framework.state import ActivePhase, State
from torchtnt.framework.unit import EvalUnit, TrainUnit, TTrainUnit
from torchtnt.utils.timer import Timer
from torchtnt.utils.version import is_torch_version_geq


class FitTest(unittest.TestCase):
TORCH_VERSION_GEQ_2_5_0: bool = is_torch_version_geq("2.5.0")

def test_fit_evaluate_every_n_epochs(self) -> None:
"""
Test fit entry point with evaluate_every_n_epochs=1
@@ -347,6 +350,41 @@ def test_error_message(self) -> None:
log.output,
)

@unittest.skipUnless(TORCH_VERSION_GEQ_2_5_0, "test requires PyTorch 2.5.0+")
@patch(
"torch.multiprocessing._get_thread_name", side_effect=["foo", "trainer_main"]
)
@patch("torch.multiprocessing._set_thread_name")
def test_fit_set_thread_name(
self, mock_set_thread_name: MagicMock, mock_get_thread_name: MagicMock
) -> None:
"""
Test that the fit entry point sets the trainer thread name
"""
input_dim = 2
train_dataset_len = 10
eval_dataset_len = 10
batch_size = 1

my_unit = DummyFitUnit(input_dim=input_dim)

train_dataloader = generate_random_dataloader(
train_dataset_len, input_dim, batch_size
)
eval_dataloader = generate_random_dataloader(
eval_dataset_len, input_dim, batch_size
)

fit(
my_unit,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
max_epochs=1,
evaluate_every_n_epochs=1,
)
self.assertEqual(mock_get_thread_name.call_count, 2)
mock_set_thread_name.assert_called_once()


class UnitWithError(TrainUnit[int], EvalUnit[int]):
def train_step(self, state: State, data: int) -> None:
16 changes: 16 additions & 0 deletions torchtnt/framework/evaluate.py
@@ -24,6 +24,7 @@
from torchtnt.framework.unit import TEvalData, TEvalUnit
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.timer import get_timer_summary, TimerProtocol
from torchtnt.utils.version import is_torch_version_geq

logger: logging.Logger = logging.getLogger(__name__)

@@ -162,6 +163,21 @@ def _evaluate_impl(

# clear step_output to avoid retaining extra memory
eval_state._step_output = None

if (
eval_unit.eval_progress.num_steps_completed_in_epoch
- prev_steps_in_epoch
== 5
):
# Set the trainer thread name to improve debuggability. We do it after
# 5 iterations to make sure that all the processes or thread pools
# spawned / forked from the current process have already been created
# and the trainer_main characterizes only the CPU thread that runs the
# forward pass and schedules GPU work.
if is_torch_version_geq("2.5.0"):
if torch.multiprocessing._get_thread_name() != "trainer_main":
torch.multiprocessing._set_thread_name("trainer_main")

except StopIteration:
stop_iteration_reached = True
break
16 changes: 16 additions & 0 deletions torchtnt/framework/predict.py
@@ -24,6 +24,7 @@
from torchtnt.framework.unit import TPredictData, TPredictUnit
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.timer import get_timer_summary, TimerProtocol
from torchtnt.utils.version import is_torch_version_geq

logger: logging.Logger = logging.getLogger(__name__)

@@ -170,6 +171,21 @@ def _predict_impl(

# clear step_output to avoid retaining extra memory
predict_state._step_output = None

if (
predict_unit.predict_progress.num_steps_completed_in_epoch
- prev_steps_in_epoch
== 5
):
# Set the trainer thread name to improve debuggability. We do it after
# 5 iterations to make sure that all the processes or thread pools
# spawned / forked from the current process have already been created
# and the trainer_main characterizes only the CPU thread that runs the
# forward pass and schedules GPU work.
if is_torch_version_geq("2.5.0"):
if torch.multiprocessing._get_thread_name() != "trainer_main":
torch.multiprocessing._set_thread_name("trainer_main")

except StopIteration:
stop_iteration_reached = True
break
15 changes: 15 additions & 0 deletions torchtnt/framework/train.py
@@ -27,6 +27,7 @@
from torchtnt.framework.unit import TTrainData, TTrainUnit
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.timer import get_timer_summary, TimerProtocol
from torchtnt.utils.version import is_torch_version_geq

logger: logging.Logger = logging.getLogger(__name__)

@@ -221,6 +222,20 @@ def _train_epoch_impl(
# clear step_output to avoid retaining extra memory
train_state._step_output = None

if (
train_unit.train_progress.num_steps_completed_in_epoch
- prev_steps_in_epoch
== 5
):
# Set the trainer thread name to improve debuggability. We do it after
# 5 iterations to make sure that all the processes or thread pools
# spawned / forked from the current process have already been created
# and the trainer_main characterizes only the CPU thread that runs the
# forward pass and schedules GPU work.
if is_torch_version_geq("2.5.0"):
if torch.multiprocessing._get_thread_name() != "trainer_main":
torch.multiprocessing._set_thread_name("trainer_main")

if (
evaluate_every_n_steps
and train_unit.train_progress.num_steps_completed
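Since the same block is added to all three loops, a hypothetical shared helper (not part of this commit; the function name and module placement are assumptions) could look like:

```
import torch
from torchtnt.utils.version import is_torch_version_geq

_THREAD_NAME = "trainer_main"
_WARMUP_STEPS = 5

def maybe_set_trainer_thread_name(
    steps_completed_in_epoch: int, prev_steps_in_epoch: int
) -> None:
    # Wait a few iterations so processes/thread pools spawned or forked from this
    # process already exist, and only the forward-pass thread gets the name.
    if steps_completed_in_epoch - prev_steps_in_epoch != _WARMUP_STEPS:
        return
    if not is_torch_version_geq("2.5.0"):
        return
    if torch.multiprocessing._get_thread_name() != _THREAD_NAME:
        torch.multiprocessing._set_thread_name(_THREAD_NAME)
```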
