Skip to content

Commit

Permalink
replace all pyre-ignore with pyre-fixme (#689)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #689

Consolidate all pyre annotations to pyre-fixme

Reviewed By: anshulverma, JKSenthil

Differential Revision: D53064137

fbshipit-source-id: 9b531899725807dcbc3345415837cb052ecd5c44
  • Loading branch information
galrotem authored and facebook-github-bot committed Jan 25, 2024
1 parent 3f2ecbf commit 67bcc82
Show file tree
Hide file tree
Showing 14 changed files with 51 additions and 51 deletions.
2 changes: 1 addition & 1 deletion tests/framework/callbacks/test_base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ def test_best_checkpoint_no_top_k(self) -> None:
save_every_n_epochs=1,
best_checkpoint_config=BestCheckpointConfig(
monitored_metric="train_loss",
# pyre-ignore: Incompatible parameter type [6]
# pyre-fixme: Incompatible parameter type [6]
mode=mode,
),
)
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def test_input_checks(self) -> None:
with self.assertRaisesRegex(
ValueError, "Unknown averaging method: foo. Only ema and swa are supported."
):
# pyre-ignore On purpose to test run time exception
# pyre-fixme On purpose to test run time exception
AveragedModel(model, averaging_method="foo")

def test_lit_ema(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion torchtnt/framework/callbacks/base_csv_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_step_output_rows(
self,
state: State,
unit: TPredictUnit,
# pyre-ignore: Missing parameter annotation [2]
# pyre-fixme: Missing parameter annotation [2]
step_output: Any,
) -> Union[List[str], List[List[str]]]:
...
Expand Down
2 changes: 1 addition & 1 deletion torchtnt/framework/callbacks/module_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
process_fn: Callable[
[List[ModuleSummaryObj]], None
] = _log_module_summary_tables,
# pyre-ignore
# pyre-fixme
module_inputs: Optional[
MutableMapping[str, Tuple[Tuple[Any, ...], Dict[str, Any]]]
] = None,
Expand Down
4 changes: 2 additions & 2 deletions torchtnt/framework/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def _train_epoch_impl(
):
_evaluate_impl(
state,
# pyre-ignore: Incompatible parameter type [6]
# pyre-fixme: Incompatible parameter type [6]
train_unit,
callback_handler,
)
Expand Down Expand Up @@ -257,7 +257,7 @@ def _train_epoch_impl(
):
_evaluate_impl(
state,
# pyre-ignore: Incompatible parameter type [6]
# pyre-fixme: Incompatible parameter type [6]
train_unit,
callback_handler,
)
Expand Down
2 changes: 1 addition & 1 deletion torchtnt/utils/data/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def __init__(
name: torch.IntTensor([idx])
for idx, name in enumerate(self._iterator_names)
}
# pyre-ignore[4]: missing attribute annotation
# pyre-fixme[4]: missing attribute annotation
self._process_group = dist.new_group(backend="gloo", ranks=None)

self._iterators_finished: List[str] = []
Expand Down
2 changes: 1 addition & 1 deletion torchtnt/utils/data/profile_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def profile_dataloader(
with timer.time("copy_data_to_device"), record_function(
"copy_data_to_device"
):
# pyre-ignore [6]: device is checked as not None before calling this
# pyre-fixme [6]: device is checked as not None before calling this
data = copy_data_to_device(data, device)

steps_completed += 1
Expand Down
4 changes: 2 additions & 2 deletions torchtnt/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,14 @@ def collect_system_stats(device: torch.device) -> Dict[str, Any]:
system_stats: Dict[str, Any] = {}
cpu_stats = get_psutil_cpu_stats()

# pyre-ignore
# pyre-fixme
system_stats.update(cpu_stats)

if torch.cuda.is_available():
try:
gpu_stats = get_nvidia_smi_gpu_stats(device)

# pyre-ignore
# pyre-fixme
system_stats.update(gpu_stats)
system_stats.update(torch.cuda.memory_stats())
except FileNotFoundError:
Expand Down
40 changes: 20 additions & 20 deletions torchtnt/utils/flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
aten: torch._ops._OpNamespace = torch.ops.aten


# pyre-ignore [2] we don't care the type in outputs
# pyre-fixme [2] we don't care the type in outputs
def _matmul_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
"""
Count flops for matmul.
Expand All @@ -35,7 +35,7 @@ def _matmul_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number
return flop


# pyre-ignore [2] we don't care the type in outputs
# pyre-fixme [2] we don't care the type in outputs
def _addmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
"""
Count flops for fully connected layers.
Expand All @@ -53,7 +53,7 @@ def _addmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
return flops


# pyre-ignore [2] we don't care the type in outputs
# pyre-fixme [2] we don't care the type in outputs
def _bmm_flop_jit(inputs: Tuple[torch.Tensor], outputs: Tuple[Any]) -> Number:
"""
Count flops for the bmm operation.
Expand Down Expand Up @@ -98,7 +98,7 @@ def _conv_flop_count(


def _conv_flop_jit(
inputs: Tuple[Any], # pyre-ignore [2] the inputs can be union of Tensor/bool/Tuple
inputs: Tuple[Any], # pyre-fixme [2] the inputs can be union of Tensor/bool/Tuple
outputs: Tuple[torch.Tensor],
) -> Number:
"""
Expand All @@ -118,7 +118,7 @@ def _transpose_shape(shape: torch.Size) -> List[int]:
return [shape[1], shape[0]] + list(shape[2:])


# pyre-ignore [2] the inputs can be union of Tensor/bool/Tuple & we don't care about outputs
# pyre-fixme [2] the inputs can be union of Tensor/bool/Tuple & we don't care about outputs
def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:
grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]]
output_mask = inputs[-1]
Expand All @@ -127,7 +127,7 @@ def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:

