From 67bcc827b11a643c7f44657228dcabf09f98d430 Mon Sep 17 00:00:00 2001
From: Gal Rotem
Date: Thu, 25 Jan 2024 02:39:58 -0800
Subject: [PATCH] replace all pyre-ignore with pyre-fixme (#689)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/689

Consolidate all pyre annotations to pyre-fixme

Reviewed By: anshulverma, JKSenthil

Differential Revision: D53064137

fbshipit-source-id: 9b531899725807dcbc3345415837cb052ecd5c44
---
 .../callbacks/test_base_checkpointer.py    |  2 +-
 tests/utils/test_swa.py                    |  2 +-
 .../framework/callbacks/base_csv_writer.py |  2 +-
 .../framework/callbacks/module_summary.py  |  2 +-
 torchtnt/framework/train.py                |  4 +-
 torchtnt/utils/data/iterators.py           |  2 +-
 torchtnt/utils/data/profile_dataloader.py  |  2 +-
 torchtnt/utils/device.py                   |  4 +-
 torchtnt/utils/flops.py                    | 40 +++++++++----------
 torchtnt/utils/memory.py                   |  4 +-
 torchtnt/utils/module_summary.py           | 24 +++++------
 torchtnt/utils/prepare_module.py           |  2 +-
 torchtnt/utils/swa.py                      | 10 ++---
 torchtnt/utils/timer.py                    |  2 +-
 14 files changed, 51 insertions(+), 51 deletions(-)

diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py
index ea746ab7da..616495821f 100644
--- a/tests/framework/callbacks/test_base_checkpointer.py
+++ b/tests/framework/callbacks/test_base_checkpointer.py
@@ -664,7 +664,7 @@ def test_best_checkpoint_no_top_k(self) -> None:
             save_every_n_epochs=1,
             best_checkpoint_config=BestCheckpointConfig(
                 monitored_metric="train_loss",
-                # pyre-ignore: Incompatible parameter type [6]
+                # pyre-fixme: Incompatible parameter type [6]
                 mode=mode,
             ),
         )
diff --git a/tests/utils/test_swa.py b/tests/utils/test_swa.py
index bae4875883..f3b6c4434a 100644
--- a/tests/utils/test_swa.py
+++ b/tests/utils/test_swa.py
@@ -191,7 +191,7 @@ def test_input_checks(self) -> None:
         with self.assertRaisesRegex(
             ValueError, "Unknown averaging method: foo. Only ema and swa are supported."
         ):
-            # pyre-ignore On purpose to test run time exception
+            # pyre-fixme On purpose to test run time exception
             AveragedModel(model, averaging_method="foo")
 
     def test_lit_ema(self) -> None:
diff --git a/torchtnt/framework/callbacks/base_csv_writer.py b/torchtnt/framework/callbacks/base_csv_writer.py
index 08116fd44c..610183b2ce 100644
--- a/torchtnt/framework/callbacks/base_csv_writer.py
+++ b/torchtnt/framework/callbacks/base_csv_writer.py
@@ -59,7 +59,7 @@ def get_step_output_rows(
         self,
         state: State,
         unit: TPredictUnit,
-        # pyre-ignore: Missing parameter annotation [2]
+        # pyre-fixme: Missing parameter annotation [2]
         step_output: Any,
     ) -> Union[List[str], List[List[str]]]:
         ...
diff --git a/torchtnt/framework/callbacks/module_summary.py b/torchtnt/framework/callbacks/module_summary.py
index 5c3d9aea0f..1d6374382b 100644
--- a/torchtnt/framework/callbacks/module_summary.py
+++ b/torchtnt/framework/callbacks/module_summary.py
@@ -43,7 +43,7 @@ def __init__(
         process_fn: Callable[
             [List[ModuleSummaryObj]], None
         ] = _log_module_summary_tables,
-        # pyre-ignore
+        # pyre-fixme
         module_inputs: Optional[
             MutableMapping[str, Tuple[Tuple[Any, ...], Dict[str, Any]]]
         ] = None,
diff --git a/torchtnt/framework/train.py b/torchtnt/framework/train.py
index 1fb186873d..8846e76916 100644
--- a/torchtnt/framework/train.py
+++ b/torchtnt/framework/train.py
@@ -223,7 +223,7 @@ def _train_epoch_impl(
         ):
             _evaluate_impl(
                 state,
-                # pyre-ignore: Incompatible parameter type [6]
+                # pyre-fixme: Incompatible parameter type [6]
                 train_unit,
                 callback_handler,
             )
@@ -257,7 +257,7 @@
         ):
             _evaluate_impl(
                 state,
-                # pyre-ignore: Incompatible parameter type [6]
+                # pyre-fixme: Incompatible parameter type [6]
                 train_unit,
                 callback_handler,
             )
