pyre annotations - remove from framework utils (#607)
Summary:
Pull Request resolved: #607

 Removing pyre-fixme annotations from utils.py and test_utils.py

Reviewed By: ananthsub

Differential Revision: D50900231

fbshipit-source-id: 88ef63a14090d14d4b23dd7214cb885e16744d8e
galrotem authored and facebook-github-bot committed Nov 2, 2023
1 parent f81d53b commit 8fdaa10
Showing 2 changed files with 52 additions and 61 deletions.
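The pattern is the same across both files: instead of suppressing Pyre errors with # pyre-fixme comments, values get explicit annotations or a cast that the checker can verify. Below is a minimal before/after sketch mirroring the hunks that follow; the list-based dataset is a stand-in for the repo's generate_random_dataset helper, and num_replicas/rank are passed explicitly so the snippet runs without an initialized process group.

```python
from typing import Dict, cast

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Before: suppress the error and leave the attribute untyped.
#   # pyre-fixme[4]: Attribute must be annotated.
#   cuda_available = torch.cuda.is_available()
# After: annotate explicitly so Pyre can verify it.
cuda_available: bool = torch.cuda.is_available()

# Before: the dict literal infers as Dict[str, Union[Linear, CrossEntropyLoss]],
# which needed a pyre-fixme[6] wherever Dict[str, nn.Module] was expected.
# After: widen the value type up front.
tracked_modules: Dict[str, nn.Module] = {
    "module": nn.Linear(1, 1),
    "loss_fn": nn.CrossEntropyLoss(),
}

# Before: DataLoader.sampler is typed Union[Sampler, Iterable], so reading
# .epoch needed a pyre-fixme[16]. After: cast to the concrete sampler type.
dataset = [torch.randn(3) for _ in range(10)]  # stand-in for generate_random_dataset(10, 3)
loader = DataLoader(
    dataset,
    # num_replicas/rank given explicitly so no process group is required here
    sampler=DistributedSampler(dataset, num_replicas=1, rank=0),
)
sampler = cast(DistributedSampler[object], loader.sampler)
sampler.set_epoch(20)
assert sampler.epoch == 20
```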
75 changes: 28 additions & 47 deletions tests/framework/test_utils.py
@@ -7,16 +7,18 @@

import time
import unittest
from typing import Iterator
from typing import cast, Dict, Iterator
from unittest.mock import MagicMock, patch

import torch
from torch import nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataset
from torchtnt.framework.state import State
from torchtnt.framework.utils import (
_construct_tracked_optimizers_and_schedulers,
_find_optimizers_for_module,
@@ -37,8 +39,8 @@


class UtilsTest(unittest.TestCase):
# pyre-fixme[4]: Attribute must be annotated.
cuda_available = torch.cuda.is_available()
cuda_available: bool = torch.cuda.is_available()
distributed_available: bool = torch.distributed.is_available()

def test_maybe_set_distributed_sampler_epoch(self) -> None:
mp_dict = spawn_multi_process(
@@ -54,10 +56,6 @@ def _test_maybe_set_distributed_sampler_epoch() -> bool:
"""
Test _maybe_set_distributed_sampler_epoch util function
"""
# pyre-fixme[6]: For 1st argument expected `Iterable[typing.Any]` but got
# `None`.
_maybe_set_distributed_sampler_epoch(None, 10)

random_dataset = generate_random_dataset(10, 3)
dummy_dataloader_with_distributed_sampler = DataLoader(
random_dataset, sampler=DistributedSampler(random_dataset)
@@ -66,9 +64,12 @@ def _test_maybe_set_distributed_sampler_epoch() -> bool:
_maybe_set_distributed_sampler_epoch(
dummy_dataloader_with_distributed_sampler, 20
)
# pyre-fixme[16]: Item `Sampler` of `Union[Sampler[typing.Any],
# Iterable[typing.Any]]` has no attribute `epoch`.
return dummy_dataloader_with_distributed_sampler.sampler.epoch == 20

sampler = cast(
DistributedSampler[object],
dummy_dataloader_with_distributed_sampler.sampler,
)
return sampler.epoch == 20

def test_set_module_training_mode(self) -> None:
"""
@@ -77,11 +78,12 @@ def test_set_module_training_mode(self) -> None:
module = nn.Linear(1, 1)
loss_fn = nn.CrossEntropyLoss()

tracked_modules = {"module": module, "loss_fn": loss_fn}
tracked_modules: Dict[str, torch.nn.Module] = {
"module": module,
"loss_fn": loss_fn,
}

# set module training mode to False
# pyre-fixme[6]: For 1st argument expected `Dict[str, Module]` but got
# `Dict[str, Union[Linear, CrossEntropyLoss]]`.
prior_module_train_states = _set_module_training_mode(tracked_modules, False)

self.assertFalse(module.training)
@@ -91,8 +93,6 @@ def test_set_module_training_mode(self) -> None:
self.assertTrue(prior_module_train_states["loss_fn"])

# set back to True
# pyre-fixme[6]: For 1st argument expected `Dict[str, Module]` but got
# `Dict[str, Union[Linear, CrossEntropyLoss]]`.
prior_module_train_states = _set_module_training_mode(tracked_modules, True)

self.assertTrue(module.training)
@@ -108,43 +108,37 @@ def test_reset_module_training_mode(self) -> None:
module = nn.Linear(1, 1)
loss_fn = nn.CrossEntropyLoss()

tracked_modules = {"module": module, "loss_fn": loss_fn}
tracked_modules: Dict[str, torch.nn.Module] = {
"module": module,
"loss_fn": loss_fn,
}

# set module training mode to False
# pyre-fixme[6]: For 1st argument expected `Dict[str, Module]` but got
# `Dict[str, Union[Linear, CrossEntropyLoss]]`.
prior_module_train_states = _set_module_training_mode(tracked_modules, False)

self.assertFalse(module.training)
self.assertFalse(loss_fn.training)

# set back to True using reset
# pyre-fixme[6]: For 1st argument expected `Dict[str, Module]` but got
# `Dict[str, Union[Linear, CrossEntropyLoss]]`.
_reset_module_training_mode(tracked_modules, prior_module_train_states)

self.assertTrue(module.training)
self.assertTrue(loss_fn.training)

def test_step_func_requires_iterator(self) -> None:
class Foo:
def bar(self) -> None:
pass
def bar(self, state: State, data: object) -> object:
return data

def baz(self, data: Iterator[int], b: int, c: str) -> int:
return b
def baz(self, state: State, data: Iterator[torch.Tensor]) -> object:
pass

def dummy(a: int, b: str, data: Iterator[str]) -> None:
pass

foo = Foo()

# pyre-fixme[6]: For 1st argument expected `(State, object) -> object` but
# got `BoundMethod[typing.Callable(Foo.bar)[[Named(self, Foo)], None], Foo]`.
self.assertFalse(_step_requires_iterator(foo.bar))
# pyre-fixme[6]: For 1st argument expected `(State, object) -> object` but
# got `BoundMethod[typing.Callable(Foo.baz)[[Named(self, Foo), Named(data,
# Iterator[int]), Named(b, int), Named(c, str)], int], Foo]`.
self.assertTrue(_step_requires_iterator(foo.baz))
self.assertTrue(_step_requires_iterator(dummy))

@@ -183,8 +177,7 @@ def test_is_epoch_done(self) -> None:
self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=None, max_steps=None))

@patch("torchtnt.framework.utils.record_function")
# pyre-fixme[2]: Parameter must be annotated.
def test_get_timing_context(self, mock_record_function) -> None:
def test_get_timing_context(self, mock_record_function: MagicMock) -> None:
state = MagicMock()
state.timer = None

@@ -206,22 +199,16 @@ def test_find_optimizers_for_module(self) -> None:
optim1 = torch.optim.Adam(module1.parameters())
optim2 = torch.optim.Adagrad(module2.parameters())

opts = {"optim1": optim1, "optim2": optim2}
# pyre-fixme[6]: For 2nd argument expected `Dict[str, Optimizer]` but got
# `Dict[str, Union[Adagrad, Adam]]`.
opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
optimizers = _find_optimizers_for_module(module1, opts)
optim_name, _ = optimizers[0]
self.assertEqual(optim_name, "optim1")
# pyre-fixme[6]: For 2nd argument expected `Dict[str, Optimizer]` but got
# `Dict[str, Union[Adagrad, Adam]]`.
optimizers = _find_optimizers_for_module(module2, opts)
optim_name, _ = optimizers[0]
self.assertEqual(optim_name, "optim2")

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
@unittest.skipUnless(
torch.distributed.is_available(), reason="Torch distributed is needed to run"
condition=distributed_available, reason="Torch distributed is needed to run"
)
@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
@@ -237,24 +224,18 @@ def _find_optimizers_for_FSDP_module() -> None:
optim1 = torch.optim.Adam(module1.parameters())
optim2 = torch.optim.Adagrad(module2.parameters())

opts = {"optim1": optim1, "optim2": optim2}
# pyre-fixme[6]: For 2nd argument expected `Dict[str, Optimizer]` but got
# `Dict[str, Union[Adagrad, Adam]]`.
opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
optim_list = _find_optimizers_for_module(module1, opts)
optim_name, _ = optim_list[0]

tc = unittest.TestCase()
tc.assertEqual(optim_name, "optim1")
# pyre-fixme[6]: For 2nd argument expected `Dict[str, Optimizer]` but got
# `Dict[str, Union[Adagrad, Adam]]`.
optim_list = _find_optimizers_for_module(module2, opts)
optim_name, _ = optim_list[0]
tc.assertEqual(optim_name, "optim2")

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
@unittest.skipUnless(
torch.distributed.is_available(), reason="Torch distributed is needed to run"
condition=distributed_available, reason="Torch distributed is needed to run"
)
@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
38 changes: 24 additions & 14 deletions torchtnt/framework/utils.py
@@ -5,11 +5,21 @@
# LICENSE file in the root directory of this source tree.

import collections
import contextlib
import inspect
import logging
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from contextlib import contextmanager, nullcontext
from typing import (
Callable,
ContextManager,
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
TypeVar,
Union,
)

import torch
import torch.nn as nn
@@ -23,6 +33,7 @@
from torchtnt.utils.progress import Progress

_logger: logging.Logger = logging.getLogger(__name__)
T = TypeVar("T")


# Helper functions common across the loops
@@ -44,8 +55,7 @@ def _is_epoch_done(


def _maybe_set_distributed_sampler_epoch(
# pyre-ignore: Missing parameter annotation [2]
dataloader: Iterable[Any],
dataloader: Iterable[object],
current_epoch: int,
) -> None:
"""Set epoch of distributed sampler in dataloader, if applicable.
@@ -82,8 +92,9 @@ def _reset_module_training_mode(


@contextmanager
# pyre-fixme[3]: Return type must be annotated.
def get_timing_context(state: State, event_name: str):
def get_timing_context(
state: State, event_name: str
) -> Generator[Tuple[ContextManager, ContextManager], None, None]:
"""
Returns a context manager that records an event to a :class:`~torchtnt.utils.timer.Timer` and to PyTorch Profiler.
@@ -92,9 +103,7 @@ def get_timing_context(state: State, event_name: str):
event_name: string identifier to use for timing
"""
timer_context = (
state.timer.time(event_name)
if state.timer is not None
else contextlib.nullcontext()
state.timer.time(event_name) if state.timer is not None else nullcontext()
)
profiler_context = record_function(event_name)
with timer_context, profiler_context:
@@ -105,7 +114,7 @@ def log_api_usage(entry_point: str) -> None:
torch._C._log_api_usage_once(f"torchtnt.framework.{entry_point}")


def _step_requires_iterator(step_func: Callable[[State, object], object]) -> bool:
def _step_requires_iterator(step_func: Callable[[State, T], object]) -> bool:
"""
Helper function to evaluate whether the loops should pass the data iterator to the `_step`
functions, or whether the loop should call `next(data_iter)` and pass a single batch to process.
@@ -160,18 +169,19 @@ def _construct_tracked_optimizers_and_schedulers(
Combines tracked optimizers and schedulers. Handles optimizers working on FSDP modules, wrapping them in FSDPOptimizerWrapper.
"""
# construct custom tracked optimizers with FSDP optimizers
tracked_optimizers_and_schedulers = _construct_tracked_optimizers(unit)
tracked_optimizers_and_schedulers: Dict[
str, Union[torch.optim.Optimizer, FSDPOptimizerWrapper, TLRScheduler]
] = {}
tracked_optimizers_and_schedulers.update(_construct_tracked_optimizers(unit))

# add schedulers
for lr_scheduler_attrib_name, lr_scheduler in unit.tracked_lr_schedulers().items():
if lr_scheduler_attrib_name in tracked_optimizers_and_schedulers:
_logger.warning(
f'Key collision "{lr_scheduler_attrib_name}" detected between LR Scheduler and optimizer attribute names. Please ensure there are no identical attribute names, as they will override each other.'
)
# pyre-ignore: Incompatible parameter type [6]: In call `dict.__setitem__`, for 2nd positional argument, expected `Optimizer` but got `str`.
tracked_optimizers_and_schedulers[lr_scheduler_attrib_name] = lr_scheduler

# pyre-ignore: Incompatible return type [7]
return tracked_optimizers_and_schedulers

