1import math
2import sys
3
# Command line (BSIZE is always taken from the last argument):
#   FILE NCV  MIN_1 MAX_1 NBIN_1 ... MIN_NCV MAX_NCV NBIN_NCV  KBT  BSIZE
USAGE_ = ("usage: %s FILE NCV [MIN MAX NBIN for each CV] KBT BSIZE"
          % sys.argv[0])
if len(sys.argv) < 3:
    sys.exit(USAGE_)
# read FILE with CVs and weights
FILENAME_ = sys.argv[1]
# number of CVs for FES
NCV_ = int(sys.argv[2])
# MIN MAX NBIN are needed for every CV, followed by KBT and BSIZE. With one
# argument too few, argv[-1] would be KBT and be silently used as block size.
if NCV_ < 1 or len(sys.argv) < 3*NCV_ + 5:
    sys.exit(USAGE_)
# read minimum, maximum and number of bins for FES grid
gmin = []; gmax = []; nbin = []
for i in range(0, NCV_):
    i0 = 3*i + 3
    gmin.append(float(sys.argv[i0]))
    gmax.append(float(sys.argv[i0+1]))
    nbin.append(int(sys.argv[i0+2]))
# read KBT_ (the free energy is reported in the same energy units)
KBT_ = float(sys.argv[3*NCV_+3])
# block size: number of consecutive samples averaged together in one block
BSIZE_ = int(sys.argv[-1])
if BSIZE_ < 1:
    sys.exit("block size must be a positive integer, got %d" % BSIZE_)
19
def get_indexes_from_index(index, nbin):
    """Convert a flat grid index into one integer index per CV.

    The first CV varies fastest, i.e.
    ``index = i0 + nbin[0] * (i1 + nbin[1] * (i2 + ...))``.

    Args:
        index: flat integer index into the grid.
        nbin: number of grid points along each CV.

    Returns:
        list of ints ``[i0, i1, ...]``, one per CV. For a single CV the
        result is ``[index % nbin[0]]``; for several CVs the last (slowest)
        index is whatever remains after dividing out the other dimensions.
    """
    indexes = []
    # get first index (raises IndexError on an empty nbin, as before)
    indexes.append(index % nbin[0])
    # Integer floor division keeps every index an int; plain "/" would
    # return floats under Python 3.
    rest = index // nbin[0]
    for n in nbin[1:-1]:
        rest, i = divmod(rest, n)
        indexes.append(i)
    if len(nbin) >= 2:
        indexes.append(rest)
    return indexes
32
def get_indexes_from_cvs(cvs, gmin, dx):
    """Map a point in CV space to the indexes of the nearest grid point.

    Coordinate ``k`` becomes ``int(round((cvs[k] - gmin[k]) / dx[k]))``.
    The built-in ``round`` rounds exact halves to the nearest even integer.

    Returns:
        tuple of ints (one per entry of ``cvs``), hashable so it can be
        used directly as a histogram dictionary key.
    """
    ncv = len(cvs)
    return tuple(int(round((cvs[k] - gmin[k]) / dx[k])) for k in range(ncv))
38
def get_points(key, gmin, dx):
    """Return the CV coordinates of the grid point with indexes ``key``.

    Coordinate ``k`` is ``gmin[k] + key[k] * dx[k]``. The result is a list
    of floats with one entry per element of ``key``.
    """
    return [gmin[k] + float(key[k]) * dx[k] for k in range(len(key))]
44
# define bin size: index 0 maps to gmin and index nbin-1 to gmax, so each CV
# has nbin-1 intervals between its grid points. That needs nbin >= 2.
for i in range(0, NCV_):
    if nbin[i] < 2:
        sys.exit("NBIN for CV %d must be at least 2, got %d" % (i + 1, nbin[i]))
dx = [(gmax[i] - gmin[i]) / float(nbin[i] - 1) for i in range(0, NCV_)]

# total numbers of bins (product over all CVs)
nbins = 1
for n in nbin:
    nbins *= n
53
# read file and store lists: for every sample, the grid key of its nearest
# grid point and its weight
cv_list = []; w_list = []
with open(FILENAME_, "r") as datafile:
    for line in datafile:
        fields = line.split()
        # skip blank lines and comment/header lines (e.g. "#! FIELDS ...")
        if not fields or fields[0].startswith("#"):
            continue
        # check format: NCV values, optionally followed by one weight
        if len(fields) != NCV_ and len(fields) != NCV_ + 1:
            # report on stderr and exit non-zero (bare exit() returned 0)
            print(FILENAME_, "is in the wrong format!", file=sys.stderr)
            sys.exit(1)
        # read CVs
        cvs = [float(fields[i]) for i in range(0, NCV_)]
        # get indexes
        key = get_indexes_from_cvs(cvs, gmin, dx)
        # read weight, if present (default 1.0)
        w = float(fields[NCV_]) if len(fields) == NCV_ + 1 else 1.0
        # store into lists
        cv_list.append(key)
        w_list.append(w)
74
# total number of data points
ndata = len(cv_list)
# number of complete blocks; trailing samples that do not fill a whole block
# are not used. Integer division is exact for any count.
nblock = ndata // BSIZE_
# the error estimate uses the variance across blocks (divided by nblock - 1),
# so at least two blocks are required
if nblock < 2:
    sys.exit("not enough data: %d points with block size %d give %d block(s), "
             "at least 2 are needed" % (ndata, BSIZE_, nblock))

# prepare histo dictionaries, keyed by grid-index tuple:
#   histo_ave  -> sum over blocks of the block-averaged histogram
#   histo_ave2 -> sum over blocks of its square
histo_ave = {}; histo_ave2 = {}

# cycle on blocks
for iblock in range(0, nblock):
    # define range of samples in this block
    i0 = iblock * BSIZE_
    i1 = i0 + BSIZE_
    # build weighted histo of this block
    histo = {}
    for i in range(i0, i1):
        histo[cv_list[i]] = histo.get(cv_list[i], 0.0) + w_list[i]
    # normalize by the block size and add to the global accumulators
    for key, h in histo.items():
        h /= float(BSIZE_)
        histo_ave[key] = histo_ave.get(key, 0.0) + h
        histo_ave2[key] = histo_ave2.get(key, 0.0) + h * h
103
# print out fes and error to fes.<BSIZE>.dat: CV values, free energy, error
nb = float(nblock)
with open("fes."+str(BSIZE_)+".dat", "w") as log:
    xs_old = []
    for i in range(0, nbins):
        # get the indexes in the multi-dimensional grid
        key = tuple(get_indexes_from_index(i, nbin))
        # get CV values for that grid point
        xs = get_points(key, gmin, dx)
        # add a blank line for gnuplot whenever any CV except the first changes
        if i > 0 and xs[1:] != xs_old[1:]:
            log.write("\n")
        xs_old = xs
        # print value of CVs
        for x in xs:
            log.write("%12.6lf " % x)
        # calculate fes; a bin with no (or no positive) weight has F = +inf
        if histo_ave.get(key, 0.0) > 0.0:
            # average and unbiased variance of the block histograms; round-off
            # can push the variance slightly below zero, so clamp it
            aveh = histo_ave[key] / nb
            s2h = max(0.0, (histo_ave2[key]/nb - aveh*aveh) * nb / (nb - 1.0))
            # standard error of the mean histogram
            errh = math.sqrt(s2h / nb)
            # free energy and its error (propagated through -KBT*log)
            fes = -KBT_ * math.log(aveh)
            errf = KBT_ / aveh * errh
            # printout
            log.write("   %12.6lf %12.6lf\n" % (fes, errf))
        else:
            log.write("       Infinity\n")
142