-
Notifications
You must be signed in to change notification settings - Fork 5
/
setup.py
104 lines (91 loc) · 2.99 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import os.path as osp
from setuptools import setup, find_packages
from textwrap import dedent
import torch
from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
# Decide which extension flavours to build.
# Default: CPU extensions are always built; CUDA extensions only when torch
# reports CUDA as available AND a CUDA toolkit root (CUDA_HOME) was found.
# Environment overrides, checked in this priority order:
#   FORCE_ONLY_CPU=1  -> build CPU extensions only
#   FORCE_ONLY_CUDA=1 -> build CUDA extensions only
#   FORCE_CUDA=1      -> build CPU and CUDA extensions even when
#                        CUDA_AVAILABLE is False (a warning is printed)
CUDA_AVAILABLE = torch.cuda.is_available() and CUDA_HOME is not None
DO_CPU = True
DO_CUDA = CUDA_AVAILABLE
if os.getenv('FORCE_ONLY_CPU', '0') == '1':
    print('FORCE_ONLY_CPU: Only compiling CPU extensions')
    DO_CPU = True
    DO_CUDA = False
elif os.getenv('FORCE_ONLY_CUDA', '0') == '1':
    print('FORCE_ONLY_CUDA: Only compiling CUDA extensions')
    if not CUDA_AVAILABLE:
        print(f'CUDA_AVAILABLE={CUDA_AVAILABLE}, high chance of failure')
    DO_CPU = False
    DO_CUDA = True
elif os.getenv('FORCE_CUDA', '0') == '1':
    print('FORCE_CUDA: Forcing compilation of CUDA extensions')
    if not CUDA_AVAILABLE:
        # Spelled out rather than f'{CUDA_AVAILABLE=}': the `=` specifier is
        # Python 3.8+ and would be a SyntaxError under python_requires>=3.6.
        print(f'CUDA_AVAILABLE={CUDA_AVAILABLE}, high chance of failure')
    DO_CUDA = True
# When set, setup() receives no ext_modules, so nothing is compiled.
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
# Define extensions
# Absolute path of the `extensions/` directory next to this setup.py; used as
# an include directory for every extension.
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'extensions')

cpu_kwargs = dict(
    include_dirs=[extensions_dir],
    extra_compile_args={'cxx': ['-O2']},
    extra_link_args=['-s']
)
cuda_kwargs = dict(
    include_dirs=[extensions_dir],
    extra_compile_args={'cxx': ['-O2'], 'nvcc': ['--expt-relaxed-constexpr', '-O2']},
    extra_link_args=['-s']
)

# Only construct the extension objects that will actually be built.
# CUDAExtension() looks up CUDA include/library directories via CUDA_HOME at
# construction time and raises when CUDA_HOME is unset, so building it
# unconditionally would break CPU-only installs even with DO_CUDA False.
extensions_cpu = [
    CppExtension('select_knn_cpu', ['extensions/select_knn_cpu.cpp'], **cpu_kwargs)
] if DO_CPU else []
extensions_cuda = [
    CUDAExtension(
        'select_knn_cuda',
        ['extensions/select_knn_cuda.cpp', 'extensions/select_knn_cuda_kernel.cu'],
        **cuda_kwargs
    )
] if DO_CUDA else []

# Everything handed to setup(ext_modules=...), CPU first then CUDA.
extensions = extensions_cpu + extensions_cuda
# Print extensions
def repr_ext(ext):
    """
    Return a multi-line, human-readable summary of an extension for debugging.

    :param ext: an extension object as produced by CppExtension/CUDAExtension;
        only its ``name``, ``sources``, ``extra_compile_args`` and
        ``extra_link_args`` attributes are read.
    :returns: the summary as a string ending in a newline (the caller prints it).
    """
    # Lines are assembled explicitly instead of dedent()-ing an interpolated
    # f-string: dedent runs after interpolation, so any interpolated value
    # containing a newline would shift the common indent and garble output.
    lines = [
        ext.name,
        f"  sources: {', '.join(ext.sources)}",
        f'  extra_compile_args: {ext.extra_compile_args}',
        f'  extra_link_args: {ext.extra_link_args}',
    ]
    return '\n'.join(lines) + '\n'
# Show a summary of every extension that will be compiled, framed by
# separator lines so it is easy to spot in the build log.
print('\n---------------------\nExtensions:')
for ext in extensions:
    print(repr_ext(ext))
print('---------------------')
# Setup call
tests_require = ['pytest', 'pytest-cov', 'scipy']
setup(
    # Package metadata
    name='torch_cmspepr',
    version='1.0.0',
    description='PyTorch Extension Library for HGCAL Specific knn optimizations',
    author='Lindsey Gray <[email protected]>, Jan Kieseler <[email protected]>, Thomas Klijnsma <[email protected]>',
    author_email='[email protected]',
    url='',
    license='MIT',
    keywords=[
        'pytorch',
        'knn',
        'geometric-deep-learning',
        'graph-neural-networks',
        'cluster-algorithms',
    ],
    # Python packages and runtime requirements
    python_requires='>=3.6',
    install_requires=[],
    packages=find_packages(),
    # Test tooling; the same requirement list is also exposed as the "test" extra
    setup_requires=['pytest-runner'],
    tests_require=tests_require,
    extras_require={'test': tests_require},
    # Compiled extensions; the list is left empty when BUILD_DOCS is set
    ext_modules=[] if BUILD_DOCS else extensions,
    cmdclass={
        'build_ext': BuildExtension.with_options(
            no_python_abi_suffix=True, use_ninja=False
        ),
    },
)