#!/usr/bin/env python
# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Qiming Sun <osirpt.sun@gmail.com>
#

'''
Some helper functions
'''

import os, sys
import warnings
import tempfile
import functools
import itertools
import collections
import ctypes
import numpy
import h5py
from threading import Thread
from multiprocessing import Queue, Process
try:
    from concurrent.futures import ThreadPoolExecutor
except ImportError:
    ThreadPoolExecutor = None

from pyscf.lib import param
from pyscf import __config__

# h5py-2.2.* is known to be unsafe in threading mode; warn and fall back to
# synchronous I/O (see also class call_in_background below).
if h5py.version.version[:4] == '2.2.':
    sys.stderr.write('h5py-%s is found in your environment. '
                     'h5py-%s has bug in threading mode.\n'
                     'Async-IO is disabled.\n' % ((h5py.version.version,)*2))

# Common ctypes pointer types used across the pyscf C bindings
c_double_p = ctypes.POINTER(ctypes.c_double)
c_int_p = ctypes.POINTER(ctypes.c_int)
c_null_ptr = ctypes.POINTER(ctypes.c_void_p)

def load_library(libname):
    '''Load the pyscf C extension library ``libname``.

    First look next to this module; if not found there, search the ``lib``
    directory of every entry in ``pyscf.__path__`` and re-raise the original
    OSError when the library cannot be located at all.
    '''
    try:
        _loaderpath = os.path.dirname(__file__)
        return numpy.ctypeslib.load_library(libname, _loaderpath)
    except OSError:
        from pyscf import __path__ as ext_modules
        for path in ext_modules:
            libpath = os.path.join(path, 'lib')
            if os.path.isdir(libpath):
                for files in os.listdir(libpath):
                    if files.startswith(libname):
                        return numpy.ctypeslib.load_library(libname, libpath)
        raise

# FIXME: the standard resource module gives wrong numbers when objects are
# released; see
# http://fa.bianp.net/blog/2013/different-ways-to-get-memory-consumption-or-lessons-learned-from-memory_profiler/#fn:1
# or use slow functions as memory_profiler._get_memory did
CLOCK_TICKS = os.sysconf("SC_CLK_TCK")
PAGESIZE = os.sysconf("SC_PAGE_SIZE")
def current_memory():
    '''Return the size of used memory and allocated virtual memory (in MB)'''
    if sys.platform.startswith('linux'):
        # /proc/<pid>/statm reports (size, resident, ...) in pages
        with open("/proc/%s/statm" % os.getpid()) as f:
            vms, rss = [int(x)*PAGESIZE for x in f.readline().split()[:2]]
        return rss/1e6, vms/1e6
    else:
        # No portable cheap way to query memory on other platforms
        return 0, 0

def num_threads(n=None):
    '''Set the number of OMP threads. If argument is not specified, the
    function will return the total number of available OMP threads.

    It's recommended to call this function to set OMP threads than
    "os.environ['OMP_NUM_THREADS'] = int(n)". This is because environment
    variables like OMP_NUM_THREADS were read when a module was imported. They
    cannot be reset through os.environ after the module was loaded.

    Examples:

    >>> from pyscf import lib
    >>> print(lib.num_threads())
    8
    >>> lib.num_threads(4)
    4
    >>> print(lib.num_threads())
    4
    '''
    from pyscf.lib.numpy_helper import _np_helper
    if n is not None:
        _np_helper.set_omp_threads.restype = ctypes.c_int
        threads = _np_helper.set_omp_threads(ctypes.c_int(int(n)))
        if threads == 0:
            # The C library was built without OpenMP support
            warnings.warn('OpenMP is not available. '
                          'Setting omp_threads to %s has no effects.' % n)
        return threads
    else:
        _np_helper.get_omp_threads.restype = ctypes.c_int
        return _np_helper.get_omp_threads()

class with_omp_threads(object):
    '''Using this macro to create a temporary context in which the number of
    OpenMP threads are set to the required value. When the program exits the
    context, the number OpenMP threads will be restored.

    Args:
        nthreads : int

    Examples:

    >>> from pyscf import lib
    >>> print(lib.num_threads())
    8
    >>> with lib.with_omp_threads(2):
    ...     print(lib.num_threads())
    2
    >>> print(lib.num_threads())
    8
    '''
    def __init__(self, nthreads=None):
        self.nthreads = nthreads
        self.sys_threads = None
    def __enter__(self):
        if self.nthreads is not None and self.nthreads >= 1:
            self.sys_threads = num_threads()
            num_threads(self.nthreads)
        return self
    def __exit__(self, type, value, traceback):
        if self.sys_threads is not None:
            num_threads(self.sys_threads)