if output_mask[0]:
grad_input_shape = outputs[0].shape
# pyre-ignore [58] this is actually sum of Number and Number
# pyre-fixme [58] this is actually sum of Number and Number
flop_count = flop_count + _conv_flop_count(
grad_out_shape, w_shape, grad_input_shape, not fwd_transposed
)
Expand All @@ -143,7 +143,7 @@ def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:
return flop_count


# pyre-ignore [5]
# pyre-fixme [5]
flop_mapping: Dict[Callable[..., Any], Callable[[Tuple[Any], Tuple[Any]], Number]] = {
aten.mm: _matmul_flop_jit,
aten.matmul: _matmul_flop_jit,
Expand All @@ -163,7 +163,7 @@ def _conv_backward_flop_jit(inputs: Tuple[Any], outputs: Tuple[Any]) -> Number:
}


# pyre-ignore [2, 3] it can be Tuple of anything.
# pyre-fixme [2, 3] it can be Tuple of anything.
def _normalize_tuple(x: Any) -> Tuple[Any]:
if not isinstance(x, tuple):
return (x,)
Expand Down Expand Up @@ -213,33 +213,33 @@ def __init__(self, module: torch.nn.Module) -> None:
)
self._parents: List[str] = [""]

# pyre-ignore
# pyre-fixme
def __exit__(self, exc_type, exc_val, exc_tb):
for hook_handle in self._all_hooks:
hook_handle.remove()
super().__exit__(exc_type, exc_val, exc_tb)

def __torch_dispatch__(
self,
func: Callable[..., Any], # pyre-ignore [2] func can be any func
types: Tuple[Any], # pyre-ignore [2]
args=(), # pyre-ignore [2]
kwargs=None, # pyre-ignore [2]
func: Callable[..., Any], # pyre-fixme [2] func can be any func
types: Tuple[Any], # pyre-fixme [2]
args=(), # pyre-fixme [2]
kwargs=None, # pyre-fixme [2]
) -> PyTree:
rs = func(*args, **kwargs)
outs = _normalize_tuple(rs)

if func in flop_mapping:
flop_count = flop_mapping[func](args, outs)
for par in self._parents:
# pyre-ignore [58]
# pyre-fixme [58]
self.flop_counts[par][func.__name__] += flop_count
else:
logging.debug(f"{func} is not yet supported in FLOPs calculation.")

