1#!/usr/bin/env python
2
3"""
4LAMMPS Replica Exchange Molecular Dynamics (REMD) trajectories are arranged by
5replica, i.e., each trajectory is a continuous replica that records all the
6ups and downs in temperature. However, often the requirement is trajectories
7that are continuous in temperature, which is achieved by this tool.
8
9Author:
10Tanmoy Sanyal, Shell lab, Chemical Engineering, UC Santa Barbara
11Email: tanmoy dot 7989 at gmail dot com
12
13Usage
14-----
15To get detailed information about the arguments, flags, etc use:
16python reorder_remd_traj.py -h or
17python reorder_remd_traj.py --help
18
19Features of this script
20-----------------------
21a) reorder LAMMPS REMD trajectories by temperature keeping only desired frames.
22Note: this only handles LAMMPS format trajectories (i.e. .lammpstrj format)
23Trajectories can be gzipped or bz2-compressed. The trajectories are assumed to
24be named as <prefix>.%d.lammpstrj[.gz or .bz2]
25
b) (optionally) calculate configurational weights for each frame at each
temperature if potential energies are supplied. Note that this is valid for
the canonical (NVT) ensemble only.
29
30Dependencies
31------------
32mpi4py
33pymbar (for getting configurational weights)
34tqdm (for printing pretty progress bars)
35StringIO (or io if in Python 3.x)
36
37"""
38
39
40
41import os, numpy as np, argparse, time, pickle
42from scipy.special import logsumexp
43from mpi4py import MPI
44
45from tqdm import tqdm
46import gzip, bz2
47try:
48    # python-2
49    from StringIO import StringIO as IOBuffer
50except ImportError:
51    # python-3
52    from io import BytesIO as IOBuffer
53
54
55
#### INITIALIZE MPI ####
# (all on-screen output is meant to come from the ROOT proc only)
ROOT = 0
comm = MPI.COMM_WORLD
me = comm.Get_rank() # id of this proc
nproc = comm.Get_size() # total number of procs
62
63
64#### HELPER FUNCTIONS ####
def _get_nearest_temp(temps, query_temp):
    """
    Helper function to get the temp nearest to a given query_temp
    from a collection of temps

    :param temps: temps to search (list, tuple, numpy array or a single
                  number).

    :param query_temp: query temp

    Returns:
    out_temp: the entry of temps that is nearest to query_temp. If two
              entries are equally near, the one that comes first is returned.

    Raises:
    ValueError if temps is empty.
    """

    # asarray accepts any sequence (not only lists) and ravel guarantees a
    # 1-D array, so that the flat index returned by argmin is always valid
    temps = np.asarray(temps).ravel()
    return temps[np.argmin(np.abs(temps - query_temp))]
82
83
def readwrite(trajfn, mode):
    """
    Open a LAMMPS traj file for input/output. The opener is chosen from the
    filename: ".gz" -> gzip, ".bz2" -> bz2, anything else -> plain file.

    :param trajfn: name of LAMMPS traj

    :param mode: "r" ("w") and "rb" ("wb") depending on read or write

    Returns: file pointer
    """

    if trajfn.endswith(".gz"):
        opener = gzip.open
    elif trajfn.endswith(".bz2"):
        opener = bz2.open
    else:
        opener = open
    return opener(trajfn, mode)
