1from __future__ import absolute_import
2from __future__ import division
3from __future__ import print_function
4from __future__ import unicode_literals
5
6from distutils.spawn import find_executable
7from distutils import sysconfig, log
8import setuptools
9import setuptools.command.build_py
10import setuptools.command.develop
11import setuptools.command.build_ext
12
13from collections import namedtuple
14from contextlib import contextmanager
15import glob
16import os
17import shlex
18import subprocess
19import sys
20import struct
21from textwrap import dedent
22import multiprocessing
23
24
# Repository layout. TOP_DIR is canonicalised with realpath, so every path
# derived from it is canonical too.
TOP_DIR = os.path.realpath(os.path.dirname(__file__))
SRC_DIR, TP_DIR, CMAKE_BUILD_DIR = (
    os.path.join(TOP_DIR, subdir)
    for subdir in ('onnx', 'third_party', '.setuptools-cmake-build')
)

WINDOWS = os.name == 'nt'

# Prefer a `cmake3` executable and fall back to plain `cmake`; stays None
# when neither is found on PATH.
CMAKE = find_executable('cmake3')
if not CMAKE:
    CMAKE = find_executable('cmake')
MAKE = find_executable('make')

# Dependency collections; later sections of this file add entries before
# they are passed to setuptools.setup().
install_requires, setup_requires, tests_require = [], [], []
extras_require = {}
39
################################################################################
# Global variables for controlling the build variant
################################################################################

# On unless explicitly disabled with ONNX_ML=0.
ONNX_ML = os.getenv('ONNX_ML') != '0'
# Forwarded to CMake as -DONNX_NAMESPACE; defaults to 'onnx'.
ONNX_NAMESPACE = os.getenv('ONNX_NAMESPACE', 'onnx')
# Off unless explicitly enabled with ONNX_BUILD_TESTS=1.
ONNX_BUILD_TESTS = os.getenv('ONNX_BUILD_TESTS') == '1'

# Any non-empty value switches these on.
DEBUG = bool(os.getenv('DEBUG'))
COVERAGE = bool(os.getenv('COVERAGE'))
50
################################################################################
# Version
################################################################################

# Commit hash of the checkout. None when git cannot be executed (OSError) or
# `git rev-parse` exits with a non-zero status (CalledProcessError).
try:
    git_version = subprocess.check_output(
        ['git', 'rev-parse', 'HEAD'], cwd=TOP_DIR,
    ).decode('ascii').strip()
except (OSError, subprocess.CalledProcessError):
    git_version = None

with open(os.path.join(TOP_DIR, 'VERSION_NUMBER')) as version_file:
    _version_number = version_file.read().strip()

# Despite the CamelCase name, VersionInfo is a namedtuple *instance*: the
# type is created and immediately instantiated in a single expression.
VersionInfo = namedtuple('VersionInfo', ['version', 'git_version'])(
    version=_version_number,
    git_version=git_version,
)
66
################################################################################
# Pre Check
################################################################################

# Explicit check rather than `assert`: assert statements are removed when
# Python runs with -O, which would let a missing cmake slip through here and
# fail later, less clearly, when CMAKE (None) is handed to subprocess.
if not CMAKE:
    raise RuntimeError('Could not find "cmake" executable!')
72
73################################################################################
74# Utilities
75################################################################################
76
77
@contextmanager
def cd(path):
    """Temporarily make *path* the current working directory.

    The previous working directory is restored when the ``with`` block
    exits, whether it finishes normally or raises.

    Raises:
        RuntimeError: if *path* is not an absolute path.
    """
    if os.path.isabs(path):
        previous_dir = os.getcwd()
        os.chdir(path)
        try:
            yield
        finally:
            os.chdir(previous_dir)
    else:
        raise RuntimeError('Can only cd to absolute path, got: {}'.format(path))
88
89
90################################################################################
91# Customized commands
92################################################################################
93
94
class ONNXCommand(setuptools.Command):
    """Base class for this file's custom commands that accept no options.

    It supplies an empty ``user_options`` list and no-op option hooks, so
    subclasses only need to implement ``run``.
    """
    user_options = []

    def initialize_options(self):
        """No options to initialize."""

    def finalize_options(self):
        """No options to validate."""
103
104
class create_version(ONNXCommand):
    """Generate onnx/version.py from the module-level VersionInfo."""

    def run(self):
        # Fill the template, then strip its source-code indentation.
        fields = VersionInfo._asdict()
        contents = dedent('''\
            # This file is generated by setup.py. DO NOT EDIT!

            from __future__ import absolute_import
            from __future__ import division
            from __future__ import print_function
            from __future__ import unicode_literals

            version = '{version}'
            git_version = '{git_version}'
            '''.format(**fields))
        target = os.path.join(SRC_DIR, 'version.py')
        with open(target, 'w') as out:
            out.write(contents)
