import json
import sys
from collections import defaultdict
from contextlib import contextmanager
from typing import Iterable, Iterator, Optional

import ase.io
from ase.db import connect
from ase.db.core import convert_str_to_int_float_or_str
from ase.db.row import row2dct
from ase.db.table import Table, all_columns
from ase.utils import plural
13
14
class CLICommand:
    """Manipulate and query ASE database.

    Query is a comma-separated list of
    selections where each selection is of the type "ID", "key" or
    "key=value".  Instead of "=", one can also use "<", "<=", ">=", ">"
    and  "!=" (these must be protected from the shell by using quotes).
    Special keys:

    * id
    * user
    * calculator
    * age
    * natoms
    * energy
    * magmom
    * charge

    Chemical symbols can also be used to select number of
    specific atomic species (H, He, Li, ...).  Selection examples:

        calculator=nwchem
        age<1d
        natoms=1
        user=alice
        2.2<bandgap<4.1
        Cu>=10

    See also: https://wiki.fysik.dtu.dk/ase/ase/db/db.html.
    """

    @staticmethod
    def add_arguments(parser):
        # Register the positional arguments and all options understood by
        # ``main``.  Registration order is also the order shown by --help.
        arg = parser.add_argument

        def switch(*flags, text):
            """Register a boolean (``store_true``) option with help *text*."""
            arg(*flags, action='store_true', help=text)

        arg('database', help='SQLite3 file, JSON file or postgres URL.')
        arg('query', nargs='*', help='Query string.')
        switch('-v', '--verbose', text='More output.')
        switch('-q', '--quiet', text='Less output.')
        switch('-n', '--count', text='Count number of selected rows.')
        switch('-l', '--long', text='Long description of selected row')
        arg('-i', '--insert-into', metavar='db-name',
            help='Insert selected rows into another database.')
        arg('-a', '--add-from-file', metavar='filename',
            help=('Add configuration(s) from file.  If the file contains '
                  'more than one configuration then you can use the syntax '
                  'filename@: to add all of them.  Default is to only add '
                  'the last.'))
        arg('-k', '--add-key-value-pairs', metavar='key1=val1,key2=val2,...',
            help=('Add key-value pairs to selected rows.  Values must be '
                  'numbers or strings and keys must follow the same rules '
                  'as keywords.'))
        # -1 is a sentinel meaning "not given"; main() picks the real default.
        arg('-L', '--limit', type=int, default=-1, metavar='N',
            help=('Show only first N rows.  Use --limit=0 to show all.  '
                  'Default is 20 rows when listing rows and no limit when '
                  '--insert-into is used.'))
        arg('--offset', type=int, default=0, metavar='N',
            help='Skip first N rows.  By default, no rows are skipped')
        switch('--delete', text='Delete selected rows.')
        arg('--delete-keys', metavar='key1,key2,...',
            help='Delete keys for selected rows.')
        switch('-y', '--yes', text='Say yes.')
        switch('--explain', text='Explain query plan.')
        arg('-c', '--columns', metavar='col1,col2,...',
            help=('Specify columns to show.  Precede the column '
                  'specification with a "+" in order to add columns to the '
                  'default set of columns.  Precede by a "-" to remove '
                  'columns.  Use "++" for all.'))
        arg('-s', '--sort', metavar='column', default='id',
            help=('Sort rows using "column".  Use "column-" for a '
                  'descending sort.  Default is to sort after id.'))
        arg('--cut', type=int, default=35,
            help=('Cut keywords and key-value columns after CUT characters.'
                  '  Use --cut=0 to disable cutting. Default is 35 '
                  'characters'))
        arg('-p', '--plot', metavar='x,y1,y2,...',
            help=('Example: "-p x,y": plot y row against x row. Use '
                  '"-p a:x,y" to make a plot for each value of a.'))
        switch('--csv', text='Write comma-separated-values file.')
        switch('-w', '--open-web-browser',
               text='Open results in web-browser.')
        switch('--no-lock-file', text="Don't use lock-files")
        switch('--analyse',
               text=('Gathers statistics about tables and indices to help '
                     'make better query planning choices.'))
        switch('-j', '--json',
               text='Write json representation of selected row.')
        switch('-m', '--show-metadata', text='Show metadata as json.')
        arg('--set-metadata', metavar='something.json',
            help='Set metadata from a json file.')
        switch('--strip-data', text='Strip data when using --insert-into.')
        switch('--progress-bar',
               text='Show a progress bar when using --insert-into.')
        switch('--show-keys', text='Show all keys.')
        arg('--show-values', metavar='key1,key2,...',
            help='Show values for key(s).')

    @staticmethod
    def run(args):
        # All the work happens in the module-level ``main`` function.
        main(args)
