#!/usr/bin/env python

"""
LAMMPS Replica Exchange Molecular Dynamics (REMD) trajectories are arranged by
replica, i.e., each trajectory is a continuous replica that records all the
ups and downs in temperature. However, often the requirement is trajectories
that are continuous in temperature, which is achieved by this tool.

Author:
Tanmoy Sanyal, Shell lab, Chemical Engineering, UC Santa Barbara
Email: tanmoy dot 7989 at gmail dot com

Usage
-----
To get detailed information about the arguments, flags, etc use:
python reorder_remd_traj.py -h or
python reorder_remd_traj.py --help

Features of this script
-----------------------
a) reorder LAMMPS REMD trajectories by temperature keeping only desired frames.
Note: this only handles LAMMPS format trajectories (i.e. .lammpstrj format)
Trajectories can be gzipped or bz2-compressed. The trajectories are assumed to
be named as <prefix>.%d.lammpstrj[.gz or .bz2]

b) (optionally) calculate configurational weights for each frame at each
temperature if potential energies are supplied. But this is for the canonical
(NVT) ensemble only.

Dependencies
------------
numpy, scipy
mpi4py
pymbar (for getting configurational weights)
tqdm (for printing pretty progress bars)
StringIO (or io if in Python 3.x)

"""


import argparse
import bz2
import gzip
import os
import pickle
import time

import numpy as np
from scipy.special import logsumexp
from mpi4py import MPI

from tqdm import tqdm
try:
    # python-2
    from StringIO import StringIO as IOBuffer
except ImportError:
    # python-3
    from io import BytesIO as IOBuffer



#### INITIALIZE MPI ####
# (note that all output on screen will be printed only on the ROOT proc)
ROOT = 0
comm = MPI.COMM_WORLD
me = comm.rank      # my proc id
nproc = comm.size


#### HELPER FUNCTIONS ####
def _get_nearest_temp_index(temps, query_temp):
    """
    Helper function to get the position of the temp in ``temps`` that is
    nearest to ``query_temp``.

    :param temps: list or array of temps.

    :param query_temp: query temp

    Returns: idx: index (python int) of the nearest temp in ``temps``
    """
    temps = np.asarray(temps, dtype=float)
    return int(np.argmin(np.abs(temps - query_temp)))


def _get_nearest_temp(temps, query_temp):
    """
    Helper function to get the nearest temp in a list
    from a given query_temp

    :param temps: list or array of temps.

    :param query_temp: query temp

    Returns: out_temp: the element of ``temps`` nearest to ``query_temp``
    """
    return np.asarray(temps)[_get_nearest_temp_index(temps, query_temp)]


def _last_frames(frametuples, nframes):
    """
    Return the last ``nframes`` entries of ``frametuples`` as a list.

    A non-positive ``nframes`` returns every entry. This makes explicit the
    behaviour of ``frametuples[-0:]`` (all frames), which is what nprod = 0
    (the command-line default) has always produced.
    """
    if nframes <= 0:
        return list(frametuples)
    return frametuples[-nframes:]


def readwrite(trajfn, mode):
    """
    Helper function for input/output LAMMPS traj files.
    Trajectories may be plain text, .gz or .bz2 compressed.

    :param trajfn: name of LAMMPS traj

    :param mode: "r" ("w") and "rb" ("wb") depending on read or write

    Returns: file pointer (gzip / bz2 / builtin file object chosen from
             the file extension)
    """
    if trajfn.endswith(".gz"):
        return gzip.open(trajfn, mode)
    if trajfn.endswith(".bz2"):
        return bz2.open(trajfn, mode)
    return open(trajfn, mode)


