1import distutils
2import os
3import platform
4import sys
5import tempfile
6import warnings
7from distutils import ccompiler
8from distutils.command.build_ext import build_ext
9from distutils.errors import CompileError, LinkError
10from distutils.sysconfig import customize_compiler
11from os.path import join
12
13import setuptools
14from setuptools import setup, Extension
15
16
class ConvertNotebooksToDocs(distutils.cmd.Command):
    """Custom setup command that renders the example notebooks for the docs.

    Every notebook listed in ``run`` is exported to reStructuredText with
    ``nbconvert`` and written to ``docs/source/examples/<notebook name>/``.

    """

    description = (
        "Convert the example notebooks to reStructuredText that will "
        "be available in the documentation."
    )

    # The command does not accept any command-line options
    user_options = []

    def initialize_options(self):
        """Set option defaults; this command has no options."""

    def finalize_options(self):
        """Validate options; this command has no options."""

    def run(self):
        """Export each example notebook to an ``.rst`` file in the docs tree."""
        # Imported here so that nbconvert is only needed when this command is
        # actually invoked
        import nbconvert
        from os.path import basename, join, splitext

        exporter = nbconvert.RSTExporter()
        writer = nbconvert.writers.FilesWriter()

        files = [
            join("examples", "01_simple_usage.ipynb"),
            join("examples", "02_advanced_usage.ipynb"),
            join("examples", "03_preserving_global_structure.ipynb"),
            join("examples", "04_large_data_sets.ipynb"),
        ]
        target_dir = join("docs", "source", "examples")

        for fname in files:
            self.announce(f"Converting {fname}...")
            # The paths are built with ``join``, so the separator is platform
            # dependent; let ``os.path`` take the path apart instead of
            # splitting on "/" and "." by hand
            nb_name, _ = splitext(basename(fname))
            body, resources = exporter.from_file(fname)
            # Each notebook gets its own directory, named after the notebook
            writer.build_directory = join(target_dir, nb_name)
            writer.write(body, resources, nb_name)
51
52
class get_numpy_include:
    """Lazy stand-in for the numpy include path

    numpy is only imported when an instance is converted to a string, so the
    object can be created before numpy is installed and still resolves to the
    result of ``get_include()`` at the time it is needed.

    """
    def __str__(self):
        from numpy import get_include

        return get_include()
63
64
def get_include_dirs():
    """Return the include directories under ``sys.prefix`` as a tuple."""
    prefix = sys.prefix
    return tuple(
        os.path.join(prefix, *parts)
        for parts in (("include",), ("Library", "include"))
    )
71
72
def get_library_dirs():
    """Return the library directories under ``sys.prefix`` as a tuple."""
    prefix = sys.prefix
    return tuple(
        os.path.join(prefix, *parts)
        for parts in (("lib",), ("Library", "lib"))
    )
79
80
def has_c_library(library, extension=".c"):
    """Test whether the compiler can build a program that includes a library.

    A tiny program that includes ``<library>.h`` is written to a temporary
    directory inside the current working directory, then compiled and linked.
    The temporary directory is removed again when the check is done.

    Parameters
    ----------
    library: str
        Name of the library, without the header suffix. To check for FFTW3,
        i.e. for `fftw3.h`, pass `fftw3`.
    extension: str
        File extension given to the test program: `.c` for a C library, or
        one of `.cc`, `.cpp` or `.cxx` for C++.

    Returns
    -------
    bool
        ``True`` if the test program compiled and linked, ``False`` if the
        compiler raised a compile or link error.

    """
    with tempfile.TemporaryDirectory(dir=".") as workdir:
        source_path = join(workdir, f"{library}{extension}")
        program = f"#include <{library}.h>\n" + "int main() {}\n"
        with open(source_path, "w") as source_file:
            source_file.write(program)

        # A fresh compiler instance, set up with the platform specific options
        cc = ccompiler.new_compiler()
        customize_compiler(cc)
        # Make the conda include dirs visible to the compiler as well
        for include_dir in get_include_dirs():
            cc.add_include_dir(include_dir)
        assert isinstance(cc, ccompiler.CCompiler)

        try:
            # Compile the test program and link the resulting objects
            objects = cc.compile([source_path])
            cc.link_executable(objects, source_path)
        except (CompileError, LinkError):
            return False
        return True
