1"""@package geometric.nifty Nifty functions, originally intended to be imported by any module within ForceBalance.
2This file was copied over from ForceBalance to geomeTRIC in order to lighten the dependencies of the latter.
3
4Table of Contents:
5- I/O formatting
6- Math: Variable manipulation, linear algebra, least squares polynomial fitting
7- Pickle: Expand Python's own pickle to accommodate writing XML etree objects
8- Commands for submitting things to the Work Queue
9- Various file and process management functions
10- Development stuff (not commonly used)
11
12Named after the mighty Sniffy Handy Nifty (King Sniffy)
13
14@author Lee-Ping Wang
15@date 2018-03-10
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import filecmp
22import itertools
23import os
24import re
25import shutil
26import sys
27from select import select
28
29import numpy as np
30from numpy.linalg import multi_dot
31
32# For Python 3 compatibility
33try:
34    from itertools import zip_longest as zip_longest
35except ImportError:
36    from itertools import izip_longest as zip_longest
37import threading
38from pickle import Pickler, Unpickler
39import tarfile
40import time
41import subprocess
42import math
43import six # For six.string_types
44from subprocess import PIPE
45from collections import OrderedDict, defaultdict
46
47#================================#
48#       Set up the logger        #
49#================================#
if "forcebalance" in __name__:
    # If this module is part of ForceBalance, use the package level logger
    from .output import *
    package="ForceBalance"
else:
    # Stand-alone / geomeTRIC usage: fall back to the stdlib logging module.
    from logging import *
    # Define two handlers that don't print newline characters at the end of each line
    class RawStreamHandler(StreamHandler):
        """
        Exactly like StreamHandler, except no newline character is printed at the end of each message.
        This is done in order to ensure functions in molecule.py and nifty.py work consistently
        across multiple packages.
        """
        def __init__(self, stream = sys.stdout):
            super(RawStreamHandler, self).__init__(stream)

        def emit(self, record):
            # Write the raw message with no added newline, then flush immediately.
            message = record.getMessage()
            self.stream.write(message)
            self.flush()

    class RawFileHandler(FileHandler):
        """
        Exactly like FileHandler, except no newline character is printed at the end of each message.
        This is done in order to ensure functions in molecule.py and nifty.py work consistently
        across multiple packages.
        """
        def __init__(self, *args, **kwargs):
            super(RawFileHandler, self).__init__(*args, **kwargs)

        def emit(self, record):
            # FileHandler may defer opening the stream (delay=True); open it on demand.
            if self.stream is None:
                self.stream = self._open()
            message = record.getMessage()
            self.stream.write(message)
            self.flush()

    if "geometric" in __name__:
        # This ensures logging behavior is consistent with the rest of geomeTRIC
        logger = getLogger(__name__)
        logger.setLevel(INFO)
        package="geomeTRIC"
    else:
        # Generic fallback: a dedicated logger writing raw messages to stdout.
        logger = getLogger("NiftyLogger")
        logger.setLevel(INFO)
        handler = RawStreamHandler()
        logger.addHandler(handler)
        if __name__ == "__main__":
            package = "LPW-nifty.py"
        else:
            package = __name__.split('.')[0]
101
102try:
103    import bz2
104    HaveBZ2 = True
105except ImportError:
106    logger.warning("bz2 module import failed (used in compressing or decompressing pickle files)\n")
107    HaveBZ2 = False
108
109try:
110    import gzip
111    HaveGZ = True
112except ImportError:
113    logger.warning("gzip module import failed (used in compressing or decompressing pickle files)\n")
114    HaveGZ = False
115
## Boltzmann constant (value matches the molar gas constant in kJ/mol/K -- TODO confirm intended units)
kb = 0.0083144100163
## Q-Chem to GMX unit conversion for energy (Hartree -> kJ/mol)
eqcgmx = 2625.5002
## Q-Chem to GMX unit conversion for force
fqcgmx = -49621.9
# Conversion factors
bohr2ang = 0.529177210       # Bohr -> Angstrom
ang2bohr = 1.0 / bohr2ang    # Angstrom -> Bohr
au2kcal = 627.5096080306     # Hartree -> kcal/mol
kcal2au = 1.0 / au2kcal      # kcal/mol -> Hartree
au2kj = 2625.5002            # Hartree -> kJ/mol
kj2au = 1.0 / au2kj          # kJ/mol -> Hartree
grad_au2gmx = 49614.75960959161   # gradient, atomic units -> GROMACS (kJ/mol/nm) -- presumably; verify against callers
grad_gmx2au = 1.0 / grad_au2gmx
# Gradient units
au2evang = 51.42209166566339      # Hartree/Bohr -> eV/Angstrom
evang2au = 1.0 / au2evang
135
136#=========================#
137#     I/O formatting      #
138#=========================#
139# These functions may be useful someday but I have not tested them
140# def bzip2(src):
141#     dest = src+'.bz2'
142#     if not os.path.exists(src):
143#         logger.error('File to be compressed does not exist')
144#         raise RuntimeError
145#     if os.path.exists(dest):
146#         logger.error('Archive to be created already exists')
147#         raise RuntimeError
148#     with open(src, 'rb') as input:
149#         with bz2.BZ2File(dest, 'wb', compresslevel=9) as output:
150#             copyfileobj(input, output)
151#     os.remove(input)
152
153# def bunzip2(src):
154#     dest = re.sub('\.bz2$', '', src)
155#     if not os.path.exists(src):
156#         logger.error('File to be decompressed does not exist')
157#         raise RuntimeError
158#     if os.path.exists(dest):
159#         logger.error('Target path for decompression already exists')
160#         raise RuntimeError
161#     with bz2.BZ2File(src, 'rb', compresslevel=9) as input:
162#         with open(dest, 'wb') as output:
163#             copyfileobj(input, output)
164#     os.remove(input)
165
def pvec1d(vec1d, precision=1, format="e", loglevel=INFO):
    """Log the elements of a 1-D vector on one line, followed by a newline.

    @param[in] vec1d a 1-D vector
    @param[in] precision digits after the decimal point
    @param[in] format printf conversion character, e.g. 'e' or 'f'
    @param[in] loglevel logging level for the printout
    """
    arr = np.array(vec1d)
    fmt = "%% .%i%s " % (precision, format)
    for elem in arr:
        logger.log(loglevel, fmt % elem)
    logger.log(loglevel, '\n')
175
def astr(vec1d, precision=4):
    """ Render a 1-D array as a string (useful as a dictionary key). """
    fmt = "%% .%ie " % precision
    return ' '.join(fmt % elem for elem in vec1d)
179
def pmat2d(mat2d, precision=1, format="e", loglevel=INFO):
    """Log a 2-D array, one row per line.

    @param[in] mat2d a 2-D array
    @param[in] precision digits after the decimal point
    @param[in] format printf conversion character, e.g. 'e' or 'f'
    @param[in] loglevel logging level for the printout
    """
    arr = np.array(mat2d)
    fmt = "%% .%i%s " % (precision, format)
    for rvals in arr:
        for elem in rvals:
            logger.log(loglevel, fmt % elem)
        logger.log(loglevel, '\n')
190
def grouper(iterable, n):
    """Collect data into chunks of at most n items (the last chunk may be shorter).

    Note: None values are filtered out along with the zip_longest padding.
    """
    # grouper('ABCDEFG', 3) --> [['A','B','C'], ['D','E','F'], ['G']]
    its = [iter(iterable)] * n
    return [[item for item in chunk if item is not None]
            for chunk in zip_longest(*its)]
197
def encode(l):
    """Run-length encode a sequence into a list of [count, value] pairs."""
    return [[len(list(run)), value] for value, run in itertools.groupby(l)]
200
def segments(e):
    """Given run-length encoded input (list of [count, value] pairs, see encode),
    return (begin, end) half-open index pairs for the runs whose value is 1.

    The previous implementation recomputed a prefix sum for every element,
    making it O(n^2); this single pass is O(n) and returns the same pairs.
    """
    out = []
    pos = 0
    for count, value in e:
        if value == 1:
            out.append((pos, pos + count))
        pos += count
    return out
206
def commadash(l):
    """Format a list of 0-based indices as a compact 1-based dashed range string.

    e.g. [26, 27, 28, 29, 30, 87] -> '27-31,88'; an empty list gives '(empty)'.
    """
    idx = sorted(l)
    if not idx:
        return "(empty)"
    # Sentinel so the final run is terminated when building the membership mask.
    idx.append(idx[-1] + 1)
    member = [i in idx for i in range(idx[-1])]
    return ','.join('%i-%i' % (a + 1, b) if b - 1 > a else '%i' % (a + 1)
                    for a, b in segments(encode(member)))
216
def uncommadash(s):
    """Parse a 1-based dashed range string into a sorted list of 0-based indices.

    e.g. '27-31,88' -> [26, 27, 28, 29, 30, 87].
    Raises RuntimeError for malformed tokens, non-positive or reversed ranges,
    duplicates, or out-of-order input.
    """
    indices = []
    try:
        for token in s.split(','):
            fields = token.split('-')
            lo = int(fields[0]) - 1
            if len(fields) == 1:
                hi = int(fields[0])
            elif len(fields) == 2:
                hi = int(fields[1])
            else:
                logger.warning("Dash-separated list cannot exceed length 2\n")
                raise
            if lo < 0 or hi <= 0 or hi <= lo:
                if lo < 0 or hi <= 0:
                    logger.warning("Items in list cannot be zero or negative: %d %d\n" % (lo, hi))
                else:
                    logger.warning("Second number cannot be smaller than first: %d %d\n" % (lo, hi))
                raise
            span = range(lo, hi)
            if any(i in indices for i in span):
                logger.warning("Duplicate entries found in list\n")
                raise
            indices += span
        if sorted(indices) != indices:
            logger.warning("List is out of order\n")
            raise
    except:
        logger.error('Invalid string for converting to list of numbers: %s\n' % s)
        raise RuntimeError
    return indices
250
def natural_sort(l):
    """ Return the list sorted in natural (human) order, e.g. 'a2' before 'a10'. """
    def keyfunc(s):
        # Split into digit and non-digit runs; compare digits numerically,
        # everything else case-insensitively.
        return [int(part) if part.isdigit() else part.lower()
                for part in re.split('([0-9]+)', s)]
    return sorted(l, key=keyfunc)
259
def printcool(text,sym="#",bold=False,color=2,ansi=None,bottom='-',minwidth=50,center=True,sym2="="):
    """Cool-looking printout for slick formatting of output.

    @param[in] text The string that the printout is based upon.  This function
    will print out the string, ANSI-colored and enclosed in the symbol
    for example:\n
    <tt> ################# </tt>\n
    <tt> ### I am cool ### </tt>\n
    <tt> ################# </tt>
    @param[in] sym The surrounding symbol\n
    @param[in] bold Whether to use bold print

    @param[in] color The ANSI color:\n
    1 red\n
    2 green\n
    3 yellow\n
    4 blue\n
    5 magenta\n
    6 cyan\n
    7 white

    @param[in] ansi An explicit ANSI SGR code string; if given it overrides 'color'

    @param[in] bottom The symbol for the bottom bar

    @param[in] minwidth The minimum width for the box, if the text is very short
    then we insert the appropriate number of padding spaces

    @param[in] center Whether to center each line; may be a per-line list of booleans

    @param[in] sym2 The symbol used to build the top bar

    @return bar The bottom bar is returned for the user to print later, e.g. to mark off a 'section'
    """
    def newlen(l):
        # Visible length of a line, not counting ANSI escape sequences.
        return len(re.sub(r"\x1b\[[0-9;]*m","",l))
    text = text.split('\n')
    width = max(minwidth,max([newlen(line) for line in text]))
    bar = ''.join([sym2 for i in range(width + 6)])
    bar = sym + bar + sym
    #bar = ''.join([sym for i in range(width + 8)])
    logger.info('\r'+bar + '\n')
    for ln, line in enumerate(text):
        # 'center' may be a list giving one centering flag per line.
        if type(center) is list: c1 = center[ln]
        else: c1 = center
        if c1:
            padleft = ' ' * (int((width - newlen(line))/2))
        else:
            padleft = ''
        padright = ' '* (width - newlen(line) - len(padleft))
        if ansi is not None:
            ansi = str(ansi)
            logger.info("%s| \x1b[%sm%s " % (sym, ansi, padleft)+line+" %s\x1b[0m |%s\n" % (padright, sym))
        elif color is not None:
            if color == 0 and bold:
                logger.info("%s| \x1b[1m%s " % (sym, padleft) + line + " %s\x1b[0m |%s\n" % (padright, sym))
            elif color == 0:
                logger.info("%s| %s " % (sym, padleft)+line+" %s |%s\n" % (padright, sym))
            else:
                # "9%i" selects the bright ANSI foreground color 90+color.
                logger.info("%s| \x1b[%s9%im%s " % (sym, bold and "1;" or "", color, padleft)+line+" %s\x1b[0m |%s\n" % (padright, sym))
            # if color == 3 or color == 7:
            #     print "%s\x1b[40m\x1b[%s9%im%s" % (''.join([sym for i in range(3)]), bold and "1;" or "", color, padleft),line,"%s\x1b[0m%s" % (padright, ''.join([sym for i in range(3)]))
            # else:
            #     print "%s\x1b[%s9%im%s" % (''.join([sym for i in range(3)]), bold and "1;" or "", color, padleft),line,"%s\x1b[0m%s" % (padright, ''.join([sym for i in range(3)]))
        else:
            warn_press_key("Inappropriate use of printcool")
    logger.info(bar + '\n')
    botbar = ''.join([bottom for i in range(width + 8)])
    return botbar + '\n'
323
def printcool_dictionary(Dict,title="Dictionary Keys : Values",bold=False,color=2,keywidth=25,topwidth=50,center=True,leftpad=0):
    """See documentation for printcool; this is a nice way to print out keys/values in a dictionary.

    The keys in the dictionary are sorted before printing out, unless it is an
    OrderedDict, in which case insertion order is preserved.  Entries whose
    value is None are skipped.

    @param[in] Dict The dictionary to be printed
    @param[in] title The title of the printout
    @param[in] keywidth Minimum width the key is left-justified to
    @param[in] leftpad Number of spaces inserted before each printed line
    """
    if Dict is None: return
    bar = printcool(title,bold=bold,color=color,minwidth=topwidth,center=center)
    def magic_string(text):
        # Left-justify the key to 'keywidth' characters.  The previous
        # implementation built this expression with eval(), which broke on
        # keys containing backslashes or newlines; plain %-formatting is
        # equivalent for all other keys and safe for every key.
        return ("%%-%is" % keywidth) % text
    if isinstance(Dict, OrderedDict):
        logger.info('\n'.join([' '*leftpad + "%s %s " % (magic_string(str(key)),str(Dict[key])) for key in Dict if Dict[key] is not None]))
    else:
        logger.info('\n'.join([' '*leftpad + "%s %s " % (magic_string(str(key)),str(Dict[key])) for key in sorted([i for i in Dict]) if Dict[key] is not None]))
    logger.info("\n%s" % bar)
344
345#===============================#
346#| Math: Variable manipulation |#
347#===============================#
def isint(word):
    """ONLY matches integers! If you have a decimal point? None shall pass!

    @param[in] word String (for instance, '123', '153.0', '2.', '-354')
    @return answer Truthy (a match object) if the string is an optional sign
    followed only by digits; falsy otherwise
    """
    try:
        text = str(word)
    except:
        return False
    return re.match('^[-+]?[0-9]+$', text)
360
def isfloat(word):
    """Matches ANY number; it can be a decimal, scientific notation, what have you
    CAUTION - this will also match an integer.

    The mantissa must now contain at least one digit: the previous pattern
    ([0-9]*\\.?[0-9]*) also matched a bare sign ('-'), a lone decimal point
    ('.') and a bare exponent ('e5'), none of which are numbers.

    @param[in] word String (for instance, '123', '153.0', '2.', '-354')
    @return answer Truthy (a match object) if the string is any number
    """
    try: word = str(word)
    except: return False
    if len(word) == 0: return False
    # Mantissa: digits with optional trailing '.digits', or '.digits' alone.
    # Exponent marker may be e/E or Fortran-style d/D.
    return re.match(r'^[-+]?([0-9]+\.?[0-9]*|\.[0-9]+)([eEdD][-+]?[0-9]+)?$', word)
373
def isdecimal(word):
    """Matches numbers that are NOT plain integers; see isint and isfloat.

    @param[in] word String (for instance, '123', '153.0', '2.', '-354')
    @return answer Truthy if the string is a number but not an integer
    """
    try:
        text = str(word)
    except:
        return False
    return isfloat(text) and not isint(text)
384
def floatornan(word):
    """Convert a string to float, substituting a big number for non-numeric input.

    @param[in] word The string to be converted
    @return answer The string converted to a float; if not a float, return 1e10
    """
    big = 1e10
    if isfloat(word):
        return float(word)
    # Bug fix: the format string has two placeholders but was previously given
    # a single bare argument ("% big"), so this branch raised TypeError
    # instead of logging and returning the fallback value.
    logger.info("Setting %s to % .1e\n" % (word, big))
    return big
398
def col(vec):
    """Return the input reshaped into a single-column 2D array.

    @param[in] vec Any list, array, or matrix
    @return answer A 1-column 2D array
    """
    return np.reshape(np.array(vec), (-1, 1))
410
def row(vec):
    """Return the input reshaped into a single-row 2D array.

    @param[in] vec Any list, array, or matrix
    @return answer A 1-row 2D array
    """
    return np.reshape(np.array(vec), (1, -1))
419
def flat(vec):
    """Return the input flattened into a 1-D array.

    @param[in] vec The data to be flattened
    @return answer The flattened data
    """
    return np.reshape(np.array(vec), -1)
427
def est124(val):
    """Round a positive value to the nearest number of the form {1,2,4}*10^n
    (or 10*10^n), choosing by proximity in log space.
    """
    exponent = math.floor(np.log10(val))
    frac = np.log10(val) - exponent
    # Midpoints in log10 between the candidates 1, 2, 4 and 10.
    if frac < 0.5 * (0.0 + 0.3010299956639812):
        mant = 1.0
    elif frac < 0.5 * (0.3010299956639812 + 0.6020599913279624):
        mant = 2.0
    elif frac < 0.5 * (0.6020599913279624 + 1.0):
        mant = 4.0
    else:
        mant = 10.0
    return mant * 10 ** exponent
448
def est1234568(val):
    """Round a positive value to the nearest number of the form
    {1,2,3,4,5,6,8}*10^n (or 10*10^n) in log space.  Just because I don't like
    seven and nine.  Call me a numberist?
    """
    exponent = math.floor(np.log10(val))
    frac = np.log10(val) - exponent
    choices = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0]
    logs = [np.log10(c) for c in choices]
    # Pick the candidate whose log10 midpoint interval contains frac.
    for k in range(len(choices) - 1):
        if frac < 0.5 * (logs[k] + logs[k + 1]):
            return choices[k] * 10 ** exponent
    return 10.0 * 10 ** exponent
482
def monotonic(arr, start, end):
    # Make sure an array is monotonically decreasing from the start to the end.
    # Operates IN PLACE: whenever a new running minimum is found, the stretch
    # since the previous minimum is replaced by a linear ramp (np.linspace),
    # flattening any upward bumps.  If end < start, the scan runs backwards
    # and the decrease is enforced from 'start' down to (and including) 'end'.
    a0 = arr[start]   # current running minimum value
    i0 = start        # index where the running minimum was found
    if end > start:
        i = start+1
        while i < end:
            if arr[i] < a0:
                # New minimum: interpolate linearly across [i0, i].
                arr[i0:i+1] = np.linspace(a0, arr[i], i-i0+1)
                a0 = arr[i]
                i0 = i
            i += 1
    if end < start:
        i = start-1
        while i >= end:
            if arr[i] < a0:
                # Backward scan: interpolate across [i, i0].
                arr[i:i0+1] = np.linspace(arr[i], a0, i0-i+1)
                a0 = arr[i]
                i0 = i
            i -= 1
503
def monotonic_decreasing(arr, start=None, end=None, verbose=False):
    """
    Return the indices of an array corresponding to strictly monotonic
    decreasing behavior.

    Parameters
    ----------
    arr : numpy.ndarray
        Input array
    start : int
        Starting index (first element if None)
    end : int
        Ending index (last element if None)
    verbose : bool
        If True, log which indices are included / excluded

    Returns
    -------
    indices : numpy.ndarray
        Selected indices
    """
    start = 0 if start is None else start
    end = len(arr) - 1 if end is None else end
    best = arr[start]
    idx = [start]
    if verbose: logger.info("Starting @ %i : %.6f\n" % (start, arr[start]))
    # Forward scan stops before 'end'; backward scan includes 'end'.
    if end > start:
        scan = range(start + 1, end)
    elif end < start:
        scan = range(start - 1, end - 1, -1)
    else:
        scan = range(0)
    for i in scan:
        if arr[i] < best:
            best = arr[i]
            idx.append(i)
            if verbose: logger.info("Including  %i : %.6f\n" % (i, arr[i]))
        else:
            if verbose: logger.info("Excluding  %i : %.6f\n" % (i, arr[i]))
    return np.array(idx)
551
552#====================================#
553#| Math: Vectors and linear algebra |#
554#====================================#
def orthogonalize(vec1, vec2):
    """Project out of vec1 the component lying along vec2.

    @param[in] vec1 The projectee (i.e. output is some modified version of vec1)
    @param[in] vec2 The projector (component subtracted out from vec1 is parallel to this)
    @return answer A copy of vec1 but with the vec2-component projected out.
    """
    unit = vec2 / np.linalg.norm(vec2)
    return vec1 - np.dot(vec1, unit) * unit
565
def invert_svd(X, thresh=1e-12):
    """
    Invert a matrix using singular value decomposition.

    Singular values whose magnitude is at or below 'thresh' are zeroed rather
    than inverted, yielding a pseudoinverse for (near-)singular matrices.

    @param[in] X The 2-D NumPy array containing the matrix to be inverted
    @param[in] thresh The SVD threshold; eigenvalues below this are not inverted but set to zero
    @return Xt The 2-D NumPy array containing the inverted matrix
    """
    u, s, vh = np.linalg.svd(X, full_matrices=0)
    # Invert each singular value above the threshold; drop the rest.
    sinv = np.array([1. / sv if abs(sv) > thresh else 0.0 for sv in s])
    return multi_dot([vh.T, np.diag(sinv), u.T])
589
590#==============================#
591#|    Linear least squares    |#
592#==============================#
def get_least_squares(x, y, w = None, thresh=1e-12):
    """
    @code
     __                  __
    |                      |
    | 1 (x0) (x0)^2 (x0)^3 |
    | 1 (x1) (x1)^2 (x1)^3 |
    | 1 (x2) (x2)^2 (x2)^3 |
    | 1 (x3) (x3)^2 (x3)^3 |
    | 1 (x4) (x4)^2 (x4)^3 |
    |__                  __|

    @endcode

    @param[in] x (2-D array) An array of X-values (see above)
    @param[in] y (array) An array of Y-values (only used in getting the least squares coefficients)
    @param[in] w (array) An array of weights, hopefully normalized to one.
    @param[in] thresh Unused here; retained for interface compatibility.
    @return Beta The least-squares coefficients
    @return Hat The hat matrix that takes linear combinations of data y-values to give fitted y-values (weights)
    @return yfit The fitted y-values
    @return MPPI The Moore-Penrose pseudoinverse (multiply by Y to get least-squares coefficients, multiply by dY/dk to get derivatives of least-squares coefficients)
    """
    # X is a 'tall' matrix.
    X = np.array(x)
    if len(X.shape) == 1:
        X = X[:,np.newaxis]
    Y = col(y)
    n_x = X.shape[0]
    n_fit = X.shape[1]
    if n_fit > n_x:
        logger.warning("Argh? It seems like this problem is underdetermined!\n")
    # Build the weight matrix.
    if w is not None:
        if len(w) != n_x:
            # Bug fix: the format arguments were previously written as
            # "... % len(w), n_x", which formatted len(w) alone and passed n_x
            # as a spurious second positional argument to warn_press_key,
            # raising TypeError instead of showing the warning.
            warn_press_key("The weight array length (%i) must be the same as the number of 'X' data points (%i)!" % (len(w), n_x))
        # Normalize a float copy so the caller's array is not mutated
        # (the previous in-place "w /= ..." also failed for plain lists).
        w = np.array(w, dtype=float)
        w = w / np.mean(w)
        WH = np.diag(w**0.5)
    else:
        WH = np.eye(n_x)
    # Make the Moore-Penrose Pseudoinverse.
    # This resembles the formula (X'WX)^-1 X' W^1/2
    MPPI = np.linalg.pinv(np.dot(WH, X))
    Beta = multi_dot([MPPI, WH, Y])
    Hat = multi_dot([WH, X, MPPI])
    yfit = flat(np.dot(Hat, Y))
    # Return three things: the least-squares coefficients, the hat matrix (turns y into yfit), and yfit
    # We could get these all from MPPI, but I might get confused later on, so might as well do it here :P
    return np.array(Beta).flatten(), np.array(Hat), np.array(yfit).flatten(), np.array(MPPI)
644
645#===========================================#
646#| John's statisticalInefficiency function |#
647#===========================================#
648def statisticalInefficiency(A_n, B_n=None, fast=False, mintime=3, warn=True):
649
650    """
651    Compute the (cross) statistical inefficiency of (two) timeseries.
652
653    Notes
654      The same timeseries can be used for both A_n and B_n to get the autocorrelation statistical inefficiency.
655      The fast method described in Ref [1] is used to compute g.
656
657    References
658      [1] J. D. Chodera, W. C. Swope, J. W. Pitera, C. Seok, and K. A. Dill. Use of the weighted
659      histogram analysis method for the analysis of simulated and parallel tempering simulations.
660      JCTC 3(1):26-41, 2007.
661
662    Examples
663
664    Compute statistical inefficiency of timeseries data with known correlation time.
665
666    >>> import timeseries
667    >>> A_n = timeseries.generateCorrelatedTimeseries(N=100000, tau=5.0)
668    >>> g = statisticalInefficiency(A_n, fast=True)
669
670    @param[in] A_n (required, numpy array) - A_n[n] is nth value of
671    timeseries A.  Length is deduced from vector.
672
673    @param[in] B_n (optional, numpy array) - B_n[n] is nth value of
674    timeseries B.  Length is deduced from vector.  If supplied, the
675    cross-correlation of timeseries A and B will be estimated instead of
676    the autocorrelation of timeseries A.
677
678    @param[in] fast (optional, boolean) - if True, will use faster (but
679    less accurate) method to estimate correlation time, described in
680    Ref. [1] (default: False)
681
682    @param[in] mintime (optional, int) - minimum amount of correlation
683    function to compute (default: 3) The algorithm terminates after
684    computing the correlation time out to mintime when the correlation
685    function furst goes negative.  Note that this time may need to be
686    increased if there is a strong initial negative peak in the
687    correlation function.
688
689    @return g The estimated statistical inefficiency (equal to 1 + 2
690    tau, where tau is the correlation time).  We enforce g >= 1.0.
691
692    """
693    # Create numpy copies of input arguments.
694    A_n = np.array(A_n)
695    if B_n is not None:
696        B_n = np.array(B_n)
697    else:
698        B_n = np.array(A_n)
699    # Get the length of the timeseries.
700    N = A_n.shape[0]
701    # Be sure A_n and B_n have the same dimensions.
702    if A_n.shape != B_n.shape:
703        logger.error('A_n and B_n must have same dimensions.\n')
704        raise ParameterError
705    # Initialize statistical inefficiency estimate with uncorrelated value.
706    g = 1.0
707    # Compute mean of each timeseries.
708    mu_A = A_n.mean()
709    mu_B = B_n.mean()
710    # Make temporary copies of fluctuation from mean.
711    dA_n = A_n.astype(np.float64) - mu_A
712    dB_n = B_n.astype(np.float64) - mu_B
713    # Compute estimator of covariance of (A,B) using estimator that will ensure C(0) = 1.
714    sigma2_AB = (dA_n * dB_n).mean() # standard estimator to ensure C(0) = 1
715    # Trap the case where this covariance is zero, and we cannot proceed.
716    if sigma2_AB == 0:
717        if warn:
718            logger.warning('Sample covariance sigma_AB^2 = 0 -- cannot compute statistical inefficiency\n')
719        return 1.0
720    # Accumulate the integrated correlation time by computing the normalized correlation time at
721    # increasing values of t.  Stop accumulating if the correlation function goes negative, since
722    # this is unlikely to occur unless the correlation function has decayed to the point where it
723    # is dominated by noise and indistinguishable from zero.
724    t = 1
725    increment = 1
726    while t < N-1:
727        # compute normalized fluctuation correlation function at time t
728        C = sum( dA_n[0:(N-t)]*dB_n[t:N] + dB_n[0:(N-t)]*dA_n[t:N] ) / (2.0 * float(N-t) * sigma2_AB)
729        # Terminate if the correlation function has crossed zero and we've computed the correlation
730        # function at least out to 'mintime'.
731        if (C <= 0.0) and (t > mintime):
732            break
733        # Accumulate contribution to the statistical inefficiency.
734        g += 2.0 * C * (1.0 - float(t)/float(N)) * float(increment)
735        # Increment t and the amount by which we increment t.
736        t += increment
737        # Increase the interval if "fast mode" is on.
738        if fast: increment += 1
739    # g must be at least unity
740    if g < 1.0: g = 1.0
741    # Return the computed statistical inefficiency.
742    return g
743
def mean_stderr(ts):
    """Return the mean and the inefficiency-corrected standard error of a time series ts."""
    g = statisticalInefficiency(ts, warn=False)
    return np.mean(ts), np.std(ts) * np.sqrt(g / len(ts))
748
749# Slices a 2D array of data by column.  The new array is fed into the statisticalInefficiency function.
def multiD_statisticalInefficiency(A_n, B_n=None, fast=False, mintime=3, warn=True):
    """Column-wise statistical inefficiency of a 2-D data array.

    Each column of A_n (and B_n, if given) is passed to statisticalInefficiency;
    the scalar result fills the corresponding column of the output.
    """
    n_row = A_n.shape[0]
    n_col = A_n.shape[-1]
    out = np.zeros((n_row, n_col))
    for j in range(n_col):
        other = None if B_n is None else B_n[:, j]
        out[:, j] = statisticalInefficiency(A_n[:, j], other, fast, mintime, warn)
    return out
760
761#========================================#
762#|      Loading compressed pickles      |#
763#========================================#
764
def lp_dump(obj, fnm, protocol=0):
    """ Pickle an object to the given path, compressed with gzip if available,
    else bz2, else uncompressed.  An existing symlink at the path is removed first. """
    if os.path.islink(fnm):
        logger.warning("Trying to write to a symbolic link %s, removing it first\n" % fnm)
        os.unlink(fnm)
    if HaveGZ:
        fobj = gzip.GzipFile(fnm, 'wb')
    elif HaveBZ2:
        fobj = bz2.BZ2File(fnm, 'wb')
    else:
        fobj = open(fnm, 'wb')
    Pickler(fobj, protocol).dump(obj)
    fobj.close()
782
def lp_load(fnm):
    """ Read a pickled object from the path, trying gzip first, then bz2,
    then an uncompressed file.  Each loader retries with latin1 encoding to
    handle pickles written by Python 2. """
    if not os.path.exists(fnm):
        logger.error("lp_load cannot read from a path that doesn't exist (%s)" % fnm)
        raise IOError

    def load_uncompress():
        # Last resort: plain (uncompressed) pickle file.
        logger.warning("Compressed file loader failed, attempting to read as uncompressed file\n")
        f = open(fnm, 'rb')
        try:
            answer = Unpickler(f).load()
        except UnicodeDecodeError:
            # Pickles produced by Python 2 may need latin1 decoding.
            answer = Unpickler(f, encoding='latin1').load()
        f.close()
        return answer

    def load_bz2():
        f = bz2.BZ2File(fnm, 'rb')
        try:
            answer = Unpickler(f).load()
        except UnicodeDecodeError:
            answer = Unpickler(f, encoding='latin1').load()
        f.close()
        return answer

    def load_gz():
        f = gzip.GzipFile(fnm, 'rb')
        try:
            answer = Unpickler(f).load()
        except UnicodeDecodeError:
            answer = Unpickler(f, encoding='latin1').load()
        f.close()
        return answer

    # Fallback cascade: gzip -> bz2 -> uncompressed, limited to the
    # compression modules that imported successfully.
    if HaveGZ:
        try:
            answer = load_gz()
        except:
            if HaveBZ2:
                try:
                    answer = load_bz2()
                except:
                    answer = load_uncompress()
            else:
                answer = load_uncompress()
    elif HaveBZ2:
        try:
            answer = load_bz2()
        except:
            answer = load_uncompress()
    else:
        answer = load_uncompress()
    return answer
836
837#==============================#
838#|      Work Queue stuff      |#
839#==============================#
# Work Queue (part of the CCTools suite) is an optional dependency used for
# distributed task execution.  If the import fails, the queue_up* / wq_wait*
# functions below will raise NameError when called.
try:
    import work_queue
except:
    pass
    #logger.warning("Work Queue library import fail (You can't queue up jobs using Work Queue)\n")

# Global variable corresponding to the Work Queue object
WORK_QUEUE = None

# Global variable containing a mapping from target names to Work Queue task IDs
WQIDS = defaultdict(list)
851
def getWorkQueue():
    """ Accessor for the module-level Work Queue object (None until createWorkQueue is called). """
    return WORK_QUEUE
855
def getWQIds():
    """ Accessor for the module-level mapping from target names to Work Queue task IDs. """
    return WQIDS
859
def createWorkQueue(wq_port, debug=True, name=package):
    """ Create the module-level Work Queue object, listening on the given port.

    @param[in] wq_port Port number for the Work Queue master.
    @param[in] debug If True, enable verbose Work Queue debug output.
    @param[in] name Project name announced to the catalog server.
    """
    global WORK_QUEUE
    if debug:
        work_queue.set_debug_flag('all')
    wq = work_queue.WorkQueue(port=wq_port, catalog=True, exclusive=False, shutdown=False)
    wq.specify_name(name)
    # Very long keepalive interval effectively disables keepalive timeouts.
    wq.specify_keepalive_interval(8640000)
    WORK_QUEUE = wq
868
def destroyWorkQueue():
    """ Reset the module-level Work Queue state: drop the queue object and clear the task ID mapping. """
    global WORK_QUEUE, WQIDS
    WQIDS = defaultdict(list)
    WORK_QUEUE = None
874
def queue_up(wq, command, input_files, output_files, tag=None, tgt=None, verbose=True, print_time=60):
    """
    Submit a job to the Work Queue.  Input/output files are mapped between the
    current working directory locally and the worker's sandbox remotely.

    @param[in] wq (Work Queue Object)
    @param[in] command (string) The command to run on the remote worker.
    @param[in] input_files (list of files) A list of locations of the input files.
    @param[in] output_files (list of files) A list of locations of the output files.
    @param[in] tag (string) Optional task tag; defaults to the command itself.
    @param[in] tgt Optional target object; the task ID is recorded under tgt.name.
    @param[in] verbose Whether to log the submission.
    @param[in] print_time Threshold (seconds) used later when reporting long-running tasks.
    """
    global WQIDS
    task = work_queue.Task(command)
    here = os.getcwd()
    for fnm in input_files:
        task.specify_input_file(os.path.join(here, fnm), fnm, cache=False)
    for fnm in output_files:
        task.specify_output_file(os.path.join(here, fnm), fnm, cache=False)
    task.specify_algorithm(work_queue.WORK_QUEUE_SCHEDULE_FCFS)
    if tag is None:
        tag = command
    task.specify_tag(tag)
    task.print_time = print_time
    taskid = wq.submit(task)
    if verbose:
        logger.info("Submitting command '%s' to the Work Queue, %staskid %i\n" % (command, "tag %s, " % tag if tag != command else "", taskid))
    key = tgt.name if tgt is not None else "None"
    WQIDS[key].append(taskid)
904
def queue_up_src_dest(wq, command, input_files, output_files, tag=None, tgt=None, verbose=True, print_time=60):
    """
    Submit a job to the Work Queue.  This function is a bit fancier in that we can explicitly
    specify where the input files come from, and where the output files go to.

    @param[in] wq (Work Queue Object)
    @param[in] command (string) The command to run on the remote worker.
    @param[in] input_files (list of 2-tuples) A list of local and
    remote locations of the input files.
    @param[in] output_files (list of 2-tuples) A list of local and
    remote locations of the output files.
    @param[in] tag (string) Optional task tag; defaults to the command itself.
    @param[in] tgt Optional target object; the task ID is recorded under tgt.name.
    """
    global WQIDS
    task = work_queue.Task(command)
    for local, remote in input_files:
        task.specify_input_file(local, remote, cache=False)
    for local, remote in output_files:
        task.specify_output_file(local, remote, cache=False)
    task.specify_algorithm(work_queue.WORK_QUEUE_SCHEDULE_FCFS)
    if tag is None:
        tag = command
    task.specify_tag(tag)
    task.print_time = print_time
    taskid = wq.submit(task)
    if verbose:
        logger.info("Submitting command '%s' to the Work Queue, taskid %i\n" % (command, taskid))
    key = tgt.name if tgt is not None else "None"
    WQIDS[key].append(taskid)
936
def wq_wait1(wq, wait_time=10, wait_intvl=1, print_time=60, verbose=False):
    """ Wait up to wait_time seconds (polling every wait_intvl seconds) for a task
    in the Work Queue to finish.  Failed tasks are automatically resubmitted and
    WQIDS is updated to track the replacement task ID.

    @param[in] wq (Work Queue Object)
    @param[in] wait_time Total time in seconds to wait for a task.
    @param[in] wait_intvl Polling interval in seconds.
    @param[in] print_time Only report successful tasks that took longer than this many seconds.
    @param[in] verbose Whether to print detailed task and queue statistics.
    """
    global WQIDS
    if verbose: logger.info('---\n')
    if wait_intvl >= wait_time:
        wait_time = wait_intvl
        numwaits = 1
    else:
        numwaits = int(wait_time/wait_intvl)
    for sec in range(numwaits):
        task = wq.wait(wait_intvl)
        if task:
            # cmd_execution_time is reported in microseconds; convert to seconds.
            exectime = task.cmd_execution_time/1000000
            if verbose:
                logger.info('A job has finished!\n')
                # Bug fix: these messages previously concatenated non-string task
                # attributes (and the float exectime) directly onto str literals,
                # raising TypeError; use %-formatting instead.  This also
                # restores the newlines missing from several of the lines.
                logger.info('Job name = %s command = %s\n' % (task.tag, task.command))
                logger.info("status = %s\n" % task.status)
                logger.info("return_status = %s\n" % task.return_status)
                logger.info("result = %s\n" % task.result)
                logger.info("host = %s\n" % task.hostname)
                logger.info("execution time = %s\n" % exectime)
                logger.info("total_bytes_transferred = %s\n" % task.total_bytes_transferred)
            if task.result != 0:
                # The task failed: find which target it belonged to, drop the old
                # task ID, resubmit, and record the new task ID.
                oldid = task.id
                oldhost = task.hostname
                tgtname = "None"
                for tnm in WQIDS:
                    if task.id in WQIDS[tnm]:
                        tgtname = tnm
                        WQIDS[tnm].remove(task.id)
                taskid = wq.submit(task)
                logger.warning("Task '%s' (task %i) failed on host %s (%i seconds), resubmitted: taskid %i\n" % (task.tag, oldid, oldhost, exectime, taskid))
                WQIDS[tgtname].append(taskid)
            else:
                if hasattr(task, 'print_time'):
                    print_time = task.print_time
                if exectime > print_time: # Assume that we're only interested in printing jobs that last longer than a minute.
                    logger.info("Task '%s' (task %i) finished successfully on host %s (%i seconds)\n" % (task.tag, task.id, task.hostname, exectime))
                for tnm in WQIDS:
                    if task.id in WQIDS[tnm]:
                        WQIDS[tnm].remove(task.id)
                del task

        # LPW 2018-09-10 Updated to use stats fields from CCTools 6.2.10
        # Please upgrade CCTools version if errors are encountered during runtime.
        if verbose:
            logger.info("Workers: %i init, %i idle, %i busy, %i total joined, %i total removed\n" \
                % (wq.stats.workers_init, wq.stats.workers_idle, wq.stats.workers_busy, wq.stats.workers_joined, wq.stats.workers_removed))
            logger.info("Tasks: %i running, %i waiting, %i dispatched, %i submitted, %i total complete\n" \
                % (wq.stats.tasks_running, wq.stats.tasks_waiting, wq.stats.tasks_dispatched, wq.stats.tasks_submitted, wq.stats.tasks_done))
            logger.info("Data: %i / %i kb sent/received\n" % (int(wq.stats.bytes_sent/1024), int(wq.stats.bytes_received/1024)))
        else:
            logger.info("\r%s : %i/%i workers busy; %i/%i jobs complete  \r" %\
            (time.ctime(), wq.stats.workers_busy, wq.stats.workers_connected, wq.stats.tasks_done, wq.stats.tasks_submitted))
            # Print a newline at most every 15 minutes so the carriage-return
            # status line does not grow without bound in log files.
            if time.time() - wq_wait1.t0 > 900:
                wq_wait1.t0 = time.time()
                logger.info('\n')
wq_wait1.t0 = time.time()
995
def wq_wait(wq, wait_time=10, wait_intvl=10, print_time=60, verbose=False):
    """ This function waits until the work queue is completely empty.

    Repeatedly calls wq_wait1 (which blocks for up to wait_time seconds per
    call, handling completions and resubmitting failures) until wq.empty()
    reports that no tasks remain.
    """
    while not wq.empty():
        wq_wait1(wq, wait_time=wait_time, wait_intvl=wait_intvl, print_time=print_time, verbose=verbose)
1000
1001#=====================================#
1002#| File and process management stuff |#
1003#=====================================#
def click():
    """ Stopwatch function: return the seconds elapsed since the previous call (or since module load). """
    now = time.time()
    elapsed = now - click.t0
    click.t0 = now
    return elapsed
click.t0 = time.time()
1010
1011# Back up a file.
def bak(path, dest=None):
    """ Back up a file by moving it to a numbered name like base_1.ext, base_2.ext, ...

    @param[in] path File to back up; if it does not exist, nothing happens.
    @param[in] dest Destination directory (created if missing); defaults to the file's own directory.
    @return The new path of the backed-up file, or None if 'path' did not exist.
    """
    newf = None
    if os.path.exists(path):
        dnm, fnm = os.path.split(path)
        dnm = dnm or '.'
        base, ext = os.path.splitext(fnm)
        tgtdir = dnm if dest is None else dest
        if not os.path.isdir(tgtdir):
            os.makedirs(tgtdir)
        # Find the first unused numeric suffix.
        idx = 1
        while True:
            candidate = os.path.join(tgtdir, "%s_%i%s" % (base, idx, ext))
            if not os.path.exists(candidate):
                newf = candidate
                break
            idx += 1
        logger.info("Backing up %s -> %s\n" % (path, newf))
        shutil.move(path, newf)
    return newf
1031
1032# Purpose: Given a file name and/or an extension, do one of the following:
1033# 1) If provided a file name, check the file, crash if not exist and err==True.  Return the file name.
1034# 2) If list is empty but extension is provided, check if one file exists that matches
1035# the extension.  If so, return the file name.
1036# 3) If list is still empty and err==True, then crash with an error.
def onefile(fnm=None, ext=None, err=False):
    """ Locate a single file by explicit name or by unique extension in the current directory.

    If 'fnm' is given and exists: copy it into the current directory if needed
    (crashing if a different file with the same name is already here) and
    return its basename.  Otherwise, look for exactly one file in the current
    directory ending in '.ext' and return its basename.  If no unique file can
    be found, raise RuntimeError when err is True, else return None.
    """
    if fnm is None and ext is None:
        if err:
            logger.error("Must provide either filename or extension to onefile()")
            raise RuntimeError
        else:
            return None
    if fnm is not None:
        if os.path.exists(fnm):
            if os.path.dirname(os.path.abspath(fnm)) != os.getcwd():
                fsrc = os.path.abspath(fnm)
                fdest = os.path.join(os.getcwd(), os.path.basename(fnm))
                #-----
                # If the file path doesn't correspond to the current directory, copy the file over
                # If the file exists in the current directory already and it's different, then crash.
                #-----
                if os.path.exists(fdest):
                    if not filecmp.cmp(fsrc, fdest):
                        logger.error("onefile() will not overwrite %s with %s\n" % (os.path.join(os.getcwd(), os.path.basename(fnm)),os.path.abspath(fnm)))
                        raise RuntimeError
                    else:
                        logger.info("\x1b[93monefile() says the files %s and %s are identical\x1b[0m\n" % (os.path.abspath(fnm), os.getcwd()))
                else:
                    logger.info("\x1b[93monefile() will copy %s to %s\x1b[0m\n" % (os.path.abspath(fnm), os.getcwd()))
                    shutil.copy2(fsrc, fdest)
            return os.path.basename(fnm)
        elif err==True or ext is None:
            # A nonexistent explicit file is fatal unless we can fall back to
            # extension-based autodetection below.
            logger.error("File specified by %s does not exist!" % fnm)
            raise RuntimeError
        elif ext is not None:
            warn_once("File specified by %s does not exist - will try to autodetect .%s extension" % (fnm, ext))
    # Extension-based autodetection: succeed only if exactly one match exists.
    answer = None
    cwd = os.getcwd()
    ls = [i for i in os.listdir(cwd) if i.endswith('.%s' % ext)]
    if len(ls) != 1:
        if err:
            logger.error("Cannot find a unique file with extension .%s in %s (%i found; %s)" % (ext, cwd, len(ls), ' '.join(ls)))
            raise RuntimeError
        else:
            warn_once("Cannot find a unique file with extension .%s in %s (%i found; %s)" %
                      (ext, cwd, len(ls), ' '.join(ls)), warnhash = "Found %i .%s files" % (len(ls), ext))
    else:
        answer = os.path.basename(ls[0])
        warn_once("Autodetected %s in %s" % (answer, cwd), warnhash = "Autodetected %s" % answer)
    return answer
1082
1083# Purpose: Given a file name / file list and/or an extension, do one of the following:
1084# 1) If provided a file list, check each file in the list
1085# and crash if any file does not exist.  Return the list.
1086# 2) If provided a file name, check the file and crash if the file
1087# does not exist.  Return a length-one list with the file name.
1088# 3) If list is empty but extension is provided, check for files that
1089# match the extension.  If so, append them to the list.
1090# 4) If list is still empty and err==True, then crash with an error.
def listfiles(fnms=None, ext=None, err=False, dnm=None):
    """ Given a file name / list of file names and/or an extension, return a list of valid file names.

    1) If provided a list, check that each file exists and return the list.
    2) If provided a single name, check it and return a length-one list.
    3) If no names but an extension, return files in the working directory matching the extension.
    4) If the result is still empty and err is True, raise RuntimeError.

    Files located outside the working directory are copied into it (crashing
    if a different file with the same name already exists here), and the
    returned names are basenames relative to the working directory.

    @param[in] dnm Optional directory to operate in; the original working directory is restored on exit.
    """
    answer = []
    cwd = os.path.abspath(os.getcwd())
    if dnm is not None:
        os.chdir(dnm)
    # Bug fix: wrap in try/finally so the original working directory is
    # restored even when one of the error paths below raises.
    try:
        if isinstance(fnms, list):
            for i in fnms:
                if not os.path.exists(i):
                    logger.error('Specified %s but it does not exist' % i)
                    raise RuntimeError
                answer.append(i)
        elif isinstance(fnms, six.string_types):
            if not os.path.exists(fnms):
                logger.error('Specified %s but it does not exist' % fnms)
                raise RuntimeError
            answer = [fnms]
        elif fnms is not None:
            logger.info(str(fnms))
            logger.error('First argument to listfiles must be a list, a string, or None')
            raise RuntimeError
        if answer == [] and ext is not None:
            answer = [os.path.basename(i) for i in os.listdir(os.getcwd()) if i.endswith('.%s' % ext)]
        if answer == [] and err:
            logger.error('listfiles function failed to come up with a file! (fnms = %s ext = %s)' % (str(fnms), str(ext)))
            raise RuntimeError

        for ifnm, fnm in enumerate(answer):
            if os.path.dirname(os.path.abspath(fnm)) != os.getcwd():
                fsrc = os.path.abspath(fnm)
                fdest = os.path.join(os.getcwd(), os.path.basename(fnm))
                #-----
                # If the file path doesn't correspond to the current directory, copy the file over
                # If the file exists in the current directory already and it's different, then crash.
                #-----
                if os.path.exists(fdest):
                    if not filecmp.cmp(fsrc, fdest):
                        # Bug fix: these messages previously said "onefile()"
                        # (copy-paste from the function above).
                        logger.error("listfiles() will not overwrite %s with %s\n" % (fdest, fsrc))
                        raise RuntimeError
                    else:
                        logger.info("\x1b[93mlistfiles() says the files %s and %s are identical\x1b[0m\n" % (fsrc, os.getcwd()))
                        answer[ifnm] = os.path.basename(fnm)
                else:
                    logger.info("\x1b[93mlistfiles() will copy %s to %s\x1b[0m\n" % (fsrc, os.getcwd()))
                    shutil.copy2(fsrc, fdest)
                    answer[ifnm] = os.path.basename(fnm)
    finally:
        os.chdir(cwd)
    return answer
1138
def extract_tar(tarfnm, fnms, force=False):
    """
    Extract a list of files from .tar archive with any compression.
    The file is extracted to the base folder of the archive.

    Parameters
    ----------
    tarfnm :
        Name of the archive file.
    fnms : str or list
        File names to be extracted.
    force : bool, optional
        If true, then force extraction of file even if they already exist on disk.
    """
    # Check type of fnms argument first.  Bug fix: this conversion previously
    # happened after the existence check below, so passing a single string
    # iterated over its characters and the early return never triggered.
    if isinstance(fnms, six.string_types): fnms = [fnms]
    # Get path of tar file.
    fdir = os.path.abspath(os.path.dirname(tarfnm))
    # If all files exist, then return - no need to extract.
    if (not force) and all(os.path.exists(os.path.join(fdir, f)) for f in fnms): return
    # If the tar file doesn't exist or isn't valid, do nothing.
    if not os.path.exists(tarfnm): return
    if not tarfile.is_tarfile(tarfnm): return
    # Load the tar file and close it when done (the previous version leaked
    # the archive handle).
    with tarfile.open(tarfnm, 'r') as arch:
        # Extract only the members we actually have (to avoid an exception).
        members = [m for m in arch.getmembers() if m.name in fnms]
        arch.extractall(fdir, members=members)
1170
def GoInto(Dir):
    """ Change into directory 'Dir', creating it (and any parents) if it does not exist.

    Raises RuntimeError if 'Dir' exists but is not a directory.
    """
    if os.path.exists(Dir):
        if os.path.isdir(Dir): pass
        else:
            # Bug fix: this message previously referenced an undefined
            # variable 'newdir', which raised NameError instead of logging.
            logger.error("Tried to create directory %s, it exists but isn't a directory\n" % Dir)
            raise RuntimeError
    else:
        os.makedirs(Dir)
    os.chdir(Dir)
1180
def allsplit(Dir):
    """ Split a path into the list of all its directory components, e.g. 'a/b/c' -> ['a', 'b', 'c']. """
    head, tail = os.path.split(os.path.normpath(Dir))
    if tail in ('', '.'):
        return []
    return allsplit(head) + [tail]
1186
def Leave(Dir):
    """ Undo a previous GoInto(Dir): verify the current directory matches Dir,
    then chdir up once per path component of Dir. """
    here = os.path.split(os.getcwd())[1]
    if here != Dir:
        logger.error("Trying to leave directory %s, but we're actually in directory %s (check your code)\n" % (Dir, here))
        raise RuntimeError
    for _ in allsplit(Dir):
        os.chdir('..')
1193
# Dictionary containing specific error messages for specific missing files or file patterns
specific_lst = [(['mdrun','grompp','trjconv','g_energy','g_traj'], "Make sure to install GROMACS and add it to your path (or set the gmxpath option)"),
                (['force.mdin', 'stage.leap'], "This file is needed for setting up AMBER force matching targets"),
                (['conf.pdb', 'mono.pdb'], "This file is needed for setting up OpenMM condensed phase property targets"),
                (['liquid.xyz', 'liquid.key', 'mono.xyz', 'mono.key'], "This file is needed for setting up OpenMM condensed phase property targets"),
                (['dynamic', 'analyze', 'minimize', 'testgrad', 'vibrate', 'optimize', 'polarize', 'superpose'], "Make sure to install TINKER and add it to your path (or set the tinkerpath option)"),
                (['runcuda.sh', 'npt.py', 'npt_tinker.py'], "This file belongs in the ForceBalance source directory, not sure why it is missing"),
                (['input.xyz'], "This file is needed for TINKER molecular property targets"),
                (['.*key$', '.*xyz$'], "I am guessing this file is probably needed by TINKER"),
                (['.*gro$', '.*top$', '.*itp$', '.*mdp$', '.*ndx$'], "I am guessing this file is probably needed by GROMACS")
                ]

# Build a dictionary mapping all of the keys in the above lists to their error messages
specific_dct = dict(list(itertools.chain(*[[(j,i[1]) for j in i[0]] for i in specific_lst])))

def MissingFileInspection(fnm):
    """ Return a hint message for a missing file, based on its name matching known
    patterns; return an empty string if no pattern matches.

    @param[in] fnm Path of the missing file (only the basename is inspected).
    """
    fnm = os.path.split(fnm)[1]
    answer = ""
    for key in specific_dct:
        if re.match(key, fnm):
            # Bug fix: the leading newline was previously prepended on the
            # first loop iteration regardless of whether anything matched,
            # so this function never returned an empty string.
            if answer == "":
                answer += "\n"
            answer += "%s\n" % specific_dct[key]
    return answer
1218
def wopen(dest, binary=False):
    """ Open 'dest' for writing; if it is a symbolic link, remove the link first
    so that the link's target is not clobbered. """
    if os.path.islink(dest):
        logger.warning("Trying to write to a symbolic link %s, removing it first\n" % dest)
        os.unlink(dest)
    mode = 'wb' if binary else 'w'
    return open(dest, mode)