return rs

# pyre-ignore [3]
# pyre-fixme [3]
def _create_backwards_push(self, name: str) -> Callable[..., Any]:
class PushState(torch.autograd.Function):
@staticmethod
Expand All @@ -262,7 +262,7 @@ def backward(ctx, *grad_outs):
# using a function parameter.
return PushState.apply

# pyre-ignore [3]
# pyre-fixme [3]
def _create_backwards_pop(self, name: str) -> Callable[..., Any]:
class PopState(torch.autograd.Function):
@staticmethod
Expand All @@ -286,9 +286,9 @@ def backward(ctx, *grad_outs):
# using a function parameter.
return PopState.apply

# pyre-ignore [3] Return a callable function
# pyre-fixme [3] Return a callable function
def _enter_module(self, name: str) -> Callable[..., Any]:
# pyre-ignore [2, 3]
# pyre-fixme [2, 3]
def f(module: torch.nn.Module, inputs: Tuple[Any]):
parents = self._parents
parents.append(name)
Expand All @@ -298,9 +298,9 @@ def f(module: torch.nn.Module, inputs: Tuple[Any]):

return f

# pyre-ignore [3] Return a callable function
# pyre-fixme [3] Return a callable function
def _exit_module(self, name: str) -> Callable[..., Any]:
# pyre-ignore [2, 3]
# pyre-fixme [2, 3]
def f(module: torch.nn.Module, inputs: Tuple[Any], outputs: Tuple[Any]):
parents = self._parents
assert parents[-1] == name
Expand Down
4 changes: 2 additions & 2 deletions torchtnt/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@


def _is_named_tuple(
# pyre-ignore: Missing parameter annotation [2]: Parameter `x` must have a type other than `Any`.
# pyre-fixme: Missing parameter annotation [2]: Parameter `x` must have a type other than `Any`.
x: Any,
) -> bool:
return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")


def get_tensor_size_bytes_map(
# pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.
# pyre-fixme: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.
obj: Any,
) -> Dict[torch.Tensor, int]:
tensor_map = {}
Expand Down
24 changes: 12 additions & 12 deletions torchtnt/utils/module_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,9 @@ def _clean_flops(flop: DefaultDict[str, DefaultDict[str, int]], N: int) -> None:

def _get_module_flops_and_activation_sizes(
module: torch.nn.Module,
# pyre-ignore
# pyre-fixme
module_args: Optional[Tuple[Any, ...]] = None,
# pyre-ignore
# pyre-fixme
module_kwargs: Optional[MutableMapping[str, Any]] = None,
) -> _ModuleSummaryData:
# a mapping from module name to activation size tuple (in_size, out_size)
Expand Down Expand Up @@ -309,9 +309,9 @@ def _has_tensor(item: Optional[PyTree]) -> bool:

