add ddr_cap config to shard_quant_model in TGIF (#2539)
Summary:
Pull Request resolved: #2539

Add the ddr_cap config to shard_quant_model in the TGIF inference path, so that we can fully utilize the CPU memory.

Reviewed By: ljyuva83

Differential Revision: D65451305

fbshipit-source-id: a77a5457283d7993d4b68b18bb7736c8cf4d7f64
zhaojuanmao authored and facebook-github-bot committed Nov 5, 2024
1 parent 786bb1e commit 63d604a
Showing 3 changed files with 4 additions and 2 deletions.
@@ -32,7 +32,7 @@
#include "torchrec/inference/BatchingQueue.h"
#include "torchrec/inference/Observer.h"
#include "torchrec/inference/ResultSplit.h"
-#include "torchrec/inference/include/torchrec/inference/Observer.h"
+#include "torchrec/inference/include/torchrec/inference/Observer.h" // @manual

namespace torchrec {

2 changes: 1 addition & 1 deletion torchrec/inference/inference_legacy/src/GPUExecutor.cpp
@@ -25,7 +25,7 @@
#include <folly/stop_watch.h>
#include <gflags/gflags.h>
#include <glog/logging.h>
-#include <torch/csrc/autograd/profiler.h>
+#include <torch/csrc/autograd/profiler.h> // @manual

// remove this after we switch over to multipy externally for torchrec
#ifdef FBCODE_CAFFE2
2 changes: 2 additions & 0 deletions torchrec/inference/modules.py
@@ -488,6 +488,7 @@ def shard_quant_model(
sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None,
device_memory_size: Optional[int] = None,
constraints: Optional[Dict[str, ParameterConstraints]] = None,
+    ddr_cap: Optional[int] = None,
) -> Tuple[torch.nn.Module, ShardingPlan]:
"""
Shard a quantized TorchRec model, used for generating the most optimal model for inference and
@@ -557,6 +558,7 @@ def shard_quant_model(
compute_device=compute_device,
local_world_size=world_size,
hbm_cap=hbm_cap,
+        ddr_cap=ddr_cap,
)
batch_size = 1
model_plan = trec_dist.planner.EmbeddingShardingPlanner(
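The diff above threads `ddr_cap` from `shard_quant_model` into the planner's `Topology` alongside `hbm_cap`, giving the planner a CPU-memory budget in addition to the GPU one. A minimal sketch of the budgeting idea, using a stand-in `Topology` dataclass rather than the real torchrec class (the field semantics and `total_memory_budget` helper are illustrative assumptions, not torchrec's API):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Topology:
    """Illustrative stand-in for trec_dist.planner.Topology (not the real API)."""
    world_size: int
    hbm_cap: Optional[int] = None  # assumed per-device GPU memory budget, bytes
    ddr_cap: Optional[int] = None  # assumed per-device CPU (DDR) memory budget, bytes


def total_memory_budget(topo: Topology) -> int:
    """Rough total budget a planner could shard against, assuming per-device caps."""
    per_device = (topo.hbm_cap or 0) + (topo.ddr_cap or 0)
    return topo.world_size * per_device


# Without ddr_cap, only GPU memory is budgeted; passing ddr_cap raises the
# ceiling so large embedding tables can also occupy CPU memory.
gpu_only = Topology(world_size=2, hbm_cap=16 * 1024**3)
with_ddr = Topology(world_size=2, hbm_cap=16 * 1024**3, ddr_cap=256 * 1024**3)
print(total_memory_budget(gpu_only))  # 2 * 16 GiB
print(total_memory_budget(with_ddr))  # 2 * (16 GiB + 256 GiB)
```

Before this change, callers of `shard_quant_model` had no way to pass the CPU budget through to the `Topology`, so the planner fell back to its default DDR value.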