1228
def LinkFile(src, dest, nosrcok = False):
    """ Create a symbolic link from src to dest.

    - No-op if src and dest resolve to the same path, or if dest already
      exists as a (valid) symbolic link.
    - A broken symlink at dest is replaced.
    - Raises RuntimeError if dest exists and is not a symlink, or if src does
      not exist (unless nosrcok is True).
    """
    if os.path.abspath(src) == os.path.abspath(dest): return
    if os.path.exists(src):
        # Remove broken link
        if os.path.islink(dest) and not os.path.exists(dest):
            os.remove(dest)
            os.symlink(src, dest)
        elif os.path.exists(dest):
            # dest exists and is not broken; only acceptable if it's already a link.
            if os.path.islink(dest): pass
            else:
                logger.error("Tried to create symbolic link %s to %s, destination exists but isn't a symbolic link\n" % (src, dest))
                raise RuntimeError
        else:
            os.symlink(src, dest)
    else:
        if not nosrcok:
            logger.error("Tried to create symbolic link %s to %s, but source file doesn't exist%s\n" % (src,dest,MissingFileInspection(src)))
            raise RuntimeError
1247
1248
def CopyFile(src, dest):
    """ Copy src to dest (with metadata, via shutil.copy2).

    Behavior notes:
    - If dest exists as a symbolic link, raise RuntimeError.
    - If dest exists as a regular file, the copy is silently skipped
      (presumably to avoid overwriting existing data -- confirm with callers).
    - If src does not exist, log an error with a hint and raise RuntimeError.
    """
    if os.path.exists(src):
        if os.path.exists(dest):
            if os.path.islink(dest):
                logger.error("Tried to copy %s to %s, destination exists but it's a symbolic link\n" % (src, dest))
                raise RuntimeError
        else:
            shutil.copy2(src, dest)
    else:
        logger.error("Tried to copy %s to %s, but source file doesn't exist%s\n" % (src,dest,MissingFileInspection(src)))
        raise RuntimeError
