#!/usr/bin/env python
# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Qiming Sun <osirpt.sun@gmail.com>
#

'''
Some helper functions
'''

import os, sys
import warnings
import tempfile
import functools
import itertools
import collections
import ctypes
import numpy
import h5py
from threading import Thread
from multiprocessing import Queue, Process
try:
    from concurrent.futures import ThreadPoolExecutor
except ImportError:
    ThreadPoolExecutor = None

from pyscf.lib import param
from pyscf import __config__

if h5py.version.version[:4] == '2.2.':
    sys.stderr.write('h5py-%s is found in your environment. '
                     'h5py-%s has a bug in threading mode.\n'
                     'Async-IO is disabled.\n' % ((h5py.version.version,)*2))

c_double_p = ctypes.POINTER(ctypes.c_double)
c_int_p = ctypes.POINTER(ctypes.c_int)
c_null_ptr = ctypes.POINTER(ctypes.c_void_p)

def load_library(libname):
    '''Load the compiled extension library "libname", searching the pyscf
    package lib directories.'''
    try:
        _loaderpath = os.path.dirname(__file__)
        return numpy.ctypeslib.load_library(libname, _loaderpath)
    except OSError:
        from pyscf import __path__ as ext_modules
        for path in ext_modules:
            libpath = os.path.join(path, 'lib')
            if os.path.isdir(libpath):
                for files in os.listdir(libpath):
                    if files.startswith(libname):
                        return numpy.ctypeslib.load_library(libname, libpath)
        raise
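# Example (illustrative): this is how the compiled extensions are typically
# loaded elsewhere in pyscf, e.g. the numpy helper library:
#
#     _np_helper = load_library('libnp_helper')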

# FIXME: the standard resource module gives the wrong number when objects are released
# http://fa.bianp.net/blog/2013/different-ways-to-get-memory-consumption-or-lessons-learned-from-memory_profiler/#fn:1
# or use slow functions as memory_profiler._get_memory does
CLOCK_TICKS = os.sysconf("SC_CLK_TCK")
PAGESIZE = os.sysconf("SC_PAGE_SIZE")
def current_memory():
    '''Return the size of used memory and allocated virtual memory (in MB)'''
    #import resource
    #return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1000
    if sys.platform.startswith('linux'):
        with open("/proc/%s/statm" % os.getpid()) as f:
            vms, rss = [int(x)*PAGESIZE for x in f.readline().split()[:2]]
            return rss/1e6, vms/1e6
    else:
        return 0, 0

def num_threads(n=None):
    '''Set the number of OMP threads.  If the argument is not specified, the
    function returns the total number of available OMP threads.

    It is recommended to set the number of OMP threads with this function
    rather than with "os.environ['OMP_NUM_THREADS'] = str(n)".  Environment
    variables such as OMP_NUM_THREADS are read when a module is imported;
    they cannot be changed through os.environ after the module has been
    loaded.

    Examples:

    >>> from pyscf import lib
    >>> print(lib.num_threads())
    8
    >>> lib.num_threads(4)
    4
    >>> print(lib.num_threads())
    4
    '''
    from pyscf.lib.numpy_helper import _np_helper
    if n is not None:
        _np_helper.set_omp_threads.restype = ctypes.c_int
        threads = _np_helper.set_omp_threads(ctypes.c_int(int(n)))
        if threads == 0:
            warnings.warn('OpenMP is not available. '
                          'Setting omp_threads to %s has no effect.' % n)
        return threads
    else:
        _np_helper.get_omp_threads.restype = ctypes.c_int
        return _np_helper.get_omp_threads()
class with_omp_threads(object):
    '''Use this context manager to create a temporary context in which the
    number of OpenMP threads is set to the required value. When the program
    exits the context, the original number of OpenMP threads is restored.

    Args:
        nthreads : int

    Examples:

    >>> from pyscf import lib
    >>> print(lib.num_threads())
    8
    >>> with lib.with_omp_threads(2):
    ...     print(lib.num_threads())
    2
    >>> print(lib.num_threads())
    8
    '''
    def __init__(self, nthreads=None):
        self.nthreads = nthreads
        self.sys_threads = None
    def __enter__(self):
        if self.nthreads is not None and self.nthreads >= 1:
            self.sys_threads = num_threads()
            num_threads(self.nthreads)
        return self
    def __exit__(self, type, value, traceback):
        if self.sys_threads is not None:
            num_threads(self.sys_threads)


def c_int_arr(m):
    npm = numpy.array(m).flatten('C')
    arr = (ctypes.c_int * npm.size)(*npm)
    # Cannot return npm.ctypes.data_as(c_int_p) here: npm is a local object
    # and would be destructed before the pointer is used.
    return arr
def f_int_arr(m):
    npm = numpy.array(m).flatten('F')
    arr = (ctypes.c_int * npm.size)(*npm)
    return arr
def c_double_arr(m):
    npm = numpy.array(m).flatten('C')
    arr = (ctypes.c_double * npm.size)(*npm)
    return arr
def f_double_arr(m):
    npm = numpy.array(m).flatten('F')
    arr = (ctypes.c_double * npm.size)(*npm)
    return arr


def member(test, x, lst):
    '''Whether x matches any element of lst, as judged by the binary predicate test'''
    for l in lst:
        if test(x, l):
            return True
    return False

def remove_dup(test, lst, from_end=False):
    '''Remove duplicated elements of lst, using the binary predicate test to
    decide whether two elements are identical'''
    if test is None:
        return set(lst)
    else:
        if from_end:
            lst = list(reversed(lst))
        seen = []
        for l in lst:
            if not member(test, l, seen):
                seen.append(l)
        return seen

def remove_if(test, lst):
    '''Remove the elements of lst for which test(element) is true'''
    return [x for x in lst if not test(x)]

def find_if(test, lst):
    '''Return the first element of lst for which test(element) is true'''
    for l in lst:
        if test(l):
            return l
    raise ValueError('No element of the given list matches the test condition.')

def arg_first_match(test, lst):
    '''Return the index of the first element of lst for which test(element) is true'''
    for i,x in enumerate(lst):
        if test(x):
            return i
    raise ValueError('No element of the given list matches the test condition.')

def _balanced_partition(cum, ntasks):
    # Partition the cumulative workload "cum" into ntasks segments of roughly
    # equal cost; return the boundary indices.
    segsize = float(cum[-1]) / ntasks
    bounds = numpy.arange(ntasks+1) * segsize
    displs = abs(bounds[:,None] - cum).argmin(axis=1)
    return displs

def _blocksize_partition(cum, blocksize):
    # Partition the cumulative workload "cum" into segments whose cost does
    # not exceed blocksize (where possible); return the boundary indices.
    n = len(cum) - 1
    displs = [0]
    if n == 0:
        return displs

    p0 = 0
    for i in range(1, n):
        if cum[i+1]-cum[p0] > blocksize:
            displs.append(i)
            p0 = i
    displs.append(n)
    return displs

def flatten(lst):
    '''flatten nested lists
    x[0] + x[1] + x[2] + ...

    Examples:

    >>> flatten([[0, 2], [1], [[9, 8, 7]]])
    [0, 2, 1, [9, 8, 7]]
    '''
    return list(itertools.chain.from_iterable(lst))

def prange(start, end, step):
    '''This function splits the number sequence between "start" and "end"
    using uniform "step" length. It yields the boundary (start, end) for each
    fragment.

    Examples:

    >>> for p0, p1 in lib.prange(0, 8, 2):
    ...    print(p0, p1)
    (0, 2)
    (2, 4)
    (4, 6)
    (6, 8)
    '''
    if start < end:
        for i in range(start, end, step):
            yield i, min(i+step, end)

def prange_tril(start, stop, blocksize):
    '''Similar to :func:`prange`, yields start (p0) and end (p1) with the
    restriction p1*(p1+1)/2-p0*(p0+1)/2 < blocksize

    Examples:

    >>> for p0, p1 in lib.prange_tril(0, 10, 25):
    ...     print(p0, p1)
    (0, 6)
    (6, 9)
    (9, 10)
    '''
    if start >= stop:
        return []
    idx = numpy.arange(start, stop+1)
    cum_costs = idx*(idx+1)//2 - start*(start+1)//2
    displs = [x+start for x in _blocksize_partition(cum_costs, blocksize)]
    return zip(displs[:-1], displs[1:])

def map_with_prefetch(func, *iterables):
    '''
    Apply function to a task and prefetch the next task.
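
    Results are yielded in task order; only the evaluation of the next task
    overlaps with consuming the current result.

    Examples:

    >>> list(map_with_prefetch(lambda x: x * 2, range(4)))
    [0, 2, 4, 6]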
    '''
    global_import_lock = False
    if sys.version_info < (3, 6):
        import imp
        global_import_lock = imp.lock_held()

    if global_import_lock:
        for task in zip(*iterables):
            yield func(*task)

    elif ThreadPoolExecutor is not None:
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = None
            for task in zip(*iterables):
                if future is None:
                    future = executor.submit(func, *task)
                else:
                    result = future.result()
                    future = executor.submit(func, *task)
                    yield result
            if future is not None:
                yield future.result()
    else:
        def func_with_buf(_output_buf, *args):
            _output_buf[0] = func(*args)
        with call_in_background(func_with_buf) as f_prefetch:
            buf0, buf1 = [None], [None]
            for istep, task in enumerate(zip(*iterables)):
                if istep == 0:
                    f_prefetch(buf0, *task)
                else:
                    buf0, buf1 = buf1, buf0
                    f_prefetch(buf0, *task)
                    yield buf1[0]
        if buf0[0] is not None:
            yield buf0[0]

def index_tril_to_pair(ij):
    '''Given tril-index ij, compute the pair indices (i,j) which satisfy
    ij = i * (i+1) / 2 + j
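
    Examples:

    >>> index_tril_to_pair(numpy.arange(6))
    (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))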
    '''
    i = (numpy.sqrt(2*ij+.25) - .5 + 1e-7).astype(int)
    j = ij - i*(i+1)//2
    return i, j


def tril_product(*iterables, **kwds):
    '''Cartesian product in lower-triangular form for multiple indices

    For a given list of indices (`iterables`), this function yields all
    indices such that the sub-indices given by the kwarg `tril_idx` satisfy a
    lower-triangular form.  The lower-triangular form satisfies:

    .. math:: i[tril_idx[0]] >= i[tril_idx[1]] >= ... >= i[tril_idx[len(tril_idx)-1]]

    Args:
        *iterables: Variable length argument list of indices for the cartesian product
        **kwds: Arbitrary keyword arguments.  Acceptable keywords include:
            repeat (int): Number of times to repeat the iterables
            tril_idx (array_like): Indices to put into lower-triangular form.

    Yields:
        product (tuple): Tuple in lower-triangular form.

    Examples:
        Specifying no `tril_idx` is equivalent to just a cartesian product.

        >>> list(tril_product(range(2), repeat=2))
        [(0, 0), (0, 1), (1, 0), (1, 1)]

        We can specify only sub-indices to satisfy a lower-triangular form:

        >>> list(tril_product(range(2), repeat=3, tril_idx=[1,2]))
        [(0, 0, 0), (0, 1, 0), (0, 1, 1), (1, 0, 0), (1, 1, 0), (1, 1, 1)]

        We specify all indices to satisfy a lower-triangular form, useful for iterating over
        the symmetry unique elements of occupied/virtual orbitals in a 3-particle operator:

        >>> list(tril_product(range(3), repeat=3, tril_idx=[0,1,2]))
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (1, 1, 1), (2, 0, 0), (2, 1, 0), (2, 1, 1), (2, 2, 0), (2, 2, 1), (2, 2, 2)]
    '''
    repeat = kwds.get('repeat', 1)
    tril_idx = kwds.get('tril_idx', [])
    niterables = len(iterables) * repeat
    ntril_idx = len(tril_idx)

    assert ntril_idx <= niterables, "Can't have a greater number of tril indices than iterables!"
    if ntril_idx > 0:
        assert numpy.max(tril_idx) < niterables, 'Tril index out of bounds for %d iterables! idx = %s' % \
                                                 (niterables, tril_idx)
    for tup in itertools.product(*iterables, repeat=repeat):
        if ntril_idx == 0:
            yield tup
            continue

        if all([tup[tril_idx[i]] >= tup[tril_idx[i+1]] for i in range(ntril_idx-1)]):
            yield tup

def square_mat_in_trilu_indices(n):
    '''Return an n x n symmetric index matrix, in which the elements are the
    indices of the unique elements of a tril vector
    [0 1 3 ... ]
    [1 2 4 ... ]
    [3 4 5 ... ]
    [...       ]
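
    Examples:

    >>> square_mat_in_trilu_indices(3)
    array([[0, 1, 3],
           [1, 2, 4],
           [3, 4, 5]])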
    '''
    idx = numpy.tril_indices(n)
    tril2sq = numpy.zeros((n,n), dtype=int)
    tril2sq[idx[0],idx[1]] = tril2sq[idx[1],idx[0]] = numpy.arange(n*(n+1)//2)
    return tril2sq

class capture_stdout(object):
    '''redirect all stdout (c printf & python print) into a string

    Examples:

    >>> import os
    >>> from pyscf import lib
    >>> with lib.capture_stdout() as out:
    ...     os.system('ls')
    >>> print(out.read())
    '''
    #TODO: handle stderr
    def __enter__(self):
        sys.stdout.flush()
        self._contents = None
        self.old_stdout_fileno = sys.stdout.fileno()
        self.bak_stdout_fd = os.dup(self.old_stdout_fileno)
        self.ftmp = tempfile.NamedTemporaryFile(dir=param.TMPDIR)
        os.dup2(self.ftmp.file.fileno(), self.old_stdout_fileno)
        return self
    def __exit__(self, type, value, traceback):
        sys.stdout.flush()
        self.ftmp.file.seek(0)
        self._contents = self.ftmp.file.read()
        self.ftmp.close()
        os.dup2(self.bak_stdout_fd, self.old_stdout_fileno)
        os.close(self.bak_stdout_fd)
    def read(self):
        if self._contents:
            return self._contents
        else:
            sys.stdout.flush()
            self.ftmp.file.seek(0)
            return self.ftmp.file.read()
ctypes_stdout = capture_stdout

class quite_run(object):
    '''capture all stdout (c printf & python print) but output nothing

    Examples:

    >>> import os
    >>> from pyscf import lib
    >>> with lib.quite_run():
    ...     os.system('ls')
    '''
    def __enter__(self):
        sys.stdout.flush()
        #TODO: to handle the redirected stdout e.g. StringIO()
        self.old_stdout_fileno = sys.stdout.fileno()
        self.bak_stdout_fd = os.dup(self.old_stdout_fileno)
        self.fnull = open(os.devnull, 'wb')
        os.dup2(self.fnull.fileno(), self.old_stdout_fileno)
    def __exit__(self, type, value, traceback):
        sys.stdout.flush()
        os.dup2(self.bak_stdout_fd, self.old_stdout_fileno)
        self.fnull.close()


# from pygeocoder
# This decorator allows a method to be used as both a static method and an
# instance method.  In contrast to classmethod, when obj.function() is called,
# the first argument passed to the function is obj itself in omnimethod,
# rather than obj.__class__ as in classmethod.
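#
# A minimal usage sketch (illustrative only; the class and method names below
# are hypothetical):
#
#     class Foo(object):
#         @omnimethod
#         def hello(self):
#             return 'called on instance' if self is not None else 'called on class'
#
#     Foo.hello()    # -> 'called on class'     (instance is None)
#     Foo().hello()  # -> 'called on instance'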
class omnimethod(object):
    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        return functools.partial(self.func, instance)


SANITY_CHECK = getattr(__config__, 'SANITY_CHECK', True)
class StreamObject(object):
    '''For most methods, there are three stream functions to pipe the
    computing stream:

    1 ``.set`` function to update object attributes, eg
    ``mf = scf.RHF(mol).set(conv_tol=1e-5)`` is identical to the two steps
    ``mf = scf.RHF(mol); mf.conv_tol=1e-5``

    2 ``.run`` function to execute the kernel function (the function arguments
    are passed to the kernel function).  If keyword arguments are given, it
    first calls the ``.set`` function to update object attributes, then
    executes the kernel function.  Eg
    ``mf = scf.RHF(mol).run(dm_init, conv_tol=1e-5)`` is identical to the three steps
    ``mf = scf.RHF(mol); mf.conv_tol=1e-5; mf.kernel(dm_init)``

    3 ``.apply`` function to apply the given function/class to the current object
    (function arguments and keyword arguments are passed to the given function).
    Eg
    ``mol.apply(scf.RHF).run().apply(mcscf.CASSCF, 6, 4, frozen=4)`` is identical to
    ``mf = scf.RHF(mol); mf.kernel(); mcscf.CASSCF(mf, 6, 4, frozen=4)``
    '''

    verbose = 0
    stdout = sys.stdout
    _keys = set(['verbose', 'stdout'])

    def kernel(self, *args, **kwargs):
        '''
        Kernel function is the main driver of a method.  Every method should
        define the kernel function as the entry of the calculation.  Note the
        return value of the kernel function is not strictly defined.  It can be
        anything related to the method (such as the energy, the wave-function,
        the DFT mesh grids etc.).
        '''
        pass

    def pre_kernel(self, envs):
        '''
        A hook to be run before the main body of the kernel function is executed.
        Internal variables are exposed to pre_kernel through the "envs"
        dictionary.  Return value of the pre_kernel function is not required.
        '''
        pass

    def post_kernel(self, envs):
        '''
        A hook to be run after the main body of the kernel function.  Internal
        variables are exposed to post_kernel through the "envs" dictionary.
        Return value of the post_kernel function is not required.
        '''
        pass

    def run(self, *args, **kwargs):
        '''
        Call the kernel function of the current object.  `args` will be passed
        to the kernel function.  `kwargs` will be used to update the attributes
        of the current object.  The return value of method run is the object
        itself.  This allows a series of functions/methods to be executed in
        pipe.
        '''
        self.set(**kwargs)
        self.kernel(*args)
        return self

    def set(self, *args, **kwargs):
        '''
        Update the attributes of the current object.  The return value of
        method set is the object itself.  This allows a series of
        functions/methods to be executed in pipe.
        '''
        if args:
            warnings.warn('method set() only supports keyword arguments.\n'
                          'Arguments %s are ignored.' % args)
        #if getattr(self, '_keys', None):
        #    for k,v in kwargs.items():
        #        setattr(self, k, v)
        #        if k not in self._keys:
        #            sys.stderr.write('Warning: %s does not have attribute %s\n'
        #                             % (self.__class__, k))
        #else:
        for k,v in kwargs.items():
            setattr(self, k, v)
        return self

    # An alias to .set method
    __call__ = set

    def apply(self, fn, *args, **kwargs):
        '''
        Apply fn to the current object, which is passed as the first argument
        of fn:  return fn(self, *args, **kwargs).  This allows a series of
        functions/methods to be executed in pipe.
        '''
        return fn(self, *args, **kwargs)

#    def _format_args(self, args, kwargs, kernel_kw_lst):
#        args1 = [kwargs.pop(k, v) for k, v in kernel_kw_lst]
#        return args + args1[len(args):], kwargs

    def check_sanity(self):
        '''
        Check the input of the class/object attributes and check whether a
        class method is accidentally overwritten.  Attributes whose names are
        prefixed with "_" are not checked.  The return value of check_sanity
        is the object itself.  This allows a series of functions/methods to be
        executed in pipe.
        '''
        if (SANITY_CHECK and
            self.verbose > 0 and  # logger.QUIET
            getattr(self, '_keys', None)):
            check_sanity(self, self._keys, self.stdout)
        return self

    def view(self, cls):
        '''New view of object with the same attributes.'''
        obj = cls.__new__(cls)
        obj.__dict__.update(self.__dict__)
        return obj

_warn_once_registry = {}
def check_sanity(obj, keysref, stdout=sys.stdout):
    '''Check for mistyped object attributes and check whether a class method
    is accidentally overwritten.  Attributes whose names are prefixed with "_"
    are not checked.
    '''
    objkeys = [x for x in obj.__dict__ if not x.startswith('_')]
    keysub = set(objkeys) - set(keysref)
    if keysub:
        class_attr = set(dir(obj.__class__))
        keyin = keysub.intersection(class_attr)
        if keyin:
            msg = ('Overwritten attributes  %s  of %s\n' %
                   (' '.join(keyin), obj.__class__))
            if msg not in _warn_once_registry:
                _warn_once_registry[msg] = 1
                sys.stderr.write(msg)
                if stdout is not sys.stdout:
                    stdout.write(msg)
        keydiff = keysub - class_attr
        if keydiff:
            msg = ('%s does not have attributes  %s\n' %
                   (obj.__class__, ' '.join(keydiff)))
            if msg not in _warn_once_registry:
                _warn_once_registry[msg] = 1
                sys.stderr.write(msg)
                if stdout is not sys.stdout:
                    stdout.write(msg)
    return obj

def with_doc(doc):
    '''Use this decorator to add a doc string to a function

        @with_doc(doc)
        def fn:
            ...

    is equivalent to

        fn.__doc__ = doc
    '''
    def fn_with_doc(fn):
        fn.__doc__ = doc
        return fn
    return fn_with_doc

def alias(fn, alias_name=None):
    '''
    The statement "fn1 = alias(fn)" in a class is equivalent to defining the
    following method in the class:

    .. code-block:: python
        def fn1(self, *args, **kwargs):
            return self.fn(*args, **kwargs)

    Use the alias function instead of "fn1 = fn" because some methods may be
    overloaded in the child class.  Using "alias" ensures that the overloaded
    method is the one called when the aliased method is invoked.
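
    Examples:

    A minimal sketch (the class names here are hypothetical):

    >>> class A:
    ...     def fn(self):
    ...         return 1
    ...     fn1 = alias(fn)
    >>> class B(A):
    ...     def fn(self):
    ...         return 2
    >>> A().fn1(), B().fn1()
    (1, 2)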
    '''
    fname = fn.__name__
    def aliased_fn(self, *args, **kwargs):
        return getattr(self, fname)(*args, **kwargs)

    if alias_name is not None:
        aliased_fn.__name__ = alias_name

    doc_str = 'An alias to method %s\n' % fname
    if sys.version_info >= (3,):
        from inspect import signature
        sig = str(signature(fn))
        if alias_name is None:
            doc_str += 'Function Signature: %s\n' % sig
        else:
            doc_str += 'Function Signature: %s%s\n' % (alias_name, sig)
    doc_str += '----------------------------------------\n\n'

    if fn.__doc__ is not None:
        doc_str += fn.__doc__

    aliased_fn.__doc__ = doc_str
    return aliased_fn

def class_as_method(cls):
    '''
    The statement "fn1 = class_as_method(Class)" is equivalent to:

    .. code-block:: python
        def fn1(self, *args, **kwargs):
            return Class(self, *args, **kwargs)
    '''
    def fn(obj, *args, **kwargs):
        return cls(obj, *args, **kwargs)
    fn.__doc__ = cls.__doc__
    fn.__name__ = cls.__name__
    fn.__module__ = cls.__module__
    return fn

def overwrite_mro(obj, mro):
    '''A hacky function to overwrite the __mro__ attribute'''
    class HackMRO(type):
        pass
    # Overwrite the type.mro function so that the Temp class can use the given mro
    HackMRO.mro = lambda self: mro
    #if sys.version_info < (3,):
    #    class Temp(obj.__class__):
    #        __metaclass__ = HackMRO
    #else:
    #    class Temp(obj.__class__, metaclass=HackMRO):
    #        pass
    Temp = HackMRO(obj.__class__.__name__, obj.__class__.__bases__, obj.__dict__)
    obj = Temp()
    # Delete the mro function, otherwise subclasses of Temp are not able to
    # resolve the right mro
    del HackMRO.mro
    return obj

def izip(*args):
    '''python2 izip == python3 zip'''
    if sys.version_info < (3,):
        return itertools.izip(*args)
    else:
        return zip(*args)

class ProcessWithReturnValue(Process):
    def __init__(self, group=None, target=None, name=None, args=(),
                 kwargs=None):
        self._q = Queue()
        self._e = None
        def qwrap(*args, **kwargs):
            try:
                self._q.put(target(*args, **kwargs))
            except BaseException as e:
                self._e = e
                raise e
        Process.__init__(self, group, qwrap, name, args, kwargs)
    def join(self):
        Process.join(self)
        if self._e is not None:
            raise ProcessRuntimeError('Error on process %s:\n%s' % (self, self._e))
        else:
            return self._q.get()
    get = join

class ProcessRuntimeError(RuntimeError):
    pass

class ThreadWithReturnValue(Thread):
    def __init__(self, group=None, target=None, name=None, args=(),
                 kwargs=None):
        self._q = Queue()
        self._e = None
        def qwrap(*args, **kwargs):
            try:
                self._q.put(target(*args, **kwargs))
            except BaseException as e:
                self._e = e
                raise e
        Thread.__init__(self, group, qwrap, name, args, kwargs)
    def join(self):
        Thread.join(self)
        if self._e is not None:
            raise ThreadRuntimeError('Error on thread %s:\n%s' % (self, self._e))
        else:
            # Note: If the return value of target is huge, Queue.get may raise
            # SystemError: NULL result without error in PyObject_Call
            # This is because the return value is serialized by pickle, and
            # pickle is unable to handle a huge amount of data.
            return self._q.get()
    get = join

class ThreadWithTraceBack(Thread):
    def __init__(self, group=None, target=None, name=None, args=(),
                 kwargs=None):
        self._e = None
        def qwrap(*args, **kwargs):
            try:
                target(*args, **kwargs)
            except BaseException as e:
                self._e = e
                raise e
        Thread.__init__(self, group, qwrap, name, args, kwargs)
    def join(self):
        Thread.join(self)
        if self._e is not None:
            raise ThreadRuntimeError('Error on thread %s:\n%s' % (self, self._e))

class ThreadRuntimeError(RuntimeError):
    pass

def background_thread(func, *args, **kwargs):
    '''Run func in a background thread.  The returned thread object provides
    .get()/.join() to collect the result.'''
    thread = ThreadWithReturnValue(target=func, args=args, kwargs=kwargs)
    thread.start()
    return thread
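# A minimal sketch of the intended usage (illustrative only):
#
#     handle = background_thread(sum, range(10))
#     ...                     # do other work while sum() runs
#     result = handle.get()   # -> 45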

def background_process(func, *args, **kwargs):
    '''Run func in a background process.  The returned process object provides
    .get()/.join() to collect the result.'''
    process = ProcessWithReturnValue(target=func, args=args, kwargs=kwargs)
    process.start()
    return process

bg = background = bg_thread = background_thread
bp = bg_process = background_process

ASYNC_IO = getattr(__config__, 'ASYNC_IO', True)
class call_in_background(object):
    '''Within this macro, function(s) can be executed asynchronously (the
    given functions are executed in background threads).

    Attributes:
        sync (bool): Whether to run in synchronized mode.  The default value
            is False (asynchronous mode).

    Examples:

    >>> with call_in_background(fun) as async_fun:
    ...     async_fun(a, b)  # == fun(a, b)
    ...     do_something_else()

    >>> with call_in_background(fun1, fun2) as (afun1, afun2):
    ...     afun2(a, b)
    ...     do_something_else()
    ...     afun2(a, b)
    ...     do_something_else()
    ...     afun1(a, b)
    ...     do_something_else()
    '''

    def __init__(self, *fns, **kwargs):
        self.fns = fns
        self.executor = None
        self.handlers = [None] * len(self.fns)
        self.sync = kwargs.get('sync', not ASYNC_IO)

    if h5py.version.version[:4] == '2.2.': # h5py-2.2.* has a bug in threading mode
        # Disable background mode
        def __enter__(self):
            if len(self.fns) == 1:
                return self.fns[0]
            else:
                return self.fns

    else:
        def __enter__(self):
            fns = self.fns
            handlers = self.handlers
            ntasks = len(self.fns)

            global_import_lock = False
            if sys.version_info < (3, 6):
                import imp
                global_import_lock = imp.lock_held()

            if self.sync or global_import_lock:
                # Some modules like nosetests, coverage etc
                #   python -m unittest test_xxx.py  or  nosetests test_xxx.py
                # hang when Python multi-threading is used in the import stage,
                # due to a bug (the Python import lock) in the threading module.
                # See also
                # https://github.com/paramiko/paramiko/issues/104
                # https://docs.python.org/2/library/threading.html#importing-in-threaded-code
                # Disable the asynchronous mode for safe importing
                def def_async_fn(i):
                    return fns[i]

            elif ThreadPoolExecutor is None: # async mode, old python
                def def_async_fn(i):
                    def async_fn(*args, **kwargs):
                        if self.handlers[i] is not None:
                            self.handlers[i].join()
                        self.handlers[i] = ThreadWithTraceBack(target=fns[i], args=args,
                                                               kwargs=kwargs)
                        self.handlers[i].start()
                        return self.handlers[i]
                    return async_fn

            else: # multiple executors in async mode, python 2.7.12 or newer
                executor = self.executor = ThreadPoolExecutor(max_workers=ntasks)
                def def_async_fn(i):
                    def async_fn(*args, **kwargs):
                        if handlers[i] is not None:
                            try:
                                handlers[i].result()
                            except Exception as e:
                                raise ThreadRuntimeError('Error on thread %s:\n%s'
                                                         % (self, e))
                        handlers[i] = executor.submit(fns[i], *args, **kwargs)
                        return handlers[i]
                    return async_fn

            if len(self.fns) == 1:
                return def_async_fn(0)
            else:
                return [def_async_fn(i) for i in range(ntasks)]

    def __exit__(self, type, value, traceback):
        for handler in self.handlers:
            if handler is not None:
                try:
                    if ThreadPoolExecutor is None:
                        handler.join()
                    else:
                        handler.result()
                except Exception as e:
                    raise ThreadRuntimeError('Error on thread %s:\n%s' % (self, e))

        if self.executor is not None:
            self.executor.shutdown(wait=True)


class H5TmpFile(h5py.File):
    '''Create and return an HDF5 temporary file.

    Kwargs:
        filename : str or None
            If a string is given, an HDF5 file of the given filename will be
            created. The temporary file will exist even if the H5TmpFile
            object is released.  If nothing is specified, the HDF5 temporary
            file will be deleted when the H5TmpFile object is released.

    The return object is an h5py.File object. The file will be automatically
    deleted when it is closed or the object is released (unless filename is
    specified).

    Examples:

    >>> from pyscf import lib
    >>> ftmp = lib.H5TmpFile()
    '''
    def __init__(self, filename=None, mode='a', *args, **kwargs):
        if filename is None:
            tmpfile = tempfile.NamedTemporaryFile(dir=param.TMPDIR)
            filename = tmpfile.name
        h5py.File.__init__(self, filename, mode, *args, **kwargs)
    #FIXME: Does GC flush/close the HDF5 file when releasing the resource?
    # To make the HDF5 file reusable, the file has to be closed or flushed
    def __del__(self):
        try:
            self.close()
        except AttributeError:  # close not defined in old h5py
            pass
        except ValueError:  # if close() is called twice
            pass
        except ImportError:  # exit program before de-referring the object
            pass

def fingerprint(a):
    '''Fingerprint of a numpy array'''
    a = numpy.asarray(a)
    return numpy.dot(numpy.cos(numpy.arange(a.size)), a.ravel())
finger = fp = fingerprint


def ndpointer(*args, **kwargs):
    # Same as numpy.ctypeslib.ndpointer, except that the returned type also
    # accepts None as an argument value.
    base = numpy.ctypeslib.ndpointer(*args, **kwargs)

    @classmethod
    def from_param(cls, obj):
        if obj is None:
            return obj
        return base.from_param(obj)
    return type(base.__name__, (base,), {'from_param': from_param})
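# Illustrative sketch (``lib_fn`` is a hypothetical ctypes function): with the
# ndpointer defined above, an optional array argument can be omitted by
# passing None, which ctypes forwards as a NULL pointer:
#
#     lib_fn.argtypes = [ndpointer(dtype=numpy.double, flags='C_CONTIGUOUS')]
#     lib_fn(numpy.zeros(3))   # a regular array argument
#     lib_fn(None)             # allowed here, rejected by numpy.ctypeslib.ndpointer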


# Tags to label the derived Scanner classes
class SinglePointScanner: pass
class GradScanner:
    def __init__(self, g):
        self.__dict__.update(g.__dict__)
        self.base = g.base.as_scanner()
    @property
    def e_tot(self):
        return self.base.e_tot
    @e_tot.setter
    def e_tot(self, x):
        self.base.e_tot = x

    @property
    def converged(self):
        # Some base methods, like MP2, do not have the attribute "converged"
        conv = getattr(self.base, 'converged', True)
        return conv

class temporary_env(object):
    '''Within the scope of this context manager, the attributes of the object
    are temporarily updated. When the program goes out of the scope of the
    context, the original value of each attribute is restored.

    Examples:

    >>> with temporary_env(lib.param, LIGHT_SPEED=15., BOHR=2.5):
    ...     print(lib.param.LIGHT_SPEED, lib.param.BOHR)
    15. 2.5
    >>> print(lib.param.LIGHT_SPEED, lib.param.BOHR)
    137.03599967994 0.52917721092
    '''
    def __init__(self, obj, **kwargs):
        self.obj = obj

        # Should I skip the keys which are not present in obj?
        #keys = [key for key in kwargs.keys() if hasattr(obj, key)]
        #self.env_bak = [(key, getattr(obj, key, 'TO_DEL')) for key in keys]
        #self.env_new = [(key, kwargs[key]) for key in keys]

        self.env_bak = [(key, getattr(obj, key, 'TO_DEL')) for key in kwargs]
        self.env_new = [(key, kwargs[key]) for key in kwargs]

    def __enter__(self):
        for k, v in self.env_new:
            setattr(self.obj, k, v)
        return self

    def __exit__(self, type, value, traceback):
        for k, v in self.env_bak:
            if isinstance(v, str) and v == 'TO_DEL':
                delattr(self.obj, k)
            else:
                setattr(self.obj, k, v)

class light_speed(temporary_env):
    '''Within the context of this macro, the environment variable LIGHT_SPEED
    can be customized.

    Examples:

    >>> with light_speed(15.):
    ...     print(lib.param.LIGHT_SPEED)
    15.
    >>> print(lib.param.LIGHT_SPEED)
    137.03599967994
    '''
    def __init__(self, c):
        temporary_env.__init__(self, param, LIGHT_SPEED=c)
        self.c = c
    def __enter__(self):
        temporary_env.__enter__(self)
        return self.c

def repo_info(repo_path):
    '''
    Repo location, version, git branch and commit ID
    '''

    def git_version(orig_head, head, branch):
        git_version = []
        if orig_head:
            git_version.append('GIT ORIG_HEAD %s' % orig_head)
        if branch:
            git_version.append('GIT HEAD (branch %s) %s' % (branch, head))
        elif head:
            git_version.append('GIT HEAD      %s' % head)
        return '\n'.join(git_version)

    repo_path = os.path.abspath(repo_path)

    if os.path.isdir(os.path.join(repo_path, '.git')):
        git_str = git_version(*git_info(repo_path))

    elif os.path.isdir(os.path.abspath(os.path.join(repo_path, '..', '.git'))):
        repo_path = os.path.abspath(os.path.join(repo_path, '..'))
        git_str = git_version(*git_info(repo_path))

    else:
        git_str = None

    # TODO: Add info of BLAS, libcint, libxc, libxcfun, tblis if applicable

    info = {'path': repo_path}
    if git_str:
        info['git'] = git_str
    return info

def git_info(repo_path):
    orig_head = None
    head = None
    branch = None
    try:
        with open(os.path.join(repo_path, '.git', 'ORIG_HEAD'), 'r') as f:
            orig_head = f.read().strip()
    except IOError:
        pass

    try:
        head = os.path.join(repo_path, '.git', 'HEAD')
        with open(head, 'r') as f:
            head = f.read().splitlines()[0].strip()

        if head.startswith('ref:'):
            branch = os.path.basename(head)
            with open(os.path.join(repo_path, '.git', head.split(' ')[1]), 'r') as f:
                head = f.read().strip()
    except IOError:
        pass
    return orig_head, head, branch


def isinteger(obj):
    '''
    Check if an object is an integer.
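
    Examples:

    >>> isinteger(3), isinteger(3.0), isinteger(True)
    (True, False, False)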
    '''
    # A bool is also an int in python, but we don't want that.
    # On the other hand, numpy.bool_ is probably not a numpy.integer, but just to be sure...
    if isinstance(obj, (bool, numpy.bool_)):
        return False
    # These are actual ints we expect to encounter.
    else:
        return isinstance(obj, (int, numpy.integer))


def issequence(obj):
    '''
    Determine if the object provided is a sequence.
    '''
    # These are the types of sequences that we permit.
    # numpy.ndarray is not a subclass of collections.abc.Sequence as of version 1.19.
    sequence_types = (collections.abc.Sequence, numpy.ndarray)
    return isinstance(obj, sequence_types)


def isintsequence(obj):
    '''
    Determine if the object provided is a sequence of integers.
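
    Examples:

    >>> isintsequence([1, 2, 3])
    True
    >>> isintsequence((1, 2.0))
    False
    >>> isintsequence(numpy.arange(3))
    True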
    '''
    if not issequence(obj):
        return False
    elif isinstance(obj, numpy.ndarray):
        return issubclass(obj.dtype.type, numpy.integer)
    else:
        are_ints = True
        for i in obj:
            are_ints = are_ints and isinteger(i)
        return are_ints


if __name__ == '__main__':
    for i,j in prange_tril(0, 90, 300):
        print(i, j, j*(j+1)//2-i*(i+1)//2)