119
120
class cmake_build(setuptools.Command):
    """
    Compiles everything when `python setup.py build` is run using cmake.

    Custom args can be passed to cmake by specifying the `CMAKE_ARGS`
    environment variable.

    The number of CPUs used by `make` can be specified by passing `-j<ncpus>`
    to `setup.py build`.  By default all CPUs are used.

    The CMake build runs at most once per setup.py invocation, even though
    several commands request it.
    """
    user_options = [
        (str('jobs='), str('j'), str('Specifies the number of jobs to use with make'))
    ]

    # Class-level (not per-instance) so a second request is a no-op.
    built = False

    def initialize_options(self):
        # Default: one job per CPU.
        self.jobs = multiprocessing.cpu_count()

    def finalize_options(self):
        # A value passed with -j/--jobs arrives as a string.
        self.jobs = int(self.jobs)

    def _configure_args(self, build_type):
        """Return the argv for the CMake configure step.

        Args:
            build_type: 'Debug' or 'Release'; forwarded as CMAKE_BUILD_TYPE.

        Side effect: if CMAKE_ARGS is set, its shell-split contents are
        appended and the variable is removed from os.environ.
        """
        args = [
            CMAKE,
            '-DPYTHON_INCLUDE_DIR={}'.format(sysconfig.get_python_inc()),
            '-DPYTHON_EXECUTABLE={}'.format(sys.executable),
            '-DBUILD_ONNX_PYTHON=ON',
            '-DCMAKE_EXPORT_COMPILE_COMMANDS=ON',
            '-DONNX_NAMESPACE={}'.format(ONNX_NAMESPACE),
            '-DPY_EXT_SUFFIX={}'.format(sysconfig.get_config_var('EXT_SUFFIX') or ''),
        ]
        if COVERAGE:
            args.append('-DONNX_COVERAGE=ON')
        args.append('-DCMAKE_BUILD_TYPE={}'.format(build_type))
        if WINDOWS:
            args.extend([
                # we need to link with libpython on windows, so
                # passing python version to window in order to
                # find python in cmake
                '-DPY_VERSION={}.{}'.format(*sys.version_info[:2]),
                '-DONNX_USE_MSVC_STATIC_RUNTIME=ON',
            ])
            # 8-byte pointers => 64-bit interpreter.
            if 8 * struct.calcsize("P") == 64:
                # Temp fix for CI
                # TODO: need a better way to determine generator
                args.append('-DCMAKE_GENERATOR_PLATFORM=x64')
        if ONNX_ML:
            args.append('-DONNX_ML=1')
        if ONNX_BUILD_TESTS:
            args.append('-DONNX_BUILD_TESTS=ON')
        if 'CMAKE_ARGS' in os.environ:
            extra_cmake_args = shlex.split(os.environ['CMAKE_ARGS'])
            # prevent crossfire with downstream scripts
            del os.environ['CMAKE_ARGS']
            log.info('Extra cmake args: {}'.format(extra_cmake_args))
            args.extend(extra_cmake_args)
        # The source directory is the final positional argument.
        args.append(TOP_DIR)
        return args

    def _build_args(self, build_type):
        """Return the argv for `cmake --build` on the current directory."""
        args = [CMAKE, '--build', os.curdir]
        if WINDOWS:
            # The configuration is chosen at build time here; the parallelism
            # flag passed after `--` is MSBuild syntax (/maxcpucount).
            args.extend(['--config', build_type])
            args.extend(['--', '/maxcpucount:{}'.format(self.jobs)])
        else:
            args.extend(['--', '-j', str(self.jobs)])
        return args

    def run(self):
        if cmake_build.built:
            return
        cmake_build.built = True
        if not os.path.exists(CMAKE_BUILD_DIR):
            os.makedirs(CMAKE_BUILD_DIR)

        # In order to get accurate coverage information, the build needs to
        # turn off optimizations, so coverage builds are Debug builds too.
        build_type = 'Debug' if (COVERAGE or DEBUG) else 'Release'

        with cd(CMAKE_BUILD_DIR):
            subprocess.check_call(self._configure_args(build_type))
            # `--build .` relies on the cwd being CMAKE_BUILD_DIR.
            subprocess.check_call(self._build_args(build_type))
201
202
class build_py(setuptools.command.build_py.build_py):
    """build_py that generates version.py and runs the CMake build first.

    Python sources (.py / .pyi) that CMake emits under CMAKE_BUILD_DIR/onnx
    are copied to the matching location under TOP_DIR before the standard
    build_py logic runs.
    """

    def run(self):
        for prerequisite in ('create_version', 'cmake_build'):
            self.run_command(prerequisite)

        generated_dir = os.path.join(CMAKE_BUILD_DIR, 'onnx')
        generated_files = []
        for pattern in ('*.py', '*.pyi'):
            generated_files.extend(glob.glob(os.path.join(generated_dir, pattern)))

        for generated in generated_files:
            # Mirror the path relative to the build dir under TOP_DIR.
            relative = os.path.relpath(generated, CMAKE_BUILD_DIR)
            self.copy_file(generated, os.path.join(TOP_DIR, relative))

        return setuptools.command.build_py.build_py.run(self)