1260
def link_dir_contents(abssrcdir, absdestdir):
    """ Symlink each file (and any directory named 'IC') from abssrcdir into absdestdir,
    first cleaning up any dangling symlinks at the destination. """
    for entry in os.listdir(abssrcdir):
        src = os.path.join(abssrcdir, entry)
        dst = os.path.join(absdestdir, entry)
        # A link whose target vanished is removed so it can be recreated.
        if os.path.islink(dst) and not os.path.exists(dst):
            os.remove(dst)
        wanted = os.path.isfile(src) or (os.path.isdir(src) and entry == 'IC')
        if wanted and not os.path.exists(dst):
            os.symlink(src, dst)
1271
def remove_if_exists(fnm):
    """ Remove the file at the given path if it exists; do nothing otherwise. """
    if not os.path.exists(fnm):
        return
    os.remove(fnm)
1276
def which(fnm):
    """ Return the directory containing executable 'fnm' by searching PATH,
    or an empty string if it is not found.  Works only on UNIX-like file systems.

    The previous implementation shelled out via os.popen('which ...'), which
    never closed the pipe and interpolated 'fnm' into a shell command; this
    version scans PATH directly and never raises.
    """
    try:
        for dnm in os.environ.get('PATH', '').split(os.pathsep):
            if not dnm: continue
            candidate = os.path.join(dnm, fnm)
            if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
                return dnm
    except Exception:
        pass
    return ''
