upgrade pyre version in fbcode/torchrec - batch 1 (#2516)
Summary: Pull Request resolved: #2516

Differential Revision: D64846836
Pyre Bot Jr. authored and facebook-github-bot committed Oct 24, 2024
1 parent 9669707 commit c41cfa7
Showing 17 changed files with 70 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .pyre_configuration
@@ -13,5 +13,5 @@
}
],
"strict": true,
"version": "0.0.101703592829"
"version": "0.0.101729681899"
}
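
The version bump above is the whole substance of the configuration change; the remaining files add, move, or drop inline `# pyre-fixme[CODE]` suppressions to match the errors the newer Pyre reports. Each suppression silences exactly one error, identified by its numeric code, on the line that follows (or on the same line when written as a trailing comment), and by convention marks debt to be cleaned up later, whereas `# pyre-ignore[CODE]` marks a suppression meant to stay. A standalone sketch of the mechanics (not torchrec code):

from typing import Optional


def first_char(value: Optional[str]) -> str:
    # In strict mode Pyre reports error [16] here, because `value` may be
    # None and `Optional[str]` has no `__getitem__`. The comment suppresses
    # only that one error, only on the next line.
    # pyre-fixme[16]: `Optional[str]` has no attribute `__getitem__`.
    return value[0]


print(first_char("pyre"))  # prints "p"; passing None would still fail at runtime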
22 changes: 22 additions & 0 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -674,13 +674,24 @@ def __init__(
self.table_name_to_count: Dict[str, int] = {}
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}

# pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
# `ShardedEmbeddingTable`.
for idx, config in enumerate(self._config.embedding_tables):
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
self._local_rows.append(config.local_rows)
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `get_weight_init_min`.
self._weight_init_mins.append(config.get_weight_init_min())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `get_weight_init_max`.
self._weight_init_maxs.append(config.get_weight_init_max())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `num_embeddings`.
self._num_embeddings.append(config.num_embeddings)
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
self._local_cols.append(config.local_cols)
self._feature_table_map.extend([idx] * config.num_features())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
if config.name not in self.table_name_to_count:
self.table_name_to_count[config.name] = 0
self.table_name_to_count[config.name] += 1
@@ -1080,13 +1091,24 @@ def __init__(
self.table_name_to_count: Dict[str, int] = {}
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}

# pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
# `ShardedEmbeddingTable`.
for idx, config in enumerate(self._config.embedding_tables):
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
self._local_rows.append(config.local_rows)
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `get_weight_init_min`.
self._weight_init_mins.append(config.get_weight_init_min())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `get_weight_init_max`.
self._weight_init_maxs.append(config.get_weight_init_max())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `num_embeddings`.
self._num_embeddings.append(config.num_embeddings)
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
self._local_cols.append(config.local_cols)
self._feature_table_map.extend([idx] * config.num_features())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
if config.name not in self.table_name_to_count:
self.table_name_to_count[config.name] = 0
self.table_name_to_count[config.name] += 1
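
The [9] and [16] suppressions in both hunks describe one underlying pattern: the loop rebinds a name that is presumably already declared as the grouped config (the class also stores `self._config`) to the per-table entries of `embedding_tables`, which carry sharding-specific fields such as `local_rows` and `local_cols`. Pyre keeps checking the name against its declared type, so the rebinding is error [9] and each sharding-only attribute access is error [16]. A reduced sketch of the same shape, with hypothetical stand-in classes rather than the torchrec definitions:

from dataclasses import dataclass, field
from typing import List


@dataclass
class ShardedTable:            # hypothetical stand-in for ShardedEmbeddingTable
    local_rows: int
    local_cols: int


@dataclass
class GroupedConfig:           # hypothetical stand-in for GroupedEmbeddingConfig
    embedding_tables: List[ShardedTable] = field(default_factory=list)


class BatchedKernel:           # hypothetical stand-in for the batched kernel
    def __init__(self, config: GroupedConfig) -> None:
        self._config = config
        self._local_rows: List[int] = []
        # Rebinding `config` (declared above as GroupedConfig) to the richer
        # per-table objects is what Pyre reports as error [9]; it then keeps
        # treating `config` as GroupedConfig, so every table-only attribute
        # access below is error [16].
        # pyre-fixme[9]: config has type `GroupedConfig`; used as `ShardedTable`.
        for config in self._config.embedding_tables:
            # pyre-fixme[16]: `GroupedConfig` has no attribute `local_rows`.
            self._local_rows.append(config.local_rows)


kernel = BatchedKernel(GroupedConfig([ShardedTable(local_rows=8, local_cols=4)]))
print(kernel._local_rows)  # [8]

Renaming the loop variable (or annotating it explicitly) would remove the errors; the commit only adds suppressions, leaving runtime behaviour unchanged.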
9 changes: 9 additions & 0 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -431,6 +431,7 @@ def transform_module(
compile_mode: CompileMode,
world_size: int,
batch_size: int,
# pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter.
ctx: ContextManager,
benchmark_unsharded_module: bool = False,
) -> torch.nn.Module:
@@ -616,7 +617,11 @@ def trace_handler(prof) -> None:
if device_type == "cuda":
with torch.profiler.profile(
activities=[
# pyre-fixme[16]: Module `profiler` has no attribute
# `ProfilerActivity`.
torch.profiler.ProfilerActivity.CPU,
# pyre-fixme[16]: Module `profiler` has no attribute
# `ProfilerActivity`.
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
@@ -745,7 +750,11 @@ def trace_handler(prof) -> None:
a = a * torch.rand(16384, 16384, device="cuda")
with torch.profiler.profile(
activities=[
# pyre-fixme[16]: Module `profiler` has no attribute
# `ProfilerActivity`.
torch.profiler.ProfilerActivity.CPU,
# pyre-fixme[16]: Module `profiler` has no attribute
# `ProfilerActivity`.
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
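
Two different kinds of error are suppressed in this file. The [24] on the `ctx: ContextManager` parameter is about the bare generic: the newer Pyre wants a type argument (the value produced by `__enter__`). The `ProfilerActivity` [16] suppressions, by contrast, appear to be a stub-coverage gap, since the attribute exists on `torch.profiler` at runtime but is not visible to the checker. A small sketch of the parameterized spelling that avoids the [24] (hypothetical helper functions, not the benchmark code):

from contextlib import contextmanager
from typing import ContextManager, Iterator


@contextmanager
def no_op_ctx() -> Iterator[None]:
    # Hypothetical stand-in for whatever context the benchmark threads through.
    yield


def run_with(ctx: ContextManager[None]) -> None:
    # `ContextManager[None]` names the type returned by __enter__, so strict
    # Pyre accepts it without a [24] suppression; the commit keeps the bare
    # `ContextManager` and silences the diagnostic instead.
    with ctx:
        pass


run_with(no_op_ctx())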
4 changes: 2 additions & 2 deletions torchrec/distributed/keyed_jagged_tensor_pool.py
@@ -667,9 +667,9 @@ def _update_local(
) -> None:
raise NotImplementedError("Inference does not support update")

# pyre-fixme[7]: Expected `KeyedJaggedTensor` but got implicit return value of
# `None`.
def _update_preproc(self, values: KeyedJaggedTensor) -> KeyedJaggedTensor:
# pyre-fixme[7]: Expected `KeyedJaggedTensor` but got implicit return value
# of `None`.
pass


10 changes: 5 additions & 5 deletions torchrec/distributed/object_pool.py
@@ -131,16 +131,16 @@ def input_dist(
*input,
# pyre-ignore[2]
**kwargs,
# pyre-fixme[7]: Expected `Awaitable[Awaitable[Tensor]]` but got implicit return
# value of `None`.
) -> Awaitable[Awaitable[torch.Tensor]]:
# pyre-fixme[7]: Expected `Awaitable[Awaitable[Tensor]]` but got implicit
# return value of `None`.
pass

# pyre-fixme[7]: Expected `DistOut` but got implicit return value of `None`.
def compute(self, ctx: ShrdCtx, dist_input: torch.Tensor) -> DistOut:
# pyre-fixme[7]: Expected `DistOut` but got implicit return value of `None`.
pass

# pyre-fixme[7]: Expected `LazyAwaitable[Out]` but got implicit return value of
# `None`.
def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
# pyre-fixme[7]: Expected `LazyAwaitable[Variable[Out]]` but got implicit
# return value of `None`.
pass
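
All of the [7] suppressions in this file (and in keyed_jagged_tensor_pool.py above) cover the same idiom: methods whose signatures promise a return value but whose bodies are just `pass`, so Pyre sees an implicit `return None`. The change also moves each suppression from above the `def` to inside the body, which is where the newer Pyre localizes the error. A reduced sketch of the idiom and of two common non-suppression alternatives (hypothetical class, not the sharded pool interface):

import abc

import torch


class TensorPoolLike(abc.ABC):     # hypothetical stand-in for the sharded pool
    def compute(self, dist_input: torch.Tensor) -> torch.Tensor:
        # The flagged idiom: the body falls through, returning None implicitly.
        # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
        pass

    def update(self, values: torch.Tensor) -> torch.Tensor:
        # A body that always raises never returns, so there is no implicit
        # None and no suppression is needed.
        raise NotImplementedError

    @abc.abstractmethod
    def output_dist(self, output: torch.Tensor) -> torch.Tensor:
        # An abstract method with an ellipsis body is typically treated as a
        # stub rather than an implementation.
        ...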
4 changes: 4 additions & 0 deletions torchrec/distributed/planner/tests/test_partitioners.py
@@ -773,6 +773,8 @@ def test_different_sharding_plan(self) -> None:
for shard in sharding_option.shards:
if shard.storage and shard.rank is not None:
greedy_perf_hbm_uses[
# pyre-fixme[6]: For 1st argument expected `SupportsIndex`
# but got `Optional[int]`.
shard.rank
] += shard.storage.hbm # pyre-ignore[16]

@@ -796,6 +798,8 @@ def test_different_sharding_plan(self) -> None:
for sharding_option in sharding_options:
for shard in sharding_option.shards:
if shard.storage and shard.rank:
# pyre-fixme[6]: For 1st argument expected `SupportsIndex` but
# got `Optional[int]`.
memory_balanced_hbm_uses[shard.rank] += shard.storage.hbm

self.assertTrue(max(memory_balanced_hbm_uses) < max(greedy_perf_hbm_uses))
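
Both [6] suppressions come from using `shard.rank`, declared `Optional[int]`, as a list index. The surrounding `if` checks it at runtime, but the newer Pyre evidently does not carry that refinement of an attribute through to the subscript, so the Optional is still visible there. A reduced sketch of the error class and of the usual local-binding alternative (hypothetical helpers, not the planner test):

from typing import List, Optional


def bump_suppressed(uses: List[int], rank: Optional[int]) -> None:
    # List indices must support __index__, and Optional[int] may be None,
    # hence error [6] without the comment.
    # pyre-fixme[6]: For 1st argument expected `SupportsIndex` but got `Optional[int]`.
    uses[rank] += 1


def bump_checked(uses: List[int], rank: Optional[int]) -> None:
    # Binding and checking the value first leaves a plain `int` at the
    # subscript, so no suppression is needed.
    if rank is not None:
        uses[rank] += 1


counts = [0, 0, 0, 0]
bump_checked(counts, 2)
bump_suppressed(counts, 2)
print(counts)  # [0, 0, 2, 0]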
4 changes: 2 additions & 2 deletions torchrec/distributed/shards_wrapper.py
@@ -27,8 +27,6 @@
aten = torch.ops.aten # pyre-ignore[5]


# pyre-fixme[13]: Attribute `_local_shards` is never initialized.
# pyre-fixme[13]: Attribute `_storage_meta` is never initialized.
class LocalShardsWrapper(torch.Tensor):
"""
A wrapper class to hold local shards of a DTensor.
@@ -37,7 +35,9 @@ class LocalShardsWrapper(torch.Tensor):
"""

__slots__ = ["_local_shards", "_storage_meta"]
# pyre-fixme[13]: Attribute `_local_shards` is never initialized.
_local_shards: List[torch.Tensor]
# pyre-fixme[13]: Attribute `_storage_meta` is never initialized.
_storage_meta: TensorStorageMetadata

@staticmethod
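
Error [13] fires when a class declares an annotated attribute that the checker never sees assigned in `__init__`; `LocalShardsWrapper` apparently populates `_local_shards` and `_storage_meta` elsewhere (the tensor subclass builds its instances through the static constructor that follows). The hunk also shows the newer Pyre wanting the suppression attached to each attribute annotation rather than to the class statement. A reduced sketch of the pattern (hypothetical wrapper, not the DTensor shards class):

from typing import List


class ShardsHolder:
    # Declared here but only assigned in the alternative constructor below,
    # so strict Pyre cannot prove initialization.
    # pyre-fixme[13]: Attribute `_shards` is never initialized.
    _shards: List[int]

    @classmethod
    def from_shards(cls, shards: List[int]) -> "ShardsHolder":
        obj = cls.__new__(cls)
        obj._shards = list(shards)
        return obj


holder = ShardsHolder.from_shards([1, 2, 3])
print(holder._shards)  # [1, 2, 3]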
2 changes: 1 addition & 1 deletion torchrec/distributed/tensor_pool.py
@@ -459,8 +459,8 @@ def _update_local(
deduped_ids, dedup_permutation = deterministic_dedup(ids)
shard.update(deduped_ids, values[dedup_permutation])

# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
def _update_preproc(self, values: torch.Tensor) -> torch.Tensor:
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
pass

def update(self, ids: torch.Tensor, values: torch.Tensor) -> None:
4 changes: 4 additions & 0 deletions torchrec/distributed/tests/test_awaitable.py
@@ -24,13 +24,17 @@ def _wait_impl(self) -> torch.Tensor:
class AwaitableTests(unittest.TestCase):
def test_callback(self) -> None:
awaitable = AwaitableInstance()
# pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got
# `(ret: Any) -> int`.
awaitable.callbacks.append(lambda ret: 2 * ret)
self.assertTrue(
torch.allclose(awaitable.wait(), torch.FloatTensor([2.0, 4.0, 6.0]))
)

def test_callback_chained(self) -> None:
awaitable = AwaitableInstance()
# pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got
# `(ret: Any) -> int`.
awaitable.callbacks.append(lambda ret: 2 * ret)
awaitable.callbacks.append(lambda ret: ret**2)
self.assertTrue(
13 changes: 6 additions & 7 deletions torchrec/distributed/tests/test_embedding_types.py
@@ -29,9 +29,9 @@ def __init__(self) -> None:
torch.nn.Module(),
]

# pyre-fixme[7]: Expected `EmbeddingBagCollectionContext` but got implicit
# return value of `None`.
def create_context(self) -> ShrdCtx:
# pyre-fixme[7]: Expected `EmbeddingBagCollectionContext` but got implicit
# return value of `None`.
pass

def input_dist(
@@ -41,19 +41,18 @@ def input_dist(
*input,
# pyre-ignore[2]
**kwargs,
) -> Awaitable[Awaitable[CompIn]]:
# pyre-fixme[7]: Expected `Awaitable[Awaitable[KJTList]]` but got implicit
# return value of `None`.
) -> Awaitable[Awaitable[CompIn]]:
pass

# pyre-fixme[7]: Expected `List[Tensor]` but got implicit return value of `None`.
def compute(self, ctx: ShrdCtx, dist_input: CompIn) -> DistOut:
# pyre-fixme[7]: Expected `List[Tensor]` but got implicit return value of
# `None`.
pass

# pyre-fixme[7]: Expected `LazyAwaitable[Dict[str, Tensor]]` but got implicit
# return value of `None`.
def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
# pyre-fixme[7]: Expected `LazyAwaitable[Dict[str, Tensor]]` but got
# implicit return value of `None`.
pass


2 changes: 0 additions & 2 deletions torchrec/distributed/tests/test_lazy_awaitable.py
@@ -244,8 +244,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

tempFile = None
with tempfile.NamedTemporaryFile(delete=False) as f:
# pyre-fixme[6]: For 2nd argument expected `SupportsWrite[bytes]` but
# got `_TemporaryFileWrapper[bytes]`.
pickle.dump(gm, f)
tempFile = f

1 change: 1 addition & 0 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -1613,6 +1613,7 @@ def __init__(

def get_compiled_autograd_ctx(
self,
# pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter.
) -> ContextManager:
# this allows for pipelining
# to avoid doing a sum on None
4 changes: 3 additions & 1 deletion torchrec/distributed/types.py
@@ -44,7 +44,6 @@
# other metaclasses (i.e. AwaitableMeta) for customized
# behaviors, as Generic is non-trival metaclass in
# python 3.6 and below
# pyre-fixme[21]: Could not find name `GenericMeta` in `typing` (stubbed).
from typing import GenericMeta
except ImportError:
# In python 3.7+, GenericMeta doesn't exist as it's no
@@ -975,6 +974,9 @@ def __init__(
torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
self._qcomm_codecs_registry = qcomm_codecs_registry

# pyre-fixme[56]: Pyre doesn't yet support decorators with ParamSpec applied to
# generic functions. Consider using a context manager instead of a decorator, if
# possible.
@abc.abstractclassmethod
# pyre-ignore [3]
def shard(
6 changes: 5 additions & 1 deletion torchrec/distributed/utils.py
@@ -476,7 +476,11 @@ def maybe_reset_parameters(m: nn.Module) -> None:


def maybe_annotate_embedding_event(
event: EmbeddingEvent, module_fqn: Optional[str], sharding_type: Optional[str]
event: EmbeddingEvent,
module_fqn: Optional[str],
sharding_type: Optional[str],
# pyre-fixme[24]: Generic type `AbstractContextManager` expects 2 type parameters,
# received 1.
) -> AbstractContextManager[None]:
if module_fqn and sharding_type:
annotation = f"[{event.value}]_[{module_fqn}]_[{sharding_type}]"
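
The [24] here is slightly different from the earlier `ContextManager` ones: in the stubs the newer Pyre consults, `AbstractContextManager` appears to take a second type parameter (the `__exit__` result) in addition to the `__enter__` result, so the existing one-argument annotation is reported as under-parameterized. A hedged sketch of the fully parameterized spelling (the signature below is illustrative only, not the torchrec helper):

from contextlib import AbstractContextManager, nullcontext
from typing import Optional


def maybe_annotate(label: Optional[str]) -> "AbstractContextManager[None, Optional[bool]]":
    # Spelling out both parameters (enter result, exit result) is the
    # non-suppressed alternative; the commit keeps the one-argument form and
    # silences the diagnostic. nullcontext() stands in for a real annotation
    # context such as a profiler range.
    return nullcontext()


with maybe_annotate("embedding_lookup"):
    pass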
2 changes: 2 additions & 0 deletions torchrec/linter/module_linter.py
@@ -37,7 +37,9 @@ def print_error_message(
"""
lint_item = {
"path": python_path,
# pyre-fixme[16]: `AST` has no attribute `lineno`.
"line": node.lineno,
# pyre-fixme[16]: `AST` has no attribute `col_offset`.
"char": node.col_offset + 1,
"severity": severity,
"name": name,
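
These two [16] suppressions are a stubs question rather than a runtime one: `lineno` and `col_offset` exist on every parsed statement and expression, but the base `ast.AST` type the linter accepts does not declare them. A sketch of the same access with and without a suppression (hypothetical helper, not the linter's signature):

import ast
from typing import Union


def location_suppressed(node: ast.AST) -> str:
    # Base ast.AST does not declare position fields in the stubs, so each
    # access needs a suppression even though the attributes exist at runtime
    # on parsed statements and expressions.
    # pyre-fixme[16]: `AST` has no attribute `lineno`.
    line = node.lineno
    # pyre-fixme[16]: `AST` has no attribute `col_offset`.
    col = node.col_offset + 1
    return f"{line}:{col}"


def location_typed(node: Union[ast.stmt, ast.expr]) -> str:
    # Narrowing the parameter to node kinds that declare positions removes
    # the need for suppressions.
    return f"{node.lineno}:{node.col_offset + 1}"


tree = ast.parse("x = 1")
print(location_suppressed(tree.body[0]))  # "1:1"
print(location_typed(tree.body[0]))       # "1:1"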
1 change: 1 addition & 0 deletions torchrec/metrics/tests/test_metric_module.py
@@ -528,6 +528,7 @@ def _test_adjust_compute_interval(
)
mock_time.time = MagicMock(return_value=0.0)

# pyre-fixme[53]: Captured variable `batch` is not annotated.
def _train(metric_module: RecMetricModule) -> float:
for _ in range(metric_module.compute_interval_steps):
metric_module.update(batch)
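
Error [53] flags a nested function that captures a variable the checker has no annotation for; here the inner `_train` closes over the `batch` built earlier in the test. A reduced sketch of the diagnostic and of the annotate-the-capture alternative (stand-in names, not the metric test's locals):

import torch


def run_suppressed() -> None:
    batch = torch.zeros(4)  # built without an annotation, as in the test

    # pyre-fixme[53]: Captured variable `batch` is not annotated.
    def _train(steps: int) -> float:
        total = 0.0
        for _ in range(steps):
            total += float(batch.sum())
        return total

    _train(3)


def run_annotated() -> None:
    batch: torch.Tensor = torch.zeros(4)  # annotating the capture avoids [53]

    def _train(steps: int) -> float:
        total = 0.0
        for _ in range(steps):
            total += float(batch.sum())
        return total

    _train(3)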
2 changes: 2 additions & 0 deletions torchrec/modules/utils.py
@@ -133,6 +133,8 @@ def convert_list_of_modules_to_modulelist(
# `Iterable[torch.nn.Module]`.
len(modules)
== sizes[0]
# pyre-fixme[6]: For 1st argument expected `pyre_extensions.PyreReadOnly[Sized]`
# but got `Iterable[Module]`.
), f"the counts of modules ({len(modules)}) do not match with the required counts {sizes}"
if len(sizes) == 1:
return torch.nn.ModuleList(modules)