def get_module_summary(
module: torch.nn.Module,
# pyre-ignore
# pyre-fixme
module_args: Optional[Tuple[Any, ...]] = None,
# pyre-ignore
# pyre-fixme
module_kwargs: Optional[MutableMapping[str, Any]] = None,
) -> ModuleSummary:
"""
Expand Down Expand Up @@ -669,13 +669,13 @@ def _activation_size_hook(
activation_sizes: Dict[
str, Tuple[Union[TUnknown, List[int]], Union[TUnknown, List[int]]]
],
# pyre-ignore: Invalid type parameters [24]
# pyre-fixme: Invalid type parameters [24]
) -> Callable[[str], Callable]:
# pyre-ignore: Missing parameter annotation [2]
# pyre-fixme: Missing parameter annotation [2]
def intermediate_hook(
module_name: str,
) -> Callable[[torch.nn.Module, Any, Any], None]:
# pyre-ignore
# pyre-fixme
def hook(_: torch.nn.Module, inp: Any, out: Any) -> None:
if len(inp) == 1:
inp = inp[0]
Expand All @@ -690,9 +690,9 @@ def hook(_: torch.nn.Module, inp: Any, out: Any) -> None:

def _forward_time_pre_hook(
timer_mapping: Dict[str, float]
# pyre-ignore: Invalid type parameters [24]
# pyre-fixme: Invalid type parameters [24]
) -> Callable[[str], Callable]:
# pyre-ignore: Missing parameter annotation [2]
# pyre-fixme: Missing parameter annotation [2]
def intermediate_hook(
module_name: str,
) -> Callable[[torch.nn.Module, Any], None]:
Expand All @@ -707,9 +707,9 @@ def hook(_module: torch.nn.Module, _inp: Any) -> None:
def _forward_time_hook(
timer_mapping: Dict[str, float],
elapsed_times: Dict[str, float],
# pyre-ignore: Invalid type parameters [24]
# pyre-fixme: Invalid type parameters [24]
) -> Callable[[str], Callable]:
# pyre-ignore: Missing parameter annotation [2]
# pyre-fixme: Missing parameter annotation [2]
def intermediate_hook(
module_name: str,
) -> Callable[[torch.nn.Module, Any, Any], None]:
Expand All @@ -725,7 +725,7 @@ def hook(_module: torch.nn.Module, _inp: Any, _out: Any) -> None:

def _register_hooks(
module: torch.nn.Module,
# pyre-ignore: Invalid type parameters [24]
# pyre-fixme: Invalid type parameters [24]
hooks: List[Tuple[Callable, _HookType]],
) -> List[RemovableHandle]:
"""
Expand Down
2 changes: 1 addition & 1 deletion torchtnt/utils/prepare_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class TorchCompileParams:

fullgraph: bool = False
dynamic: bool = False
# pyre-ignore: Invalid type parameters [24]
# pyre-fixme: Invalid type parameters [24]
backend: Union[str, Callable] = "inductor"
mode: Union[str, None] = None
options: Optional[Dict[str, Union[str, int, bool]]] = None
Expand Down
10 changes: 5 additions & 5 deletions torchtnt/utils/swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch

# TODO: torch/optim/swa_utils.pyi needs to be updated
# pyre-ignore Undefined import [21]: Could not find a name `get_ema_multi_avg_fn` defined in module `torch.optim.swa_utils`.
# pyre-fixme Undefined import [21]: Could not find a name `get_ema_multi_avg_fn` defined in module `torch.optim.swa_utils`.
from torch.optim.swa_utils import (
AveragedModel as PyTorchAveragedModel,
get_ema_multi_avg_fn,
Expand Down Expand Up @@ -55,11 +55,11 @@ def __init__(
raise ValueError(f"Decay must be between 0 and 1, got {ema_decay}")

# TODO: torch/optim/swa_utils.pyi needs to be updated
# pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
# pyre-fixme Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
multi_avg_fn = get_ema_multi_avg_fn(ema_decay)
elif averaging_method == "swa":
# TODO: torch/optim/swa_utils.pyi needs to be updated
# pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_swa_multi_avg_fn`.
# pyre-fixme Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_swa_multi_avg_fn`.
multi_avg_fn = get_swa_multi_avg_fn()

if use_lit:
Expand Down Expand Up @@ -88,7 +88,7 @@ def __init__(
# use default init implementation

# TODO: torch/optim/swa_utils.pyi needs to be updated
# pyre-ignore Unexpected keyword [28]
# pyre-fixme Unexpected keyword [28]
super().__init__(
model,
device=device,
Expand All @@ -104,6 +104,6 @@ def update_parameters(self, model: torch.nn.Module) -> None:
)

# TODO: torch/optim/swa_utils.pyi needs to be updated
# pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
# pyre-fixme Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
self.multi_avg_fn = get_ema_multi_avg_fn(decay)
super().update_parameters(model)
2 changes: 1 addition & 1 deletion torchtnt/utils/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def _sync_durations(
pg_wrapper.all_gather_object(outputs, recorded_durations)
ret = defaultdict(list)
for output in outputs:
# pyre-ignore [16]: `Optional` has no attribute `__getitem__`.
# pyre-fixme [16]: `Optional` has no attribute `__getitem__`.
for k, v in output.items():
if k not in ret:
ret[k] = []
Expand Down

0 comments on commit 67bcc82

Please sign in to comment.