diff --git a/torchtnt/utils/data/iterators.py b/torchtnt/utils/data/iterators.py
index fcd17a8e09..a172871b9c 100644
--- a/torchtnt/utils/data/iterators.py
+++ b/torchtnt/utils/data/iterators.py
@@ -343,7 +343,7 @@ def __init__(
             name: torch.IntTensor([idx])
            for idx, name in enumerate(self._iterator_names)
         }
-        # pyre-ignore[4]: missing attribute annotation
+        # pyre-fixme[4]: missing attribute annotation
         self._process_group = dist.new_group(backend="gloo", ranks=None)
 
         self._iterators_finished: List[str] = []
diff --git a/torchtnt/utils/data/profile_dataloader.py b/torchtnt/utils/data/profile_dataloader.py
index 8caffeed83..e0c67d08e4 100644
--- a/torchtnt/utils/data/profile_dataloader.py
+++ b/torchtnt/utils/data/profile_dataloader.py
@@ -52,7 +52,7 @@ def profile_dataloader(
             with timer.time("copy_data_to_device"), record_function(
                 "copy_data_to_device"
             ):
-                # pyre-ignore [6]: device is checked as not None before calling this
+                # pyre-fixme [6]: device is checked as not None before calling this
                 data = copy_data_to_device(data, device)
 
             steps_completed += 1
diff --git a/torchtnt/utils/device.py b/torchtnt/utils/device.py
index 04c567a25a..072fe4e2c3 100644
--- a/torchtnt/utils/device.py
+++ b/torchtnt/utils/device.py
@@ -313,14 +313,14 @@ def collect_system_stats(device: torch.device) -> Dict[str, Any]:
     system_stats: Dict[str, Any] = {}
     cpu_stats = get_psutil_cpu_stats()
-    # pyre-ignore
+    # pyre-fixme
     system_stats.update(cpu_stats)
 
     if torch.cuda.is_available():
         try:
             gpu_stats = get_nvidia_smi_gpu_stats(device)
-            # pyre-ignore
+            # pyre-fixme
             system_stats.update(gpu_stats)
             system_stats.update(torch.cuda.memory_stats())
         except FileNotFoundError:
diff --git a/torchtnt/utils/flops.py b/torchtnt/utils/flops.py
index 97df8a5a22..3b02f0e738 100644
--- a/torchtnt/utils/flops.py
+++ b/torchtnt/utils/flops.py
@@ -21,7 +21,7 @@
 aten: torch._ops._OpNamespace = torch.ops.aten
 
 
-# pyre-ignore [2] we don't care the type in outputs
+# pyre-fixme [2] we don't care the type in outputs
 def _matmul_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
     """
     Count flops for matmul.
@@ -35,7 +35,7 @@ def _matmul_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number
     return flop
 
 
-# pyre-ignore [2] we don't care the type in outputs
+# pyre-fixme [2] we don't care the type in outputs
 def _addmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
     """
     Count flops for fully connected layers.
@@ -53,7 +53,7 @@ def _addmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
     return flops
 
 
-# pyre-ignore [2] we don't care the type in outputs
+# pyre-fixme [2] we don't care the type in outputs
 def _bmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
     """
     Count flops for the bmm operation.
@@ -98,7 +98,7 @@ def _conv_flop_count(
 
 
 def _conv_flop_jit(
-    inputs: Tuple[Any],  # pyre-ignore [2] the inputs can be union of Tensor/bool/Tuple
+    inputs: Tuple[Any],  # pyre-fixme [2] the inputs can be union of Tensor/bool/Tuple
     outputs: Tuple[torch.Tensor],
 ) -> Number:
     """
@@ -118,7 +118,7 @@ def _transpose_shape(shape: torch.Size) -> List[int]:
     return [shape[1], shape[0]] + list(shape[2:])
 
 
-# pyre-ignore [2] the inputs can be union of Tensor/bool/Tuple & we don't care about outputs
+# pyre-fixme [2] the inputs can be union of Tensor/bool/Tuple & we don't care about outputs
 def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:
     grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]]
     output_mask = inputs[-1]