121
122
def count_keys(db, query):
    """Print, for each key in ``row._keys`` of the selected rows, its count.

    Parameters
    ----------
    db:
        Database object; only its ``select(query)`` method is used, and the
        rows it yields must have a ``_keys`` attribute.
    query:
        Selection passed unchanged to ``db.select``.

    Output is one line per key, in first-seen order, formatted as
    ``key: count`` with the counts aligned in one column.  When no selected
    row has any key (including an empty selection) nothing is printed.
    """
    counts = defaultdict(int)
    for row in db.select(query):
        for key in row._keys:
            counts[key] += 1

    if not counts:
        # max() below would raise ValueError on an empty sequence.
        return

    width = max(len(key) for key in counts) + 1  # +1 for the trailing ':'
    for key, number in counts.items():
        print('{:{}} {}'.format(key + ':', width, number))
133
134
def main(args):
    """Execute the ``ase db`` command for the parsed command-line *args*.

    *args* is the namespace produced by a parser set up with
    ``CLICommand.add_arguments``.  The options are checked in a fixed order
    and the first matching action (``--analyse``, ``--show-keys``,
    ``--show-values``, ``--add-from-file``, ``--count``, ``--insert-into``,
    ``--explain``, ``--show-metadata``, ``--set-metadata``, key-value
    updates via ``-k``/``--delete-keys``, ``--delete``, ``--plot``,
    ``--json``, ``--long``, ``--open-web-browser``) is performed, after
    which the function returns.  If none applies, the selected rows are
    printed as a table (or as CSV with ``--csv``).

    ``args.sort`` and ``args.limit`` are normalised in place.

    Raises
    ------
    ValueError
        If an item of ``--add-key-value-pairs`` contains no ``=``.
    """
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        # Allow using "key-" instead of "-key" for reverse sorting
        args.sort = '-' + args.sort[:-1]

    if query.isdigit():
        # A purely numeric query is an "ID" selection.
        query = int(query)

    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            # Split on the first "=" only, so that string values may
            # themselves contain "=".
            key, sep, value = pair.partition('=')
            if not sep:
                raise ValueError('Bad key-value pair {!r}: expected key=value'
                                 .format(pair))
            add_key_value_pairs[key] = convert_str_to_int_float_or_str(value)

    if args.delete_keys:
        delete_keys = args.delete_keys.split(',')
    else:
        delete_keys = []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*messages):
        # Informational output; silenced when verbosity drops to 0 (-q).
        if verbosity > 0:
            print(*messages)

    # Each of the following branches performs one action and returns.
    if args.analyse:
        db.analyse()
        return

    if args.show_keys:
        count_keys(db, query)
        return

    if args.show_values:
        keys = args.show_values.split(',')
        values = {key: defaultdict(int) for key in keys}
        numbers = set()  # keys with at least one non-string value
        for row in db.select(query):
            kvp = row.key_value_pairs
            for key in keys:
                value = kvp.get(key)
                if value is not None:
                    values[key][value] += 1
                    if not isinstance(value, str):
                        numbers.add(key)

        width = max(len(key) for key in keys) + 1
        for key in keys:
            vals = values[key]
            if key in numbers:
                # Non-string values: show the range.
                print('{:{}} [{}..{}]'
                      .format(key + ':', width, min(vals), max(vals)))
            else:
                # String values: show each distinct value with its count.
                print('{:{}} {}'
                      .format(key + ':', width,
                              ', '.join('{}({})'.format(v, count)
                                        for v, count in vals.items())))
        return

    if args.add_from_file:
        filename = args.add_from_file
        # ase.io.read may return a single Atoms object or a list of them.
        configs = ase.io.read(filename)
        if not isinstance(configs, list):
            configs = [configs]
        for atoms in configs:
            db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added ' + plural(len(configs), 'row'))
        return

    if args.count:
        n = db.count(query)
        print('%s' % plural(n, 'row'))
        return

    if args.insert_into:
        if args.limit == -1:
            # -1 means --limit was not given; 0 means "no limit".
            args.limit = 0

        progressbar = no_progressbar
        length = None

        if args.progress_bar:
            # Try to import the one from click.
            # People using ase.db will most likely have flask installed
            # and therefore also click.
            try:
                from click import progressbar
            except ImportError:
                pass
            else:
                length = db.count(query)

        nkvp = 0
        nrows = 0
        with connect(args.insert_into,
                     use_lock_file=not args.no_lock_file) as db2:
            with progressbar(db.select(query,
                                       sort=args.sort,
                                       limit=args.limit,
                                       offset=args.offset),
                             length=length) as rows:
                for row in rows:
                    kvp = row.get('key_value_pairs', {})
                    # Count only keys that are new for this row; keys that
                    # already existed are reported as "updated".
                    nkvp -= len(kvp)
                    kvp.update(add_key_value_pairs)
                    nkvp += len(kvp)
                    if args.strip_data:
                        db2.write(row.toatoms(), **kvp)
                    else:
                        db2.write(row, data=row.get('data'), **kvp)
                    nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out('Inserted %s' % plural(nrows, 'row'))
        return

    if args.limit == -1:
        # Default row limit when listing rows (see the --limit help).
        args.limit = 20

    if args.explain:
        for row in db.select(query, explain=True,
                             verbosity=verbosity,
                             limit=args.limit, offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        with open(args.set_metadata) as fd:
            db.metadata = json.load(fd)
        return

    if add_key_value_pairs or delete_keys:
        ids = [row['id'] for row in db.select(query)]
        nadded = 0
        nremoved = 0
        with db:
            for row_id in ids:
                added, removed = db.update(row_id, delete_keys=delete_keys,
                                           **add_key_value_pairs)
                nadded += added
                nremoved += removed
        out('Added %s (%s updated)' %
            (plural(nadded, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - nadded, 'pair')))
        out('Removed', plural(nremoved, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query, include_data=False)]
        if ids and not args.yes:
            msg = 'Delete %s? (yes/No): ' % plural(len(ids), 'row')
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out('Deleted %s' % plural(len(ids), 'row'))
        return

    if args.plot:
        if ':' in args.plot:
            tags, keys = args.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = args.plot
        keys = keys.split(',')
        plots = defaultdict(list)
        X = {}  # string x-value -> integer position on the x-axis
        labels = []
        for row in db.select(query, sort=args.sort, include_data=False):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, str):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            # Transpose the [x, y1, y2, ...] rows into columns.  zip()
            # returns an iterator, so materialise it before indexing.
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    if args.long:
        row = db.get(query)
        print(row2str(row))
        return

    if args.open_web_browser:
        try:
            import flask  # noqa
        except ImportError:
            print('Please install Flask: python3 -m pip install flask')
            return
        check_jsmol()
        import ase.db.app as app
        app.add_project(db)
        # NOTE(review): debug=True together with host='0.0.0.0' exposes
        # Flask's interactive debugger to the whole network; consider
        # binding to 127.0.0.1 or disabling debug mode.
        app.app.run(host='0.0.0.0', debug=True)
        return

    # Default action: list the selected rows as a table.
    columns = list(all_columns)
    c = args.columns
    if c and c.startswith('++'):
        # "++" adds every key found in the selected rows.
        keys = set()
        for row in db.select(query,
                             limit=args.limit, offset=args.offset,
                             include_data=False):
            keys.update(row._keys)
        columns.extend(keys)
        if c[2:3] == ',':
            c = c[3:]
        else:
            c = ''
    if c:
        # "+a,b" extends the default columns, "-a" removes one, and a
        # plain "a,b" replaces the default set entirely.
        if c[0] == '+':
            c = c[1:]
        elif c[0] != '-':
            columns = []
        for col in c.split(','):
            if col[0] == '-':
                columns.remove(col[1:])
            else:
                columns.append(col.lstrip('+'))

    table = Table(db, verbosity=verbosity, cut=args.cut)
    table.select(query, columns, args.sort, args.limit, args.offset)
    if args.csv:
        table.write_csv()
    else:
        table.write(query)
