add unit_utils (#638)
Summary:
Pull Request resolved: #638

# Context
Having `torchtnt/framework/utils` is a bit misleading, since one might think it contains only external utils, while at the moment it is a mix of internal helpers and an external one (`get_timing_context`). For the sake of discoverability, it makes sense to split `framework/utils.py`.

# This diff
Extract unit-related utils from `framework/utils.py` into `framework/_unit_utils.py`.
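
Concretely, the internal helpers now get imported from the new private module, while the externally facing `get_timing_context` keeps its current home (a small sketch mirroring the diff below):

```python
# Internal helpers move to the new private module added in this diff:
from torchtnt.framework._unit_utils import (
    _find_optimizers_for_module,
    _step_requires_iterator,
)

# The external util stays in framework/utils.py:
from torchtnt.framework.utils import get_timing_context
```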

Reviewed By: JKSenthil

Differential Revision: D51593728

fbshipit-source-id: 389ebec0fbaa133c1fa04345b5902f823f289679
galrotem authored and facebook-github-bot committed Nov 29, 2023
1 parent ba46e08 commit affb136
Showing 5 changed files with 139 additions and 114 deletions.
84 changes: 84 additions & 0 deletions tests/framework/test_unit_utils.py
@@ -0,0 +1,84 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Dict, Iterator

import torch

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torchtnt.framework._unit_utils import (
    _find_optimizers_for_module,
    _step_requires_iterator,
)
from torchtnt.framework.state import State
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import spawn_multi_process


class UnitUtilsTest(unittest.TestCase):
    cuda_available: bool = torch.cuda.is_available()
    distributed_available: bool = torch.distributed.is_available()

    def test_step_func_requires_iterator(self) -> None:
        class Foo:
            def bar(self, state: State, data: object) -> object:
                return data

            def baz(self, state: State, data: Iterator[torch.Tensor]) -> object:
                pass

        def dummy(a: int, b: str, data: Iterator[str]) -> None:
            pass

        foo = Foo()

        self.assertFalse(_step_requires_iterator(foo.bar))
        self.assertTrue(_step_requires_iterator(foo.baz))
        self.assertTrue(_step_requires_iterator(dummy))

    def test_find_optimizers_for_module(self) -> None:
        module1 = torch.nn.Linear(10, 10)
        module2 = torch.nn.Linear(10, 10)
        optim1 = torch.optim.Adam(module1.parameters())
        optim2 = torch.optim.Adagrad(module2.parameters())

        opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
        optimizers = _find_optimizers_for_module(module1, opts)
        optim_name, _ = optimizers[0]
        self.assertEqual(optim_name, "optim1")
        optimizers = _find_optimizers_for_module(module2, opts)
        optim_name, _ = optimizers[0]
        self.assertEqual(optim_name, "optim2")

    @unittest.skipUnless(
        condition=distributed_available, reason="Torch distributed is needed to run"
    )
    @unittest.skipUnless(
        condition=cuda_available, reason="This test needs a GPU host to run."
    )
    def test_find_optimizers_for_FSDP_module(self) -> None:
        spawn_multi_process(2, "nccl", self._find_optimizers_for_FSDP_module)

    @staticmethod
    def _find_optimizers_for_FSDP_module() -> None:
        device = init_from_env()
        module1 = FSDP(torch.nn.Linear(10, 10).to(device))
        module2 = torch.nn.Linear(10, 10)
        optim1 = torch.optim.Adam(module1.parameters())
        optim2 = torch.optim.Adagrad(module2.parameters())

        opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
        optim_list = _find_optimizers_for_module(module1, opts)
        optim_name, _ = optim_list[0]

        tc = unittest.TestCase()
        tc.assertEqual(optim_name, "optim1")
        optim_list = _find_optimizers_for_module(module2, opts)
        optim_name, _ = optim_list[0]
        tc.assertEqual(optim_name, "optim2")
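
(The new tests can be run on their own with, for example, `python -m pytest tests/framework/test_unit_utils.py`, assuming pytest is installed; the FSDP case additionally needs torch.distributed and a GPU host, and is skipped otherwise.)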
74 changes: 1 addition & 73 deletions tests/framework/test_utils.py
@@ -7,45 +7,14 @@

import time
import unittest
from typing import Dict, Iterator
from unittest.mock import MagicMock, patch

import torch
from torchtnt.framework.utils import get_timing_context

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torchtnt.framework.state import State
from torchtnt.framework.utils import (
    _find_optimizers_for_module,
    _step_requires_iterator,
    get_timing_context,
)
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import spawn_multi_process
from torchtnt.utils.timer import Timer


class UtilsTest(unittest.TestCase):
    cuda_available: bool = torch.cuda.is_available()
    distributed_available: bool = torch.distributed.is_available()

    def test_step_func_requires_iterator(self) -> None:
        class Foo:
            def bar(self, state: State, data: object) -> object:
                return data

            def baz(self, state: State, data: Iterator[torch.Tensor]) -> object:
                pass

        def dummy(a: int, b: str, data: Iterator[str]) -> None:
            pass

        foo = Foo()

        self.assertFalse(_step_requires_iterator(foo.bar))
        self.assertTrue(_step_requires_iterator(foo.baz))
        self.assertTrue(_step_requires_iterator(dummy))

    @patch("torchtnt.framework.utils.record_function")
    def test_get_timing_context(self, mock_record_function: MagicMock) -> None:
        state = MagicMock()
@@ -62,44 +31,3 @@ def test_get_timing_context(self, mock_record_function: MagicMock) -> None:
            time.sleep(1)
        self.assertTrue("b" in state.timer.recorded_durations.keys())
        mock_record_function.assert_called_with("b")

    def test_find_optimizers_for_module(self) -> None:
        module1 = torch.nn.Linear(10, 10)
        module2 = torch.nn.Linear(10, 10)
        optim1 = torch.optim.Adam(module1.parameters())
        optim2 = torch.optim.Adagrad(module2.parameters())

        opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
        optimizers = _find_optimizers_for_module(module1, opts)
        optim_name, _ = optimizers[0]
        self.assertEqual(optim_name, "optim1")
        optimizers = _find_optimizers_for_module(module2, opts)
        optim_name, _ = optimizers[0]
        self.assertEqual(optim_name, "optim2")

    @unittest.skipUnless(
        condition=distributed_available, reason="Torch distributed is needed to run"
    )
    @unittest.skipUnless(
        condition=cuda_available, reason="This test needs a GPU host to run."
    )
    def test_find_optimizers_for_FSDP_module(self) -> None:
        spawn_multi_process(2, "nccl", self._find_optimizers_for_FSDP_module)

    @staticmethod
    def _find_optimizers_for_FSDP_module() -> None:
        device = init_from_env()
        module1 = FSDP(torch.nn.Linear(10, 10).to(device))
        module2 = torch.nn.Linear(10, 10)
        optim1 = torch.optim.Adam(module1.parameters())
        optim2 = torch.optim.Adagrad(module2.parameters())

        opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
        optim_list = _find_optimizers_for_module(module1, opts)
        optim_name, _ = optim_list[0]

        tc = unittest.TestCase()
        tc.assertEqual(optim_name, "optim1")
        optim_list = _find_optimizers_for_module(module2, opts)
        optim_name, _ = optim_list[0]
        tc.assertEqual(optim_name, "optim2")
52 changes: 52 additions & 0 deletions torchtnt/framework/_unit_utils.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import inspect
import logging
from typing import Callable, Dict, List, Tuple, TypeVar

import torch
import typing_extensions
from torchtnt.framework.state import State

_logger: logging.Logger = logging.getLogger(__name__)
T = TypeVar("T")


def _step_requires_iterator(step_func: Callable[[State, T], object]) -> bool:
"""
Helper function to evaluate whether the get_next_X_batch method should pass the data iterator to the `X_step`
functions, or whether get_next_X_batch should call `next(data_iter)` and pass a single batch to the step method.
This is closely tied to the Unit's corresponding step function signature.
"""
argspec = inspect.getfullargspec(step_func)
annotations = argspec.annotations
if "data" not in annotations:
_logger.warning(
f"Expected step function to have an annotated argument named ``data``. Found {annotations}."
)
return False
annotated_type = annotations["data"]
return typing_extensions.get_origin(annotated_type) is collections.abc.Iterator


def _find_optimizers_for_module(
    module: torch.nn.Module, optimizers: Dict[str, torch.optim.Optimizer]
) -> List[Tuple[str, torch.optim.Optimizer]]:
    """
    Given a module, returns a list of optimizers that are associated with it.
    """
    optimizer_list = []
    module_params = [param.data_ptr() for param in module.parameters()]
    for optim_name, optimizer in optimizers.items():
        optimizer_params = [
            param.data_ptr() for param in optimizer.param_groups[0]["params"]
        ]
        if all(module_param in optimizer_params for module_param in module_params):
            optimizer_list.append((optim_name, optimizer))
    return optimizer_list
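
For reference, a minimal sketch of how the two extracted helpers behave (it mirrors the unit tests above; the step-function names here are illustrative, not part of the API):

```python
from typing import Iterator

import torch

from torchtnt.framework._unit_utils import (
    _find_optimizers_for_module,
    _step_requires_iterator,
)
from torchtnt.framework.state import State


def batch_step(state: State, data: torch.Tensor) -> None:
    # `data` is annotated as a single batch, so the loop should call
    # next(data_iter) and hand one batch to the step.
    pass


def iterator_step(state: State, data: Iterator[torch.Tensor]) -> None:
    # `data` is annotated as an Iterator, so the loop should pass the
    # whole data iterator to the step.
    pass


assert not _step_requires_iterator(batch_step)
assert _step_requires_iterator(iterator_step)

module = torch.nn.Linear(4, 4)
optims = {"adam": torch.optim.Adam(module.parameters())}
# Every parameter of `module` is owned by the "adam" optimizer, so it is returned.
assert _find_optimizers_for_module(module, optims)[0][0] == "adam"
```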
2 changes: 1 addition & 1 deletion torchtnt/framework/unit.py
@@ -11,9 +11,9 @@

import torch
from torchtnt.framework._loop_utils import _step_requires_iterator
from torchtnt.framework._unit_utils import _find_optimizers_for_module

from torchtnt.framework.state import State
from torchtnt.framework.utils import _find_optimizers_for_module
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper
from torchtnt.utils.progress import Progress
41 changes: 1 addition & 40 deletions torchtnt/framework/utils.py
@@ -4,14 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import inspect
import logging
from contextlib import contextmanager, nullcontext
from typing import Callable, ContextManager, Dict, Generator, List, Tuple, TypeVar
from typing import ContextManager, Generator, Tuple, TypeVar

import torch
import typing_extensions
from torch.profiler import record_function
from torchtnt.framework.state import State

@@ -36,38 +32,3 @@ def get_timing_context(
    profiler_context = record_function(event_name)
    with timer_context, profiler_context:
        yield (timer_context, profiler_context)


def _step_requires_iterator(step_func: Callable[[State, T], object]) -> bool:
"""
Helper function to evaluate whether the loops should pass the data iterator to the `_step`
functions, or whether the loop should call `next(data_iter)` and pass a single batch to process.
This is closely tied to the Unit's corresponding step function signature.
"""
argspec = inspect.getfullargspec(step_func)
annotations = argspec.annotations
if "data" not in annotations:
_logger.warning(
f"Expected step function to have an annotated argument named ``data``. Found {annotations}."
)
return False
annotated_type = annotations["data"]
return typing_extensions.get_origin(annotated_type) is collections.abc.Iterator


def _find_optimizers_for_module(
    module: torch.nn.Module, optimizers: Dict[str, torch.optim.Optimizer]
) -> List[Tuple[str, torch.optim.Optimizer]]:
    """
    Given a module, returns a list of optimizers that are associated with it.
    """
    optimizer_list = []
    module_params = [param.data_ptr() for param in module.parameters()]
    for optim_name, optimizer in optimizers.items():
        optimizer_params = [
            param.data_ptr() for param in optimizer.param_groups[0]["params"]
        ]
        if all(module_param in optimizer_params for module_param in module_params):
            optimizer_list.append((optim_name, optimizer))
    return optimizer_list
