From 63d604a3bb38297db37bcddb3d4ac4c54958e33f Mon Sep 17 00:00:00 2001
From: Yanli Zhao
Date: Tue, 5 Nov 2024 10:37:25 -0800
Subject: [PATCH] add ddr_cap config to shard_quant_model in TGIF (#2539)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2539

Add a ddr_cap config to shard_quant_model in the TGIF inference path, so that we can fully utilize CPU memory.

Reviewed By: ljyuva83

Differential Revision: D65451305

fbshipit-source-id: a77a5457283d7993d4b68b18bb7736c8cf4d7f64
---
 torchrec/inference/include/torchrec/inference/GPUExecutor.h | 2 +-
 torchrec/inference/inference_legacy/src/GPUExecutor.cpp     | 2 +-
 torchrec/inference/modules.py                               | 2 ++
 3 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/torchrec/inference/include/torchrec/inference/GPUExecutor.h b/torchrec/inference/include/torchrec/inference/GPUExecutor.h
index 00c93668b..d2d289670 100644
--- a/torchrec/inference/include/torchrec/inference/GPUExecutor.h
+++ b/torchrec/inference/include/torchrec/inference/GPUExecutor.h
@@ -32,7 +32,7 @@
 #include "torchrec/inference/BatchingQueue.h"
 #include "torchrec/inference/Observer.h"
 #include "torchrec/inference/ResultSplit.h"
-#include "torchrec/inference/include/torchrec/inference/Observer.h"
+#include "torchrec/inference/include/torchrec/inference/Observer.h" // @manual

 namespace torchrec {

diff --git a/torchrec/inference/inference_legacy/src/GPUExecutor.cpp b/torchrec/inference/inference_legacy/src/GPUExecutor.cpp
index 38b00ad21..8178ed3f0 100644
--- a/torchrec/inference/inference_legacy/src/GPUExecutor.cpp
+++ b/torchrec/inference/inference_legacy/src/GPUExecutor.cpp
@@ -25,7 +25,7 @@
 #include
 #include
 #include
-#include
+#include // @manual

 // remove this after we switch over to multipy externally for torchrec
 #ifdef FBCODE_CAFFE2
diff --git a/torchrec/inference/modules.py b/torchrec/inference/modules.py
index 1dd1735bc..fb8c9c21d 100644
--- a/torchrec/inference/modules.py
+++ b/torchrec/inference/modules.py
@@ -488,6 +488,7 @@ def shard_quant_model(
     sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None,
     device_memory_size: Optional[int] = None,
     constraints: Optional[Dict[str, ParameterConstraints]] = None,
+    ddr_cap: Optional[int] = None,
 ) -> Tuple[torch.nn.Module, ShardingPlan]:
     """
     Shard a quantized TorchRec model, used for generating the most optimal model for inference and
@@ -557,6 +558,7 @@
         compute_device=compute_device,
         local_world_size=world_size,
         hbm_cap=hbm_cap,
+        ddr_cap=ddr_cap,
     )
     batch_size = 1
     model_plan = trec_dist.planner.EmbeddingShardingPlanner(
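
Usage note (editorial addition, not part of the patch): the new ddr_cap argument is forwarded to trec_dist.planner.Topology, where it sets the DDR (host) memory budget, in bytes, that the embedding sharding planner may assign per rank, complementing hbm_cap for device memory. Below is a minimal sketch of how a caller might pass it; the quant_model variable and the 40 GB figure are illustrative assumptions and do not come from the diff.

    from torchrec.inference.modules import shard_quant_model

    # Illustrative host-memory budget: 40 GB of DDR (value is an assumption).
    DDR_CAP_BYTES = 40 * 1024 * 1024 * 1024

    # quant_model is assumed to be an already-quantized TorchRec inference model.
    sharded_model, sharding_plan = shard_quant_model(
        quant_model,
        ddr_cap=DDR_CAP_BYTES,  # added by this diff; caps the CPU memory the planner may use
    )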