def get_replica_frames(logfn, temps, nswap, writefreq):
    """
    Get a list of frames from each replica that is
    at a particular temp. Do this for all temps.

    :param logfn: master LAMMPS log file that contains the temp
                  swap history of all replicas

    :param temps: list of all temps used in the REMD simulation.

    :param nswap: swap frequency of the REMD simulation

    :param writefreq: traj dump frequency in LAMMPS

    Returns: master_frametuple_dict:
             dict containing a list of tuples (replica #, frame #) for
             each temp. index.

    Raises: ValueError if nswap or writefreq is not a positive number.
    """
    if nswap is None or nswap <= 0:
        raise ValueError("nswap must be a positive integer, got %r" % (nswap,))
    if writefreq is None or writefreq <= 0:
        raise ValueError("writefreq must be a positive integer, got %r"
                         % (writefreq,))

    n_rep = len(temps)
    # ndmin=2 keeps a single-row swap history iterable row by row
    swap_history = np.loadtxt(logfn, skiprows=3, ndmin=2)
    master_frametuple_dict = dict((n, []) for n in range(n_rep))

    # walk through the replicas
    print("Getting frames from all replicas at temperature:")
    for n in range(n_rep):
        print("%3.2f K" % temps[n])
        # for every swap row: the column (after the step column) whose value
        # equals temp index n is taken as the replica holding temp n
        rep_inds = [int(np.where(row[1:] == n)[0][0]) for row in swap_history]
        frames = master_frametuple_dict[n]

        if writefreq <= nswap:
            # case-1: frames are dumped faster than temp. swaps, so every
            # frame dumped between two swaps comes from the same replica
            for ii, i in enumerate(rep_inds[:-1]):
                start = int(ii * nswap // writefreq)
                stop = int((ii + 1) * nswap // writefreq)
                frames.extend((i, x) for x in range(start, stop))
        else:
            # case-2: temps. are swapped faster than frames are dumped, so
            # only every nskip-th swap row lines up with a dumped frame
            nskip = int(writefreq // nswap)
            frames.extend((i, ii) for ii, i in enumerate(rep_inds[0::nskip]))

    return master_frametuple_dict


def get_byte_index(rep_inds, byteindfns, intrajfns):
    """
    Get byte indices from (un-ordered) trajectories.

    For every replica the output file holds rows (frame #, byte offset),
    with an extra dummy row holding the (decompressed) EOF offset, so that
    frame k spans bytes [row k, row k+1).

    :param rep_inds: indices of replicas to process on this proc

    :param byteindfns: list of filenames that will contain the byte indices
                       (existing files are reused as a cache)

    :param intrajfns: list of (unordered) input traj filenames
    """
    for n in rep_inds:
        # check if the byte indices for this traj has already been computed
        # NOTE(review): the cache is keyed only on replica index, so a stale
        # file from a different run in the same outdir would be reused.
        if os.path.isfile(byteindfns[n]):
            continue

        # status printed only for replica read on root proc
        # this assumes that each proc takes roughly the same time
        pb = None
        if me == ROOT:
            pb = tqdm(desc="Reading replicas", leave=True,
                      position=ROOT + 2 * me,
                      unit="B/replica", unit_scale=True,
                      unit_divisor=1024)

        byteinds = [[0, 0]]
        nframe = 0
        try:
            with readwrite(intrajfns[n], "rb") as fobj:
                # the first line (ITEM: TIMESTEP in lammpstrj format) marks
                # the start of every frame; each recurrence is a new frame
                first_line = fobj.readline()
                # in binary mode the offset equals the number of bytes read
                # so far, which avoids a tell() call per line
                cur_pos = len(first_line)
                while True:
                    next_line = fobj.readline()
                    if not next_line:
                        break
                    if next_line == first_line:
                        nframe += 1
                        byteinds.append([nframe, cur_pos])
                        if pb is not None:
                            pb.update()
                    cur_pos += len(next_line)

                # take care of the EOF: dummy index for the end of last frame
                byteinds.append([nframe + 1, cur_pos])
        finally:
            if pb is not None:
                pb.close()

        # write to file
        np.savetxt(byteindfns[n], np.array(byteinds), fmt="%d")

    return


def write_reordered_traj(temp_inds, byte_inds, outtemps, temps,
                         frametuple_dict, nprod, writefreq,
                         outtrajfns, infobjs):
    """
    Reorders trajectories by temp. and writes them to disk

    :param temp_inds: list index of temps (in the list of all temps) for which
                      reordered trajs will be produced on this proc.

    :param byte_inds: dict containing the (previously stored) byte indices
                      for each replica file (key = replica number)

    :param outtemps: list of all temps for which to produce reordered trajs.

    :param temps: list of all temps used in the REMD simulation.

    :param frametuple_dict: dict containing a tuple (replica #, frame #)
                            for each temp.

    :param nprod: number of production timesteps.
                  Last (nprod / writefreq) frames
                  from the end will be written to disk
                  (all frames if this is 0).

    :param writefreq: traj dump frequency in LAMMPS

    :param outtrajfns: list of filenames for output (ordered) trajs.

    :param infobjs: list of file pointers to input (unordered) trajs.
                    All of them are closed when this function returns.
    """
    nframes = int(nprod / writefreq)

    try:
        for n in temp_inds:
            # get frames of the simulated temp. nearest to the requested one
            abs_temp_ind = _get_nearest_temp_index(temps, outtemps[n])
            frametuple = _last_frames(frametuple_dict[abs_temp_ind], nframes)

            pb = None
            iterable = frametuple
            if me == ROOT:
                pb = tqdm(frametuple,
                          desc=("Buffering trajectories for writing"),
                          leave=True, position=ROOT + 2 * me,
                          unit='frame/replica', unit_scale=True)
                iterable = pb

            # write frames to buffer
            buf = IOBuffer()
            try:
                try:
                    for rep, frame in iterable:
                        infobj = infobjs[rep]
                        start_ptr = int(byte_inds[rep][frame, 1])
                        stop_ptr = int(byte_inds[rep][frame + 1, 1])
                        infobj.seek(start_ptr)
                        buf.write(infobj.read(stop_ptr - start_ptr))
                finally:
                    if pb is not None:
                        pb.close()

                # write buffer to disk; the output is opened only now so a
                # failure while buffering does not leave an empty file behind
                if me == ROOT:
                    print("Writing buffer to file")
                with readwrite(outtrajfns[n], "wb") as of:
                    of.write(buf.getvalue())
            finally:
                buf.close()
    finally:
        for infobj in infobjs:
            infobj.close()

    return


def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq,
                       kB):
    r"""
    Gets configurational log-weights (logw) for each frame and at each temp.
    from the REMD simulation. ONLY WRITTEN FOR THE CANONICAL (NVT) ensemble.

    These weights can be used to calculate the
    ensemble averaged value of any simulation observable X at a given temp. T :
    <X> (T) = \sum_{k=1, ntemps} \sum_{n=1, nframes} w[idx][k,n] X[k,n]
    where nframes is the number of frames to use from each *reordered* traj
    and idx is the index of T in ``temps``.

    The weights are the MBAR weights
    w[l][k,n] = exp(f_l - beta_l u_kn) / \sum_m N_m exp(f_m - beta_m u_kn)
    with f_m the reduced free energies estimated by pymbar and
    N_m = nframes for every temp., so that \sum_{k,n} w[l][k,n] = 1.

    :param enefn: ascii file (readable by numpy.loadtxt) containing an array
                  u[r,n] of *total* potential energy for the n-th frame for
                  the r-th replica.

    :param frametuple_dict: dict containing a tuple (replica #, frame #)
                            for each temp.

    :param temps: array of temps. used in the REMD simulation

    :param nprod: number of production timesteps. Last (nprod / writefreq)
                  frames from the end are used. If this is 0, or exceeds
                  the available frames, all available frames are used.

    :param writefreq: traj dump frequency in LAMMPS

    :param kB : Boltzmann constant to set the energy scale.
                Default is in kcal/mol

    Returns: logw: dict, logw[l][k,n] gives the log weights from the
                   n-th frame of the k-th temp. *ordered* trajectory
                   to reweight to the l-th temp.

    Raises: ImportError if pymbar is not installed.
            ValueError if the energy file and temps disagree on the number
            of temps, or if no frames are available.
    """

    try:
        import pymbar
    except ImportError:
        print("""
        Configurational log-weight calculation requires pymbar.
        Here are some options to install it:
        conda install -c omnia pymbar
        pip install --user pymbar
        sudo pip install pymbar

        To install the dev. version directly from github, use:
        pip install git+https://github.com/choderalab/pymbar.git
        """)
        raise

    u_rn = np.loadtxt(enefn, ndmin=2)
    temps = np.asarray(temps, dtype=float)
    ntemps = u_rn.shape[0]      # number of temps.
    if ntemps != len(temps):
        raise ValueError("Energy file %s has %d rows but %d temps were given"
                         % (enefn, ntemps, len(temps)))

    # number of frames at each temp.
    nframes = int(nprod / writefreq)
    navail = min(len(frametuple_dict[k]) for k in range(ntemps))
    if nframes <= 0 or nframes > navail:
        if nframes > navail:
            print("Only %d frames available per temperature; using all"
                  % navail)
        nframes = navail
    if nframes == 0:
        raise ValueError("No frames available to compute log-weights")

    # reorder the energies by temp.
    u_kn = np.zeros([ntemps, nframes], float)
    for k in range(ntemps):
        reps, frames = zip(*frametuple_dict[k][-nframes:])
        u_kn[k] = u_rn[list(reps), list(frames)]

    # prep input for pymbar
    #1) number of frames at each temp. (int: uint8 would overflow > 255)
    nframes_k = np.full(ntemps, nframes, dtype=int)

    #2) inverse temps. for chosen energy scale
    beta_k = 1.0 / (kB * temps)

    #3) get reduced energies (*ONLY FOR THE CANONICAL ENSEMBLE*)
    #   u_kln[k, l, n] = beta_l * u_kn[k, n]
    u_kln = beta_k[np.newaxis, :, np.newaxis] * u_kn[:, np.newaxis, :]

    # run pymbar and extract the free energies
    print("\nRunning pymbar...")
    mbar = pymbar.mbar.MBAR(u_kln, nframes_k, verbose=True)
    f_k = np.asarray(mbar.f_k)      # (1 x k array)

    # calculate the log-weights
    print("\nExtracting log-weights...")
    # log of the mixture denominator sum_m N_m exp(f_m - beta_m u_kn);
    # it does not depend on the target temp. l
    log_denom_kn = (logsumexp(f_k[:, np.newaxis, np.newaxis]
                              - beta_k[:, np.newaxis, np.newaxis]
                              * u_kn[np.newaxis, :, :],
                              axis=0)
                    + np.log(nframes))
    # get log-weights to reweight to each temp. l
    logw = dict((l, f_k[l] - beta_k[l] * u_kn - log_denom_kn)
                for l in range(ntemps))

    return logw



#### MAIN WORKFLOW ####
def _build_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument("prefix",
                        help="Prefix of REMD LAMMPS trajectories. "
                             "Supply full path. Trajectories assumed to be "
                             "named as <prefix>.%%d.lammpstrj. "
                             "Can be in compressed (.gz or .bz2) format. "
                             "This is a required argument")

    parser.add_argument("-logfn", "--logfn", default="log.lammps",
                        help="LAMMPS log file that contains swap history "
                             "of temperatures among replicas. "
                             "Default = 'log.lammps'")

    parser.add_argument("-tfn", "--tempfn", default="temps.txt",
                        help="ascii file (readable by numpy.loadtxt) with "
                             "the temperatures used in the REMD simulation.")

    parser.add_argument("-ns", "--nswap", type=int,
                        help="Swap frequency used in LAMMPS temper command "
                             "(required)")

    parser.add_argument("-nw", "--nwrite", type=int, default=1,
                        help="Trajectory writing frequency used "
                             "in LAMMPS dump command")

    parser.add_argument("-np", "--nprod", type=int, default=0,
                        help="Number of timesteps to save in the reordered "
                             "trajectories. "
                             "This should be in units of the LAMMPS timestep. "
                             "Default (0) keeps all frames")

    parser.add_argument("-logw", "--logw", action='store_true',
                        help="Supplying this flag "
                             "calculates *canonical* (NVT ensemble) "
                             "log weights")

    parser.add_argument("-e", "--enefn",
                        help="File that has n_replica x n_frames array "
                             "of total potential energies")

    parser.add_argument("-kB", "--boltzmann_const",
                        type=float, default=0.001987,
                        help="Boltzmann constant in appropriate units. "
                             "Default is kcal/mol")

    parser.add_argument("-ot", "--out_temps", nargs='+', type=np.float64,
                        help="Reorder trajectories at these temperatures. "
                             "Default is all temperatures used in the "
                             "simulation")

    parser.add_argument("-od", "--outdir", default=".",
                        help="All output will be saved to this directory")
    return parser


def _resolve_traj_fn(trajfn):
    """Return trajfn, trajfn.gz or trajfn.bz2 (first that exists), else None."""
    for candidate in (trajfn, trajfn + ".gz", trajfn + ".bz2"):
        if os.path.isfile(candidate):
            return candidate
    return None


def main():
    """Parse the command line and run the (MPI-parallel) reordering."""
    # accept user inputs
    parser = _build_parser()
    args = parser.parse_args()
    if args.nswap is None or args.nswap <= 0:
        parser.error("a positive --nswap (-ns) is required")
    if args.nwrite <= 0:
        parser.error("--nwrite (-nw) must be positive")
    if args.nprod < 0:
        parser.error("--nprod (-np) must be non-negative")

    # parse inputs
    traj_prefix = os.path.abspath(args.prefix)
    logfn = os.path.abspath(args.logfn)
    tempfn = os.path.abspath(args.tempfn)

    nswap = args.nswap
    writefreq = args.nwrite
    nprod = args.nprod

    enefn = args.enefn
    if enefn is not None:
        enefn = os.path.abspath(enefn)
    get_logw = args.logw
    kB = args.boltzmann_const

    outdir = os.path.abspath(args.outdir)

    # check that all input files are present. This runs on every proc so
    # that all procs fail together; raising only on ROOT would leave the
    # others blocked at the next collective call.
    if not os.path.isfile(tempfn):
        raise IOError("Temperature file %s not found." % tempfn)
    if not os.path.isfile(logfn):
        raise IOError("LAMMPS log file %s not found." % logfn)
    if get_logw and (enefn is None or not os.path.isfile(enefn)):
        raise IOError("Canonical log-weight calculation requested but "
                      "energy file %s not found" % enefn)

    if me == ROOT and not os.path.isdir(outdir):
        os.makedirs(outdir)
    # every proc writes byte-index files into outdir, so wait until it exists
    comm.barrier()

    # get (unordered) trajectories, or their zipped versions
    temps = np.loadtxt(tempfn, ndmin=1)
    ntemps = len(temps)
    intrajfns = []
    for k in range(ntemps):
        this_intrajfn = _resolve_traj_fn("%s.%d.lammpstrj" % (traj_prefix, k))
        if this_intrajfn is None:
            raise IOError("Trajectory for replica # %d missing" % k)
        intrajfns.append(this_intrajfn)

    # documented default: reorder at every temperature of the simulation
    out_temps = args.out_temps
    if out_temps is None:
        out_temps = np.array(temps, dtype=np.float64)
    n_out_temps = len(out_temps)

    # set output filenames
    outprefix = os.path.join(outdir, os.path.basename(traj_prefix))
    outtrajfns = ["%s.%3.2f.lammpstrj.gz" %
                  (outprefix, _get_nearest_temp(temps, t))
                  for t in out_temps]
    byteindfns = [os.path.join(outdir, ".byteind_%d.gz" % k)
                  for k in range(ntemps)]
    frametuplefn = outprefix + '.frametuple.pickle'
    logwfn = outprefix + ".logw.pickle"

    # get a list of all frames at a particular temp visited by each replica
    # this is fast so run only on ROOT proc. A failure is broadcast too, so
    # the other procs raise instead of waiting forever.
    payload = None
    if me == ROOT:
        try:
            payload = get_replica_frames(logfn=logfn, temps=temps,
                                         nswap=nswap, writefreq=writefreq)
            # save to a pickle from the ROOT proc
            with open(frametuplefn, 'wb') as of:
                pickle.dump(payload, of)
        except Exception as err:
            payload = err
    payload = comm.bcast(payload, root=ROOT)
    if isinstance(payload, Exception):
        raise payload
    master_frametuple_dict = payload

    # get byte indices from replica (un-ordered) trajs. in parallel
    # (round-robin so every proc gets work even when nproc > ntemps)
    get_byte_index(rep_inds=range(me, ntemps, nproc),
                   byteindfns=byteindfns,
                   intrajfns=intrajfns)

    # block until all procs have finished
    comm.barrier()

    # open all replica files for reading
    infobjs = [readwrite(fn, "rb") for fn in intrajfns]

    # open all byteindex files
    byte_inds = dict((i, np.loadtxt(fn, dtype=np.int64, ndmin=2))
                     for i, fn in enumerate(byteindfns))

    # # of reordered trajs. to write may be less than the # of procs;
    # procs beyond that get an empty range and are effectively retired
    # (write_reordered_traj still closes their input file objects)
    nproc_active = min(nproc, n_out_temps)
    if nproc_active < nproc and me == ROOT:
        print("\nReleasing %d excess procs" % (nproc - nproc_active))
    my_temp_inds = range(me, n_out_temps, nproc)

    # write reordered trajectories to disk from active procs in parallel
    write_reordered_traj(temp_inds=my_temp_inds,
                         byte_inds=byte_inds,
                         outtemps=out_temps, temps=temps,
                         frametuple_dict=master_frametuple_dict,
                         nprod=nprod, writefreq=writefreq,
                         outtrajfns=outtrajfns,
                         infobjs=infobjs)

    # calculate canonical log-weights if requested
    # usually this is very fast so retire all but the ROOT proc
    if not get_logw or me != ROOT:
        return

    logw = get_canonical_logw(enefn=enefn, temps=temps,
                              frametuple_dict=master_frametuple_dict,
                              nprod=nprod, writefreq=writefreq,
                              kB=kB)

    # save the logweights to a pickle
    with open(logwfn, 'wb') as of:
        pickle.dump(logw, of)


if __name__ == "__main__":
    main()