forked from NervanaSystems/ngraph-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
140 lines (120 loc) · 4.62 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext as _build_ext
import sys
import sysconfig
import os
import re
def get_version(public_build_version):
    """Return the package version, optionally extended with a local label.

    The local label is read from the ``LOCAL_VERSION`` environment variable.
    When it is unset or empty, ``public_build_version`` is returned unchanged;
    otherwise the result is ``"<public_build_version>+<LOCAL_VERSION>"``.

    Args:
        public_build_version: The public version string, e.g. ``"0.4.0"``.

    Returns:
        The version string to pass to ``setup()``.

    Raises:
        RuntimeError: If ``LOCAL_VERSION`` is set but is not a valid PEP 440
            local version label.
    """
    local_build_version = os.environ.get('LOCAL_VERSION', '')
    if not local_build_version:
        return public_build_version
    # Local version matching PEP440: alphanumeric segments separated by
    # '-', '_' or '.'. The trailing \Z makes the WHOLE label conform;
    # re.match only anchors the start and would otherwise accept e.g.
    # "abc!!" by matching just its "abc" prefix. (\Z, unlike $, also
    # rejects a trailing newline.)
    local_ver_regexp = re.compile(r'(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*)\Z', re.IGNORECASE)
    if not local_ver_regexp.match(local_build_version):
        raise RuntimeError("Invalid local version format:{}".format(local_build_version))
    return "{}+{}".format(public_build_version, local_build_version)
class build_ext(_build_ext):
    """
    Class to build Extensions without platform suffixes
    ex: mkldnn_engine.cpython-35m-x86_64-linux-gnu.so => mkldnn_engine.so
    """

    def get_ext_filename(self, ext_name):
        """Return the base-class filename for ``ext_name`` minus the ABI tag."""
        _filename = _build_ext.get_ext_filename(self, ext_name)
        return self.get_ext_filename_without_suffix(_filename)

    def get_ext_filename_without_suffix(self, _filename):
        """Strip the interpreter-specific part of ``EXT_SUFFIX`` from a filename.

        Args:
            _filename: An extension filename, e.g.
                ``"mkldnn_engine.cpython-35m-x86_64-linux-gnu.so"``.

        Returns:
            The filename with the tag removed (``"mkldnn_engine.so"``), or
            ``_filename`` unchanged when there is no tag to strip or the tag
            does not end the base name.
        """
        name, ext = os.path.splitext(_filename)
        ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
        # Nothing to strip: EXT_SUFFIX is undefined, or it is just the bare
        # file extension.
        if ext_suffix is None or ext_suffix == ext:
            return _filename
        # e.g. ".cpython-35m-x86_64-linux-gnu.so" -> ".cpython-35m-x86_64-linux-gnu"
        ext_suffix = ext_suffix.replace(ext, '')
        # Only strip the tag when it actually terminates the base name. An
        # empty tag must be rejected explicitly: str.find('') returns 0, which
        # previously discarded the entire base name.
        if not ext_suffix or not name.endswith(ext_suffix):
            return _filename
        return name[:-len(ext_suffix)] + ext
ext_modules = []

# The MKL-DNN CPU engine extension is only built when $MKLDNN_ROOT is set;
# headers are taken from $MKLDNN_ROOT/include and libraries from
# $MKLDNN_ROOT/lib.
if "MKLDNN_ROOT" in os.environ:
    MKLDNNROOT = os.environ['MKLDNN_ROOT']

    # Always embed an rpath to $MKLDNN_ROOT/lib; on every platform except
    # macOS, "-shared" is passed ahead of it.
    extra_link_args = ["-Wl,-rpath,%s/lib" % MKLDNNROOT]
    if sys.platform != 'darwin':
        extra_link_args.insert(0, "-shared")

    ext_modules.append(
        Extension(
            'mkldnn_engine',
            sources=[
                'ngraph/transformers/cpu/convolution.c',
                'ngraph/transformers/cpu/elementwise.c',
                'ngraph/transformers/cpu/innerproduct.c',
                'ngraph/transformers/cpu/mkldnn_engine.c',
                'ngraph/transformers/cpu/relu.c',
                'ngraph/transformers/cpu/pooling.c',
                'ngraph/transformers/cpu/batchnorm.c',
            ],
            include_dirs=['%s/include' % MKLDNNROOT],
            library_dirs=['%s/lib' % MKLDNNROOT],
            libraries=['mkldnn'],
            extra_compile_args=["-std=gnu99"],
            extra_link_args=extra_link_args,
        )
    )
"""
List requirements here as loosely as possible but include known limitations.
For example if we know that cffi <1.0 does not work, we list the min version here,
but for other package where no known limitations exist we do not impose a restriction.
This impacts external users who install ngraph via pip, and may install ngraph inside
an environment where an existing version of these required packages exists and should
not be upgraded/downgraded by our install unless absolutely necessary.
"""
requirements = [
"numpy",
"h5py",
"appdirs",
"six",
"tensorflow",
"scipy",
"protobuf",
"requests",
"frozendict",
"cached-property",
"orderedset",
"tqdm",
"enum34",
"future",
"configargparse",
"cachetools",
"decorator",
"pynvrtc",
"monotonic",
"pillow",
"jupyter",
"nbconvert",
"nbformat",
"setuptools",
"cffi>=1.0",
"parsel",
]
setup(
    name="ngraph",
    version=get_version(public_build_version="0.4.0"),
    # Exclude the top-level "tests" package AND its subpackages: an exclude
    # entry of "tests" alone matches only the top-level package, so any
    # "tests.<sub>" packages would otherwise still be installed.
    packages=find_packages(exclude=["tests", "tests.*"]),
    install_requires=requirements,
    author='Nervana Systems',
    author_email='[email protected]',
    url='http://www.nervanasys.com',
    # NOTE(review): this value looks like a trove-classifier string rather
    # than a plain license name; consider license='Apache 2.0' together with
    # a "License :: OSI Approved :: Apache Software License" classifier.
    license='License :: Apache 2.0',
    # Route build_ext through the local subclass, which drops the
    # interpreter/platform tag from built extension filenames.
    cmdclass={
        'build_ext': build_ext,
    },
    ext_modules=ext_modules,
    package_data={'ngraph': ['logging.json']},
)