105
106
def get_replica_frames(logfn, temps, nswap, writefreq):
    """
    Get a list of frames from each replica that is
    at a particular temp. Do this for all temps.

    :param logfn: master LAMMPS log file that contains the temp
                  swap history of all replicas. The first 3 lines are
                  skipped. In every following row the first column is
                  ignored and column (r+1) is read as the index of the temp
                  held by replica r.

    :param temps: list of all temps used in the REMD simulation.

    :param nswap: swap frequency of the REMD simulation

    :param writefreq: traj dump frequency in LAMMPS

    Returns: master_frametuple_dict:
             dict containing a list of tuples (replica #, frame #)
             for each temp (key = index of the temp in temps).
    """

    n_rep = len(temps)
    # ndmin = 2 keeps a log with a single swap record as a (1 x ncols) array,
    # so that iterating over it always yields rows
    swap_history = np.loadtxt(logfn, skiprows = 3, ndmin = 2)
    master_frametuple_dict = dict( (n, []) for n in range(n_rep) )

    # walk through the temps
    print("Getting frames from all replicas at temperature:")
    for n in range(n_rep):
        print("%3.2f K" % temps[n])
        frames = master_frametuple_dict[n]
        # replica that holds temp n at each recorded swap
        rep_inds = [int(np.where(x[1:] == n)[0][0]) for x in swap_history]

        # case-1: when frames are dumped faster than temp. swaps
        # each swap interval (the last record has none following it)
        # contributes a run of consecutive frames from the same replica
        if writefreq <= nswap:
            for ii, rep in enumerate(rep_inds[:-1]):
                start = int(ii * nswap / writefreq)
                stop = int( (ii+1) * nswap / writefreq)
                frames.extend( (rep, x) for x in range(start, stop) )

        # case-2: when temps. are swapped faster than dumping frames
        # only every nskip-th swap record coincides with a dumped frame
        else:
            nskip = int(writefreq / nswap)
            frames.extend( (rep, ii) \
                           for ii, rep in enumerate(rep_inds[0::nskip]) )

    return master_frametuple_dict
150
151
def get_byte_index(rep_inds, byteindfns, intrajfns):
    """
    Get byte indices from (un-ordered) trajectories and save them to disk.

    For every replica a two-column integer table is written with numpy.savetxt:
    each row is (frame #, byte offset at which that frame starts). The first
    row is (0, 0) and the last row is a dummy entry that holds the byte
    offset of the end of the file, so that the extent of frame i is always
    given by rows i and i+1.

    :param rep_inds: indices of replicas to process on this proc

    :param byteindsfns: list of filenames that will contain the byte indices.
                        Replicas whose file already exists are skipped.

    :param intrajfns: list of (unordered) input traj filenames
    """
    for n in rep_inds:
        # check if the byte indices for this traj has already been computed
        if os.path.isfile(byteindfns[n]): continue

        byteinds = [ [0,0] ]
        nframe = 0

        # status printed only for replica read on root proc
        # this assumes that each proc takes roughly the same time
        pb = None
        if me == ROOT:
            pb = tqdm(desc = "Reading replicas", leave = True,
                  position = ROOT + 2*me,
                  unit = "B/replica", unit_scale = True,
                  unit_divisor = 1024)

        fobj = readwrite(intrajfns[n], "rb")
        try:
            # the first line is the token that starts every frame
            first_line = fobj.readline()
            cur_pos = fobj.tell()

            # start crawling through the bytes
            while True:
                next_line = fobj.readline()
                if not next_line: break
                # this will only work with lammpstrj traj format.
                # this condition essentially checks periodic recurrences
                # of the token TIMESTEP. Each time it is found,
                # we have crawled through a frame (snapshot)
                if next_line == first_line:
                    nframe += 1
                    # cur_pos still points to the start of the line just read
                    byteinds.append( [nframe, cur_pos] )
                    if pb is not None: pb.update()
                cur_pos = fobj.tell()
                if pb is not None: pb.update(0)

            # take care of the EOF
            cur_pos = fobj.tell()
            byteinds.append( [nframe+1, cur_pos] ) # dummy index for the EOF
        finally:
            # release the traj file and the progress bar even on failure
            fobj.close()
            if pb is not None: pb.close()

        # write to file
        np.savetxt(byteindfns[n], np.array(byteinds), fmt = "%d")

    # NOTE: the return must stay outside the loop; returning from within it
    # would leave all but the first replica of this proc without an index
    return
