Summary:
Pull Request resolved: #638

# Context
Having `torchtnt/framework/utils` is a bit misleading, since one might assume these are external utils, while at the moment the module is a mix of internal and external utilities (e.g. `get_timing_context`). For the sake of discoverability, it makes sense to split `framework/utils.py`.

# This diff
Extract loop-related utils from `framework/utils.py` into `framework/_unit_utils.py`.

Reviewed By: JKSenthil

Differential Revision: D51593728

fbshipit-source-id: 389ebec0fbaa133c1fa04345b5902f823f289679
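A minimal sketch of the call-site change implied by the split; the pre-split import path is an assumption based on the description above and is not shown in this diff:

    # Hypothetical call-site update (old path assumed; only the new path appears in this commit):
    # from torchtnt.framework.utils import _step_requires_iterator   # before the split
    from torchtnt.framework._unit_utils import _step_requires_iterator  # after the split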
1 parent ba46e08, commit affb136. Showing 5 changed files with 139 additions and 114 deletions.
@@ -0,0 +1,84 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Dict, Iterator

import torch

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torchtnt.framework._unit_utils import (
    _find_optimizers_for_module,
    _step_requires_iterator,
)
from torchtnt.framework.state import State
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import spawn_multi_process


class UnitUtilsTest(unittest.TestCase):
    cuda_available: bool = torch.cuda.is_available()
    distributed_available: bool = torch.distributed.is_available()

    def test_step_func_requires_iterator(self) -> None:
        class Foo:
            def bar(self, state: State, data: object) -> object:
                return data

            def baz(self, state: State, data: Iterator[torch.Tensor]) -> object:
                pass

        def dummy(a: int, b: str, data: Iterator[str]) -> None:
            pass

        foo = Foo()

        self.assertFalse(_step_requires_iterator(foo.bar))
        self.assertTrue(_step_requires_iterator(foo.baz))
        self.assertTrue(_step_requires_iterator(dummy))

    def test_find_optimizers_for_module(self) -> None:
        module1 = torch.nn.Linear(10, 10)
        module2 = torch.nn.Linear(10, 10)
        optim1 = torch.optim.Adam(module1.parameters())
        optim2 = torch.optim.Adagrad(module2.parameters())

        opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
        optimizers = _find_optimizers_for_module(module1, opts)
        optim_name, _ = optimizers[0]
        self.assertEqual(optim_name, "optim1")
        optimizers = _find_optimizers_for_module(module2, opts)
        optim_name, _ = optimizers[0]
        self.assertEqual(optim_name, "optim2")

    @unittest.skipUnless(
        condition=distributed_available, reason="Torch distributed is needed to run"
    )
    @unittest.skipUnless(
        condition=cuda_available, reason="This test needs a GPU host to run."
    )
    def test_find_optimizers_for_FSDP_module(self) -> None:
        spawn_multi_process(2, "nccl", self._find_optimizers_for_FSDP_module)

    @staticmethod
    def _find_optimizers_for_FSDP_module() -> None:
        device = init_from_env()
        module1 = FSDP(torch.nn.Linear(10, 10).to(device))
        module2 = torch.nn.Linear(10, 10)
        optim1 = torch.optim.Adam(module1.parameters())
        optim2 = torch.optim.Adagrad(module2.parameters())

        opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
        optim_list = _find_optimizers_for_module(module1, opts)
        optim_name, _ = optim_list[0]

        tc = unittest.TestCase()
        tc.assertEqual(optim_name, "optim1")
        optim_list = _find_optimizers_for_module(module2, opts)
        optim_name, _ = optim_list[0]
        tc.assertEqual(optim_name, "optim2")
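The iterator-vs-batch behavior exercised by the first test maps onto user-defined units. A minimal sketch, assuming torchtnt's public `TrainUnit` API; `MyIteratorUnit` and its body are hypothetical and not part of this diff:

    from typing import Iterator

    import torch
    from torchtnt.framework.state import State
    from torchtnt.framework.unit import TrainUnit


    class MyIteratorUnit(TrainUnit[Iterator[torch.Tensor]]):
        """Hypothetical unit whose train_step is annotated to receive the raw data iterator."""

        def train_step(self, state: State, data: Iterator[torch.Tensor]) -> None:
            # Because `data` is annotated as Iterator, _step_requires_iterator(self.train_step)
            # returns True, so the framework would hand over the iterator instead of one batch.
            for batch in data:
                pass  # consume batches at the unit's own pace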
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import inspect
import logging
from typing import Callable, Dict, List, Tuple, TypeVar

import torch
import typing_extensions
from torchtnt.framework.state import State

_logger: logging.Logger = logging.getLogger(__name__)
T = TypeVar("T")


def _step_requires_iterator(step_func: Callable[[State, T], object]) -> bool:
    """
    Helper function to evaluate whether the get_next_X_batch method should pass the data iterator to the `X_step`
    functions, or whether get_next_X_batch should call `next(data_iter)` and pass a single batch to the step method.
    This is closely tied to the Unit's corresponding step function signature.
    """
    argspec = inspect.getfullargspec(step_func)
    annotations = argspec.annotations
    if "data" not in annotations:
        _logger.warning(
            f"Expected step function to have an annotated argument named ``data``. Found {annotations}."
        )
        return False
    annotated_type = annotations["data"]
    return typing_extensions.get_origin(annotated_type) is collections.abc.Iterator


def _find_optimizers_for_module(
    module: torch.nn.Module, optimizers: Dict[str, torch.optim.Optimizer]
) -> List[Tuple[str, torch.optim.Optimizer]]:
    """
    Given a module, returns a list of optimizers that are associated with it.
    """
    optimizer_list = []
    module_params = [param.data_ptr() for param in module.parameters()]
    for optim_name, optimizer in optimizers.items():
        optimizer_params = [
            param.data_ptr() for param in optimizer.param_groups[0]["params"]
        ]
        if all(module_param in optimizer_params for module_param in module_params):
            optimizer_list.append((optim_name, optimizer))
    return optimizer_list
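A rough sketch of how a fetch helper could branch on `_step_requires_iterator`; the function name `_get_next_batch` and its shape are hypothetical, pieced together from the docstring rather than taken from torchtnt's actual loop code:

    from typing import Callable, Iterator, TypeVar

    from torchtnt.framework._unit_utils import _step_requires_iterator
    from torchtnt.framework.state import State

    T = TypeVar("T")


    def _get_next_batch(
        state: State, step_func: Callable[[State, object], object], data_iter: Iterator[T]
    ) -> object:
        # Hypothetical dispatch: hand the whole iterator to step functions whose `data`
        # argument is annotated as Iterator; otherwise advance it and pass a single batch.
        if _step_requires_iterator(step_func):
            return data_iter
        return next(data_iter)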