1from collections import defaultdict
2from numpy import inf
3
4import ase.db
5
6
# A run that needs this many steps or more is treated as not converged.
_STEP_LIMIT = 100
# A run whose energy lies within this distance of the lowest energy seen for
# its formula counts as having reached that minimum.
_ENERGY_TOLERANCE = 0.01
# (nsteps, time) recorded for a failed run, and used as a placeholder when an
# optimizer has no row at all for a formula.
_FAILED = (9999, inf)


def _write_table(path, formulas, table, column, cell_format, blank):
    """Write a header line plus one comma-separated line per optimizer.

    ``table`` is a sequence of ``(optimizer, entries)`` pairs in which
    ``entries`` holds one tuple per formula, in the order of ``formulas``.
    ``column`` selects the tuple element that is printed with
    ``cell_format``; entries whose step count (element 0) is not below the
    step limit are written as ``blank`` instead.
    """
    with open(path, 'w') as fd:
        print('optimizer,' + ','.join(formulas), file=fd)
        for optimizer, entries in table:
            cells = ','.join(cell_format.format(entry[column])
                             if entry[0] < _STEP_LIMIT else blank
                             for entry in entries)
            print('{:18},{}'.format(optimizer, cells), file=fd)


def analyze(filename, tag='results'):
    """Summarise optimizer runs stored in an ASE database as two CSV files.

    Each database row is read for its ``formula``, ``optimizer``, ``n``
    (used as the step count) and ``t`` (used as the run time) values, plus
    an optional ``energy``.  A run is accepted when its energy is less than
    0.01 above the lowest energy recorded for the same formula and it used
    fewer than 100 steps; every other run is recorded as failed.

    Two files are written, each with one line per optimizer and one column
    per formula; failed or missing runs are left blank:

    * ``<tag>-iterations.csv`` holds step counts, optimizers ordered by
      their total step count (failed runs count as 9999).
    * ``<tag>-time.csv`` holds run times divided by the fastest accepted
      time for that formula, optimizers ordered by the sum of those ratios
      (each ratio capped at 999).

    The list of formulas is also printed to standard output.

    Parameters
    ----------
    filename:
        Database to read; passed unchanged to ``ase.db.connect``.
    tag:
        Prefix of the two output file names.
    """
    # Read the rows once and reuse them for both passes below.
    rows = list(ase.db.connect(filename).select(sort='formula'))

    # Pass 1: formulas in order of first appearance and the lowest energy
    # seen for each.  Rows without an energy count as +inf.
    formulas = []
    emin = {}
    for row in rows:
        energy = row.get('energy', inf)
        if row.formula in emin:
            emin[row.formula] = min(emin[row.formula], energy)
        else:
            formulas.append(row.formula)
            emin[row.formula] = energy

    # Pass 2: classify every run.  Results are keyed by formula, not stored
    # positionally, so an optimizer with a missing or repeated formula can
    # never shift its values under the wrong column header.
    mintimes = defaultdict(lambda: 999999)
    results = defaultdict(dict)
    for row in rows:
        formula = row.formula
        outcome = _FAILED
        if row.get('energy', inf) - emin[formula] < _ENERGY_TOLERANCE:
            t = row.t
            if row.n < _STEP_LIMIT:
                outcome = (row.n, t)
                mintimes[formula] = min(mintimes[formula], t)
        per_formula = results[row.optimizer]
        # With several rows for one (optimizer, formula) pair, keep the best
        # one: fewest steps first, then shortest time.
        if formula not in per_formula or outcome < per_formula[formula]:
            per_formula[formula] = outcome

    print(formulas)

    # One entry per formula for every optimizer, padding gaps with _FAILED.
    table = [(optimizer, [per_formula.get(f, _FAILED) for f in formulas])
             for optimizer, per_formula in results.items()]

    by_steps = sorted(table, key=lambda item: sum(n for n, _ in item[1]))
    _write_table(tag + '-iterations.csv', formulas, by_steps,
                 0, '{:3}', ' ' * 3)

    # Express times relative to the fastest accepted run for each formula.
    # Formulas without any accepted run fall back to the 999999 default, and
    # their (infinite) times are blanked out by _write_table anyway.
    relative = [(optimizer,
                 [(n, t / mintimes[f])
                  for (n, t), f in zip(entries, formulas)])
                for optimizer, entries in table]
    by_time = sorted(
        relative,
        key=lambda item: sum(min(ratio, 999) for _, ratio in item[1]))
    _write_table(tag + '-time.csv', formulas, by_time,
                 1, '{:8.1f}', ' ' * 8)
56
57
if __name__ == '__main__':
    # Run as a script: analyse 'results.db' from the current working
    # directory, using the default output tag.
    analyze(filename='results.db')
60