210
211
def write_reordered_traj(temp_inds, byte_inds, outtemps, temps,
                         frametuple_dict, nprod, writefreq,
                         outtrajfns, infobjs):
    """
    Reorders trajectories by temp. and writes them to disk

    :param temp_inds: list index of temps (in the list of output temps) for
                      which reordered trajs will be produced on this proc.

    :param byte_inds: dict containing the (previously stored) byte indices
                      for each replica file (key = replica number). Column 1
                      of row i is used as the byte offset of frame i.

    :param outtemps: list of all temps for which to produce reordered trajs.

    :param temps: list of all temps used in the REMD simulation.

    :param outtrajfns: list of filenames for output (ordered) trajs.

    :param frametuple_dict: dict containing a tuple (replica #, frame #)
                            for each temp.

    :param nprod: number of production timesteps.
                  Last (nprod / writefreq) frames
                  from the end will be written to disk.

    :param writefreq: traj dump frequency in LAMMPS

    :param infobjs: list of file pointers to input (unordered) trajs.
                    All of them are closed before this function returns.
    """

    # NOTE: if nprod < writefreq this is 0, and the slice [-0:] used below
    # then selects *all* frames recorded at a temp.
    nframes = int(nprod / writefreq)
    temps = np.asarray(temps)

    try:
        for n in temp_inds:
            # get frames of the simulated temp nearest to the requested one
            abs_temp_ind = int(np.argmin( abs(temps - outtemps[n]) ))
            frametuple = frametuple_dict[abs_temp_ind][-nframes:]

            # progress bar is shown only on the ROOT proc
            pb = None
            iterable = frametuple
            if me == ROOT:
                pb = tqdm(frametuple,
                      desc = ("Buffering trajectories for writing"),
                      leave = True, position = ROOT + 2*me,
                      unit = 'frame/replica', unit_scale = True)
                iterable = pb

            buf = IOBuffer()
            try:
                # write frames to buffer
                try:
                    for rep, frame in iterable:
                        infobj = infobjs[rep]
                        # a frame spans the bytes between its own offset
                        # and the offset of the next frame
                        start_ptr = int(byte_inds[rep][frame,1])
                        stop_ptr = int(byte_inds[rep][frame+1,1])
                        infobj.seek(start_ptr)
                        buf.write(infobj.read(stop_ptr - start_ptr))
                finally:
                    if pb is not None: pb.close()

                # write buffer to disk. The output file is opened only now,
                # so a failed read does not leave a truncated traj behind
                if me == ROOT: print("Writing buffer to file")
                with readwrite(outtrajfns[n], "wb") as of:
                    of.write(buf.getvalue())
            finally:
                buf.close()
    finally:
        # input trajs are closed no matter how the loop ended
        for fobj in infobjs: fobj.close()

    return
282
283
def _logw_from_reduced_energies(u_kln, f_k):
    """
    Helper for get_canonical_logw: log-weights from reduced energies and
    reduced free energies, for the same number of frames at every temp.

    :param u_kln: array, u_kln[k,l,n] is the reduced potential energy of the
                  n-th frame of the k-th temp. *ordered* trajectory,
                  evaluated at the l-th temp.

    :param f_k: array of reduced free energies, one for each temp.

    Returns: dict, logw[l][k,n] =
                   f_k[l] - u_kln[k,l,n]
                   - log( sum_m nframes * exp(f_k[m] - u_kln[k,m,n]) )
    """

    u_kln = np.asarray(u_kln, float)
    f_k = np.asarray(f_k, float)
    ntemps, _, nframes = u_kln.shape

    # log of the (unnormalized) Boltzmann factor of each frame at each temp.
    log_num = f_k[None, :, None] - u_kln
    # log of the mixture density over all temps., shape (ntemps x nframes)
    log_denom = logsumexp(log_num, axis = 1) + np.log(nframes)

    return dict( (l, log_num[:, l, :] - log_denom) for l in range(ntemps) )


