v0.6.0
VBE
TorchRec now natively supports VBE (variable batched embeddings) within the EmbeddingBagCollection module. This allows a variable batch size per feature, unlocking sparse input data deduplication, which can greatly speed up embedding lookup and all-to-all time. To enable it, initialize KeyedJaggedTensor with the stride_per_key_per_rank and inverse_indices fields, which specify the batch size per feature and the inverse indices to reindex the embedding output, respectively.
Embedding offloading
Embedding offloading is UVM caching (i.e. storing embedding tables in host memory with a cache in HBM) plus prefetching and optimal sizing of the cache. Embedding offloading allows running a larger model with fewer GPUs while maintaining competitive performance. To use it, switch to the prefetching pipeline (PrefetchTrainPipelineSparseDist) and pass a per-table cache load factor and the prefetch_pipeline flag through constraints in the planner.
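A hedged sketch of how these knobs could be wired together (the table name, sizes, and exact constraint fields are illustrative and should be verified against the planner API):

```python
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import CacheParams
from torchrec.distributed.train_pipeline import PrefetchTrainPipelineSparseDist

# "t_user_id" is a hypothetical table name; the per-table cache load factor
# and the prefetch_pipeline flag are passed through planner constraints.
constraints = {
    "t_user_id": ParameterConstraints(
        compute_kernels=["fused_uvm_caching"],  # keep the table in UVM with an HBM cache
        cache_params=CacheParams(
            load_factor=0.2,           # size the HBM cache at ~20% of the table
            prefetch_pipeline=True,    # required for the prefetching pipeline
        ),
    ),
}

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda"),
    constraints=constraints,
)

# Training then uses the prefetching pipeline instead of TrainPipelineSparseDist,
# so cache rows for the next batch are fetched while the current batch executes.
# (model, optimizer, and device are assumed to be defined elsewhere.)
# pipeline = PrefetchTrainPipelineSparseDist(model, optimizer, device)
```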
Trec.shard/shard_modules
These APIs replace embedding submodules with their sharded variants. The shard API applies to an individual embedding module, while the shard_modules API replaces all embedding modules and does not touch other non-embedding submodules.
Embedding sharding follows the same behavior as the prior TorchRec DistributedModelParallel, except that the sharded modules are now composable: they are backed by TableBatchedEmbeddingSlices, which are views into the underlying TBE (including .grad). This means fused parameters are now returned by named_parameters(), including in DistributedModelParallel.
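A minimal sketch of shard_modules (the table configuration is made up, and this assumes torch.distributed is already initialized so the default process group is available):

```python
import torch
import torchrec
from torchrec.distributed.shard import shard_modules

# Hypothetical single-table EmbeddingBagCollection, built on the meta device
# so weights are only materialized once sharded.
ebc = torchrec.EmbeddingBagCollection(
    device=torch.device("meta"),
    tables=[
        torchrec.EmbeddingBagConfig(
            name="t_user_id",
            embedding_dim=64,
            num_embeddings=1_000_000,
            feature_names=["user_id"],
            pooling=torchrec.PoolingType.SUM,
        ),
    ],
)

# Replaces the embedding module(s) with their sharded variants;
# non-embedding submodules are left untouched.
sharded_ebc = shard_modules(ebc, device=torch.device("cuda"))

# Sharded modules are composable: fused parameters (views into the
# underlying TBE) are visible via named_parameters().
for name, param in sharded_ebc.named_parameters():
    print(name, param.shape)
```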