Commit 9572b4c

Fix inference example and polish CMAKE paths and README

PaulZhang12 committed Oct 23, 2024
1 parent 1a57ce1 commit 9572b4c

Showing 4 changed files with 42 additions and 13 deletions.
5 changes: 2 additions & 3 deletions torchrec/inference/CMakeLists.txt
@@ -8,11 +8,10 @@ cmake_minimum_required(VERSION 3.8)

project(inference C CXX)

include(/home/paulzhan/grpc/examples/cpp/cmake/common.cmake)

include(${GRPC_COMMON_CMAKE_PATH})

# Proto file
get_filename_component(hw_proto "/home/paulzhan/torchrec/torchrec/inference/protos/predictor.proto" ABSOLUTE)
get_filename_component(hw_proto "${CMAKE_CURRENT_BINARY_DIR}/../protos/predictor.proto" ABSOLUTE)
get_filename_component(hw_proto_path "${hw_proto}" PATH)

# Generated sources
39 changes: 33 additions & 6 deletions torchrec/inference/README.md
@@ -12,9 +12,17 @@ C++ 17 is a requirement. GCC version has to be >= 9, with initial testing done o
<br>

### **1. Install Dependencies**
1. [GRPC for C++][https://grpc.io/docs/languages/cpp/quickstart/] needs to be installed, with the resulting installation directory being `$HOME/.local`
1. [GRPC for C++](https://grpc.io/docs/languages/cpp/quickstart/) needs to be installed, with the resulting installation directory being `$HOME/.local`
2. Ensure that **the protobuf compiler (protoc) binary being used is from the GRPC installation above**. The protoc binary will live in `$HOME/.local/bin`, which may not match the system protoc binary; you can check with `which protoc`.
3. Install PyTorch, FBGEMM, and TorchRec (ideally in a virtual environment):
3. Create a Python virtual environment (this example uses miniconda):
```
conda create -n inference python='3.11'
```
4. Install GRPC tooling for the proto files:
```
pip install grpcio-tools
```
5. Install PyTorch, FBGEMM, and TorchRec:
```
pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cu121
@@ -27,12 +35,27 @@ pip install torchrec --index-url https://download.pytorch.org/whl/cu121

Replace these variables with the relevant paths in your system. Check `CMakeLists.txt` and `server.cpp` to see how they're used throughout the build and runtime.

1. Set up FBGEMM library path:
```
# provide fbgemm_gpu_py.so to enable fbgemm_gpu c++ operators
find $HOME -name fbgemm_gpu_py.so
# Use path from correct virtual environment above and set environment variable $FBGEMM_LIB to it
export FBGEMM_LIB=""
# Use the path from your current virtual environment above and set the environment variable $FBGEMM_LIB to it.
# WARNING: Below is just an example path, your path will be different! Replace the example path.
export FBGEMM_LIB="/home/paulzhan/miniconda3/envs/inference/lib/python3.11/site-packages/fbgemm_gpu/fbgemm_gpu_py.so"
```
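
If `find` returns more than one hit (one per environment), a quick sanity check — shown here only as an optional sketch, assuming `fbgemm-gpu` is importable in the environment you created above — is to ask Python where the installed package lives:
```
# Print the fbgemm_gpu_py.so shipped with the fbgemm_gpu package in the active environment
python -c "import fbgemm_gpu, os; print(os.path.join(os.path.dirname(fbgemm_gpu.__file__), 'fbgemm_gpu_py.so'))"
```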

2. Set GRPC github repository path:
In setting up [GRPC for C++](https://grpc.io/docs/languages/cpp/quickstart/), you will have cloned the GRPC github repository.
Let's set a variable for our GRPC repository path to make things easier later on.

```
# Find the common.cmake file from grpc
find $HOME -name common.cmake
# Use the absolute path from the result above; there is usually only 1 result, and it will be the path containing "/grpc/examples/cpp/cmake/common.cmake"
# WARNING: Below is just an example path, your path will be different! Replace the example path.
export GRPC_COMMON_CMAKE_PATH="/home/paulzhan/grpc/examples/cpp/cmake/common.cmake"
```
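
`torchrec/inference/CMakeLists.txt` now reads this path via `include(${GRPC_COMMON_CMAKE_PATH})` (see the diff above), which expects a CMake variable. If your configure step does not already forward the environment variable, one way to do it — an assumed example invocation, not the project's documented build command — is to pass it explicitly when configuring:
```
# Assumed example: forward the environment variable to CMake when configuring the build
mkdir -p build && cd build
cmake -DGRPC_COMMON_CMAKE_PATH="$GRPC_COMMON_CMAKE_PATH" ..
```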

### **3. Generate TorchScripted DLRM model**
@@ -73,7 +96,7 @@ Start the server, loading in the model saved previously
./server /tmp/model.pt
```

**output**
**Output**

In the logs, you should see:

@@ -96,9 +119,13 @@ Server listening on 0.0.0.0:50051
+-----------------------------------------------------------------------------+
```

In another terminal instance, make a request to the server via the client:
In another terminal instance, make a request to the server via the client.

```
# Revisit the TorchRec inference folder
cd ~/torchrec/torchrec/inference/
# Run the client to make requests to server
python client.py
```
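
If the client request hangs or fails to connect, confirm that the server from the previous step is still listening on port 50051. A minimal check, assuming the `ss` utility from iproute2 is available on your machine:
```
# List listening TCP sockets and look for the gRPC port used by the server
ss -ltn | grep 50051
```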

4 changes: 3 additions & 1 deletion torchrec/inference/dlrm_packager.py
@@ -13,6 +13,8 @@
import sys
from typing import List

import torch

from dlrm_predict import create_training_batch, DLRMModelConfig, DLRMPredictFactory
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES

@@ -102,7 +104,7 @@ def main(argv: List[str]) -> None:
sample_input=batch,
)

script_module = DLRMPredictFactory(model_config).create_predict_module(world_size=1)
script_module = DLRMPredictFactory(model_config).create_predict_module(device="cuda")

script_module.save(args.output_path)
print(f"Package is saved to {args.output_path}")
7 changes: 4 additions & 3 deletions torchrec/inference/dlrm_predict.py
@@ -89,15 +89,16 @@ def __init__(
dense_arch_layer_sizes: List[int],
over_arch_layer_sizes: List[int],
id_list_features_keys: List[str],
dense_device: Optional[torch.device] = None,
dense_device: Optional[str] = None,
) -> None:
module = DLRM(
embedding_bag_collection=embedding_bag_collection,
dense_in_features=dense_in_features,
dense_arch_layer_sizes=dense_arch_layer_sizes,
over_arch_layer_sizes=over_arch_layer_sizes,
dense_device=dense_device,
dense_device=torch.device(dense_device),
)

super().__init__(module, dense_device)

self.id_list_features_keys: List[str] = id_list_features_keys
@@ -154,7 +155,7 @@ class DLRMPredictFactory(PredictFactory):
def __init__(self, model_config: DLRMModelConfig) -> None:
self.model_config = model_config

def create_predict_module(self, world_size: int, device: str) -> torch.nn.Module:
def create_predict_module(self, device: str) -> torch.nn.Module:
logging.basicConfig(level=logging.INFO)
set_propogate_device(True)