1283
1284# Thanks to cesarkawakami on #python (IRC freenode) for this code.
class LineChunker(object):
    """ Accumulate byte data pushed from a stream and invoke a callback once per
    complete chunk, where chunks are terminated by either a newline or a
    carriage return.  On close, any trailing partial chunk is flushed with a
    newline appended.  Usable as a context manager. """
    def __init__(self, callback):
        self.callback = callback
        self.buf = ""

    def push(self, data):
        # Incoming data arrives as raw bytes; decode to text before buffering.
        # (Py3 compatibility -- Py2 allowed appending bytes directly.)
        self.buf += data.decode('utf-8')
        self.nomnom()

    def close(self):
        # Flush whatever is left, terminating it with a newline.
        if self.buf:
            self.callback(self.buf + "\n")

    def nomnom(self):
        # Emit one callback per buffered chunk ending in \r or \n.
        while True:
            m = re.search(r"\r|\n", self.buf)
            if m is None:
                break
            end = m.end()
            self.callback(self.buf[:end])
            self.buf = self.buf[end:]

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        self.close()
1314
def _exec(command, print_to_screen = False, outfnm = None, logfnm = None, stdin = "", print_command = True, copy_stdout = True, copy_stderr = False, persist = False, expand_cr=False, print_error=True, rbytes=1, cwd=None, **kwargs):
    """Runs command line using subprocess, optionally returning stdout.
    Options:
    command (required) = Name of the command you want to execute
    outfnm (optional) = Name of the output file name (overwritten if exists)
    logfnm (optional) = Name of the log file name (appended if exists)
    stdin (optional) = A string to be passed to stdin, as if it were typed (use newline character to mimic Enter key)
    print_command = Whether to print the command.
    copy_stdout = Copy the stdout stream; can set to False in strange situations
    copy_stderr = Copy the stderr stream to the stdout stream; useful for GROMACS which prints out everything to stderr (argh.)
    expand_cr = Whether to expand carriage returns into newlines (useful for GROMACS mdrun).
    print_error = Whether to print error messages on a crash. Should be true most of the time.
    persist = Continue execution even if the command gives a nonzero return code.
    rbytes = Number of bytes to read from stdout and stderr streams at a time.  GMX requires rbytes = 1 otherwise streams are interleaved.  Higher values for speed.

    Returns the command's standard output as a list of lines (trailing empty
    line removed).  Raises RuntimeError on a nonzero return code unless
    persist=True; the return code is also stored in _exec.returncode.
    """

    # Dictionary of options to be passed to the Popen object.
    cmd_options={'shell':isinstance(command, six.string_types), 'stdin':PIPE, 'stdout':PIPE, 'stderr':PIPE, 'universal_newlines':expand_cr, 'cwd':cwd}

    # If the current working directory is provided, the outputs will be written to there as well.
    if cwd is not None:
        if outfnm is not None:
            outfnm = os.path.abspath(os.path.join(cwd, outfnm))
        if logfnm is not None:
            logfnm = os.path.abspath(os.path.join(cwd, logfnm))

    # "write to file" : Function for writing some characters to the log and/or output files.
    def wtf(out):
        if logfnm is not None:
            with open(logfnm,'ab+') as f:
                f.write(out.encode('utf-8'))
                f.flush()
        if outfnm is not None:
            # The output file is truncated on the first write, appended afterwards.
            with open(outfnm,'wb+' if wtf.first else 'ab+') as f:
                f.write(out.encode('utf-8'))
                f.flush()
        wtf.first = False
    wtf.first = True

    # Preserve backwards compatibility; sometimes None gets passed to stdin.
    if stdin is None: stdin = ""

    if print_command:
        logger.info("Executing process: \x1b[92m%-50s\x1b[0m%s%s%s%s\n" % (' '.join(command) if type(command) is list else command,
                                                               " In: %s" % cwd if cwd is not None else "",
                                                               " Output: %s" % outfnm if outfnm is not None else "",
                                                               " Append: %s" % logfnm if logfnm is not None else "",
                                                               (" Stdin: %s" % stdin.replace('\n','\\n')) if stdin else ""))
        wtf("Executing process: %s%s\n" % (command, (" Stdin: %s" % stdin.replace('\n','\\n')) if stdin else ""))

    cmd_options.update(kwargs)
    p = subprocess.Popen(command, **cmd_options)

    # Write the stdin stream to the process.
    # NOTE(review): stdin is encoded as ASCII here; non-ASCII input would raise.
    p.stdin.write(stdin.encode('ascii'))
    p.stdin.close()

    #===============================================================#
    #| Read the output streams from the process.  This is a bit    |#
    #| complicated because programs like GROMACS tend to print out |#
    #| stdout as well as stderr streams, and also carriage returns |#
    #| along with newline characters.                              |#
    #===============================================================#
    # stdout and stderr streams of the process.
    streams = [p.stdout, p.stderr]
    # These are functions that take chunks of lines (read) as inputs.
    def process_out(read):
        # Bug fix: 'read' is already a decoded str here (LineChunker decodes
        # the raw bytes), so str(read.encode('utf-8')) printed a bytes repr
        # like b'...' on Python 3.  Write the string directly.
        if print_to_screen: sys.stdout.write(read)
        if copy_stdout:
            process_out.stdout.append(read)
            wtf(read)
    process_out.stdout = []

    def process_err(read):
        # Bug fix: same as process_out -- write the decoded string directly.
        if print_to_screen: sys.stderr.write(read)
        process_err.stderr.append(read)
        if copy_stderr:
            process_out.stdout.append(read)
            wtf(read)
    process_err.stderr = []
    # This reads the streams one byte at a time, and passes it to the LineChunker
    # which splits it by either newline or carriage return.
    # If the stream has ended, then it is removed from the list.
    with LineChunker(process_out) as out_chunker, LineChunker(process_err) as err_chunker:
        while True:
            to_read, _, _ = select(streams, [], [])
            for fh in to_read:
                if fh is p.stdout:
                    read_nbytes = 0
                    read = ''.encode('utf-8')
                    while True:
                        if read_nbytes == 0:
                            read += fh.read(rbytes)
                            read_nbytes += rbytes
                        else:
                            read += fh.read(1)
                            read_nbytes += 1
                        if read_nbytes > 10+rbytes:
                            raise RuntimeError("Failed to decode stdout from external process.")
                        if not read:
                            streams.remove(p.stdout)
                            p.stdout.close()
                            break
                        else:
                            try:
                                out_chunker.push(read)
                                break
                            except UnicodeDecodeError:
                                # Partial multi-byte character; keep reading one byte at a time.
                                pass
                elif fh is p.stderr:
                    read_nbytes = 0
                    read = ''.encode('utf-8')
                    while True:
                        if read_nbytes == 0:
                            read += fh.read(rbytes)
                            read_nbytes += rbytes
                        else:
                            read += fh.read(1)
                            read_nbytes += 1
                        if read_nbytes > 10+rbytes:
                            raise RuntimeError("Failed to decode stderr from external process.")
                        if not read:
                            streams.remove(p.stderr)
                            p.stderr.close()
                            break
                        else:
                            try:
                                err_chunker.push(read)
                                break
                            except UnicodeDecodeError:
                                pass
                else:
                    raise RuntimeError
            if len(streams) == 0: break

    p.wait()

    process_out.stdout = ''.join(process_out.stdout)
    process_err.stderr = ''.join(process_err.stderr)

    _exec.returncode = p.returncode
    if p.returncode != 0:
        if process_err.stderr and print_error:
            logger.warning("Received an error message:\n")
            logger.warning("\n[====] \x1b[91mError Message\x1b[0m [====]\n")
            logger.warning(process_err.stderr)
            logger.warning("[====] \x1b[91mEnd o'Message\x1b[0m [====]\n")
        if persist:
            if print_error:
                logger.info("%s gave a return code of %i (it may have crashed) -- carrying on\n" % (command, p.returncode))
        else:
            # This code (commented out) would not throw an exception, but instead exit with the returncode of the crashed program.
            # sys.stderr.write("\x1b[1;94m%s\x1b[0m gave a return code of %i (\x1b[91mit may have crashed\x1b[0m)\n" % (command, p.returncode))
            # sys.exit(p.returncode)
            logger.error("\x1b[1;94m%s\x1b[0m gave a return code of %i (\x1b[91mit may have crashed\x1b[0m)\n\n" % (command, p.returncode))
            raise RuntimeError

    # Return the output in the form of a list of lines, so we can loop over it using "for line in output".
    Out = process_out.stdout.split('\n')
    if Out[-1] == '':
        Out = Out[:-1]
    return Out