218
219
class develop(setuptools.command.develop.develop):
    """develop that runs this file's build_py before the standard logic."""

    def run(self):
        self.run_command('build_py')
        base_command = setuptools.command.develop.develop
        base_command.run(self)
224
225
class build_ext(setuptools.command.build_ext.build_ext):
    """Install CMake-built extension libraries instead of compiling them.

    `run` triggers the CMake build, and `build_extensions` copies each
    library from the CMake output directory into build_lib/onnx.
    """

    def run(self):
        self.run_command('cmake_build')
        setuptools.command.build_ext.build_ext.run(self)

    def _cmake_lib_dir(self):
        """Return the directory that holds the CMake-built library.

        On Windows, cmake_build builds with `--config Debug` when DEBUG or
        COVERAGE is set and `--config Release` otherwise, and the library is
        looked for in a sub-directory named after the config. The directory
        of the config actually being built is checked first, so a stale
        directory from an earlier build of the other config cannot shadow the
        fresh output. Falls back to CMAKE_BUILD_DIR itself if neither exists
        (and always on other platforms).
        """
        if os.name != 'nt':
            return CMAKE_BUILD_DIR
        current = 'Debug' if (COVERAGE or DEBUG) else 'Release'
        other = 'Release' if current == 'Debug' else 'Debug'
        for config in (current, other):
            candidate = os.path.join(CMAKE_BUILD_DIR, config)
            if os.path.exists(candidate):
                return candidate
        return CMAKE_BUILD_DIR

    def build_extensions(self):
        # Loop-invariant: every extension comes from the same CMake build.
        lib_path = self._cmake_lib_dir()
        for ext in self.extensions:
            fullname = self.get_ext_fullname(ext.name)
            filename = os.path.basename(self.get_ext_filename(fullname))
            src = os.path.join(lib_path, filename)
            dst = os.path.join(os.path.realpath(self.build_lib), "onnx", filename)
            self.copy_file(src, dst)
247
248
class mypy_type_check(ONNXCommand):
    description = 'Run MyPy type checker'

    def run(self):
        """Run tools/mypy-onnx.py with this interpreter and exit with its status."""
        setup_dir = os.path.dirname(os.path.abspath(__file__))
        checker = os.path.realpath(os.path.join(setup_dir, "tools/mypy-onnx.py"))
        exit_code = subprocess.call([sys.executable, checker])
        sys.exit(exit_code)
257
258
# setup.py command name -> command class defined in this file.
cmdclass = dict([
    ('create_version', create_version),
    ('cmake_build', cmake_build),
    ('build_py', build_py),
    ('develop', develop),
    ('build_ext', build_ext),
    ('typecheck', mypy_type_check),
])
267
################################################################################
# Extensions
################################################################################

# One extension with an empty source list: this file's build_ext copies the
# library produced by cmake_build instead of compiling anything.
ext_modules = [setuptools.Extension(str('onnx.onnx_cpp2py_export'), [])]
277
################################################################################
# Packages
################################################################################

# no need to do fancy stuff so far
packages = setuptools.find_packages()

install_requires.extend([
    'protobuf',
    'numpy',
    'six',
    # `typing` has been in the standard library since Python 3.5. Installing
    # the PyPI backport on newer interpreters is unnecessary and is known to
    # cause import errors, so restrict it with a PEP 508 marker.
    'typing>=3.6.4; python_version < "3.5"',
    'typing-extensions>=3.6.2.1',
])
292
################################################################################
# Test
################################################################################

setup_requires.append('pytest-runner')
tests_require.extend([
    'pytest',
    'nbval',
    'tabulate',
    # Only pull the `typing` backport where the stdlib lacks it (< 3.5).
    'typing; python_version < "3.5"',
    'typing-extensions',
])

if sys.version_info[0] == 3:
    # Mypy doesn't work with Python 2
    extras_require['mypy'] = ['mypy==0.600']
307
################################################################################
# Final
################################################################################

setuptools.setup(
    # Distribution metadata.
    name="onnx",
    version=VersionInfo.version,
    description="Open Neural Network Exchange",
    author='bddppq',
    author_email='jbai@fb.com',
    url='https://github.com/onnx/onnx',
    # What gets built and shipped.
    packages=packages,
    include_package_data=True,
    ext_modules=ext_modules,
    cmdclass=cmdclass,
    # Dependencies collected earlier in this file.
    install_requires=install_requires,
    setup_requires=setup_requires,
    tests_require=tests_require,
    extras_require=extras_require,
    # Console scripts installed with the package.
    entry_points={
        'console_scripts': [
            'check-model = onnx.bin.checker:check_model',
            'check-node = onnx.bin.checker:check_node',
            'backend-test-tools = onnx.backend.test.cmd_tools:main',
        ],
    },
)
335