def c_int_arr(m):
    '''Copy (nested) sequence m into a C-ordered ctypes int array.'''
    npm = numpy.array(m).flatten('C')
    # A ctypes array owns its memory; returning npm.ctypes.data_as(c_int_p)
    # would leave the pointer dangling once npm is garbage collected.
    arr = (ctypes.c_int * npm.size)(*npm)
    return arr
def f_int_arr(m):
    '''Copy (nested) sequence m into a Fortran-ordered ctypes int array.'''
    npm = numpy.array(m).flatten('F')
    arr = (ctypes.c_int * npm.size)(*npm)
    return arr
def c_double_arr(m):
    '''Copy (nested) sequence m into a C-ordered ctypes double array.'''
    npm = numpy.array(m).flatten('C')
    arr = (ctypes.c_double * npm.size)(*npm)
    return arr
def f_double_arr(m):
    '''Copy (nested) sequence m into a Fortran-ordered ctypes double array.'''
    npm = numpy.array(m).flatten('F')
    arr = (ctypes.c_double * npm.size)(*npm)
    return arr


def member(test, x, lst):
    '''True if test(x, l) holds for any element l of lst.'''
    for l in lst:
        if test(x, l):
            return True
    return False

def remove_dup(test, lst, from_end=False):
    '''Remove duplicates from lst using the binary predicate ``test``.

    Note: when test is None a set (not a list) is returned.
    '''
    if test is None:
        return set(lst)
    else:
        if from_end:
            lst = list(reversed(lst))
        seen = []
        for l in lst:
            if not member(test, l, seen):
                seen.append(l)
        return seen

def remove_if(test, lst):
    '''Return a list of the elements of lst for which test is False.'''
    return [x for x in lst if not test(x)]

def find_if(test, lst):
    '''Return the first element of lst satisfying test; raise if none does.'''
    for l in lst:
        if test(l):
            return l
    raise ValueError('No element of the given list matches the test condition.')

def arg_first_match(test, lst):
    '''Return the index of the first element of lst satisfying test.'''
    for i,x in enumerate(lst):
        if test(x):
            return i
    raise ValueError('No element of the given list matches the test condition.')

def _balanced_partition(cum, ntasks):
    '''Given cumulative costs ``cum``, pick ntasks+1 boundaries so that each
    segment carries roughly equal cost.'''
    segsize = float(cum[-1]) / ntasks
    bounds = numpy.arange(ntasks+1) * segsize
    # For each ideal boundary, take the index whose cumulative cost is closest
    displs = abs(bounds[:,None] - cum).argmin(axis=1)
    return displs

def _blocksize_partition(cum, blocksize):
    '''Partition indices so that the cost (cum differences) of each block
    does not exceed blocksize. Returns the list of boundary indices.'''
    n = len(cum) - 1
    displs = [0]
    if n == 0:
        return displs

    p0 = 0
    for i in range(1, n):
        if cum[i+1]-cum[p0] > blocksize:
            displs.append(i)
            p0 = i
    displs.append(n)
    return displs

def flatten(lst):
    '''flatten nested lists
    x[0] + x[1] + x[2] + ...

    Examples:

    >>> flatten([[0, 2], [1], [[9, 8, 7]]])
    [0, 2, 1, [9, 8, 7]]
    '''
    return list(itertools.chain.from_iterable(lst))

def prange(start, end, step):
    '''This function splits the number sequence between "start" and "end"
    using uniform "step" length. It yields the boundary (start, end) for each
    fragment.

    Examples:

    >>> for p0, p1 in lib.prange(0, 8, 2):
    ...    print(p0, p1)
    (0, 2)
    (2, 4)
    (4, 6)
    (6, 8)
    '''
    if start < end:
        for i in range(start, end, step):
            yield i, min(i+step, end)

def prange_tril(start, stop, blocksize):
    '''Similar to :func:`prange`, yields start (p0) and end (p1) with the
    restriction p1*(p1+1)/2-p0*(p0+1)/2 < blocksize

    Examples:

    >>> for p0, p1 in lib.prange_tril(0, 10, 25):
    ...     print(p0, p1)
    (0, 6)
    (6, 9)
    (9, 10)
    '''
    if start >= stop:
        return []
    idx = numpy.arange(start, stop+1)
    # cumulative number of tril elements relative to row "start"
    cum_costs = idx*(idx+1)//2 - start*(start+1)//2
    displs = [x+start for x in _blocksize_partition(cum_costs, blocksize)]
    return zip(displs[:-1], displs[1:])

