Skip to content

Commit

Permalink
Merge pull request #513 from argonne-lcf/libtorch
Browse files Browse the repository at this point in the history
Updated LibTorch docs for Polaris and Aurora
  • Loading branch information
felker authored Nov 1, 2024
2 parents f6cc295 + a174961 commit cbe001f
Show file tree
Hide file tree
Showing 3 changed files with 322 additions and 122 deletions.
294 changes: 172 additions & 122 deletions docs/aurora/data-science/frameworks/libtorch.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,103 @@ During compilation, Intel optimizations will be activated automatically once the
## Environment Setup

To use LibTorch on Aurora, load the ML frameworks module
```bash
module load frameworks/2024.2.1_u1
```
module use /soft/modulefiles
module load frameworks/2023.12.15.001
```
which will also load the consistent oneAPI SDK and `cmake`.
which will also load the consistent oneAPI SDK (version 2024.2) and `cmake`.


## Torch and IPEX libraries

With the ML frameworks module loaded as shown above, run
```
```bash
python -c 'import torch; print(torch.__path__[0])'
python -c 'import torch;print(torch.utils.cmake_prefix_path)'
```
to find the path to the Torch libraries, include files, and CMake files.

For the path to the IPEX dynamic library, run
```
```bash
python -c 'import torch; print(torch.__path__[0].replace("torch","intel_extension_for_pytorch"))'
```


## Linking LibTorch and IPEX Libraries

## Model Inferencing Using the Torch API
This example shows how to perform inference on the ResNet50 model using only the LibTorch API.
First, get a jit-traced version of the model running `resnet50_trace.py` below.
When using the CMake build system, LibTorch and IPEX libraries can be linked to an example C++ application using the following `CMakeLists.txt` file
```bash
cmake_minimum_required(VERSION 3.5 FATAL_ERROR)
cmake_policy(SET CMP0074 NEW)
project(project-name)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -Wl,--no-as-needed")
set(TORCH_LIBS ${TORCH_LIBRARIES})

find_library(IPEX_LIB intel-ext-pt-gpu PATHS ${INTEL_EXTENSION_FOR_PYTORCH_PATH}/lib NO_DEFAULT_PATH REQUIRED)
set(TORCH_LIBS ${TORCH_LIBS} ${IPEX_LIB})
include_directories(SYSTEM ${INTEL_EXTENSION_FOR_PYTORCH_PATH}/include)

add_executable(exe main.cpp)
target_link_libraries(exe ${TORCH_LIBS})

set_property(TARGET exe PROPERTY CXX_STANDARD 17)
```

and configuring the build with
```
cmake \
-DCMAKE_PREFIX_PATH=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'` \
-DINTEL_EXTENSION_FOR_PYTORCH_PATH=`python -c 'import torch; print(torch.__path__[0].replace("torch","intel_extension_for_pytorch"))'` \
./
make
```


## Device Introspection

Similarly to PyTorch, LibTorch provides API to perform instrospection on the devices available on the system.
The simple code below shows how to check if XPU devices are available, how many are present, and how to loop through them to discover some properties.

```c++
#include <torch/torch.h>
#include <c10/xpu/XPUFunctions.h>

