Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix torchrec out of sync with github issue #2509

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pyre.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
steps:
- uses: conda-incubator/setup-miniconda@v2
with:
python-version: 3.9
python-version: 3.11
- name: Checkout Torchrec
uses: actions/checkout@v2
- name: Install dependencies
Expand Down
2 changes: 1 addition & 1 deletion .pyre_configuration
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
}
],
"strict": true,
"version": "0.0.101703592829"
"version": "0.0.101729595488"
}
22 changes: 22 additions & 0 deletions torchrec/distributed/batched_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,13 +674,24 @@ def __init__(
self.table_name_to_count: Dict[str, int] = {}
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}

# pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
# `ShardedEmbeddingTable`.
for idx, config in enumerate(self._config.embedding_tables):
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
self._local_rows.append(config.local_rows)
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `get_weight_init_min`.
self._weight_init_mins.append(config.get_weight_init_min())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `get_weight_init_max`.
self._weight_init_maxs.append(config.get_weight_init_max())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `num_embeddings`.
self._num_embeddings.append(config.num_embeddings)
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
self._local_cols.append(config.local_cols)
self._feature_table_map.extend([idx] * config.num_features())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
if config.name not in self.table_name_to_count:
self.table_name_to_count[config.name] = 0
self.table_name_to_count[config.name] += 1
Expand Down Expand Up @@ -1080,13 +1091,24 @@ def __init__(
self.table_name_to_count: Dict[str, int] = {}
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}

# pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
# `ShardedEmbeddingTable`.
for idx, config in enumerate(self._config.embedding_tables):
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
self._local_rows.append(config.local_rows)
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `get_weight_init_min`.
self._weight_init_mins.append(config.get_weight_init_min())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `get_weight_init_max`.
self._weight_init_maxs.append(config.get_weight_init_max())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `num_embeddings`.
self._num_embeddings.append(config.num_embeddings)
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
self._local_cols.append(config.local_cols)
self._feature_table_map.extend([idx] * config.num_features())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
if config.name not in self.table_name_to_count:
self.table_name_to_count[config.name] = 0
self.table_name_to_count[config.name] += 1
Expand Down
9 changes: 9 additions & 0 deletions torchrec/distributed/benchmark/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def transform_module(
compile_mode: CompileMode,
world_size: int,
batch_size: int,
# pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter.
ctx: ContextManager,
benchmark_unsharded_module: bool = False,
) -> torch.nn.Module:
Expand Down Expand Up @@ -616,7 +617,11 @@ def trace_handler(prof) -> None:
if device_type == "cuda":
with torch.profiler.profile(
activities=[
# pyre-fixme[16]: Module `profiler` has no attribute
# `ProfilerActivity`.
torch.profiler.ProfilerActivity.CPU,
# pyre-fixme[16]: Module `profiler` has no attribute
# `ProfilerActivity`.
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
Expand Down Expand Up @@ -745,7 +750,11 @@ def trace_handler(prof) -> None:
a = a * torch.rand(16384, 16384, device="cuda")
with torch.profiler.profile(
activities=[
# pyre-fixme[16]: Module `profiler` has no attribute
# `ProfilerActivity`.
torch.profiler.ProfilerActivity.CPU,
# pyre-fixme[16]: Module `profiler` has no attribute
# `ProfilerActivity`.
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
Expand Down
4 changes: 2 additions & 2 deletions torchrec/distributed/keyed_jagged_tensor_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,9 @@ def _update_local(
) -> None:
raise NotImplementedError("Inference does not support update")

# pyre-fixme[7]: Expected `KeyedJaggedTensor` but got implicit return value of
# `None`.
def _update_preproc(self, values: KeyedJaggedTensor) -> KeyedJaggedTensor:
# pyre-fixme[7]: Expected `KeyedJaggedTensor` but got implicit return value
# of `None`.
pass


Expand Down
10 changes: 5 additions & 5 deletions torchrec/distributed/object_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,16 @@ def input_dist(
*input,
# pyre-ignore[2]
**kwargs,
# pyre-fixme[7]: Expected `Awaitable[Awaitable[Tensor]]` but got implicit return
# value of `None`.
) -> Awaitable[Awaitable[torch.Tensor]]:
# pyre-fixme[7]: Expected `Awaitable[Awaitable[Tensor]]` but got implicit
# return value of `None`.
pass

# pyre-fixme[7]: Expected `DistOut` but got implicit return value of `None`.
def compute(self, ctx: ShrdCtx, dist_input: torch.Tensor) -> DistOut:
# pyre-fixme[7]: Expected `DistOut` but got implicit return value of `None`.
pass

# pyre-fixme[7]: Expected `LazyAwaitable[Out]` but got implicit return value of
# `None`.
def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
# pyre-fixme[7]: Expected `LazyAwaitable[Variable[Out]]` but got implicit
# return value of `None`.
pass
4 changes: 4 additions & 0 deletions torchrec/distributed/planner/tests/test_partitioners.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,8 @@ def test_different_sharding_plan(self) -> None:
for shard in sharding_option.shards:
if shard.storage and shard.rank is not None:
greedy_perf_hbm_uses[
# pyre-fixme[6]: For 1st argument expected `SupportsIndex`
# but got `Optional[int]`.
shard.rank
] += shard.storage.hbm # pyre-ignore[16]

Expand All @@ -796,6 +798,8 @@ def test_different_sharding_plan(self) -> None:
for sharding_option in sharding_options:
for shard in sharding_option.shards:
if shard.storage and shard.rank:
# pyre-fixme[6]: For 1st argument expected `SupportsIndex` but
# got `Optional[int]`.
memory_balanced_hbm_uses[shard.rank] += shard.storage.hbm

self.assertTrue(max(memory_balanced_hbm_uses) < max(greedy_perf_hbm_uses))
Expand Down
4 changes: 2 additions & 2 deletions torchrec/distributed/shards_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
aten = torch.ops.aten # pyre-ignore[5]


# pyre-fixme[13]: Attribute `_local_shards` is never initialized.
# pyre-fixme[13]: Attribute `_storage_meta` is never initialized.
class LocalShardsWrapper(torch.Tensor):
"""
A wrapper class to hold local shards of a DTensor.
Expand All @@ -37,7 +35,9 @@ class LocalShardsWrapper(torch.Tensor):
"""

__slots__ = ["_local_shards", "_storage_meta"]
# pyre-fixme[13]: Attribute `_local_shards` is never initialized.
_local_shards: List[torch.Tensor]
# pyre-fixme[13]: Attribute `_storage_meta` is never initialized.
_storage_meta: TensorStorageMetadata

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion torchrec/distributed/tensor_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,8 @@ def _update_local(
deduped_ids, dedup_permutation = deterministic_dedup(ids)
shard.update(deduped_ids, values[dedup_permutation])

# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
def _update_preproc(self, values: torch.Tensor) -> torch.Tensor:
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
pass

def update(self, ids: torch.Tensor, values: torch.Tensor) -> None:
Expand Down
4 changes: 4 additions & 0 deletions torchrec/distributed/tests/test_awaitable.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@ def _wait_impl(self) -> torch.Tensor:
class AwaitableTests(unittest.TestCase):
def test_callback(self) -> None:
awaitable = AwaitableInstance()
# pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got
# `(ret: Any) -> int`.
awaitable.callbacks.append(lambda ret: 2 * ret)
self.assertTrue(
torch.allclose(awaitable.wait(), torch.FloatTensor([2.0, 4.0, 6.0]))
)

def test_callback_chained(self) -> None:
awaitable = AwaitableInstance()
# pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got
# `(ret: Any) -> int`.
awaitable.callbacks.append(lambda ret: 2 * ret)
awaitable.callbacks.append(lambda ret: ret**2)
self.assertTrue(
Expand Down
13 changes: 6 additions & 7 deletions torchrec/distributed/tests/test_embedding_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def __init__(self) -> None:
torch.nn.Module(),
]

# pyre-fixme[7]: Expected `EmbeddingBagCollectionContext` but got implicit
# return value of `None`.
def create_context(self) -> ShrdCtx:
# pyre-fixme[7]: Expected `EmbeddingBagCollectionContext` but got implicit
# return value of `None`.
pass

def input_dist(
Expand All @@ -41,19 +41,18 @@ def input_dist(
*input,
# pyre-ignore[2]
**kwargs,
) -> Awaitable[Awaitable[CompIn]]:
# pyre-fixme[7]: Expected `Awaitable[Awaitable[KJTList]]` but got implicit
# return value of `None`.
) -> Awaitable[Awaitable[CompIn]]:
pass

# pyre-fixme[7]: Expected `List[Tensor]` but got implicit return value of `None`.
def compute(self, ctx: ShrdCtx, dist_input: CompIn) -> DistOut:
# pyre-fixme[7]: Expected `List[Tensor]` but got implicit return value of
# `None`.
pass

# pyre-fixme[7]: Expected `LazyAwaitable[Dict[str, Tensor]]` but got implicit
# return value of `None`.
def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
# pyre-fixme[7]: Expected `LazyAwaitable[Dict[str, Tensor]]` but got
# implicit return value of `None`.
pass


Expand Down
2 changes: 0 additions & 2 deletions torchrec/distributed/tests/test_lazy_awaitable.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

tempFile = None
with tempfile.NamedTemporaryFile(delete=False) as f:
# pyre-fixme[6]: For 2nd argument expected `SupportsWrite[bytes]` but
# got `_TemporaryFileWrapper[bytes]`.
pickle.dump(gm, f)
tempFile = f

Expand Down
1 change: 1 addition & 0 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,6 +1612,7 @@ def __init__(

def get_compiled_autograd_ctx(
self,
# pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter.
) -> ContextManager:
# this allows for pipelining
# to avoid doing a sum on None
Expand Down
4 changes: 3 additions & 1 deletion torchrec/distributed/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
# other metaclasses (i.e. AwaitableMeta) for customized
# behaviors, as Generic is non-trival metaclass in
# python 3.6 and below
# pyre-fixme[21]: Could not find name `GenericMeta` in `typing` (stubbed).
from typing import GenericMeta
except ImportError:
# In python 3.7+, GenericMeta doesn't exist as it's no
Expand Down Expand Up @@ -975,6 +974,9 @@ def __init__(
torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
self._qcomm_codecs_registry = qcomm_codecs_registry

# pyre-fixme[56]: Pyre doesn't yet support decorators with ParamSpec applied to
# generic functions. Consider using a context manager instead of a decorator, if
# possible.
@abc.abstractclassmethod
# pyre-ignore [3]
def shard(
Expand Down
6 changes: 5 additions & 1 deletion torchrec/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,11 @@ def maybe_reset_parameters(m: nn.Module) -> None:


def maybe_annotate_embedding_event(
event: EmbeddingEvent, module_fqn: Optional[str], sharding_type: Optional[str]
event: EmbeddingEvent,
module_fqn: Optional[str],
sharding_type: Optional[str],
# pyre-fixme[24]: Generic type `AbstractContextManager` expects 2 type parameters,
# received 1.
) -> AbstractContextManager[None]:
if module_fqn and sharding_type:
annotation = f"[{event.value}]_[{module_fqn}]_[{sharding_type}]"
Expand Down
2 changes: 2 additions & 0 deletions torchrec/linter/module_linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def print_error_message(
"""
lint_item = {
"path": python_path,
# pyre-fixme[16]: `AST` has no attribute `lineno`.
"line": node.lineno,
# pyre-fixme[16]: `AST` has no attribute `col_offset`.
"char": node.col_offset + 1,
"severity": severity,
"name": name,
Expand Down
1 change: 1 addition & 0 deletions torchrec/metrics/tests/test_metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ def _test_adjust_compute_interval(
)
mock_time.time = MagicMock(return_value=0.0)

# pyre-fixme[53]: Captured variable `batch` is not annotated.
def _train(metric_module: RecMetricModule) -> float:
for _ in range(metric_module.compute_interval_steps):
metric_module.update(batch)
Expand Down
2 changes: 2 additions & 0 deletions torchrec/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def convert_list_of_modules_to_modulelist(
# `Iterable[torch.nn.Module]`.
len(modules)
== sizes[0]
# pyre-fixme[6]: For 1st argument expected `pyre_extensions.PyreReadOnly[Sized]`
# but got `Iterable[Module]`.
), f"the counts of modules ({len(modules)}) do not match with the required counts {sizes}"
if len(sizes) == 1:
return torch.nn.ModuleList(modules)
Expand Down
Loading