1"""
2Caching mechanism for compiled functions.
3"""
4
5
6from abc import ABCMeta, abstractmethod, abstractproperty
7import contextlib
8import errno
9import hashlib
10import inspect
11import itertools
12import os
13import pickle
14import sys
15import tempfile
16import warnings
17
18from numba.misc.appdirs import AppDirs
19from numba.core.utils import add_metaclass, file_replace
20
21import numba
22from numba.core.errors import NumbaWarning
23from numba.core.base import BaseContext
24from numba.core.codegen import CodeLibrary
25from numba.core.compiler import CompileResult
26from numba.core import config, compiler
27from numba.core.serialize import dumps
28
29
def _get_codegen(obj):
    """
    Look up the Codegen object backing *obj*.

    *obj* may be a target context (BaseContext), a CodeLibrary or a
    CompileResult.  Any other type raises a TypeError carrying the
    offending type.
    """
    if isinstance(obj, BaseContext):
        return obj.codegen()
    if isinstance(obj, CodeLibrary):
        # Libraries expose their codegen as a plain attribute, not a method
        return obj.codegen
    if isinstance(obj, CompileResult):
        # Compile results reach the codegen through their target context
        return obj.target_context.codegen()
    raise TypeError(type(obj))
42
43
def _cache_log(msg, *args):
    """
    Print *msg* %-formatted with *args*, but only when
    `config.DEBUG_CACHE` is set.
    """
    if not config.DEBUG_CACHE:
        return
    print(msg % args)
48
49
@add_metaclass(ABCMeta)
class _Cache(object):
    """
    Abstract interface implemented by every compilation cache.
    """

    # `abstractproperty` is deprecated since Python 3.3; stacking
    # `property` on top of `abstractmethod` is the supported spelling.
    @property
    @abstractmethod
    def cache_path(self):
        """
        The base filesystem path of this cache (for example its root folder).
        """

    @abstractmethod
    def load_overload(self, sig, target_context):
        """
        Load an overload for the given signature using the target context.
        The saved object must be returned if successful, None if not found
        in the cache.
        """

    @abstractmethod
    def save_overload(self, sig, data):
        """
        Save the overload for the given signature.
        """

    @abstractmethod
    def enable(self):
        """
        Enable the cache.
        """

    @abstractmethod
    def disable(self):
        """
        Disable the cache.
        """

    @abstractmethod
    def flush(self):
        """
        Flush the cache.
        """
90
91
class NullCache(_Cache):
    """
    A cache implementation that stores nothing: every operation is a no-op
    and every lookup yields None.
    """

    @property
    def cache_path(self):
        # There is no backing location
        return None

    def load_overload(self, sig, target_context):
        # Nothing is ever saved, so nothing can be found
        return None

    def save_overload(self, sig, cres):
        return None

    def enable(self):
        return None

    def disable(self):
        return None

    def flush(self):
        return None
