1"""
2To update the scan Cython code in Theano you must
- update the version in this file and in scan_perform.pyx
4- call "cd theano/scan_module/; cython scan_perform.pyx; patch scan_perform.c numpy_api_changes.diff"
5
6"""
7
8from __future__ import absolute_import, print_function, division
9import errno
10import logging
11import os
12import sys
13import warnings
14
15import numpy as np
16
17import theano
18from theano import config
19from theano.compat import reload
20from theano.gof.compilelock import get_lock, release_lock
21from theano.gof import cmodule
22
23
# Version of the compiled scan extension expected by this loader.
# It must match the constant returned in function get_version().
version = 0.296

# Set to True once scan_perform has been imported successfully, so that a
# later attempt reloads the module instead of importing it again.
need_reload = False

# Logger used to report (re)compilation of the scan C code.
_logger = logging.getLogger('theano.scan_module.scan_perform')
30
31
def try_import():
    """Import the compiled ``scan_perform`` package from the compile directory.

    ``config.compiledir`` is temporarily inserted at the front of
    ``sys.path`` so that the import is resolved there first. On success the
    imported module is bound to the module-level name ``scan_perform``.

    Raises
    ------
    ImportError
        If ``scan_perform`` cannot be imported.
    """
    global scan_perform
    sys.path[0:0] = [config.compiledir]
    try:
        import scan_perform
    finally:
        # Always undo the sys.path change, even when the import fails;
        # otherwise every failed attempt would leave the compile directory
        # at the front of sys.path for the rest of the process.
        del sys.path[0]
37
38
def try_reload():
    """Reload the already imported ``scan_perform`` module.

    ``config.compiledir`` is temporarily inserted at the front of
    ``sys.path`` for the duration of the reload. The module-level name
    ``scan_perform`` must already be bound (see ``try_import``).
    """
    sys.path[0:0] = [config.compiledir]
    try:
        reload(scan_perform)
    finally:
        # Restore sys.path even if the reload raises, so that a failed
        # reload does not leave the compile directory on sys.path.
        del sys.path[0]
43
# Import the compiled scan_perform module. If it is missing, or if its
# recorded version differs from ``version``, compile it from the shipped C
# source while holding the lock on the compilation directory.
try:
    try_import()
    need_reload = True
    if version != getattr(scan_perform, '_version', None):
        # Treat an outdated module like a missing one.
        raise ImportError()
except ImportError:
    get_lock()
    try:
        # Maybe someone else already finished compiling it while we were
        # waiting for the lock?
        try:
            if need_reload:
                # The module was successfully imported earlier: we need to
                # reload it to check if the version was updated.
                try_reload()
            else:
                try_import()
                need_reload = True
            if version != getattr(scan_perform, '_version', None):
                raise ImportError()
        except ImportError:
            if not theano.config.cxx:
                raise ImportError("no c compiler, can't compile cython code")
            _logger.info("Compiling C code for scan")
            dirname = 'scan_perform'
            cfile = os.path.join(theano.__path__[0], 'scan_module', 'c_code',
                                 'scan_perform.c')
            if not os.path.exists(cfile):
                # This should not happen in a normal setup. We just
                # disable the cython code. If we are here the user
                # didn't disable the compiler, so print a warning.
                warnings.warn(
                    "The file scan_perform.c is not available. This does "
                    "not happen normally. You are probably in a strange "
                    "setup. This means Theano cannot use the cython code "
                    "for scan. If you want to remove this warning, use the "
                    "Theano flag 'cxx=' (set to an empty string) to disable "
                    "all c code generation."
                )
                raise ImportError("The file scan_perform.c is not available.")

            with open(cfile) as f:
                code = f.read()
            loc = os.path.join(config.compiledir, dirname)
            if not os.path.exists(loc):
                try:
                    os.mkdir(loc)
                except OSError as e:
                    # Only tolerate the directory having been created in
                    # the meantime; re-raise any other failure explicitly
                    # (an assert here would be stripped under ``-O`` and
                    # would hide the real error).
                    if e.errno != errno.EEXIST or not os.path.exists(loc):
                        raise

            preargs = ['-fwrapv', '-O2', '-fno-strict-aliasing']
            preargs += cmodule.GCC_compiler.compile_args()
            # Cython 19.1 always uses the old NumPy interface.  So we
            # need to manually modify the .c file to get it compiled
            # by Theano, as by default we tell NumPy not to import
            # the old interface.
            if False:
                # During scan cython development, it is helpful to keep the
                # old interface, to avoid manually editing the c file each
                # time. Flip this condition to enable that.
                preargs.remove('-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION')
            else:
                numpy_ver = [int(n) for n in np.__version__.split('.')[:2]]
                # Add some macros to lower the number of edits
                # needed to the c file.
                if numpy_ver >= [1, 7]:
                    # Needed when we disable the old API, as cython
                    # uses the old interface.
                    preargs.append("-DNPY_ENSUREARRAY=NPY_ARRAY_ENSUREARRAY")
                    preargs.append("-DNPY_ENSURECOPY=NPY_ARRAY_ENSURECOPY")
                    preargs.append("-DNPY_ALIGNED=NPY_ARRAY_ALIGNED")
                    preargs.append("-DNPY_WRITEABLE=NPY_ARRAY_WRITEABLE")
                    preargs.append("-DNPY_UPDATE_ALL=NPY_ARRAY_UPDATE_ALL")
                    preargs.append("-DNPY_C_CONTIGUOUS=NPY_ARRAY_C_CONTIGUOUS")
                    preargs.append("-DNPY_F_CONTIGUOUS=NPY_ARRAY_F_CONTIGUOUS")

            cmodule.GCC_compiler.compile_str(dirname, code, location=loc,
                                             preargs=preargs,
                                             hide_symbols=False)
            # Save version into the __init__.py file.
            init_py = os.path.join(loc, '__init__.py')
            with open(init_py, 'w') as f:
                f.write('_version = %s\n' % version)
            # If we just compiled the module for the first time, then it was
            # imported at the same time: we need to make sure we do not
            # reload the now outdated __init__.pyc below.
            init_pyc = os.path.join(loc, '__init__.pyc')
            if os.path.isfile(init_pyc):
                os.remove(init_pyc)
            try_import()

            try_reload()
            from scan_perform import scan_perform as scan_c
            # The version written to __init__.py must agree with the one
            # reported by the freshly compiled extension.
            assert (scan_perform._version ==
                    scan_c.get_version())
            _logger.info("New version %s", scan_perform._version)
    finally:
        # Release lock on compilation directory.
        release_lock()
143
# Re-export the names of the compiled extension from this module.
# The "numpy.ndarray size changed" warning is silenced during the import:
# it is caused by Cython using the old NumPy C-API while we use the new one.
# To fix it completely, we would need to modify Cython to use the new API.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore",
                            message="numpy.ndarray size changed")
    from scan_perform.scan_perform import *
# Sanity check: the loaded extension must report the version expected here
# (get_version is expected to be provided by the star import above).
assert version == get_version()
151