@@ -127,7 +127,7 @@ def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:
 
     if output_mask[0]:
         grad_input_shape = outputs[0].shape
-        # pyre-ignore [58] this is actually sum of Number and Number
+        # pyre-fixme [58] this is actually sum of Number and Number
         flop_count = flop_count + _conv_flop_count(
             grad_out_shape, w_shape, grad_input_shape, not fwd_transposed
         )
@@ -143,7 +143,7 @@ def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:
     return flop_count
 
 
-# pyre-ignore [5]
+# pyre-fixme [5]
 flop_mapping: Dict[Callable[..., Any], Callable[[Tuple[Any], Tuple[Any]], Number]] = {
     aten.mm: _matmul_flop_jit,
     aten.matmul: _matmul_flop_jit,
@@ -163,7 +163,7 @@ def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:
 }
 
 
-# pyre-ignore [2, 3] it can be Tuple of anything.
+# pyre-fixme [2, 3] it can be Tuple of anything.
 def _normalize_tuple(x: Any) -> Tuple[Any]:
     if not isinstance(x, tuple):
         return (x,)
@@ -213,7 +213,7 @@ def __init__(self, module: torch.nn.Module) -> None:
         )
         self._parents: List[str] = [""]
 
-    # pyre-ignore
+    # pyre-fixme
     def __exit__(self, exc_type, exc_val, exc_tb):
         for hook_handle in self._all_hooks:
             hook_handle.remove()
@@ -221,10 +221,10 @@
 
     def __torch_dispatch__(
         self,
-        func: Callable[..., Any],  # pyre-ignore [2] func can be any func
-        types: Tuple[Any],  # pyre-ignore [2]
-        args=(),  # pyre-ignore [2]
-        kwargs=None,  # pyre-ignore [2]
+        func: Callable[..., Any],  # pyre-fixme [2] func can be any func
+        types: Tuple[Any],  # pyre-fixme [2]
+        args=(),  # pyre-fixme [2]
+        kwargs=None,  # pyre-fixme [2]
     ) -> PyTree:
         rs = func(*args, **kwargs)
         outs = _normalize_tuple(rs)