389
390
def row2str(row) -> str:
    """Return a human-readable, multi-line description of *row*.

    The text is built from ``row2dct(row)`` plus ``row.pbc`` and contains:
    the formula, a unit-cell table (periodicity, cell vectors, lengths and
    angles), optional stress / dipole / constraints / data sections (only
    when those keys are present) and finally a "Key | Description | Value"
    table.
    """
    dct = row2dct(row)
    lines = [dct['formula'] + ':',
             'Unit cell in Ang:',
             'axis|periodic|          x|          y|          z|' +
             '    length|     angle']
    fmt = ('   {0}|     {1}|{2[0]:>11}|{2[1]:>11}|{2[2]:>11}|' +
           '{3:>10}|{4:>10}')
    cell_rows = zip(row.pbc, dct['cell'], dct['lengths'], dct['angles'])
    for axis_no, (periodic, axis, length, angle) in enumerate(cell_rows, 1):
        lines.append(fmt.format(axis_no, [' no', 'yes'][periodic],
                                axis, length, angle))
    lines.append('')

    if 'stress' in dct:
        lines += ['Stress tensor (xx, yy, zz, zy, zx, yx) in eV/Ang^3:',
                  '   {}\n'.format(dct['stress'])]

    if 'dipole' in dct:
        lines.append('Dipole moment in e*Ang: ({})\n'.format(dct['dipole']))

    if 'constraints' in dct:
        lines.append('Constraints: {}\n'.format(dct['constraints']))

    if 'data' in dct:
        lines.append('Data: {}\n'.format(dct['data']))

    table = dct['table']
    # Each column is at least as wide as its header ('Key' = 3 chars,
    # 'Description' = 11 chars).  Seeding max() with the header width also
    # keeps an empty table from raising ValueError.
    width0 = max([3] + [len(entry[0]) for entry in table])
    width1 = max([11] + [len(entry[1]) for entry in table])
    lines.append('{:{}} | {:{}} | Value'
                 .format('Key', width0, 'Description', width1))
    for key, desc, value in table:
        lines.append('{:{}} | {:{}} | {}'
                     .format(key, width0, desc, width1, value))
    return '\n'.join(lines)