def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq,
                       kB):
    """
    Gets configurational log-weights (logw) for each frame and at each temp.
    from the REMD simulation. ONLY WRITTEN FOR THE CANONICAL (NVT) ensemble.

    This weights can be used to calculate the
    ensemble averaged value of any simulation observable X at a given temp. T :
    <X> (T) = sum_{k=1, ntemps} sum_{n=1, nframes} w[idx][k,n] X[k,n]
    where nframes is the number of frames to use from each *reordered* traj,
    idx is the index of T in temps and w = exp(logw).

    The weights follow the MBAR expression
    w[l][k,n] = exp(f_l - beta_l u[k,n])
                / sum_m ( nframes * exp(f_m - beta_m u[k,n]) )
    where f are the reduced free energies estimated by pymbar.

    :param enefn: ascii file (readable by numpy.loadtxt) containing an array
                  u[r,n] of *total* potential energy for the n-th frame for
                  the r-th replica.

    :param frametuple_dict: dict containing a tuple (replica #, frame #)
                            for each temp.

    :param temps: array of temps. used in the REMD simulation

    :param nprod: number of production timesteps. Last (nprod / writefreq)
                  frames from the end will be used.

    :param writefreq: traj dump frequency in LAMMPS

    :param kB : Boltzmann constant to set the energy scale.

    Returns: logw: dict, logw[l][k,n] gives the log weights from the
                   n-th frame of the k-th temp. *ordered* trajectory
                   to reweight to the l-th temp.

    Raises:
    ImportError if pymbar is not installed.
    ValueError if fewer than the requested number of frames are available.
    """

    try:
        import pymbar
    except ImportError:
        print("""
              Configurational log-weight calculation requires pymbar.
              Here are some options to install it:
              conda install -c omnia pymbar
              pip install --user pymbar
              sudo pip install pymbar

              To install the dev. version directly from github, use:
              pip install git+https://github.com/choderalab/pymbar.git
              """)
        # nothing below can work without pymbar
        raise

    # ndmin = 2 keeps a single-replica energy file as a (1 x nframes) array
    u_rn = np.loadtxt(enefn, ndmin = 2)
    ntemps = u_rn.shape[0] # number of temps.
    nframes = int(nprod / writefreq) # number of frames at each temp.
    if nframes < 1:
        raise ValueError("nprod (%s) must span at least one dumped frame "
                         "(writefreq = %s)" % (nprod, writefreq))

    # reorder the energies by temp.
    u_kn = np.zeros([ntemps, nframes], float)
    for k in range(ntemps):
        frame_tuple = frametuple_dict[k][-nframes:]
        if len(frame_tuple) != nframes:
            # silently padding with zero energies would corrupt the weights
            raise ValueError("Only %d frames available at temp. # %d, "
                             "but %d were requested"
                             % (len(frame_tuple), k, nframes))
        for i, (rep, frame) in enumerate(frame_tuple):
            u_kn[k, i] = u_rn[rep, frame]

    # prep input for pymbar
    #1) array of frames at each temp.
    # (a platform integer, since a uint8 cannot hold more than 255 frames)
    nframes_k = np.full(ntemps, nframes, dtype = int)

    #2) inverse temps. for chosen energy scale
    beta_k = 1.0 / (kB * np.asarray(temps, float))

    #3) get reduced energies (*ONLY FOR THE CANONICAL ENSEMBLE*)
    # u_kln[k,l,n] = beta_k[l] * u_kn[k,n]
    u_kln = beta_k[None, :, None] * u_kn[:, None, :]

    # run pymbar and extract the free energies
    print("\nRunning pymbar...")
    mbar = pymbar.mbar.MBAR(u_kln, nframes_k, verbose = True)
    f_k = mbar.f_k # (1 x k array)

    # calculate the log-weights
    print("\nExtracting log-weights...")
    return _logw_from_reduced_energies(u_kln, f_k)