def map_with_prefetch(func, *iterables):
    '''
    Apply function to an task and prefetch the next task
    '''
    global_import_lock = False
    if sys.version_info < (3, 6):
        import imp
        global_import_lock = imp.lock_held()

    if global_import_lock:
        # Threading during import can deadlock on the global import lock;
        # fall back to a plain synchronous map.
        for task in zip(*iterables):
            yield func(*task)

    elif ThreadPoolExecutor is not None:
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = None
            for task in zip(*iterables):
                if future is None:
                    future = executor.submit(func, *task)
                else:
                    result = future.result()
                    future = executor.submit(func, *task)
                    yield result
            if future is not None:
                yield future.result()
    else:
        # Old-python fallback: double-buffer results via a background thread.
        # NOTE: if func returns None the final buffered value is dropped.
        def func_with_buf(_output_buf, *args):
            _output_buf[0] = func(*args)
        with call_in_background(func_with_buf) as f_prefetch:
            buf0, buf1 = [None], [None]
            for istep, task in enumerate(zip(*iterables)):
                if istep == 0:
                    f_prefetch(buf0, *task)
                else:
                    buf0, buf1 = buf1, buf0
                    f_prefetch(buf0, *task)
                    yield buf1[0]
            if buf0[0] is not None:
                yield buf0[0]

def index_tril_to_pair(ij):
    '''Given tril-index ij, compute the pair indices (i,j) which satisfy
    ij = i * (i+1) / 2 + j
    '''
    # 1e-7 guards against floating-point round-off pushing sqrt just below
    # the exact integer value
    i = (numpy.sqrt(2*ij+.25) - .5 + 1e-7).astype(int)
    j = ij - i*(i+1)//2
    return i, j


def tril_product(*iterables, **kwds):
    '''Cartesian product in lower-triangular form for multiple indices

    For a given list of indices (`iterables`), this function yields all
    indices such that the sub-indices given by the kwarg `tril_idx` satisfy a
    lower-triangular form.  The lower-triangular form satisfies:

    .. math:: i[tril_idx[0]] >= i[tril_idx[1]] >= ... >= i[tril_idx[len(tril_idx)-1]]

    Args:
        *iterables: Variable length argument list of indices for the cartesian product
        **kwds: Arbitrary keyword arguments.  Acceptable keywords include:
            repeat (int): Number of times to repeat the iterables
            tril_idx (array_like): Indices to put into lower-triangular form.

    Yields:
        product (tuple): Tuple in lower-triangular form.

    Examples:
        Specifying no `tril_idx` is equivalent to just a cartesian product.

        >>> list(tril_product(range(2), repeat=2))
        [(0, 0), (0, 1), (1, 0), (1, 1)]

        We can specify only sub-indices to satisfy a lower-triangular form:

        >>> list(tril_product(range(2), repeat=3, tril_idx=[1,2]))
        [(0, 0, 0), (0, 1, 0), (0, 1, 1), (1, 0, 0), (1, 1, 0), (1, 1, 1)]

        We specify all indices to satisfy a lower-triangular form, useful for iterating over
        the symmetry unique elements of occupied/virtual orbitals in a 3-particle operator:

        >>> list(tril_product(range(3), repeat=3, tril_idx=[0,1,2]))
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (1, 1, 1), (2, 0, 0), (2, 1, 0), (2, 1, 1), (2, 2, 0), (2, 2, 1), (2, 2, 2)]
    '''
    repeat = kwds.get('repeat', 1)
    tril_idx = kwds.get('tril_idx', [])
    niterables = len(iterables) * repeat
    ntril_idx = len(tril_idx)

    assert ntril_idx <= niterables, 'Cant have a greater number of tril indices than iterables!'
    if ntril_idx > 0:
        assert numpy.max(tril_idx) < niterables, 'Tril index out of bounds for %d iterables! idx = %s' % \
                                                 (niterables, tril_idx)
    for tup in itertools.product(*iterables, repeat=repeat):
        if ntril_idx == 0:
            yield tup
            continue

        if all([tup[tril_idx[i]] >= tup[tril_idx[i+1]] for i in range(ntril_idx-1)]):
            yield tup
        else:
            pass