@@ -232,14 +232,14 @@ def __torch_dispatch__(
         if func in flop_mapping:
             flop_count = flop_mapping[func](args, outs)
             for par in self._parents:
-                # pyre-ignore [58]
+                # pyre-fixme [58]
                 self.flop_counts[par][func.__name__] += flop_count
         else:
             logging.debug(f"{func} is not yet supported in FLOPs calculation.")
 
         return rs
 
-    # pyre-ignore [3]
+    # pyre-fixme [3]
     def _create_backwards_push(self, name: str) -> Callable[..., Any]:
         class PushState(torch.autograd.Function):
             @staticmethod
@@ -262,7 +262,7 @@ def backward(ctx, *grad_outs):
         # using a function parameter.
         return PushState.apply
 
-    # pyre-ignore [3]
+    # pyre-fixme [3]
     def _create_backwards_pop(self, name: str) -> Callable[..., Any]:
         class PopState(torch.autograd.Function):
             @staticmethod
@@ -286,9 +286,9 @@ def backward(ctx, *grad_outs):
         # using a function parameter.
         return PopState.apply
 
-    # pyre-ignore [3] Return a callable function
+    # pyre-fixme [3] Return a callable function
     def _enter_module(self, name: str) -> Callable[..., Any]:
-        # pyre-ignore [2, 3]
+        # pyre-fixme [2, 3]
         def f(module: torch.nn.Module, inputs: Tuple[Any]):
             parents = self._parents
             parents.append(name)
@@ -298,9 +298,9 @@
 
         return f
 
-    # pyre-ignore [3] Return a callable function
+    # pyre-fixme [3] Return a callable function
     def _exit_module(self, name: str) -> Callable[..., Any]:
-        # pyre-ignore [2, 3]
+        # pyre-fixme [2, 3]
         def f(module: torch.nn.Module, inputs: Tuple[Any], outputs: Tuple[Any]):
             parents = self._parents
             assert parents[-1] == name
diff --git a/torchtnt/utils/memory.py b/torchtnt/utils/memory.py
index a2119e94fe..ff85ae2bfa 100644
--- a/torchtnt/utils/memory.py
+++ b/torchtnt/utils/memory.py
@@ -20,14 +20,14 @@
 
 
 def _is_named_tuple(
-    # pyre-ignore: Missing parameter annotation [2]: Parameter `x` must have a type other than `Any`.
+    # pyre-fixme: Missing parameter annotation [2]: Parameter `x` must have a type other than `Any`.
     x: Any,
 ) -> bool:
     return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")
 
 
 def get_tensor_size_bytes_map(
-    # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.
+    # pyre-fixme: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.
     obj: Any,
 ) -> Dict[torch.Tensor, int]:
     tensor_map = {}
diff --git a/torchtnt/utils/module_summary.py b/torchtnt/utils/module_summary.py
index fd5f4d34c8..0cad76595d 100644
--- a/torchtnt/utils/module_summary.py
+++ b/torchtnt/utils/module_summary.py
@@ -209,9 +209,9 @@ def _clean_flops(flop: DefaultDict[str, DefaultDict[str, int]], N: int) -> None:
 
 def _get_module_flops_and_activation_sizes(
     module: torch.nn.Module,
-    # pyre-ignore
+    # pyre-fixme
     module_args: Optional[Tuple[Any, ...]] = None,
-    # pyre-ignore
+    # pyre-fixme
     module_kwargs: Optional[MutableMapping[str, Any]] = None,
 ) -> _ModuleSummaryData:
     # a mapping from module name to activation size tuple (in_size, out_size)
@@ -309,9 +309,9 @@ def _has_tensor(item: Optional[PyTree]) -> bool:
 
 def get_module_summary(
     module: torch.nn.Module,
-    # pyre-ignore
+    # pyre-fixme
     module_args: Optional[Tuple[Any, ...]] = None,
-    # pyre-ignore
+    # pyre-fixme
     module_kwargs: Optional[MutableMapping[str, Any]] = None,
 ) -> ModuleSummary:
     """
@@ -669,13 +669,13 @@ def _activation_size_hook(
     activation_sizes: Dict[
         str, Tuple[Union[TUnknown, List[int]], Union[TUnknown, List[int]]]
     ],
-    # pyre-ignore: Invalid type parameters [24]
+    # pyre-fixme: Invalid type parameters [24]
 ) -> Callable[[str], Callable]:
-    # pyre-ignore: Missing parameter annotation [2]
+    # pyre-fixme: Missing parameter annotation [2]
     def intermediate_hook(
         module_name: str,
     ) -> Callable[[torch.nn.Module, Any, Any], None]:
-        # pyre-ignore
+        # pyre-fixme
         def hook(_: torch.nn.Module, inp: Any, out: Any) -> None:
             if len(inp) == 1:
                 inp = inp[0]
@@ -690,9 +690,9 @@ def hook(_: torch.nn.Module, inp: Any, out: Any) -> None:
 
 def _forward_time_pre_hook(
     timer_mapping: Dict[str, float]
-    # pyre-ignore: Invalid type parameters [24]
+    # pyre-fixme: Invalid type parameters [24]
 ) -> Callable[[str], Callable]:
-    # pyre-ignore: Missing parameter annotation [2]
+    # pyre-fixme: Missing parameter annotation [2]
     def intermediate_hook(
         module_name: str,
     ) -> Callable[[torch.nn.Module, Any], None]:
@@ -707,9 +707,9 @@ def hook(_module: torch.nn.Module, _inp: Any) -> None:
 def _forward_time_hook(
     timer_mapping: Dict[str, float],
     elapsed_times: Dict[str, float],
-    # pyre-ignore: Invalid type parameters [24]
+    # pyre-fixme: Invalid type parameters [24]
 ) -> Callable[[str], Callable]:
-    # pyre-ignore: Missing parameter annotation [2]
+    # pyre-fixme: Missing parameter annotation [2]
     def intermediate_hook(
         module_name: str,
     ) -> Callable[[torch.nn.Module, Any, Any], None]:
@@ -725,7 +725,7 @@ def hook(_module: torch.nn.Module, _inp: Any, _out: Any) -> None:
 
 def _register_hooks(
     module: torch.nn.Module,
-    # pyre-ignore: Invalid type parameters [24]
+    # pyre-fixme: Invalid type parameters [24]
     hooks: List[Tuple[Callable, _HookType]],
 ) -> List[RemovableHandle]:
     """
diff --git a/torchtnt/utils/prepare_module.py b/torchtnt/utils/prepare_module.py
index 299450b683..58caffedf4 100644
--- a/torchtnt/utils/prepare_module.py
+++ b/torchtnt/utils/prepare_module.py
@@ -112,7 +112,7 @@ class TorchCompileParams:
 
     fullgraph: bool = False
     dynamic: bool = False
-    # pyre-ignore: Invalid type parameters [24]
+    # pyre-fixme: Invalid type parameters [24]
     backend: Union[str, Callable] = "inductor"
     mode: Union[str, None] = None
     options: Optional[Dict[str, Union[str, int, bool]]] = None
diff --git a/torchtnt/utils/swa.py b/torchtnt/utils/swa.py
index 857cdd08da..774fc461e6 100644
--- a/torchtnt/utils/swa.py
+++ b/torchtnt/utils/swa.py
@@ -9,7 +9,7 @@
 import torch
 
 # TODO: torch/optim/swa_utils.pyi needs to be updated
-# pyre-ignore Undefined import [21]: Could not find a name `get_ema_multi_avg_fn` defined in module `torch.optim.swa_utils`.
+# pyre-fixme Undefined import [21]: Could not find a name `get_ema_multi_avg_fn` defined in module `torch.optim.swa_utils`.
 from torch.optim.swa_utils import (
     AveragedModel as PyTorchAveragedModel,
     get_ema_multi_avg_fn,
@@ -55,11 +55,11 @@ def __init__(
                 raise ValueError(f"Decay must be between 0 and 1, got {ema_decay}")
 
             # TODO: torch/optim/swa_utils.pyi needs to be updated
-            # pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
+            # pyre-fixme Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
             multi_avg_fn = get_ema_multi_avg_fn(ema_decay)
         elif averaging_method == "swa":
             # TODO: torch/optim/swa_utils.pyi needs to be updated
-            # pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_swa_multi_avg_fn`.
+            # pyre-fixme Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_swa_multi_avg_fn`.
             multi_avg_fn = get_swa_multi_avg_fn()
 
         if use_lit:
@@ -88,7 +88,7 @@
 
             # use default init implementation
             # TODO: torch/optim/swa_utils.pyi needs to be updated
-            # pyre-ignore Unexpected keyword [28]
+            # pyre-fixme Unexpected keyword [28]
             super().__init__(
                 model,
                 device=device,
@@ -104,6 +104,6 @@ def update_parameters(self, model: torch.nn.Module) -> None:
             )
 
         # TODO: torch/optim/swa_utils.pyi needs to be updated
-        # pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
+        # pyre-fixme Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
         self.multi_avg_fn = get_ema_multi_avg_fn(decay)
         super().update_parameters(model)
diff --git a/torchtnt/utils/timer.py b/torchtnt/utils/timer.py
index f84e994cb4..24b8c0fbcf 100644
--- a/torchtnt/utils/timer.py
+++ b/torchtnt/utils/timer.py
@@ -364,7 +364,7 @@ def _sync_durations(
         pg_wrapper.all_gather_object(outputs, recorded_durations)
         ret = defaultdict(list)
         for output in outputs:
-            # pyre-ignore [16]: `Optional` has no attribute `__getitem__`.
+            # pyre-fixme [16]: `Optional` has no attribute `__getitem__`.
             for k, v in output.items():
                 if k not in ret:
                     ret[k] = []