373
374
375
376#### MAIN WORKFLOW ####
if __name__ == "__main__":
    # accept user inputs
    parser = argparse.ArgumentParser(description = __doc__,
             formatter_class = argparse.RawDescriptionHelpFormatter)

    parser.add_argument("prefix",
                        help = "Prefix of REMD LAMMPS trajectories. "
                        "Supply full path. Trajectories assumed to be named "
                        "as <prefix>.%%d.lammpstrj. "
                        "Can be in compressed (.gz or .bz2) format. "
                        "This is a required argument")

    parser.add_argument("-logfn", "--logfn", default = "log.lammps",
                        help = "LAMMPS log file that contains swap history "
                        "of temperatures among replicas. "
                        "Default = 'log.lammps'")

    parser.add_argument("-tfn", "--tempfn", default = "temps.txt",
                        help = "ascii file (readable by numpy.loadtxt) with "
                        "the temperatures used in the REMD simulation.")

    # the swap frequency has no sensible default, so it must be supplied
    parser.add_argument("-ns", "--nswap", type = int, required = True,
                        help = "Swap frequency used in LAMMPS temper command. "
                        "This is a required argument")

    parser.add_argument("-nw", "--nwrite", type = int, default = 1,
                        help = "Trajectory writing frequency used "
                        "in LAMMPS dump command")

    parser.add_argument("-np", "--nprod", type = int, default = 0,
                        help = "Number of timesteps to save in the reordered "
                        "trajectories. "
                        "This should be in units of the LAMMPS timestep")

    parser.add_argument("-logw", "--logw", action = 'store_true',
                        help = "Supplying this flag "
                        "calculates *canonical* (NVT ensemble) log weights")

    parser.add_argument("-e", "--enefn",
                        help = "File that has n_replica x n_frames array "
                        "of total potential energies")

    parser.add_argument("-kB", "--boltzmann_const",
                        type = float, default = 0.001987,
                        help = "Boltzmann constant in appropriate units. "
                        "Default is kcal/mol")

    parser.add_argument("-ot", "--out_temps", nargs = '+', type = np.float64,
                        help = "Reorder trajectories at these temperatures. "
                        "Default is all temperatures used in the simulation")

    parser.add_argument("-od", "--outdir", default = ".",
                        help = "All output will be saved to this directory")

    # parse inputs
    args = parser.parse_args()
    traj_prefix = os.path.abspath(args.prefix)
    logfn = os.path.abspath(args.logfn)
    tempfn = os.path.abspath(args.tempfn)

    nswap = args.nswap
    writefreq = args.nwrite
    nprod = args.nprod

    enefn = args.enefn
    if enefn is not None: enefn = os.path.abspath(enefn)
    get_logw = args.logw
    kB = args.boltzmann_const

    out_temps = args.out_temps
    outdir = os.path.abspath(args.outdir)
    # only the ROOT proc creates the output dir; everybody else waits at the
    # barrier so that nobody writes into a dir that does not exist yet
    if me == ROOT and not os.path.isdir(outdir):
        os.makedirs(outdir)
    comm.barrier()

    # check that all input files are present (only on the ROOT proc)
    if me == ROOT:
        if not os.path.isfile(tempfn):
            raise IOError("Temperature file %s not found." % tempfn)
        elif not os.path.isfile(logfn):
            raise IOError("LAMMPS log file %s not found." % logfn)
        elif get_logw and (enefn is None or not os.path.isfile(enefn)):
            raise IOError("Canonical log-weight calculation requested but "
                          "energy file %s not found" % enefn)

    # get (unordered) trajectories
    # ndmin = 1 keeps a file with a single temperature as a 1-D array
    temps = np.loadtxt(tempfn, ndmin = 1)
    ntemps = len(temps)
    # default is to reorder at all the temperatures used in the simulation
    if out_temps is None: out_temps = temps
    intrajfns = ["%s.%d.lammpstrj" % (traj_prefix, k) for k in range(ntemps)]
    # check if the trajs. (or their zipped versions are present)
    for i, this_intrajfn in enumerate(intrajfns):
        if os.path.isfile(this_intrajfn): continue
        for ext in (".gz", ".bz2"):
            if os.path.isfile(this_intrajfn + ext):
                intrajfns[i] = this_intrajfn + ext
                break
        else:
            # neither the plain nor a compressed traj was found
            if me == ROOT:
                raise IOError("Trajectory for replica # %d missing" % i)

    # set output filenames
    outprefix = os.path.join(outdir, os.path.basename(traj_prefix))
    outtrajfns = ["%s.%3.2f.lammpstrj.gz" % \
                 (outprefix, _get_nearest_temp(temps, t)) \
                  for t in out_temps]
    byteindfns = [os.path.join(outdir, ".byteind_%d.gz" % k) \
                  for k in range(ntemps)]
    frametuplefn = outprefix + '.frametuple.pickle'
    if get_logw:
        logwfn = outprefix + ".logw.pickle"


    # get a list of all frames at a particular temp visited by each replica
    # this is fast so run only on ROOT proc.
    master_frametuple_dict = {}
    if me == ROOT:
        master_frametuple_dict = get_replica_frames(logfn = logfn,
                                                    temps = temps,
                                                    nswap = nswap,
                                                    writefreq = writefreq)
        # save to a pickle from the ROOT proc
        with open(frametuplefn, 'wb') as of:
            pickle.dump(master_frametuple_dict, of)

    # broadcast to all procs
    master_frametuple_dict = comm.bcast(master_frametuple_dict, root = ROOT)

    # define a chunk of replicas  to process on each proc
    # (the last proc also takes whatever is left over)
    CHUNKSIZE_1 = int(ntemps/nproc)
    if me < nproc - 1:
        my_rep_inds = range( (me*CHUNKSIZE_1), (me+1)*CHUNKSIZE_1 )
    else:
        my_rep_inds = range( (me*CHUNKSIZE_1), ntemps )

    # get byte indices from replica (un-ordered) trajs. in parallel
    get_byte_index(rep_inds = my_rep_inds,
                   byteindfns = byteindfns,
                   intrajfns = intrajfns)

    # block until all procs have finished
    comm.barrier()

    # define a chunk of output trajs. to process for each proc.
    # # of reordered trajs. to write may be less than the total # of replicas
    # which is usually equal to the requested nproc. If that is indeed the case,
    # retire excess procs
    n_out_temps = len(out_temps)
    CHUNKSIZE_2 = int(n_out_temps / nproc)
    if CHUNKSIZE_2 == 0:
        nproc_active = n_out_temps
        CHUNKSIZE_2 = 1
        if me == ROOT:
            print("\nReleasing %d excess procs" % (nproc - nproc_active))
    else:
        nproc_active = nproc
    # chunks of output temps. must be cut with the output chunk size
    if me < nproc_active-1:
        my_temp_inds = range( (me*CHUNKSIZE_2), (me+1)*CHUNKSIZE_2 )
    else:
        my_temp_inds = range( (me*CHUNKSIZE_2), n_out_temps)

    # write reordered trajectories to disk from active procs in parallel
    # (retired procs skip this and open no files at all)
    if me < nproc_active:
        # open all replica files for reading
        # (they are closed by write_reordered_traj)
        infobjs = [readwrite(i, "rb") for i in intrajfns]

        # open all byteindex files
        byte_inds = dict( (i, np.loadtxt(fn)) \
                          for i, fn in enumerate(byteindfns) )

        write_reordered_traj(temp_inds = my_temp_inds,
                             byte_inds = byte_inds,
                             outtemps = out_temps, temps = temps,
                             frametuple_dict = master_frametuple_dict,
                             nprod = nprod, writefreq = writefreq,
                             outtrajfns = outtrajfns,
                             infobjs = infobjs)

    # calculate canonical log-weights if requested
    # usually this is very fast so only the ROOT proc does it
    if get_logw and me == ROOT:
        logw = get_canonical_logw(enefn = enefn, temps = temps,
                                  frametuple_dict = master_frametuple_dict,
                                  nprod = nprod, writefreq = writefreq,
                                  kB = kB)

        # save the logweights to a pickle
        with open(logwfn, 'wb') as of:
            pickle.dump(logw, of)
572
573
574