int main(int argc, const char* argv[])
{
torch::DeviceType device;
int num_devices = 0;
if (torch::xpu::is_available()) {
std::cout << "XPU devices detected" << std::endl;
device = torch::kXPU;

num_devices = torch::xpu::device_count();
std::cout << "Number of XPU devices: " << num_devices << std::endl;

for (int i = 0; i < num_devices; ++i) {
c10::xpu::set_device(i);
std::cout << "Device " << i << ":" << std::endl;

c10::xpu::DeviceProp device_prop{};
c10::xpu::get_device_properties(&device_prop, i);
std::cout << " Name: " << device_prop.name << std::endl;
std::cout << " Total memory: " << device_prop.global_mem_size / (1024 * 1024) << " MB" << std::endl;
}
} else {
device = torch::kCPU;
std::cout << "No XPU devices detected, setting device to CPU" << std::endl;
}

return 0;
}
```
## Model Inferencing Using the Torch API
This example shows how to perform inference with the ResNet50 model using LibTorch.
First, get a jit-traced version of the model executing `python resnet50_trace.py` (shown below) on a compute node.
```python
import torch
import torchvision
import intel_extension_for_pytorch as ipex
Expand All @@ -58,81 +128,53 @@ print(f"Inference time: {toc-tic}")
torch.jit.save(model_jit, f"resnet50_jit.pt")
```

Then, use the source code in `inference-example.cpp`
```
Then, build `inference-example.cpp` (shown below)
```c++
#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>

int main(int argc, const char* argv[]) {
torch::jit::script::Module model;
try {
model = torch::jit::load(argv[1]);
std::cout << "Loaded the model\n";
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
// Upload model to GPU
model.to(torch::Device(torch::kXPU));
std::cout << "Model offloaded to GPU\n\n";
auto options = torch::TensorOptions()
torch::jit::script::Module model;
try {
model = torch::jit::load(argv[1]);
std::cout << "Loaded the model\n";
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}

model.to(torch::Device(torch::kXPU));
std::cout << "Model offloaded to GPU\n\n";

auto options = torch::TensorOptions()
.dtype(torch::kFloat32)
.device(torch::kXPU);
torch::Tensor input_tensor = torch::rand({1,3,224,224}, options);
assert(input_tensor.dtype() == torch::kFloat32);
assert(input_tensor.device().type() == torch::kXPU);
std::cout << "Created the input tesor on GPU\n";
torch::Tensor input_tensor = torch::rand({1,3,224,224}, options);
assert(input_tensor.dtype() == torch::kFloat32);
assert(input_tensor.device().type() == torch::kXPU);
std::cout << "Created the input tesor on GPU\n";

torch::Tensor output = model.forward({input_tensor}).toTensor();
std::cout << "Performed inference\n\n";
torch::Tensor output = model.forward({input_tensor}).toTensor();
std::cout << "Performed inference\n\n";

std::cout << "Predicted tensor is : \n";
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/10) << '\n';
std::cout << "Slice of predicted tensor is : \n";
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/10) << '\n';

return 0;
return 0;
}
```
and the `CMakeLists.txt` file

```
cmake_minimum_required(VERSION 3.5 FATAL_ERROR)
cmake_policy(SET CMP0074 NEW)
project(inference-example)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -Wl,--no-as-needed")
add_executable(inference-example inference-example.cpp)
target_link_libraries(inference-example "${TORCH_LIBRARIES}" "${INTEL_EXTENSION_FOR_PYTORCH_PATH}/lib/libintel-ext-pt-gpu.so")
set_property(TARGET inference-example PROPERTY CXX_STANDARD 17)
```

to build the inference example.