426
427
@contextmanager
def no_progressbar(iterable: Iterable,
                   length: Optional[int] = None) -> Iterator[Iterable]:
    """Do-nothing stand-in for ``click.progressbar``.

    Used as a context manager, it yields *iterable* unchanged.  *length* is
    accepted and ignored so that it can be called the same way as
    ``click.progressbar(iterable, length=...)``.
    """
    yield iterable
433
434
def check_jsmol():
    """Print JSmol installation instructions to stderr if JSmol is missing.

    Checks whether ``jsmol/JSmol.min.js`` exists below the
    ``ase/db/static`` directory of ``ase.db.app.root``.  If it exists, this
    returns silently.  Otherwise it writes a warning to ``sys.stderr``
    explaining how to download JSmol and symlink it into place.
    """
    from ase.db.app import root
    static_dir = root / 'ase/db/static'
    if (static_dir / 'jsmol/JSmol.min.js').is_file():
        return

    warning = f"""
    WARNING:
        You don't have jsmol on your system.

        Download Jmol-*-binary.tar.gz from
        https://sourceforge.net/projects/jmol/files/Jmol/,
        extract jsmol.zip, unzip it and create a soft-link:

            $ tar -xf Jmol-*-binary.tar.gz
            $ unzip jmol-*/jsmol.zip
            $ ln -s $PWD/jsmol {static_dir}/jsmol
    """
    print(warning, file=sys.stderr)
452