120
121
class CythonBuildExt(build_ext):
    """``build_ext`` command that configures the extensions before building.

    Compiler and linker flags, include directories and library directories
    are chosen at build time, based on the compiler type, the platform and
    whether an openMP header is available.

    """

    def build_extensions(self):
        """Adjust sources, flags and search paths of all extensions and build.

        Raises
        ------
        RuntimeError
            If none of the extensions is the Annoy extension.

        """
        # Automatically append the file extension based on language.
        # ``cythonize`` does this for us automatically, so it's not necessary if
        # that was run
        for extension in self.extensions:
            for idx, source in enumerate(extension.sources):
                base, ext = os.path.splitext(source)
                if ext == ".pyx":
                    base += ".cpp" if extension.language == "c++" else ".c"
                    extension.sources[idx] = base

        # Flags collected here are applied to every extension at the end
        extra_compile_args = []
        extra_link_args = []

        # Optimization compiler/linker flags are added appropriately
        compiler = self.compiler.compiler_type
        if compiler == "unix":
            extra_compile_args += ["-O3"]
        elif compiler == "msvc":
            extra_compile_args += ["/Ox", "/fp:fast"]

        if compiler == "unix" and platform.system() == "Darwin":
            # For some reason fast math causes segfaults on linux but works on mac
            extra_compile_args += ["-ffast-math", "-fno-associative-math"]

        # Annoy specific flags
        annoy_ext = None
        for extension in self.extensions:
            if "annoy.annoylib" in extension.name:
                annoy_ext = extension
        # Raise explicitly rather than ``assert``, which is skipped under -O
        if annoy_ext is None:
            raise RuntimeError("Annoy extension not found!")

        if compiler == "unix":
            annoy_ext.extra_compile_args += ["-std=c++14"]
            annoy_ext.extra_compile_args += ["-DANNOYLIB_MULTITHREADED_BUILD"]
        elif compiler == "msvc":
            annoy_ext.extra_compile_args += ["/std:c++14"]

        # Set minimum deployment version for MacOS
        if compiler == "unix" and platform.system() == "Darwin":
            extra_compile_args += ["-mmacosx-version-min=10.12"]
            extra_link_args += ["-stdlib=libc++", "-mmacosx-version-min=10.12"]

        # We don't want the compiler to optimize for system architecture if
        # we're building packages to be distributed by conda-forge, but if the
        # package is being built locally, this is desired
        if not ("AZURE_BUILD" in os.environ or "CONDA_BUILD" in os.environ):
            if platform.machine() == "ppc64le":
                extra_compile_args += ["-mcpu=native"]
            if platform.machine() == "x86_64":
                extra_compile_args += ["-march=native"]

        # We will disable openmp flags if the compiler doesn't support it. This
        # is only really an issue with OSX clang
        if has_c_library("omp"):
            print("Found openmp. Compiling with openmp flags...")
            if platform.system() == "Darwin" and compiler == "unix":
                extra_compile_args += ["-Xpreprocessor", "-fopenmp"]
                extra_link_args += ["-lomp"]
            elif compiler == "unix":
                extra_compile_args += ["-fopenmp"]
                extra_link_args += ["-fopenmp"]
            elif compiler == "msvc":
                extra_compile_args += ["/openmp"]
                extra_link_args += ["/openmp"]
        else:
            warnings.warn(
                "You appear to be using a compiler which does not support "
                "openMP, meaning that the library will not be able to run on "
                "multiple cores. Please install/enable openMP to use multiple "
                "cores."
            )

        for extension in self.extensions:
            # Apply the shared compiler and linker flags
            extension.extra_compile_args += extra_compile_args
            extension.extra_link_args += extra_link_args

            # Add numpy and system include directories
            extension.include_dirs.extend(get_include_dirs())
            extension.include_dirs.append(get_numpy_include())

            # Add system library directories
            extension.library_dirs.extend(get_library_dirs())

        super().build_extensions()
