diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index 2d8cea3c9..5d4e30443 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -8,6 +8,7 @@ # pyre-strict import copy +import logging from collections import defaultdict, OrderedDict from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Type @@ -54,6 +55,8 @@ from torchrec.modules.utils import construct_jagged_tensors from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +logger: logging.Logger = logging.getLogger(__name__) + class ManagedCollisionCollectionAwaitable(LazyAwaitable[KeyedJaggedTensor]): def __init__( @@ -140,6 +143,7 @@ def __init__( qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: super().__init__() + self.need_preprocess: bool = module.need_preprocess self._device = device self._env = env self._table_name_to_parameter_sharding: Dict[str, ParameterSharding] = ( @@ -249,15 +253,33 @@ def _create_managed_collision_modules( ) -> None: self._mc_module_name_shard_metadata: DefaultDict[str, List[int]] = defaultdict() - self._feature_to_offset: Dict[str, int] = {} + # To map mch output indices from local to global. key: table_name self._table_to_offset: Dict[str, int] = {} + # the split sizes of tables belonging to each sharding. outer len is # shardings + self._sharding_per_table_feature_splits: List[List[int]] = [] + # the split sizes of features per sharding. len is # shardings + self._sharding_feature_splits: List[int] = [] + # the split sizes of features per table. len is # tables sum over all shardings + self._table_feature_splits: List[int] = [] + self._feature_names: List[str] = [] + + # table names of each sharding + self._sharding_tables: List[List[str]] = [] + self._sharding_features: List[List[str]] = [] + for sharding in self._embedding_shardings: assert isinstance(sharding, BaseRwEmbeddingSharding) + self._sharding_tables.append([]) + self._sharding_features.append([]) + self._sharding_per_table_feature_splits.append([]) grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( sharding._grouped_embedding_configs ) + self._sharding_feature_splits.append(len(sharding.feature_names())) + + num_sharding_features = 0 for group_config in grouped_embedding_configs: for table in group_config.embedding_tables: # pyre-ignore [16] @@ -271,6 +293,9 @@ def _create_managed_collision_modules( ] + [table.num_embeddings] mc_module = module._managed_collision_modules[table.name] + self._sharding_tables[-1].append(table.name) + self._sharding_features[-1].extend(table.feature_names) + self._feature_names.extend(table.feature_names) self._managed_collision_modules[table.name] = ( mc_module.rebuild_with_output_id_range( output_id_range=( @@ -315,21 +340,42 @@ def _create_managed_collision_modules( zch_size, zch_size_cumsum[-1], ) - for feature in table.feature_names: - self._feature_to_offset[feature] = new_min_output_id self._table_to_offset[table.name] = new_min_output_id + self._table_feature_splits.append(len(table.feature_names)) + self._sharding_per_table_feature_splits[-1].append( + self._table_feature_splits[-1] + ) + num_sharding_features += self._table_feature_splits[-1] + + assert num_sharding_features == len( + sharding.feature_names() + ), f"Shared feature is not supported. 
{num_sharding_features=}, {self._sharding_per_table_feature_splits[-1]=}" + + if self._sharding_features[-1] != sharding.feature_names(): + logger.warn( + "The order of tables of this sharding is altered due to grouping: " + f"{self._sharding_features[-1]=} vs {sharding.feature_names()=}" + ) + + logger.info(f"{self._table_feature_splits=}") + logger.info(f"{self._sharding_per_table_feature_splits=}") + logger.info(f"{self._feature_names=}") + logger.info(f"{self._table_to_offset=}") + logger.info(f"{self._sharding_tables=}") + logger.info(f"{self._sharding_features=}") + def _create_input_dists( self, input_feature_names: List[str], ) -> None: - feature_names: List[str] = [] - self._feature_splits: List[int] = [] - for sharding in self._embedding_shardings: + for sharding, sharding_features in zip( + self._embedding_shardings, self._sharding_features + ): assert isinstance(sharding, BaseRwEmbeddingSharding) feature_hash_sizes: List[int] = [ self._managed_collision_modules[self._feature_to_table[f]].input_size() - for f in sharding.feature_names() + for f in sharding_features ] input_dist = RwSparseFeaturesDist( @@ -344,11 +390,9 @@ def _create_input_dists( keep_original_indices=True, ) self._input_dists.append(input_dist) - feature_names.extend(sharding.feature_names()) - self._feature_splits.append(len(sharding.feature_names())) self._features_order: List[int] = [] - for f in feature_names: + for f in self._feature_names: self._features_order.append(input_feature_names.index(f)) self._features_order = ( [] @@ -386,30 +430,48 @@ def input_dist( self._has_uninitialized_input_dists = False with torch.no_grad(): - features_dict = features.to_dict() - output: Dict[str, JaggedTensor] = features_dict.copy() - for table, mc_module in self._managed_collision_modules.items(): - feature_list: List[str] = self._table_to_features[table] - mc_input: Dict[str, JaggedTensor] = {} - for feature in feature_list: - mc_input[feature] = features_dict[feature] - mc_input = mc_module.preprocess(mc_input) - output.update(mc_input) - - # NOTE shared features not currently supported - features = KeyedJaggedTensor.from_jt_dict(output) - if self._features_order: features = features.permute( self._features_order, self._features_order_tensor, ) - features_by_sharding = features.split( - self._feature_splits, - ) + + feature_splits: List[KeyedJaggedTensor] = [] + if self.need_preprocess: + # NOTE: No shared features allowed! 
+ feature_splits = features.split(self._table_feature_splits) + else: + feature_splits = features.split(self._sharding_feature_splits) + + ti: int = 0 awaitables = [] - for input_dist, features in zip(self._input_dists, features_by_sharding): - awaitables.append(input_dist(features)) + for i, tables in enumerate(self._sharding_tables): + if self.need_preprocess: + output: Dict[str, JaggedTensor] = {} + for table in tables: + kjt: KeyedJaggedTensor = feature_splits[ti] + mc_module = self._managed_collision_modules[table] + # TODO: change to Dict[str, Tensor] + mc_input: Dict[str, JaggedTensor] = { + table: JaggedTensor( + values=kjt.values(), + lengths=kjt.lengths(), + ) + } + mc_input = mc_module.preprocess(mc_input) + output.update(mc_input) + ti += 1 + shard_kjt = KeyedJaggedTensor( + keys=self._sharding_features[i], + values=torch.cat([jt.values() for jt in output.values()]), + lengths=torch.cat([jt.lengths() for jt in output.values()]), + ) + else: + shard_kjt = feature_splits[i] + + input_dist = self._input_dists[i] + + awaitables.append(input_dist(shard_kjt)) ctx.sharding_contexts.append( SequenceShardingContext( features_before_input_dist=features, @@ -426,26 +488,31 @@ def input_dist( def _kjt_list_to_tensor_list( self, kjt_list: KJTList, - feature_to_offset: Dict[str, int], ) -> List[torch.Tensor]: remapped_ids_ret: List[torch.Tensor] = [] - # TODO: find a better solution - for kjt in kjt_list: - jt_dict = kjt.to_dict() - for feature, jt in jt_dict.items(): - offset = feature_to_offset[feature] - jt._values = jt.values().add(offset) - new_kjt = KeyedJaggedTensor.from_jt_dict(jt_dict) - remapped_ids_ret.append(new_kjt.values().view(-1, 1)) + # TODO: find a better solution, could be padding + for kjt, tables, splits in zip( + kjt_list, self._sharding_tables, self._sharding_per_table_feature_splits + ): + if len(splits) > 1: + feature_splits = kjt.split(splits) + vals: List[torch.Tensor] = [] + # assert len(feature_splits) == len(sharding.embedding_tables()) + for feature_split, table in zip(feature_splits, tables): + offset = self._table_to_offset[table] + vals.append(feature_split.values() + offset) + remapped_ids_ret.append(torch.cat(vals).view(-1, 1)) + else: + remapped_ids_ret.append(kjt.values() + self._table_to_offset[tables[0]]) return remapped_ids_ret def global_to_local_index( self, - feature_dict: Dict[str, JaggedTensor], + jt_dict: Dict[str, JaggedTensor], ) -> Dict[str, JaggedTensor]: - for feature, jt in feature_dict.items(): - jt._values = jt.values() - self._feature_to_offset[feature] - return feature_dict + for table, jt in jt_dict.items(): + jt._values = jt.values() - self._table_to_offset[table] + return jt_dict def compute( self, @@ -454,26 +521,59 @@ def compute( ) -> KJTList: remapped_kjts: List[KeyedJaggedTensor] = [] - for features, sharding_ctx in zip( + # per shard + for features, sharding_ctx, tables, splits, fns in zip( dist_input, ctx.sharding_contexts, + self._sharding_tables, + self._sharding_per_table_feature_splits, + self._sharding_features, ): sharding_ctx.lengths_after_input_dist = features.lengths().view( -1, features.stride() ) - features_dict = features.to_dict() - output: Dict[str, JaggedTensor] = features_dict.copy() - for table, mc_module in self._managed_collision_modules.items(): - feature_list: List[str] = self._table_to_features[table] - mc_input: Dict[str, JaggedTensor] = {} - for feature in feature_list: - mc_input[feature] = features_dict[feature] - mc_input = mc_module.profile(mc_input) - mc_input = mc_module.remap(mc_input) - 
mc_input = self.global_to_local_index(mc_input) - output.update(mc_input) - remapped_kjts.append(KeyedJaggedTensor.from_jt_dict(output)) + values: torch.Tensor + if len(splits) > 1: + # features per shard split by tables + feature_splits = features.split(splits) + output: Dict[str, JaggedTensor] = {} + for table, kjt in zip(tables, feature_splits): + # TODO: Dict[str, Tensor] + mc_input: Dict[str, JaggedTensor] = { + table: JaggedTensor( + values=kjt.values(), + lengths=kjt.lengths(), + ) + } + mcm = self._managed_collision_modules[table] + mc_input = mcm.profile(mc_input) + mc_input = mcm.remap(mc_input) + mc_input = self.global_to_local_index(mc_input) + output.update(mc_input) + values = torch.cat([jt.values() for jt in output.values()]) + else: + table: str = tables[0] + mc_input: Dict[str, JaggedTensor] = { + table: JaggedTensor( + values=features.values(), + lengths=features.lengths(), + ) + } + mcm = self._managed_collision_modules[table] + mc_input = mcm.profile(mc_input) + mc_input = mcm.remap(mc_input) + mc_input = self.global_to_local_index(mc_input) + values = mc_input[table].values() + + remapped_kjts.append( + KeyedJaggedTensor( + keys=fns, + values=values, + lengths=features.lengths(), + weights=features.weights_or_none(), + ) + ) return KJTList(remapped_kjts) def evict(self) -> Dict[str, Optional[torch.Tensor]]: @@ -505,8 +605,7 @@ def output_dist( ctx: ManagedCollisionCollectionContext, output: KJTList, ) -> LazyAwaitable[KeyedJaggedTensor]: - - global_remapped = self._kjt_list_to_tensor_list(output, self._feature_to_offset) + global_remapped = self._kjt_list_to_tensor_list(output) awaitables_per_sharding: List[Awaitable[torch.Tensor]] = [] features_before_all2all_per_sharding: List[KeyedJaggedTensor] = [] for odist, remapped_ids, sharding_ctx in zip( diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index f9fd66280..91d37edc0 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -36,6 +36,8 @@ from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.modules.embedding_tower import EmbeddingTower, EmbeddingTowerCollection from torchrec.modules.feature_processor import PositionWeightedProcessor +from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection +from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection from torchrec.modules.regroup import KTRegroupAsDict from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor from torchrec.streamable import Pipelineable @@ -1009,52 +1011,43 @@ def __init__( tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], device: Optional[torch.device] = None, - max_feature_lengths_list: Optional[List[Dict[str, int]]] = None, + max_feature_lengths: Optional[Dict[str, int]] = None, ) -> None: super().__init__() if device is None: device = torch.device("cpu") self.fps: Optional[nn.ModuleList] = None - self.fp_ebc: Optional[EmbeddingBagCollection] = None - if max_feature_lengths_list is not None: - self.fps = nn.ModuleList( - [ - PositionWeightedProcessor( - max_feature_lengths=max_feature_lengths, - device=( - device - if device != torch.device("meta") - else torch.device("cpu") - ), - ) - for max_feature_lengths in max_feature_lengths_list - ] - ) - normal_id_list_tables = [] - fp_id_list_tables = [] - for table in tables: - # the key set of feature_processor is either subset or none in the 
feature_names - if set(table.feature_names).issubset( - set(max_feature_lengths_list[0].keys()) - ): - fp_id_list_tables.append(table) - else: - normal_id_list_tables.append(table) + self.fp_ebc: Optional[FeatureProcessedEmbeddingBagCollection] = None + + if max_feature_lengths is not None: + fp_tables_names = set(max_feature_lengths.keys()) + normal_tables_names = {table.name for table in tables} - fp_tables_names self.ebc: EmbeddingBagCollection = EmbeddingBagCollection( - tables=normal_id_list_tables, + tables=[table for table in tables if table.name in normal_tables_names], device=device, ) - self.fp_ebc: EmbeddingBagCollection = EmbeddingBagCollection( - tables=fp_id_list_tables, - device=device, - is_weighted=True, + + fp = PositionWeightedModuleCollection( + max_feature_lengths=max_feature_lengths, + device=( + device if device != torch.device("meta") else torch.device("cpu") + ), + ) + self.fp_ebc = FeatureProcessedEmbeddingBagCollection( + embedding_bag_collection=EmbeddingBagCollection( + tables=[table for table in tables if table.name in fp_tables_names], + device=device, + is_weighted=True, + ), + feature_processors=fp, ) else: self.ebc: EmbeddingBagCollection = EmbeddingBagCollection( tables=tables, device=device, ) + self.weighted_ebc: Optional[EmbeddingBagCollection] = ( EmbeddingBagCollection( tables=weighted_tables, @@ -1109,7 +1102,6 @@ def __init__( embedding_groups: Optional[Dict[str, List[str]]] = None, dense_device: Optional[torch.device] = None, sparse_device: Optional[torch.device] = None, - feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, ) -> None: super().__init__() if dense_device is None: @@ -1148,7 +1140,7 @@ def __init__( embedding_groups: Optional[Dict[str, List[str]]] = None, dense_device: Optional[torch.device] = None, sparse_device: Optional[torch.device] = None, - max_feature_lengths_list: Optional[List[Dict[str, int]]] = None, + max_feature_lengths: Optional[Dict[str, int]] = None, feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, over_arch_clazz: Type[nn.Module] = TestOverArch, preproc_module: Optional[nn.Module] = None, @@ -1167,7 +1159,7 @@ def __init__( tables, weighted_tables, sparse_device, - max_feature_lengths_list if max_feature_lengths_list is not None else None, + max_feature_lengths, ) embedding_names = ( diff --git a/torchrec/inference/modules.py b/torchrec/inference/modules.py index d1249f3ef..997aeb02f 100644 --- a/torchrec/inference/modules.py +++ b/torchrec/inference/modules.py @@ -420,14 +420,19 @@ def _quantize_fp_module( fp_module_fqn: str, activation_dtype: torch.dtype = torch.float, weight_dtype: torch.dtype = DEFAULT_QUANTIZATION_DTYPE, + per_fp_table_weight_dtype: Optional[Dict[str, torch.dtype]] = None, ) -> None: """ If FeatureProcessedEmbeddingBagCollection is found, quantize via direct module swap. """ - fp_module.qconfig = quant.QConfig( + + quant_prep_enable_register_tbes(model, [FeatureProcessedEmbeddingBagCollection]) + fp_module.qconfig = QuantConfig( activation=quant.PlaceholderObserver.with_args(dtype=activation_dtype), weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype), + per_table_weight_dtype=per_fp_table_weight_dtype, ) + # ie. 
"root.submodule.feature_processed_mod" -> "root.submodule", "feature_processed_mod" fp_ebc_parent_fqn, fp_ebc_name = fp_module_fqn.rsplit(".", 1) fp_ebc_parent = getattr_recursive(model, fp_ebc_parent_fqn) @@ -447,7 +452,15 @@ def _quantize_fp_module( additional_mapping[type(m)] = quantization_mapping[typename] elif typename == FEATURE_PROCESSED_EBC_TYPE: # handle the fp ebc separately - _quantize_fp_module(model, m, n, weight_dtype=fp_weight_dtype) + _quantize_fp_module( + model, + m, + n, + weight_dtype=fp_weight_dtype, + # Pass in per_fp_table_weight_dtype if it is provided, perhaps + # fpebc parameters are also in here + per_fp_table_weight_dtype=per_table_weight_dtype, + ) quant_prep_enable_register_tbes(model, list(additional_mapping.keys())) quantize_embeddings( diff --git a/torchrec/inference/tests/test_inference.py b/torchrec/inference/tests/test_inference.py index 9def46b74..4d3da310c 100644 --- a/torchrec/inference/tests/test_inference.py +++ b/torchrec/inference/tests/test_inference.py @@ -12,6 +12,7 @@ from argparse import Namespace import torch +from fbgemm_gpu.split_embedding_configs import SparseType from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES from torchrec.distributed.global_settings import set_propogate_device from torchrec.distributed.test_utils.test_model import ( @@ -175,3 +176,42 @@ def test_set_pruning_data(self) -> None: spec[1], pruning_dict[spec[0]], ) + + def test_quantize_per_table_dtype(self) -> None: + max_feature_lengths = {} + + # First two tables as FPEBC + max_feature_lengths[self.tables[0].name] = 100 + max_feature_lengths[self.tables[1].name] = 100 + + model = TestSparseNN( + tables=self.tables, + weighted_tables=self.weighted_tables, + num_float_features=10, + dense_device=torch.device("cpu"), + sparse_device=torch.device("cpu"), + over_arch_clazz=TestOverArchRegroupModule, + max_feature_lengths=max_feature_lengths, + ) + + per_table_dtype = {} + + for table in self.tables + self.weighted_tables: + # quint4x2 different than int8, which is default + per_table_dtype[table.name] = torch.quint4x2 + + quantized_model = quantize_inference_model( + model, per_table_weight_dtype=per_table_dtype + ) + + num_tbes = 0 + # Check EBC configs and TBE for correct shapes + for module in quantized_model.modules(): + if module.__class__.__name__ == "IntNBitTableBatchedEmbeddingBagsCodegen": + num_tbes += 1 + for i, spec in enumerate(module.embedding_specs): + self.assertEqual(spec[3], SparseType.INT4) + + # 3 TBES (1 FPEBC, 2 EBCs (1 weighted, 1 unweighted)) + + self.assertEqual(num_tbes, 3) diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py index a52b76871..391a52123 100644 --- a/torchrec/modules/mc_modules.py +++ b/torchrec/modules/mc_modules.py @@ -20,14 +20,6 @@ from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor -try: - from torchrec.sparse.jagged_tensor import ComputeJTDictToKJT -except ImportError: - # Dummy implementation, use try catch for torch package compatibility issue - torch._C._log_api_usage_once( - "ImportError ComputeJTDictToKJT, ignoring, but possible incompatiblity with torch.package" - ) - logger: Logger = getLogger(__name__) @@ -44,6 +36,46 @@ def apply_mc_method_to_jt_dict( return attr(features_dict) +@torch.fx.wrap +def _update( + base: Optional[Dict[str, JaggedTensor]], delta: Dict[str, JaggedTensor] +) -> Dict[str, JaggedTensor]: + if base is None: + base = delta + else: + base.update(delta) + return base + + +@torch.fx.wrap +def _cat_jagged_values(jd: Dict[str, 
JaggedTensor]) -> torch.Tensor: + return torch.cat([jt.values() for jt in jd.values()]) + + +@torch.fx.wrap +def _mcc_lazy_init( + features: KeyedJaggedTensor, + feature_names: List[str], + features_order: List[int], + created_feature_order: bool, +) -> Tuple[KeyedJaggedTensor, bool, List[int]]: # features_order + input_feature_names: List[str] = features.keys() + if not created_feature_order: + for f in feature_names: + features_order.append(input_feature_names.index(f)) + + if features_order == list(range(len(features_order))): + features_order = torch.jit.annotate(List[int], []) + created_feature_order = True + + if len(features_order) > 0: + features = features.permute( + features_order, + ) + + return (features, created_feature_order, features_order) + + @torch.no_grad() def dynamic_threshold_filter( id_counts: torch.Tensor, @@ -251,28 +283,30 @@ class ManagedCollisionCollection(nn.Module): """ _table_to_features: Dict[str, List[str]] + _features_order: List[int] def __init__( self, managed_collision_modules: Dict[str, ManagedCollisionModule], embedding_configs: List[BaseEmbeddingConfig], + need_preprocess: bool = True, ) -> None: super().__init__() self._managed_collision_modules = nn.ModuleDict(managed_collision_modules) self._embedding_configs = embedding_configs + self.need_preprocess = need_preprocess self._feature_to_table: Dict[str, str] = { feature: config.name for config in embedding_configs for feature in config.feature_names } - self._table_to_features = {} - - self._compute_jt_dict_to_kjt: torch.nn.Module = ComputeJTDictToKJT() - for feature, table in self._feature_to_table.items(): - if table not in self._table_to_features: - self._table_to_features[table] = [] + self._table_to_features: Dict[str, List[str]] = { + config.name: config.feature_names for config in embedding_configs + } - self._table_to_features[table].append(feature) + self._table_feature_splits: List[int] = [ + len(features) for features in self._table_to_features.values() + ] table_to_config = {config.name: config for config in embedding_configs} @@ -287,6 +321,23 @@ def __init__( f"max_output_id in managed collision module for {name} " f"must match {config.num_embeddings}" ) + self._feature_names: List[str] = [ + feature for config in embedding_configs for feature in config.feature_names + ] + self._created_feature_order = False + self._features_order = [] + + def _create_feature_order( + self, + input_feature_names: List[str], + device: torch.device, + ) -> None: + features_order: List[int] = [] + for f in self._feature_names: + features_order.append(input_feature_names.index(f)) + + if features_order != list(range(len(features_order))): + self._features_order = features_order def embedding_configs(self) -> List[BaseEmbeddingConfig]: return self._embedding_configs @@ -295,16 +346,41 @@ def forward( self, features: KeyedJaggedTensor, ) -> KeyedJaggedTensor: - features_dict = features.to_dict() - output: Dict[str, JaggedTensor] = features_dict.copy() - for table, mc_module in self._managed_collision_modules.items(): - feature_list: List[str] = self._table_to_features[table] - mc_input: Dict[str, JaggedTensor] = {} - for feature in feature_list: - mc_input[feature] = features_dict[feature] + ( + features, + self._created_feature_order, + self._features_order, + ) = _mcc_lazy_init( + features, + self._feature_names, + self._features_order, + self._created_feature_order, + ) + + feature_splits: List[KeyedJaggedTensor] = features.split( + self._table_feature_splits + ) + + output: Optional[Dict[str, 
JaggedTensor]] = None + for i, (table, mc_module) in enumerate(self._managed_collision_modules.items()): + kjt: KeyedJaggedTensor = feature_splits[i] + mc_input: Dict[str, JaggedTensor] = { + table: JaggedTensor( + values=kjt.values(), + lengths=kjt.lengths(), + ) + } mc_input = mc_module(mc_input) - output.update(mc_input) - return self._compute_jt_dict_to_kjt(output) + output = _update(output, mc_input) + + assert output is not None + values: torch.Tensor = _cat_jagged_values(output) + return KeyedJaggedTensor( + keys=features.keys(), + values=values, + lengths=features.lengths(), + weights=features.weights_or_none(), + ) def evict(self) -> Dict[str, Optional[torch.Tensor]]: evictions: Dict[str, Optional[torch.Tensor]] = {}
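
A minimal standalone sketch (not part of the diff) of the per-table split-and-remap pattern that the reworked ManagedCollisionCollection.forward and the sharded input_dist/compute paths above rely on. It only uses KeyedJaggedTensor.split/values/lengths/keys and the JaggedTensor constructor; toy_remap is a hypothetical placeholder for the real preprocess/profile/remap calls, and it assumes the features have already been permuted into per-table order (what _mcc_lazy_init / _features_order does in the diff):

    import torch
    from typing import Dict
    from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

    # Two tables: "t0" owns features ["f0", "f1"], "t1" owns ["f2"] -> splits [2, 1].
    table_names = ["t0", "t1"]
    table_feature_splits = [2, 1]

    # Keys are already grouped by table, mirroring the permuted input in forward().
    features = KeyedJaggedTensor(
        keys=["f0", "f1", "f2"],
        values=torch.tensor([10, 11, 12, 13, 14, 15]),
        lengths=torch.tensor([1, 1, 1, 1, 1, 1]),
    )

    def toy_remap(values: torch.Tensor) -> torch.Tensor:
        # Stand-in for mc_module.preprocess/profile/remap; real modules are stateful.
        return values % 100

    # One KJT per table; each table's module sees a single-key JaggedTensor dict,
    # keyed by table name rather than by feature name.
    per_table = features.split(table_feature_splits)
    remapped: Dict[str, JaggedTensor] = {}
    for table, kjt in zip(table_names, per_table):
        remapped[table] = JaggedTensor(
            values=toy_remap(kjt.values()),
            lengths=kjt.lengths(),
        )

    # Concatenating per-table values preserves the (table-grouped) key order,
    # so lengths and keys from the permuted input can be reused directly.
    out = KeyedJaggedTensor(
        keys=features.keys(),
        values=torch.cat([jt.values() for jt in remapped.values()]),
        lengths=features.lengths(),
    )

This is the same shape of logic as the diff's forward(): split by _table_feature_splits, run each table's managed-collision module on raw values/lengths, then rebuild one KJT from the concatenated values instead of going through to_dict()/from_jt_dict.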