_exec.returncode = None
1478
def warn_press_key(warning, timeout=10):
    """ Print a warning; if attached to a TTY, pause until Enter is pressed or 'timeout' seconds elapse. """
    logger.warning(warning + '\n')
    if not sys.stdin.isatty():
        return
    logger.warning("\x1b[1;91mPress Enter or wait %i seconds (I assume no responsibility for what happens after this!)\x1b[0m\n" % timeout)
    try:
        ready, _, _ = select([sys.stdin], [], [], timeout)
        if ready:
            sys.stdin.readline()
    except: pass
1488
def warn_once(warning, warnhash = None):
    """ Print a warning, but only the first time a given warnhash is seen in this run.

    @param[in] warning A string or a list of strings to print.
    @param[in] warnhash Deduplication key; defaults to the warning itself.
    """
    key = warning if warnhash is None else warnhash
    if key in warn_once.already:
        return
    warn_once.already.add(key)
    if type(warning) is str:
        logger.info(warning + '\n')
    elif type(warning) is list:
        for line in warning:
            logger.info(line + '\n')
warn_once.already = set()
1502
1503#=========================================#
1504#| Development stuff (not commonly used) |#
1505#=========================================#
def concurrent_map(func, data):
    """
    Similar to the builtin function map(), but spawn a thread for each argument
    and apply `func` concurrently.

    Note: unlike map(), we cannot take an iterable argument. `data` should be an
    indexable sequence.
    """
    results = [None] * len(data)

    # Each worker writes its result into the slot matching its input index.
    def worker(idx):
        results[idx] = func(data[idx])

    workers = [threading.Thread(target=worker, args=(idx,)) for idx in range(len(data))]
    for w in workers:
        w.start()
    for w in workers:
        w.join()

    return results
1529