Finally, execute the `doConfig.sh` script below
```
#!/bin/bash
cmake \
-DCMAKE_PREFIX_PATH=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'` \
-DINTEL_EXTENSION_FOR_PYTORCH_PATH=`python -c 'import torch; print(torch.__path__[0].replace("torch","intel_extension_for_pytorch"))'` \
./
and execute it with `./inference-example ./resnet50_jit.pt`.
make
./inference-example ./resnet50_jit.pt
```
## LibTorch Interoperability with SYCL Pipelines
The LibTorch API can be integrated with data pilelines using SYCL to offload and operate on the input and output data on the Intel Max 1550 GPU.
The code below extends the above example of performing inference with the ResNet50 model by first generating the input data on the CPU, then offloading it to the GPU with SYCL, and finally passing the device pointer to LibTorch for inference.
The LibTorch API can be integrated with data pilelines using SYCL to operate on input and output data already offloaded to the Intel Max 1550 GPU.
The code below extends the above example of performing inference with the ResNet50 model by first generating the input data on the CPU, then offloading it to the GPU with SYCL, and finally passing the device pointer to LibTorch for inference using `torch::from_blob()`, which create a Torch tensor from a data pointer with zero-copy.
The source code for `inference-example.cpp` is modified as follows
```
```c++
#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>
Expand All @@ -143,78 +185,86 @@ const int N_BATCH = 1;
const int N_CHANNELS = 3;
const int N_PIXELS = 224;
const int INPUTS_SIZE = N_BATCH*N_CHANNELS*N_PIXELS*N_PIXELS;
const int OUTPUTS_SIZE = N_BATCH*N_CHANNELS;
int main(int argc, const char* argv[]) {
torch::jit::script::Module model;
try {
model = torch::jit::load(argv[1]);
std::cout << "Loaded the model\n";
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
// Upload model to GPU
model.to(torch::Device(torch::kXPU));
std::cout << "Model offloaded to GPU\n\n";
// Create the input data on the host
std::vector<float> inputs(INPUTS_SIZE);
srand(12345);
for (int i=0; i<INPUTS_SIZE; i++) {
inputs[i] = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
}
std::cout << "Generated input data on the host \n\n";
// Move input data to the device with SYCL
sycl::queue Q(sycl::gpu_selector_v);
std::cout << "SYCL running on "
torch::jit::script::Module model;
try {
model = torch::jit::load(argv[1]);
std::cout << "Loaded the model\n";
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
model.to(torch::Device(torch::kXPU));
std::cout << "Model offloaded to GPU\n\n";
// Create the input data on the host
std::vector<float> inputs(INPUTS_SIZE);
srand(12345);
for (int i=0; i<INPUTS_SIZE; i++) {
inputs[i] = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
}
std::cout << "Generated input data on the host \n\n";
// Move input data to the device with SYCL
sycl::queue Q(sycl::gpu_selector_v);
std::cout << "SYCL running on "
<< Q.get_device().get_info<sycl::info::device::name>()
<< "\n\n";
float *d_inputs = sycl::malloc_device<float>(INPUTS_SIZE, Q);
Q.memcpy((void *) d_inputs, (void *) inputs.data(), INPUTS_SIZE*sizeof(float));
Q.wait();
// Convert input array to Torch tensor
auto options = torch::TensorOptions()
float *d_inputs = sycl::malloc_device<float>(INPUTS_SIZE, Q);
Q.memcpy((void *) d_inputs, (void *) inputs.data(), INPUTS_SIZE*sizeof(float));
Q.wait();
// Pre-allocate the output array on device and fill with a number
double *d_outputs = sycl::malloc_device<double>(OUTPUTS_SIZE, Q);
Q.submit([&](sycl::handler &cgh) {
cgh.parallel_for(OUTPUTS_SIZE, [=](sycl::id<1> idx) {
d_outputs[idx] = 1.2345;
});
});
Q.wait();
std::cout << "Offloaded input data to the GPU \n\n";
// Convert input array to Torch tensor
auto options = torch::TensorOptions()
.dtype(torch::kFloat32)
.device(torch::kXPU);
torch::Tensor input_tensor = at::from_blob(d_inputs, {N_BATCH,N_CHANNELS,N_PIXELS,N_PIXELS},
nullptr, at::device(torch::kXPU).dtype(torch::kFloat32),
torch::kXPU)
.to(torch::kXPU);
assert(input_tensor.dtype() == torch::kFloat32);
assert(input_tensor.device().type() == torch::kXPU);
std::cout << "Created the input tesor on GPU\n";
// Perform inference
torch::Tensor output = model.forward({input_tensor}).toTensor();
std::cout << "Performed inference\n\n";
std::cout << "Predicted tensor is : \n";
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/10) << '\n';
return 0;
torch::Tensor input_tensor = torch::from_blob(
d_inputs,
{N_BATCH,N_CHANNELS,N_PIXELS,N_PIXELS},
options);
assert(input_tensor.dtype() == torch::kFloat32);
assert(input_tensor.device().type() == torch::kXPU);
std::cout << "Created the input Torch tesor on GPU\n\n";
// Perform inference
torch::NoGradGuard no_grad; // equivalent to "with torch.no_grad():" in PyTorch
torch::Tensor output = model.forward({input_tensor}).toTensor();
std::cout << "Performed inference\n\n";
// Copy the output Torch tensor to the SYCL pointer
auto output_tensor_ptr = output.contiguous().data_ptr();
Q.memcpy((void *) d_outputs, (void *) output_tensor_ptr, OUTPUTS_SIZE*sizeof(double));
Q.wait();
std::cout << "Copied output Torch tensor to SYCL pointer\n";
return 0;
}
```

and the CMake commands also change to include
```
#!/bin/bash
Note that an additional C++ flag is needed in this case, as shown below in the `cmake` command
```bash
cmake \
-DCMAKE_CXX_FLAGS="-std=c++17 -fsycl" \
-DCMAKE_PREFIX_PATH=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'` \
-DINTEL_EXTENSION_FOR_PYTORCH_PATH=`python -c 'import torch; print(torch.__path__[0].replace("torch","intel_extension_for_pytorch"))'` \
./
make
./inference-example ./resnet50_jit.pt
```

## Known Issues

* The LibTorch introspection API that are available for CUDA devices, such as `torch::cuda::is_available()`, are still under development for Intel Max 1550 GPU.



Expand Down
Loading

0 comments on commit cbe001f

Please sign in to comment.