111
112
@add_metaclass(ABCMeta)
class _CacheLocator(object):
    """
    A filesystem locator for caching a given function.
    """

    def ensure_cache_path(self):
        """
        Create the cache directory if it does not exist yet and check that
        it is writable.

        Raises OSError if the directory cannot be created or written to.
        """
        path = self.get_cache_path()
        # exist_ok=True tolerates a directory that is already there (e.g.
        # created concurrently) without hand-rolled errno inspection.
        os.makedirs(path, exist_ok=True)
        # Ensure the directory is writable by trying to write a temporary file
        tempfile.TemporaryFile(dir=path).close()

    @abstractmethod
    def get_cache_path(self):
        """
        Return the directory the function is cached in.
        """

    @abstractmethod
    def get_source_stamp(self):
        """
        Get a timestamp representing the source code's freshness.
        Can return any picklable Python object.
        """

    @abstractmethod
    def get_disambiguator(self):
        """
        Get a string disambiguator for this locator's function.
        It should allow disambiguating different but similarly-named functions.
        """

    @classmethod
    def from_function(cls, py_func, py_file):
        """
        Create a locator instance for the given function located in the
        given file.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError

    @classmethod
    def get_suitable_cache_subpath(cls, py_file):
        """Given the Python file path, compute a suitable path inside the
        cache directory.

        This will reduce a file path that is too long, which can be a problem
        on some operating systems (e.g. Windows 7).

        The result is "<parent directory name>_<SHA1 hex digest of the
        absolute directory path>".
        """
        path = os.path.abspath(py_file)
        subpath = os.path.dirname(path)
        parentdir = os.path.split(subpath)[-1]
        # Use SHA1 to reduce path length.
        # Note: windows doesn't like long path.
        hashed = hashlib.sha1(subpath.encode()).hexdigest()
        # Retain parent directory name for easier debugging
        return '_'.join([parentdir, hashed])
173
174
class _SourceFileBackedLocatorMixin(object):
    """
    Mixin providing the locator methods shared by locators whose function
    is backed by a well-known Python source file.

    Users of the mixin must set `_py_file` and `_lineno` on the instance.
    """

    def get_source_stamp(self):
        # In a frozen application the executable is stat'ed instead of
        # the source file.
        if getattr(sys, 'frozen', False):
            stamped_file = sys.executable
        else:
            stamped_file = self._py_file
        info = os.stat(stamped_file)
        # The modification time alone is not enough: some filesystems only
        # have second granularity, so the size is part of the stamp too.
        return info.st_mtime, info.st_size

    def get_disambiguator(self):
        # The first line number of the function, as a string
        return str(self._lineno)

    @classmethod
    def from_function(cls, py_func, py_file):
        # A path that does not exist is perhaps a placeholder
        # (e.g. "<ipython-XXX>"): this locator does not apply.
        if os.path.exists(py_file):
            locator = cls(py_func, py_file)
            try:
                locator.ensure_cache_path()
            except OSError:
                # Cannot ensure the cache directory exists or is writable
                return None
            else:
                return locator
        return None
205
206
class _UserProvidedCacheLocator(_SourceFileBackedLocatorMixin, _CacheLocator):
    """
    Locator that stores cache files under the user-provided directory
    configured in `numba.config.CACHE_DIR`.
    """
    def __init__(self, py_func, py_file):
        self._py_file = py_file
        self._lineno = py_func.__code__.co_firstlineno
        subdir = self.get_suitable_cache_subpath(py_file)
        self._cache_path = os.path.join(config.CACHE_DIR, subdir)

    def get_cache_path(self):
        return self._cache_path

    @classmethod
    def from_function(cls, py_func, py_file):
        # Only applicable when the user configured a cache directory
        if config.CACHE_DIR:
            base = super(_UserProvidedCacheLocator, cls)
            return base.from_function(py_func, py_file)
        return None
227
228
class _InTreeCacheLocator(_SourceFileBackedLocatorMixin, _CacheLocator):
    """
    Locator that caches next to the source file, in the `__pycache__`
    directory of the module's folder.
    """

    def __init__(self, py_func, py_file):
        self._py_file = py_file
        self._lineno = py_func.__code__.co_firstlineno
        source_dir = os.path.dirname(self._py_file)
        self._cache_path = os.path.join(source_dir, '__pycache__')

    def get_cache_path(self):
        return self._cache_path
242
243
class _UserWideCacheLocator(_SourceFileBackedLocatorMixin, _CacheLocator):
    """
    Locator for functions backed by a regular Python module or a frozen
    executable; the cache lives in a user-wide cache directory.
    """

    def __init__(self, py_func, py_file):
        self._py_file = py_file
        self._lineno = py_func.__code__.co_firstlineno
        user_cache_dir = AppDirs(appname="numba",
                                 appauthor=False).user_cache_dir
        self._cache_path = os.path.join(
            user_cache_dir, self.get_suitable_cache_subpath(py_file))

    def get_cache_path(self):
        return self._cache_path

    @classmethod
    def from_function(cls, py_func, py_file):
        # A missing file is perhaps a placeholder (e.g. "<ipython-XXX>").
        # Do not bail out when frozen though, since a frozen application
        # uses a temp placeholder.
        if not os.path.exists(py_file) and not getattr(sys, 'frozen', False):
            return None
        locator = cls(py_func, py_file)
        try:
            locator.ensure_cache_path()
        except OSError:
            # Cannot ensure the cache directory exists or is writable
            return None
        return locator
274
275
class _IPythonCacheLocator(_CacheLocator):
    """
    Locator for functions typed at the IPython prompt (notebook or other),
    i.e. whose file name starts with "<ipython-".
    """

    def __init__(self, py_func, py_file):
        self._py_file = py_file
        # Note IPython enhances the linecache module to be able to
        # inspect source code of functions defined on the interactive prompt.
        source = inspect.getsource(py_func)
        # Always keep the source as bytes so that it can be hashed
        self._bytes_source = (source if isinstance(source, bytes)
                              else source.encode('utf-8'))

    def get_cache_path(self):
        # We could also use jupyter_core.paths.jupyter_runtime_dir()
        # In both cases this is a user-wide directory, so we need to
        # be careful when disambiguating if we don't want too many
        # conflicts (see get_disambiguator()).
        try:
            from IPython.paths import get_ipython_cache_dir
        except ImportError:
            # older IPython version
            from IPython.utils.path import get_ipython_cache_dir
        ipython_cache_dir = get_ipython_cache_dir()
        return os.path.join(ipython_cache_dir, 'numba_cache')

    def get_source_stamp(self):
        # The stamp is a digest of the whole function source
        digest = hashlib.sha256(self._bytes_source)
        return digest.hexdigest()

    def get_disambiguator(self):
        # Heuristic: we don't want too many variants being saved, but
        # we don't want similar named functions (e.g. "f") to compete
        # for the cache, so we hash the first two lines of the function
        # source (usually this will be the @jit decorator + the function
        # signature).
        head_lines = self._bytes_source.splitlines(True)[:2]
        digest = hashlib.sha256(b''.join(head_lines)).hexdigest()
        return digest[:10]

    @classmethod
    def from_function(cls, py_func, py_file):
        if py_file.startswith("<ipython-"):
            locator = cls(py_func, py_file)
            try:
                locator.ensure_cache_path()
            except OSError:
                # Cannot ensure the cache directory exists
                return None
            return locator
        return None
326
327
@add_metaclass(ABCMeta)
class _CacheImpl(object):
    """
    Provides the core machinery for caching.
    - implement how to serialize and deserialize the data in the cache.
    - control the filename of the cache.
    - provide the cache locator
    """
    # Locator classes are tried in this order; the first one whose
    # from_function() returns a non-None instance is used.
    _locator_classes = [_UserProvidedCacheLocator,
                        _InTreeCacheLocator,
                        _UserWideCacheLocator,
                        _IPythonCacheLocator]

    def __init__(self, py_func):
        """
        Select a cache locator for *py_func* and compute the base name of
        its cache files.

        Raises RuntimeError if none of `_locator_classes` provides a
        locator for the file defining *py_func*.
        """
        self._lineno = py_func.__code__.co_firstlineno
        # Get qualname
        try:
            qualname = py_func.__qualname__
        except AttributeError:
            qualname = py_func.__name__
        # Find a locator
        source_path = inspect.getfile(py_func)
        for cls in self._locator_classes:
            locator = cls.from_function(py_func, source_path)
            if locator is not None:
                break
        else:
            raise RuntimeError("cannot cache function %r: no locator available "
                               "for file %r" % (qualname, source_path))
        self._locator = locator
        # Use filename base name as module name to avoid conflict between
        # foo/__init__.py and foo/foo.py
        modname = os.path.splitext(os.path.basename(source_path))[0]
        fullname = "%s.%s" % (modname, qualname)
        abiflags = getattr(sys, 'abiflags', '')
        self._filename_base = self.get_filename_base(fullname, abiflags)

    def get_filename_base(self, fullname, abiflags):
        """
        Return the base name of the cache files, built from *fullname*,
        the locator's disambiguator, the Python major and minor version
        and *abiflags*.
        """
        # '<' and '>' can appear in the qualname (e.g. '<locals>') but
        # are forbidden in Windows filenames
        fixed_fullname = fullname.replace('<', '').replace('>', '')
        fmt = '%s-%s.py%d%d%s'
        return fmt % (fixed_fullname, self.locator.get_disambiguator(),
                      sys.version_info[0], sys.version_info[1], abiflags)

    @property
    def filename_base(self):
        return self._filename_base

    @property
    def locator(self):
        return self._locator

    @abstractmethod
    def reduce(self, data):
        """
        Returns the serialized form of the data.
        """

    @abstractmethod
    def rebuild(self, target_context, reduced_data):
        """
        Returns the de-serialized form of the *reduced_data*.
        """

    @abstractmethod
    def check_cachable(self, data):
        """
        Returns True if the given data is cachable; otherwise, returns False.
        """
396
397
class CompileResultCacheImpl(_CacheImpl):
    """
    Cache implementation handling CompileResult objects.
    """

    def reduce(self, cres):
        """
        Serialize the CompileResult *cres* (delegates to `cres._reduce()`).
        """
        return cres._reduce()

    def rebuild(self, target_context, payload):
        """
        Recreate a CompileResult from *payload* for *target_context*.
        """
        return compiler.CompileResult._rebuild(target_context, *payload)

    def check_cachable(self, cres):
        """
        Return True if *cres* can be cached.  Otherwise emit a NumbaWarning
        located at the function's file and first line, and return False.
        """
        if not all(x.can_cache for x in cres.lifted):
            reason = "as it uses lifted code"
        elif cres.library.has_dynamic_globals:
            reason = ("as it uses dynamic globals "
                      "(such as ctypes pointers and large global arrays)")
        else:
            return True
        # Only the last component of the qualified name is reported
        short_name = cres.fndesc.qualname.split('.')[-1]
        msg = 'Cannot cache compiled function "%s" %s' % (short_name, reason)
        warnings.warn_explicit(msg, NumbaWarning,
                               self._locator._py_file, self._lineno)
        return False
432
433
class CodeLibraryCacheImpl(_CacheImpl):
    """
    Cache implementation handling CodeLibrary objects.
    """

    # Prepended to the cache filename; must be overridden by subclasses
    _filename_prefix = None

    def reduce(self, codelib):
        """
        Serialize the CodeLibrary *codelib* using its object code.
        """
        return codelib.serialize_using_object_code()

    def rebuild(self, target_context, payload):
        """
        Recreate a CodeLibrary from *payload* through the codegen of
        *target_context*.
        """
        codegen = target_context.codegen()
        return codegen.unserialize_library(payload)

    def check_cachable(self, codelib):
        """
        A CodeLibrary is cachable unless it has dynamic globals.
        """
        if codelib.has_dynamic_globals:
            return False
        return True

    def get_filename_base(self, fullname, abiflags):
        # Same base name as the parent class, with the prefix in front
        base = super(CodeLibraryCacheImpl, self).get_filename_base(
            fullname, abiflags)
        return '-'.join((self._filename_prefix, base))
463
464
class IndexDataCacheFile(object):
    """
    Manages the index file and the data files used by one cache.

    The index file holds a pickled version string followed by a pickled
    (source stamp, overloads) pair, where overloads maps keys to data file
    names.  Each data file holds one pickled cache entry.

    NOTE: both kinds of files are read back with pickle, so the cache
    directory has to be trusted.
    """
    def __init__(self, cache_path, filename_base, source_stamp):
        self._cache_path = cache_path
        self._source_stamp = source_stamp
        self._version = numba.__version__
        self._index_name = '%s.nbi' % (filename_base,)
        self._index_path = os.path.join(self._cache_path, self._index_name)
        self._data_name_pattern = '%s.{number:d}.nbc' % (filename_base,)

    def flush(self):
        # Overwrite the index with an empty mapping
        self._save_index({})

    def save(self, key, data):
        """
        Store *data* under *key*, adding an index entry if needed.
        """
        index = self._load_index()
        if key in index:
            # Known key: the existing data file gets overwritten
            target = index[key]
        else:
            # New key: pick the lowest-numbered data file name not in use
            taken = set(index.values())
            candidates = (self._data_name(i) for i in itertools.count(1))
            target = next(name for name in candidates if name not in taken)
            index[key] = target
            self._save_index(index)
        self._save_data(target, data)

    def load(self, key):
        """
        Return the entry stored under *key*, or None if there is none.
        """
        data_name = self._load_index().get(key)
        if data_name is not None:
            try:
                return self._load_data(data_name)
            except EnvironmentError:
                # File could have been removed while the index still
                # refers it.
                pass
        return None

    def _load_index(self):
        """
        Read the cache index and return its overloads dictionary.  An empty
        dictionary is returned when the index does not exist, was written
        by another version, or carries another source stamp.
        """
        try:
            with open(self._index_path, "rb") as f:
                version = pickle.load(f)
                payload = f.read()
        except EnvironmentError as e:
            if e.errno != errno.ENOENT:
                raise
            # Index doesn't exist yet?
            return {}
        if version != self._version:
            # This is another version.  Avoid trying to unpickling the
            # rest of the stream, as that may fail.
            return {}
        stamp, overloads = pickle.loads(payload)
        _cache_log("[cache] index loaded from %r", self._index_path)
        # A different stamp means the cache is not fresh.  Stale data files
        # will be eventually overwritten, since they are numbered in
        # incrementing order.
        return {} if stamp != self._source_stamp else overloads

    def _save_index(self, overloads):
        payload = self._dump((self._source_stamp, overloads))
        with self._open_for_write(self._index_path) as f:
            # The version goes first so that it can be checked on its own
            pickle.dump(self._version, f, protocol=-1)
            f.write(payload)
        _cache_log("[cache] index saved to %r", self._index_path)

    def _load_data(self, name):
        path = self._data_path(name)
        with open(path, "rb") as f:
            raw = f.read()
        entry = pickle.loads(raw)
        _cache_log("[cache] data loaded from %r", path)
        return entry

    def _save_data(self, name, data):
        path = self._data_path(name)
        serialized = self._dump(data)
        with self._open_for_write(path) as f:
            f.write(serialized)
        _cache_log("[cache] data saved to %r", path)

    def _data_name(self, number):
        # Name of the data file with the given number
        return self._data_name_pattern.format(number=number)

    def _data_path(self, name):
        # Data files live directly in the cache directory
        return os.path.join(self._cache_path, name)

    def _dump(self, obj):
        # protocol=-1 selects the highest pickle protocol available
        return pickle.dumps(obj, protocol=-1)

    @contextlib.contextmanager
    def _open_for_write(self, filepath):
        """
        Open *filepath* for writing in a race condition-free way
        (hopefully): the content goes to a per-process temporary file
        which then replaces *filepath*.
        """
        tmpname = '%s.tmp.%d' % (filepath, os.getpid())
        try:
            with open(tmpname, "wb") as f:
                yield f
            file_replace(tmpname, filepath)
        except Exception:
            # In case of error, remove dangling tmp file
            with contextlib.suppress(OSError):
                os.unlink(tmpname)
            raise
590
591
class Cache(_Cache):
    """
    A per-function compilation cache, made of one index file and a set of
    data files.

    The index file ("function_name-<lineno>.pyXY.nbi") exists once per
    function and Python version.  It maps signatures and architectures to
    data files, and is prefixed by a versioning key and a timestamp of the
    Python source file containing the function.

    A data file ("function_name-<lineno>.pyXY.<number>.nbc") exists once
    per function, function signature, target architecture and Python
    version.

    Keeping separate index and data files per Python version avoids pickle
    compatibility problems.

    Note:
    Only the driver logic lives here.  The core logic comes from a subclass
    of ``_CacheImpl``, which subclasses name in *_impl_class*.
    """

    # The following class variables must be overridden by subclass.
    _impl_class = None

    def __init__(self, py_func):
        self._name = repr(py_func)
        self._py_func = py_func
        self._impl = self._impl_class(py_func)
        locator = self._impl.locator
        self._cache_path = locator.get_cache_path()
        # This may be a bit strict but avoids us maintaining a magic number
        source_stamp = locator.get_source_stamp()
        self._cache_file = IndexDataCacheFile(
            cache_path=self._cache_path,
            filename_base=self._impl.filename_base,
            source_stamp=source_stamp)
        self.enable()

    def __repr__(self):
        return "<%s py_func=%r>" % (self.__class__.__name__, self._name)

    @property
    def cache_path(self):
        return self._cache_path

    def enable(self):
        self._enabled = True

    def disable(self):
        self._enabled = False

    def flush(self):
        self._cache_file.flush()

    def load_overload(self, sig, target_context):
        """
        Recreate the object cached for the signature *sig* using the
        *target_context*.  None is returned if nothing could be loaded.
        """
        # Refresh the context to ensure it is initialized
        target_context.refresh()
        # Stays None if the guard swallows an exception
        loaded = None
        with self._guard_against_spurious_io_errors():
            loaded = self._load_overload(sig, target_context)
        return loaded

    def _load_overload(self, sig, target_context):
        if not self._enabled:
            return None
        key = self._index_key(sig, _get_codegen(target_context))
        payload = self._cache_file.load(key)
        if payload is None:
            return None
        return self._impl.rebuild(target_context, payload)

    def save_overload(self, sig, data):
        """
        Store *data* in the cache for the signature *sig*.
        """
        with self._guard_against_spurious_io_errors():
            self._save_overload(sig, data)

    def _save_overload(self, sig, data):
        # Nothing is written when disabled or when the data is not cachable
        if self._enabled and self._impl.check_cachable(data):
            self._impl.locator.ensure_cache_path()
            key = self._index_key(sig, _get_codegen(data))
            reduced = self._impl.reduce(data)
            self._cache_file.save(key, reduced)

    @contextlib.contextmanager
    def _guard_against_spurious_io_errors(self):
        if os.name != 'nt':
            # No such conditions under non-Windows OSes
            yield
            return
        # Guard against permission errors due to accessing the file
        # from several processes (see #2028)
        try:
            yield
        except EnvironmentError as e:
            if e.errno != errno.EACCES:
                raise

    def _index_key(self, sig, codegen):
        """
        Build the index key for *sig* and *codegen*: the signature, the
        codegen's magic tuple, and SHA-256 digests of the function's
        bytecode and of its serialized closure cell contents (an empty
        byte string is hashed when there is no closure).
        """
        func = self._py_func
        codebytes = func.__code__.co_code
        closure = func.__closure__
        if closure is None:
            cvarbytes = b''
        else:
            cvarbytes = dumps(tuple(cell.cell_contents for cell in closure))
        magic = codegen.magic_tuple()
        digests = tuple(hashlib.sha256(raw).hexdigest()
                        for raw in (codebytes, cvarbytes))
        return (sig, magic, digests)
714
715
class FunctionCache(Cache):
    """
    A Cache whose entries are CompileResult objects; the serialization
    logic is provided by CompileResultCacheImpl.
    """
    _impl_class = CompileResultCacheImpl
721
722
# Cache filename prefixes that are already taken.  The empty prefix is
# registered from the start.
_lib_cache_prefixes = {''}
725
726
def make_library_cache(prefix):
    """
    Create a Cache class for additional compilation features to cache their
    result for reuse.  The cache is saved in filename pattern like
    in ``FunctionCache`` but with additional *prefix* as specified.

    Raises AssertionError if *prefix* has already been used, since two
    caches sharing a prefix would use the same file names.
    """
    # avoid cache prefix reuse
    # An explicit raise is used instead of an `assert` statement so that
    # the check still runs under `python -O`; the exception type is kept.
    if prefix in _lib_cache_prefixes:
        raise AssertionError(
            "cache prefix %r is already in use" % (prefix,))
    _lib_cache_prefixes.add(prefix)

    class CustomCodeLibraryCacheImpl(CodeLibraryCacheImpl):
        _filename_prefix = prefix

    class LibraryCache(Cache):
        """
        Implements Cache that saves and loads CodeLibrary objects for additional
        feature for the specified python function.
        """
        _impl_class = CustomCodeLibraryCacheImpl

    return LibraryCache
748
749