210
211
# Prepare the Annoy extension (adapted from the annoy setup.py).
# No extra flags are set at this point; the platform-dependent ones are added
# in ``CythonBuildExt.build_extensions``.
extra_compile_args = []
extra_link_args = []

annoy_path = "openTSNE/dependencies/annoy/"
annoy = Extension(
    "openTSNE.dependencies.annoy.annoylib",
    [f"{annoy_path}annoymodule.cc"],
    depends=[
        f"{annoy_path}{header}"
        for header in ("annoylib.h", "kissrandom.h", "mman.h")
    ],
    language="c++",
    extra_compile_args=extra_compile_args,
    extra_link_args=extra_link_args,
)

# The Cython extensions all follow the same naming scheme, followed by Annoy
extensions = [
    Extension(f"openTSNE.{module}", [f"openTSNE/{module}.pyx"], language="c++")
    for module in ("quad_tree", "_tsne", "kl_divergence")
]
extensions.append(annoy)
235
236
# Pick the FFT backend of the matrix multiplication extension: FFTW3 if its
# header can be compiled against, otherwise the numpy implementation
if has_c_library("fftw3"):
    print("FFTW3 header files found. Using FFTW implementation of FFT.")
    matrix_mul_backend, matrix_mul_libraries = "fftw3", ["fftw3"]
else:
    print("FFTW3 header files not found. Using numpy implementation of FFT.")
    matrix_mul_backend, matrix_mul_libraries = "numpy", []

extension_ = Extension(
    "openTSNE._matrix_mul.matrix_mul",
    [f"openTSNE/_matrix_mul/matrix_mul_{matrix_mul_backend}.pyx"],
    libraries=matrix_mul_libraries,
    language="c++",
)
extensions.append(extension_)

# Run Cython over the extensions when it is installed. Without Cython the
# extensions are left as they are, and ``CythonBuildExt`` swaps the ``.pyx``
# sources for the matching ``.c``/``.cpp`` files at build time.
try:
    from Cython.Build import cythonize
    extensions = cythonize(extensions)
except ImportError:
    pass
261
262
def readme():
    """Return the contents of ``README.rst``, read as UTF-8."""
    with open("README.rst", mode="r", encoding="utf-8") as readme_file:
        contents = readme_file.read()
    return contents
266
267
# Read in version
__version__: str = ""  # Placeholder, overridden when version.py is executed
# version.py is executed in this module's namespace. The file is opened in a
# ``with`` block so the handle is closed, and read as UTF-8 so the result does
# not depend on the locale.
with open(os.path.join("openTSNE", "version.py"), encoding="utf-8") as version_file:
    exec(version_file.read())
271
repository_url = "https://github.com/pavlin-policar/openTSNE"

setuptools.setup(
    # Project identity
    name="openTSNE",
    version=__version__,
    description="Extensible, parallel implementations of t-SNE",
    long_description=readme(),
    license="BSD-3-Clause",

    # Author and project links
    author="Pavlin Poličar",
    author_email="pavlin.g.p@gmail.com",
    url=repository_url,
    project_urls={
        "Documentation": "https://opentsne.readthedocs.io/",
        "Source": repository_url,
        "Issue Tracker": f"{repository_url}/issues",
    },
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Science/Research",
        "Intended Audience :: Developers",
        "Topic :: Software Development",
        "Topic :: Scientific/Engineering",
        "Operating System :: Microsoft :: Windows",
        "Operating System :: POSIX",
        "Operating System :: Unix",
        "Operating System :: MacOS",
        "License :: OSI Approved",
        "Programming Language :: Python :: 3",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Scientific/Engineering :: Visualization",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],

    # Packages and their requirements
    packages=setuptools.find_packages(include=["openTSNE", "openTSNE.*"]),
    python_requires=">=3.6",
    install_requires=[
        "numpy>=1.16.6",
        "scikit-learn>=0.20",
        "scipy",
    ],
    extras_require={
        "hnsw": "hnswlib~=0.4.0",
        "pynndescent": "pynndescent~=0.5.0",
    },

    # Compiled extensions and the custom commands of this project
    ext_modules=extensions,
    cmdclass={
        "build_ext": CythonBuildExt,
        "convert_notebooks": ConvertNotebooksToDocs,
    },
)
318