def square_mat_in_trilu_indices(n):
    '''Return a n x n symmetric index matrix, in which the elements are the
    indices of the unique elements of a tril vector
    [0 1 3 ... ]
    [1 2 4 ... ]
    [3 4 5 ... ]
    [...       ]
    '''
    idx = numpy.tril_indices(n)
    tril2sq = numpy.zeros((n,n), dtype=int)
    tril2sq[idx[0],idx[1]] = tril2sq[idx[1],idx[0]] = numpy.arange(n*(n+1)//2)
    return tril2sq

class capture_stdout(object):
    '''redirect all stdout (c printf & python print) into a string

    Examples:

    >>> import os
    >>> from pyscf import lib
    >>> with lib.capture_stdout() as out:
    ...     os.system('ls')
    >>> print(out.read())
    '''
    #TODO: handle stderr
    def __enter__(self):
        sys.stdout.flush()
        self._contents = None
        # Redirect at the OS file-descriptor level so C-level printf output
        # is captured too, not only Python's sys.stdout.
        self.old_stdout_fileno = sys.stdout.fileno()
        self.bak_stdout_fd = os.dup(self.old_stdout_fileno)
        self.ftmp = tempfile.NamedTemporaryFile(dir=param.TMPDIR)
        os.dup2(self.ftmp.file.fileno(), self.old_stdout_fileno)
        return self
    def __exit__(self, type, value, traceback):
        sys.stdout.flush()
        self.ftmp.file.seek(0)
        self._contents = self.ftmp.file.read()
        self.ftmp.close()
        # Restore the original stdout file descriptor
        os.dup2(self.bak_stdout_fd, self.old_stdout_fileno)
        os.close(self.bak_stdout_fd)
    def read(self):
        '''Return the captured output (may be called inside the context).'''
        if self._contents:
            return self._contents
        else:
            sys.stdout.flush()
            self.ftmp.file.seek(0)
            return self.ftmp.file.read()
ctypes_stdout = capture_stdout

class quite_run(object):
    '''capture all stdout (c printf & python print) but output nothing

    Examples:

    >>> import os
    >>> from pyscf import lib
    >>> with lib.quite_run():
    ...     os.system('ls')
    '''
    def __enter__(self):
        sys.stdout.flush()
        #TODO: to handle the redirected stdout e.g. StringIO()
        self.old_stdout_fileno = sys.stdout.fileno()
        self.bak_stdout_fd = os.dup(self.old_stdout_fileno)
        self.fnull = open(os.devnull, 'wb')
        os.dup2(self.fnull.fileno(), self.old_stdout_fileno)
    def __exit__(self, type, value, traceback):
        sys.stdout.flush()
        os.dup2(self.bak_stdout_fd, self.old_stdout_fileno)
        self.fnull.close()


# from pygeocoder
# this decorator lets me use methods as both static and instance methods
# In contrast to classmethod, when obj.function() is called, the first
# argument is obj in omnimethod rather than obj.__class__ in classmethod
class omnimethod(object):
    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        # instance is None when accessed through the class; the partial then
        # passes None as the first argument
        return functools.partial(self.func, instance)


SANITY_CHECK = getattr(__config__, 'SANITY_CHECK', True)
class StreamObject(object):
    '''For most methods, there are three stream functions to pipe computing stream:

    1 ``.set_`` function to update object attributes, eg
    ``mf = scf.RHF(mol).set(conv_tol=1e-5)`` is identical to proceed in two steps
    ``mf = scf.RHF(mol); mf.conv_tol=1e-5``

    2 ``.run`` function to execute the kernel function (the function arguments
    are passed to kernel function).  If keyword arguments is given, it will first
    call ``.set`` function to update object attributes then execute the kernel
    function.  Eg
    ``mf = scf.RHF(mol).run(dm_init, conv_tol=1e-5)`` is identical to three steps
    ``mf = scf.RHF(mol); mf.conv_tol=1e-5; mf.kernel(dm_init)``

    3 ``.apply`` function to apply the given function/class to the current object
    (function arguments and keyword arguments are passed to the given function).
    Eg
    ``mol.apply(scf.RHF).run().apply(mcscf.CASSCF, 6, 4, frozen=4)`` is identical to
    ``mf = scf.RHF(mol); mf.kernel(); mcscf.CASSCF(mf, 6, 4, frozen=4)``
    '''

    verbose = 0
    stdout = sys.stdout
    _keys = set(['verbose', 'stdout'])

    def kernel(self, *args, **kwargs):
        '''
        Kernel function is the main driver of a method.  Every method should
        define the kernel function as the entry of the calculation.  Note the
        return value of kernel function is not strictly defined.  It can be
        anything related to the method (such as the energy, the wave-function,
        the DFT mesh grids etc.).
        '''
        pass

    def pre_kernel(self, envs):
        '''
        A hook to be run before the main body of kernel function is executed.
        Internal variables are exposed to pre_kernel through the "envs"
        dictionary.  Return value of pre_kernel function is not required.
        '''
        pass

    def post_kernel(self, envs):
        '''
        A hook to be run after the main body of the kernel function.  Internal
        variables are exposed to post_kernel through the "envs" dictionary.
        Return value of post_kernel function is not required.
        '''
        pass

    def run(self, *args, **kwargs):
        '''
        Call the kernel function of current object.  `args` will be passed
        to kernel function.  `kwargs` will be used to update the attributes of
        current object.  The return value of method run is the object itself.
        This allows a series of functions/methods to be executed in pipe.
        '''
        self.set(**kwargs)
        self.kernel(*args)
        return self

    def set(self, *args, **kwargs):
        '''
        Update the attributes of the current object.  The return value of
        method set is the object itself.  This allows a series of
        functions/methods to be executed in pipe.
        '''
        if args:
            warnings.warn('method set() only supports keyword arguments.\n'
                          'Arguments %s are ignored.' % args)
        for k,v in kwargs.items():
            setattr(self, k, v)
        return self

    # An alias to .set method
    __call__ = set

    def apply(self, fn, *args, **kwargs):
        '''
        Apply the fn to rest arguments:  return fn(*args, **kwargs).  The
        return value of method set is the object itself.  This allows a series
        of functions/methods to be executed in pipe.
        '''
        return fn(self, *args, **kwargs)

    def check_sanity(self):
        '''
        Check input of class/object attributes, check whether a class method is
        overwritten.  It does not check the attributes which are prefixed with
        "_".  The
        return value of method set is the object itself.  This allows a series
        of functions/methods to be executed in pipe.
        '''
        if (SANITY_CHECK and
            self.verbose > 0 and  # logger.QUIET
            getattr(self, '_keys', None)):
            check_sanity(self, self._keys, self.stdout)
        return self

    def view(self, cls):
        '''New view of object with the same attributes.'''
        obj = cls.__new__(cls)
        obj.__dict__.update(self.__dict__)
        return obj

# Remember messages that were already printed so each warning is emitted once
_warn_once_registry = {}
def check_sanity(obj, keysref, stdout=sys.stdout):
    '''Check misinput of class attributes, check whether a class method is
    overwritten.  It does not check the attributes which are prefixed with
    "_".
    '''
    objkeys = [x for x in obj.__dict__ if not x.startswith('_')]
    keysub = set(objkeys) - set(keysref)
    if keysub:
        class_attr = set(dir(obj.__class__))
        keyin = keysub.intersection(class_attr)
        if keyin:
            msg = ('Overwritten attributes  %s  of %s\n' %
                   (' '.join(keyin), obj.__class__))
            if msg not in _warn_once_registry:
                _warn_once_registry[msg] = 1
                sys.stderr.write(msg)
                if stdout is not sys.stdout:
                    stdout.write(msg)
        keydiff = keysub - class_attr
        if keydiff:
            msg = ('%s does not have attributes  %s\n' %
                   (obj.__class__, ' '.join(keydiff)))
            if msg not in _warn_once_registry:
                _warn_once_registry[msg] = 1
                sys.stderr.write(msg)
                if stdout is not sys.stdout:
                    stdout.write(msg)
    return obj

def with_doc(doc):
    '''Use this decorator to add doc string for function

        @with_doc(doc)
        def fn:
            ...

    is equivalent to

        fn.__doc__ = doc
    '''
    def fn_with_doc(fn):
        fn.__doc__ = doc
        return fn
    return fn_with_doc

def alias(fn, alias_name=None):
    '''
    The statement "fn1 = alias(fn)" in a class is equivalent to define the
    following method in the class:

    .. code-block:: python
        def fn1(self, *args, **kwargs):
            return self.fn(*args, **kwargs)

    Using alias function instead of fn1 = fn because some methods may be
    overloaded in the child class.  Using "alias" can make sure that the
    overloaded methods were called when calling the aliased method.
    '''
    fname = fn.__name__
    def aliased_fn(self, *args, **kwargs):
        # Late-bound lookup through self so subclass overrides take effect
        return getattr(self, fname)(*args, **kwargs)

    if alias_name is not None:
        aliased_fn.__name__ = alias_name

    doc_str = 'An alias to method %s\n' % fname
    if sys.version_info >= (3,):
        from inspect import signature
        sig = str(signature(fn))
        if alias_name is None:
            doc_str += 'Function Signature: %s\n' % sig
        else:
            doc_str += 'Function Signature: %s%s\n' % (alias_name, sig)
    doc_str += '----------------------------------------\n\n'

    if fn.__doc__ is not None:
        doc_str += fn.__doc__

    aliased_fn.__doc__ = doc_str
    return aliased_fn

def class_as_method(cls):
    '''
    The statement "fn1 = alias(Class)" is equivalent to:

    .. code-block:: python
        def fn1(self, *args, **kwargs):
            return Class(self, *args, **kwargs)
    '''
    def fn(obj, *args, **kwargs):
        return cls(obj, *args, **kwargs)
    fn.__doc__ = cls.__doc__
    fn.__name__ = cls.__name__
    fn.__module__ = cls.__module__
    return fn

def overwrite_mro(obj, mro):
    '''A hacky function to overwrite the __mro__ attribute'''
    class HackMRO(type):
        pass
    # Overwrite type.mro function so that Temp class can use the given mro
    HackMRO.mro = lambda self: mro
    Temp = HackMRO(obj.__class__.__name__, obj.__class__.__bases__, obj.__dict__)
    obj = Temp()
    # Delete mro function otherwise all subclass of Temp are not able to
    # resolve the right mro
    del HackMRO.mro
    return obj

def izip(*args):
    '''python2 izip == python3 zip'''
    if sys.version_info < (3,):
        return itertools.izip(*args)
    else:
        return zip(*args)

class ProcessWithReturnValue(Process):
    '''A multiprocessing.Process whose join() returns the target's result
    (transported through a Queue) and re-raises errors from the child.'''
    def __init__(self, group=None, target=None, name=None, args=(),
                 kwargs=None):
        if kwargs is None:
            # Process.__init__ calls dict(kwargs); None would raise TypeError
            kwargs = {}
        self._q = Queue()
        self._e = None
        def qwrap(*args, **kwargs):
            try:
                self._q.put(target(*args, **kwargs))
            except BaseException as e:
                self._e = e
                raise
        Process.__init__(self, group, qwrap, name, args, kwargs)
    def join(self):
        Process.join(self)
        if self._e is not None:
            raise ProcessRuntimeError('Error on process %s:\n%s' % (self, self._e))
        else:
            return self._q.get()
    get = join

class ProcessRuntimeError(RuntimeError):
    pass

class ThreadWithReturnValue(Thread):
    '''A threading.Thread whose join() returns the target's result and
    re-raises errors from the worker thread.'''
    def __init__(self, group=None, target=None, name=None, args=(),
                 kwargs=None):
        if kwargs is None:
            kwargs = {}
        self._q = Queue()
        self._e = None
        def qwrap(*args, **kwargs):
            try:
                self._q.put(target(*args, **kwargs))
            except BaseException as e:
                self._e = e
                raise
        Thread.__init__(self, group, qwrap, name, args, kwargs)
    def join(self):
        Thread.join(self)
        if self._e is not None:
            raise ThreadRuntimeError('Error on thread %s:\n%s' % (self, self._e))
        else:
            # Note: If the return value of target is huge, Queue.get may raise
            # SystemError: NULL result without error in PyObject_Call
            # It is because return value is cached somewhere by pickle but pickle is
            # unable to handle huge amount of data.
            return self._q.get()
    get = join

class ThreadWithTraceBack(Thread):
    '''A threading.Thread that records exceptions from the worker so join()
    can re-raise them in the caller.'''
    def __init__(self, group=None, target=None, name=None, args=(),
                 kwargs=None):
        if kwargs is None:
            kwargs = {}
        self._e = None
        def qwrap(*args, **kwargs):
            try:
                target(*args, **kwargs)
            except BaseException as e:
                self._e = e
                raise
        Thread.__init__(self, group, qwrap, name, args, kwargs)
    def join(self):
        Thread.join(self)
        if self._e is not None:
            raise ThreadRuntimeError('Error on thread %s:\n%s' % (self, self._e))

class ThreadRuntimeError(RuntimeError):
    pass

def background_thread(func, *args, **kwargs):
    '''applying function in background'''
    thread = ThreadWithReturnValue(target=func, args=args, kwargs=kwargs)
    thread.start()
    return thread

def background_process(func, *args, **kwargs):
    '''applying function in background'''
    thread = ProcessWithReturnValue(target=func, args=args, kwargs=kwargs)
    thread.start()
    return thread

bg = background = bg_thread = background_thread
bp = bg_process = background_process

ASYNC_IO = getattr(__config__, 'ASYNC_IO', True)
class call_in_background(object):
    '''Within this macro, function(s) can be executed asynchronously (the
    given functions are executed in background).

    Attributes:
        sync (bool): Whether to run in synchronized mode.  The default value
            is False (asynchronized mode).

    Examples:

    >>> with call_in_background(fun) as async_fun:
    ...     async_fun(a, b)  # == fun(a, b)
    ...     do_something_else()

    >>> with call_in_background(fun1, fun2) as (afun1, afun2):
    ...     afun2(a, b)
    ...     do_something_else()
    ...     afun2(a, b)
    ...     do_something_else()
    ...     afun1(a, b)
    ...     do_something_else()
    '''

    def __init__(self, *fns, **kwargs):
        self.fns = fns
        self.executor = None
        self.handlers = [None] * len(self.fns)
        self.sync = kwargs.get('sync', not ASYNC_IO)

    if h5py.version.version[:4] == '2.2.':  # h5py-2.2.* has bug in threading mode
        # Disable back-ground mode
        def __enter__(self):
            if len(self.fns) == 1:
                return self.fns[0]
            else:
                return self.fns

    else:
        def __enter__(self):
            fns = self.fns
            handlers = self.handlers
            ntasks = len(self.fns)

            global_import_lock = False
            if sys.version_info < (3, 6):
                import imp
                global_import_lock = imp.lock_held()

            if self.sync or global_import_lock:
                # Some modules like nosetests, coverage etc
                #   python -m unittest test_xxx.py  or  nosetests test_xxx.py
                # hang when Python multi-threading was used in the import stage due to (Python
                # import lock) bug in the threading module.  See also
                # https://github.com/paramiko/paramiko/issues/104
                # https://docs.python.org/2/library/threading.html#importing-in-threaded-code
                # Disable the asynchronous mode for safe importing
                def def_async_fn(i):
                    return fns[i]

            elif ThreadPoolExecutor is None:  # async mode, old python
                def def_async_fn(i):
                    def async_fn(*args, **kwargs):
                        if self.handlers[i] is not None:
                            self.handlers[i].join()
                        self.handlers[i] = ThreadWithTraceBack(target=fns[i], args=args,
                                                               kwargs=kwargs)
                        self.handlers[i].start()
                        return self.handlers[i]
                    return async_fn

            else:  # multiple executors in async mode, python 2.7.12 or newer
                executor = self.executor = ThreadPoolExecutor(max_workers=ntasks)
                def def_async_fn(i):
                    def async_fn(*args, **kwargs):
                        if handlers[i] is not None:
                            try:
                                handlers[i].result()
                            except Exception as e:
                                raise ThreadRuntimeError('Error on thread %s:\n%s'
                                                         % (self, e))
                        handlers[i] = executor.submit(fns[i], *args, **kwargs)
                        return handlers[i]
                    return async_fn

            if len(self.fns) == 1:
                return def_async_fn(0)
            else:
                return [def_async_fn(i) for i in range(ntasks)]

    def __exit__(self, type, value, traceback):
        # Wait for all outstanding work and surface any worker errors
        for handler in self.handlers:
            if handler is not None:
                try:
                    if ThreadPoolExecutor is None:
                        handler.join()
                    else:
                        handler.result()
                except Exception as e:
                    raise ThreadRuntimeError('Error on thread %s:\n%s' % (self, e))

        if self.executor is not None:
            self.executor.shutdown(wait=True)


class H5TmpFile(h5py.File):
    '''Create and return an HDF5 temporary file.

    Kwargs:
        filename : str or None
            If a string is given, an HDF5 file of the given filename will be
            created.  The temporary file will exist even if the H5TmpFile
            object is released.  If nothing is specified, the HDF5 temporary
            file will be deleted when the H5TmpFile object is released.

    The return object is an h5py.File object.  The file will be automatically
    deleted when it is closed or the object is released (unless filename is
    specified).

    Examples:

    >>> from pyscf import lib
    >>> ftmp = lib.H5TmpFile()
    '''
    def __init__(self, filename=None, mode='a', *args, **kwargs):
        if filename is None:
            # The NamedTemporaryFile is unlinked when this local is released;
            # the already-opened HDF5 handle keeps the inode alive (POSIX).
            tmpfile = tempfile.NamedTemporaryFile(dir=param.TMPDIR)
            filename = tmpfile.name
        h5py.File.__init__(self, filename, mode, *args, **kwargs)
#FIXME: Does GC flush/close the HDF5 file when releasing the resource?
# To make HDF5 file reusable, file has to be closed or flushed
    def __del__(self):
        try:
            self.close()
        except AttributeError:  # close not defined in old h5py
            pass
        except ValueError:  # if close() is called twice
            pass
        except ImportError:  # exit program before de-referring the object
            pass

def fingerprint(a):
    '''Fingerprint of numpy array'''
    a = numpy.asarray(a)
    return numpy.dot(numpy.cos(numpy.arange(a.size)), a.ravel())
finger = fp = fingerprint


def ndpointer(*args, **kwargs):
    '''Like numpy.ctypeslib.ndpointer but the returned type also accepts
    None (converted to a NULL pointer) as an argument value.'''
    base = numpy.ctypeslib.ndpointer(*args, **kwargs)

    @classmethod
    def from_param(cls, obj):
        if obj is None:
            return obj
        return base.from_param(obj)
    return type(base.__name__, (base,), {'from_param': from_param})


# A tag to label the derived Scanner class
class SinglePointScanner: pass
class GradScanner:
    def __init__(self, g):
        self.__dict__.update(g.__dict__)
        self.base = g.base.as_scanner()
    @property
    def e_tot(self):
        return self.base.e_tot
    @e_tot.setter
    def e_tot(self, x):
        self.base.e_tot = x

    @property
    def converged(self):
        # Some base methods like MP2 does not have the attribute converged
        conv = getattr(self.base, 'converged', True)
        return conv

class temporary_env(object):
    '''Within the context of this macro, the attributes of the object are
    temporarily updated.  When the program goes out of the scope of the
    context, the original value of each attribute will be restored.

    Examples:

    >>> with temporary_env(lib.param, LIGHT_SPEED=15., BOHR=2.5):
    ...     print(lib.param.LIGHT_SPEED, lib.param.BOHR)
    15. 2.5
    >>> print(lib.param.LIGHT_SPEED, lib.param.BOHR)
    137.03599967994 0.52917721092
    '''
    # Unique sentinel marking "attribute did not exist"; a plain string
    # sentinel could collide with a legitimate attribute value.
    _TO_DEL = object()

    def __init__(self, obj, **kwargs):
        self.obj = obj
        self.env_bak = [(key, getattr(obj, key, self._TO_DEL)) for key in kwargs]
        self.env_new = [(key, kwargs[key]) for key in kwargs]

    def __enter__(self):
        for k, v in self.env_new:
            setattr(self.obj, k, v)
        return self

    def __exit__(self, type, value, traceback):
        for k, v in self.env_bak:
            if v is self._TO_DEL:
                # Attribute was absent before entering the context
                delattr(self.obj, k)
            else:
                setattr(self.obj, k, v)

class light_speed(temporary_env):
    '''Within the context of this macro, the environment variable LIGHT_SPEED
    can be customized.

    Examples:

    >>> with light_speed(15.):
    ...     print(lib.param.LIGHT_SPEED)
    15.
    >>> print(lib.param.LIGHT_SPEED)
    137.03599967994
    '''
    def __init__(self, c):
        temporary_env.__init__(self, param, LIGHT_SPEED=c)
        self.c = c
    def __enter__(self):
        temporary_env.__enter__(self)
        return self.c

def repo_info(repo_path):
    '''
    Repo location, version, git branch and commit ID
    '''

    def git_version(orig_head, head, branch):
        git_version = []
        if orig_head:
            git_version.append('GIT ORIG_HEAD %s' % orig_head)
        if branch:
            git_version.append('GIT HEAD (branch %s) %s' % (branch, head))
        elif head:
            git_version.append('GIT HEAD      %s' % head)
        return '\n'.join(git_version)

    repo_path = os.path.abspath(repo_path)

    if os.path.isdir(os.path.join(repo_path, '.git')):
        git_str = git_version(*git_info(repo_path))

    elif os.path.isdir(os.path.abspath(os.path.join(repo_path, '..', '.git'))):
        # repo_path may be a subdirectory of the checkout
        repo_path = os.path.abspath(os.path.join(repo_path, '..'))
        git_str = git_version(*git_info(repo_path))

    else:
        git_str = None

    # TODO: Add info of BLAS, libcint, libxc, libxcfun, tblis if applicable

    info = {'path': repo_path}
    if git_str:
        info['git'] = git_str
    return info

def git_info(repo_path):
    '''Read ORIG_HEAD, HEAD commit and branch name directly from .git
    (no git executable required).  Missing pieces are returned as None.'''
    orig_head = None
    head = None
    branch = None
    try:
        with open(os.path.join(repo_path, '.git', 'ORIG_HEAD'), 'r') as f:
            orig_head = f.read().strip()
    except IOError:
        pass

    try:
        head = os.path.join(repo_path, '.git', 'HEAD')
        with open(head, 'r') as f:
            head = f.read().splitlines()[0].strip()

        if head.startswith('ref:'):
            # HEAD is symbolic; resolve the referenced branch file
            branch = os.path.basename(head)
            with open(os.path.join(repo_path, '.git', head.split(' ')[1]), 'r') as f:
                head = f.read().strip()
    except IOError:
        pass
    return orig_head, head, branch


def isinteger(obj):
    '''
    Check if an object is an integer.
    '''
    # A bool is also an int in python, but we don't want that.
    # On the other hand, numpy.bool_ is probably not a numpy.integer, but just to be sure...
    if isinstance(obj, (bool, numpy.bool_)):
        return False
    # These are actual ints we expect to encounter.
    else:
        return isinstance(obj, (int, numpy.integer))


def issequence(obj):
    '''
    Determine if the object provided is a sequence.
    '''
    # These are the types of sequences that we permit.
    # numpy.ndarray is not a subclass of collections.abc.Sequence as of version 1.19.
    sequence_types = (collections.abc.Sequence, numpy.ndarray)
    return isinstance(obj, sequence_types)


def isintsequence(obj):
    '''
    Determine if the object provided is a sequence of integers.
    '''
    if not issequence(obj):
        return False
    elif isinstance(obj, numpy.ndarray):
        return issubclass(obj.dtype.type, numpy.integer)
    else:
        are_ints = True
        for i in obj:
            are_ints = are_ints and isinteger(i)
        return are_ints


if __name__ == '__main__':
    for i,j in prange_tril(0, 90, 300):
        print(i, j, j*(j+1)//2-i*(i+1)//2)