Skip to content

Commit

Permalink
Add proper support for CPU-only Pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
ducksoup committed Aug 20, 2019
1 parent fa84237 commit 4043965
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 35 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ To install PyTorch, please refer to https://github.com/pytorch/pytorch#installat

To install the package containing the iABN layers:
```bash
pip install git+https://github.com/mapillary/[email protected].4
pip install git+https://github.com/mapillary/[email protected].5
```
Note that some parts of InPlace-ABN have native C++/CUDA implementations, meaning that the command above will need to
compile them.
Expand Down
16 changes: 16 additions & 0 deletions include/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,19 @@
AT_ERROR(#NAME, " not implemented for '", toString(x_type), "'"); \
} \
}()

#ifdef WITH_CUDA
#define CUDA_DISPATCH(REF_TENSOR, METHOD, ...) \
if ((REF_TENSOR).is_cuda()) { \
return METHOD ## _cuda(__VA_ARGS__); \
} else { \
return METHOD ## _cpu(__VA_ARGS__); \
}
#else
#define CUDA_DISPATCH(REF_TENSOR, METHOD, ...) \
if ((REF_TENSOR).is_cuda()) { \
AT_ERROR("CUDA support was not enabled at compile time"); \
} else { \
return METHOD ## _cpu(__VA_ARGS__); \
}
#endif
44 changes: 30 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from os import path, listdir

import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension


def find_sources(root_dir):
def find_sources(root_dir, with_cuda=True):
extensions = [".cpp", ".cu"] if with_cuda else [".cpp"]

sources = []
for file in listdir(root_dir):
_, ext = path.splitext(file)
if ext in [".cpp", ".cu"]:
if ext in extensions:
sources.append(path.join(root_dir, file))

return sources
Expand All @@ -19,6 +22,29 @@ def find_sources(root_dir):
with open(path.join(here, "README.md"), encoding="utf-8") as f:
long_description = f.read()

if torch.cuda.is_available():
ext_modules = [
CUDAExtension(
name="inplace_abn._backend",
sources=find_sources("src"),
extra_compile_args={
"cxx": ["-O3"],
"nvcc": []
},
include_dirs=["include/"],
define_macros=[("WITH_CUDA", 1)]
)
]
else:
ext_modules = [
CppExtension(
name="inplace_abn._backend",
sources=find_sources("src", False),
extra_compile_args=["-O3"],
include_dirs=["include/"]
)
]

setuptools.setup(
# Meta-data
name="inplace-abn",
Expand All @@ -45,16 +71,6 @@ def find_sources(root_dir):

# Package description
packages=["inplace_abn"],
ext_modules=[
CUDAExtension(
name="inplace_abn._backend",
sources=find_sources("src"),
extra_compile_args={
"cxx": ["-O3"],
"nvcc": []
},
include_dirs=["include/"],
)
],
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension}
)
29 changes: 9 additions & 20 deletions src/inplace_abn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "inplace_abn.h"
#include "checks.h"
#include "utils.h"
#include "dispatch.h"

/***********************************************************************************************************************
* Exposed methods
Expand All @@ -14,13 +15,10 @@
std::tuple<at::Tensor, at::Tensor, at::Tensor> statistics(const at::Tensor& x) {
AT_CHECK(x.ndimension() >= 2, "x should have at least 2 dimensions");

if (x.is_cuda()) {
return statistics_cuda(x);
} else {
return statistics_cpu(x);
}
CUDA_DISPATCH(x, statistics, x)
}

#ifdef WITH_CUDA
std::tuple<at::Tensor, at::Tensor, at::Tensor> reduce_statistics(
const at::Tensor& all_mean, const at::Tensor& all_var, const at::Tensor& all_count) {
// Inputs shouldn't be half
Expand All @@ -44,6 +42,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> reduce_statistics(

return reduce_statistics_cuda(all_mean, all_var, all_count);
}
#endif

void forward(at::Tensor& x, const at::Tensor& mean, const at::Tensor& var,
const c10::optional<at::Tensor>& weight, const c10::optional<at::Tensor>& bias,
Expand All @@ -62,11 +61,7 @@ void forward(at::Tensor& x, const at::Tensor& mean, const at::Tensor& var,
AT_CHECK((weight.has_value() && bias.has_value()) || (!weight.has_value() && !bias.has_value()),
"weight and bias must be equally present or not present");

if (x.is_cuda()) {
forward_cuda(x, mean, var, weight, bias, eps, activation, activation_param);
} else {
forward_cpu(x, mean, var, weight, bias, eps, activation, activation_param);
}
CUDA_DISPATCH(x, forward, x, mean, var, weight, bias, eps, activation, activation_param)
}

std::tuple<at::Tensor, at::Tensor> backward_reduce(
Expand All @@ -88,11 +83,7 @@ std::tuple<at::Tensor, at::Tensor> backward_reduce(
AT_CHECK((weight.has_value() && bias.has_value()) || (!weight.has_value() && !bias.has_value()),
"weight and bias must be equally present or not present");

if (y_act.is_cuda()) {
return backward_reduce_cuda(y_act, dy_act, weight, bias, eps, activation, activation_param);
} else {
return backward_reduce_cpu(y_act, dy_act, weight, bias, eps, activation, activation_param);
}
CUDA_DISPATCH(y_act, backward_reduce, y_act, dy_act, weight, bias, eps, activation, activation_param)
}

at::Tensor backward(const at::Tensor& xhat, const at::Tensor& dy, const at::Tensor& var, const at::Tensor& count,
Expand All @@ -111,11 +102,7 @@ at::Tensor backward(const at::Tensor& xhat, const at::Tensor& dy, const at::Tens
AT_CHECK(is_compatible_weight(xhat, weight.value()),
"weight is not compatible with xhat (wrong size or scalar type)");

if (xhat.is_cuda()) {
return backward_cuda(xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, eps);
} else {
return backward_cpu(xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, eps);
}
CUDA_DISPATCH(xhat, backward, xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, eps)
}

at::Tensor backward_test(const at::Tensor& dy_, const at::Tensor& var, const c10::optional<at::Tensor>& weight,
Expand Down Expand Up @@ -148,7 +135,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

// Forward methods
m.def("statistics", &statistics, "Compute iABN statistics, i.e. mean, biased variance and sample count");
#ifdef WITH_CUDA
m.def("reduce_statistics", &reduce_statistics, "Reduce statistics from multiple GPUs");
#endif
m.def("forward", &forward, "iABN forward pass. This is an in-place operation w.r.t. x");

// Backward methods
Expand Down

0 comments on commit 4043965

Please sign in to comment.