import json
import sys
from collections import defaultdict
from contextlib import contextmanager
from typing import Iterable, Iterator, Optional

import ase.io
from ase.db import connect
from ase.db.core import convert_str_to_int_float_or_str
from ase.db.row import row2dct
from ase.db.table import Table, all_columns
from ase.utils import plural


class CLICommand:
    """Manipulate and query ASE database.

    Query is a comma-separated list of
    selections where each selection is of the type "ID", "key" or
    "key=value". Instead of "=", one can also use "<", "<=", ">=", ">"
    and "!=" (these must be protected from the shell by using quotes).
    Special keys:

    * id
    * user
    * calculator
    * age
    * natoms
    * energy
    * magmom
    * charge

    Chemical symbols can also be used to select number of
    specific atomic species (H, He, Li, ...). Selection examples:

        calculator=nwchem
        age<1d
        natoms=1
        user=alice
        2.2<bandgap<4.1
        Cu>=10

    See also: https://wiki.fysik.dtu.dk/ase/ase/db/db.html.
    """

    @staticmethod
    def add_arguments(parser):
        """Register the ``ase db`` command-line options on *parser*."""
        add = parser.add_argument
        add('database', help='SQLite3 file, JSON file or postgres URL.')
        add('query', nargs='*', help='Query string.')
        add('-v', '--verbose', action='store_true', help='More output.')
        add('-q', '--quiet', action='store_true', help='Less output.')
        add('-n', '--count', action='store_true',
            help='Count number of selected rows.')
        add('-l', '--long', action='store_true',
            help='Long description of selected row')
        add('-i', '--insert-into', metavar='db-name',
            help='Insert selected rows into another database.')
        add('-a', '--add-from-file', metavar='filename',
            help='Add configuration(s) from file. '
            'If the file contains more than one configuration then you can '
            'use the syntax filename@: to add all of them. Default is to '
            'only add the last.')
        add('-k', '--add-key-value-pairs', metavar='key1=val1,key2=val2,...',
            help='Add key-value pairs to selected rows. Values must '
            'be numbers or strings and keys must follow the same rules as '
            'keywords.')
        add('-L', '--limit', type=int, default=-1, metavar='N',
            help='Show only first N rows. Use --limit=0 '
            'to show all. Default is 20 rows when listing rows and no '
            'limit when --insert-into is used.')
        add('--offset', type=int, default=0, metavar='N',
            help='Skip first N rows. By default, no rows are skipped')
        add('--delete', action='store_true',
            help='Delete selected rows.')
        add('--delete-keys', metavar='key1,key2,...',
            help='Delete keys for selected rows.')
        add('-y', '--yes', action='store_true',
            help='Say yes.')
        add('--explain', action='store_true',
            help='Explain query plan.')
        add('-c', '--columns', metavar='col1,col2,...',
            help='Specify columns to show. Precede the column specification '
            'with a "+" in order to add columns to the default set of '
            'columns. Precede by a "-" to remove columns. Use "++" for all.')
        add('-s', '--sort', metavar='column', default='id',
            help='Sort rows using "column". Use "column-" for a descending '
            'sort. Default is to sort after id.')
        add('--cut', type=int, default=35, help='Cut keywords and key-value '
            'columns after CUT characters. Use --cut=0 to disable cutting. '
            'Default is 35 characters')
        add('-p', '--plot', metavar='x,y1,y2,...',
            help='Example: "-p x,y": plot y row against x row. Use '
            '"-p a:x,y" to make a plot for each value of a.')
        add('--csv', action='store_true',
            help='Write comma-separated-values file.')
        add('-w', '--open-web-browser', action='store_true',
            help='Open results in web-browser.')
        add('--no-lock-file', action='store_true', help="Don't use lock-files")
        add('--analyse', action='store_true',
            help='Gathers statistics about tables and indices to help make '
            'better query planning choices.')
        add('-j', '--json', action='store_true',
            help='Write json representation of selected row.')
        add('-m', '--show-metadata', action='store_true',
            help='Show metadata as json.')
        add('--set-metadata', metavar='something.json',
            help='Set metadata from a json file.')
        add('--strip-data', action='store_true',
            help='Strip data when using --insert-into.')
        add('--progress-bar', action='store_true',
            help='Show a progress bar when using --insert-into.')
        add('--show-keys', action='store_true',
            help='Show all keys.')
        add('--show-values', metavar='key1,key2,...',
            help='Show values for key(s).')

    @staticmethod
    def run(args):
        """Entry point used by the CLI framework; delegates to :func:`main`."""
        main(args)


def count_keys(db, query):
    """Print how many selected rows contain each key.

    Keys are collected from each selected row's ``_keys`` and printed as
    aligned ``key: count`` lines. Nothing is printed when no selected row
    has any keys.
    """
    keys = defaultdict(int)
    for row in db.select(query):
        for key in row._keys:
            keys[key] += 1

    # default=0: max() of an empty sequence would raise ValueError.
    width = max((len(key) for key in keys), default=0) + 1
    for key, number in keys.items():
        print('{:{}} {}'.format(key + ':', width, number))


def _show_values(db, query, keys):
    """Print a summary of the values taken by each of *keys* in the selection.

    Keys whose values are all non-strings are shown as a ``[min..max]``
    range. All other keys (strings, or a mix of strings and non-strings)
    are shown as a list of ``value(count)`` entries.
    """
    values = {key: defaultdict(int) for key in keys}
    numeric = set()
    textual = set()
    for row in db.select(query):
        kvp = row.key_value_pairs
        for key in keys:
            value = kvp.get(key)
            if value is None:
                continue
            values[key][value] += 1
            if isinstance(value, str):
                textual.add(key)
            else:
                numeric.add(key)

    width = max(len(key) for key in keys) + 1
    for key in keys:
        vals = values[key]
        if key in numeric and key not in textual:
            print('{:{}} [{}..{}]'
                  .format(key + ':', width, min(vals), max(vals)))
        else:
            # Mixed str/number values cannot be ordered by min()/max(),
            # so list every distinct value with its count instead.
            print('{:{}} {}'
                  .format(key + ':', width,
                          ', '.join('{}({})'.format(value, count)
                                    for value, count in vals.items())))


def _plot(db, query, spec, sort):
    """Plot values from the selected rows with matplotlib.

    *spec* is either ``"x,y1,y2,..."`` or ``"a,b:x,y1,..."``. In the second
    form the rows are grouped by the string values of ``a,b``, and each
    group gets its own curves. String-valued x data are mapped to integer
    positions and used as tick labels. Rows without an x value are skipped.
    """
    if ':' in spec:
        tagspec, keyspec = spec.split(':')
        tags = tagspec.split(',')
    else:
        tags = []
        keyspec = spec
    keys = keyspec.split(',')

    plots = defaultdict(list)
    X = {}
    labels = []
    for row in db.select(query, sort=sort, include_data=False):
        name = ','.join(str(row[tag]) for tag in tags)
        x = row.get(keys[0])
        if x is None:
            continue
        if isinstance(x, str):
            if x not in X:
                X[x] = len(X)
                labels.append(x)
            x = X[x]
        plots[name].append([x] + [row.get(key) for key in keys[1:]])

    import matplotlib.pyplot as plt
    for name, plot in plots.items():
        # zip() returns an iterator in Python 3; materialise it so the
        # transposed columns can be indexed.
        columns = list(zip(*plot))
        x = columns[0]
        for y, key in zip(columns[1:], keys[1:]):
            plt.plot(x, y, label=name + ':' + key)
    if X:
        plt.xticks(range(len(labels)), labels, rotation=90)
    plt.legend()
    plt.show()


def _build_columns(db, query, spec, limit, offset):
    """Return the table columns described by a ``--columns`` spec.

    Starts from ``all_columns``. A leading ``"++"`` adds every key found in
    the selected rows. A leading ``"+"`` adds to the default set. A first
    entry that starts with neither ``"+"`` nor ``"-"`` replaces the default
    set. Entries starting with ``"-"`` remove a column.
    """
    columns = list(all_columns)
    c = spec
    if c and c.startswith('++'):
        keys = set()
        for row in db.select(query,
                             limit=limit, offset=offset,
                             include_data=False):
            keys.update(row._keys)
        columns.extend(keys)
        if c[2:3] == ',':
            c = c[3:]
        else:
            c = ''
    if c:
        if c[0] == '+':
            c = c[1:]
        elif c[0] != '-':
            columns = []
        for col in c.split(','):
            if not col:
                # Tolerate empty entries (e.g. a trailing comma or a bare
                # "+") instead of failing with IndexError on col[0].
                continue
            if col[0] == '-':
                columns.remove(col[1:])
            else:
                columns.append(col.lstrip('+'))
    return columns


def main(args):
    """Run the ``ase db`` command described by the parsed *args*.

    One action is performed, chosen by checking the options in this order:
    analyse, show-keys, show-values, add-from-file, count, insert-into,
    explain, show-metadata, set-metadata, add/delete keys, delete, plot,
    json, long, open-web-browser. The first option present wins. If none
    is given, the selected rows are printed as a table, or as CSV with
    ``--csv``. Key-value pairs from ``-k`` are also attached to rows added
    with ``-a`` or copied with ``-i``.
    """
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        # Allow using "key-" instead of "-key" for reverse sorting
        args.sort = '-' + args.sort[:-1]

    if query.isdigit():
        # A bare integer is an "ID" selection.
        query = int(query)

    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            # Split on the first "=" only, so that values may contain "=".
            key, value = pair.split('=', 1)
            add_key_value_pairs[key] = convert_str_to_int_float_or_str(value)

    delete_keys = args.delete_keys.split(',') if args.delete_keys else []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*items):
        """Print *items* only when verbosity is positive."""
        if verbosity > 0:
            print(*items)

    if args.analyse:
        db.analyse()
        return

    if args.show_keys:
        count_keys(db, query)
        return

    if args.show_values:
        _show_values(db, query, args.show_values.split(','))
        return

    if args.add_from_file:
        filename = args.add_from_file
        configs = ase.io.read(filename)
        if not isinstance(configs, list):
            configs = [configs]
        for atoms in configs:
            db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added ' + plural(len(configs), 'row'))
        return

    if args.count:
        print(plural(db.count(query), 'row'))
        return

    if args.insert_into:
        if args.limit == -1:
            # Per the --limit help: no limit by default with --insert-into.
            args.limit = 0

        progressbar = no_progressbar
        length = None

        if args.progress_bar:
            # Try to import the one from click.
            # People using ase.db will most likely have flask installed
            # and therefore also click.
            try:
                from click import progressbar
            except ImportError:
                pass
            else:
                length = db.count(query)

        nkvp = 0
        nrows = 0
        with connect(args.insert_into,
                     use_lock_file=not args.no_lock_file) as db2:
            with progressbar(db.select(query,
                                       sort=args.sort,
                                       limit=args.limit,
                                       offset=args.offset),
                             length=length) as rows:
                for row in rows:
                    kvp = row.get('key_value_pairs', {})
                    # Count only the keys that -k actually adds; keys the
                    # row already had are overwritten ("updated").
                    nkvp -= len(kvp)
                    kvp.update(add_key_value_pairs)
                    nkvp += len(kvp)
                    if args.strip_data:
                        db2.write(row.toatoms(), **kvp)
                    else:
                        db2.write(row, data=row.get('data'), **kvp)
                    nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out('Inserted %s' % plural(nrows, 'row'))
        return

    if args.limit == -1:
        # Default row limit when listing rows (see the --limit help).
        args.limit = 20

    if args.explain:
        for row in db.select(query, explain=True,
                             verbosity=verbosity,
                             limit=args.limit, offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        with open(args.set_metadata) as fd:
            db.metadata = json.load(fd)
        return

    if add_key_value_pairs or delete_keys:
        ids = [row['id'] for row in db.select(query)]
        added = 0
        removed = 0
        with db:
            for row_id in ids:
                m, n = db.update(row_id, delete_keys=delete_keys,
                                 **add_key_value_pairs)
                added += m
                removed += n
        out('Added %s (%s updated)' %
            (plural(added, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - added, 'pair')))
        out('Removed', plural(removed, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query, include_data=False)]
        if ids and not args.yes:
            msg = 'Delete %s? (yes/No): ' % plural(len(ids), 'row')
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out('Deleted %s' % plural(len(ids), 'row'))
        return

    if args.plot:
        _plot(db, query, args.plot, args.sort)
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    if args.long:
        row = db.get(query)
        print(row2str(row))
        return

    if args.open_web_browser:
        try:
            import flask  # noqa
        except ImportError:
            print('Please install Flask: python3 -m pip install flask')
            return
        check_jsmol()
        import ase.db.app as app
        app.add_project(db)
        # NOTE(review): this listens on all interfaces with debug=True.
        # If app.app is a Flask app (Flask is required above), the
        # interactive debugger would be reachable from the network.
        # Consider binding to 127.0.0.1 or disabling debug.
        app.app.run(host='0.0.0.0', debug=True)
        return

    columns = _build_columns(db, query, args.columns,
                             args.limit, args.offset)

    table = Table(db, verbosity=verbosity, cut=args.cut)
    table.select(query, columns, args.sort, args.limit, args.offset)
    if args.csv:
        table.write_csv()
    else:
        table.write(query)


def row2str(row) -> str:
    """Return a multi-line, human-readable description of *row*.

    Includes the formula, a unit-cell table (periodicity, cell vectors,
    lengths, angles), the optional stress/dipole/constraints/data sections,
    and a Key | Description | Value table built by ``row2dct``.
    """
    t = row2dct(row)
    S = [t['formula'] + ':',
         'Unit cell in Ang:',
         'axis|periodic|          x|          y|          z|' +
         '    length|     angle']
    fmt = ('   {0}|     {1}|{2[0]:>11}|{2[1]:>11}|{2[2]:>11}|' +
           '{3:>10}|{4:>10}')
    for c, (p, axis, L, A) in enumerate(
            zip(row.pbc, t['cell'], t['lengths'], t['angles']), start=1):
        S.append(fmt.format(c, [' no', 'yes'][p], axis, L, A))
    S.append('')

    if 'stress' in t:
        S += ['Stress tensor (xx, yy, zz, zy, zx, yx) in eV/Ang^3:',
              '   {}\n'.format(t['stress'])]

    if 'dipole' in t:
        S.append('Dipole moment in e*Ang: ({})\n'.format(t['dipole']))

    if 'constraints' in t:
        S.append('Constraints: {}\n'.format(t['constraints']))

    if 'data' in t:
        S.append('Data: {}\n'.format(t['data']))

    # Column widths are at least as wide as the 'Key'/'Description' headers.
    # default=0 keeps max() from raising on an empty table.
    width0 = max(max((len(entry[0]) for entry in t['table']), default=0), 3)
    width1 = max(max((len(entry[1]) for entry in t['table']), default=0), 11)
    S.append('{:{}} | {:{}} | Value'
             .format('Key', width0, 'Description', width1))
    for key, desc, value in t['table']:
        S.append('{:{}} | {:{}} | {}'
                 .format(key, width0, desc, width1, value))
    return '\n'.join(S)


@contextmanager
def no_progressbar(iterable: Iterable,
                   length: Optional[int] = None) -> Iterator[Iterable]:
    """A do-nothing implementation.

    Used in place of ``click.progressbar`` (same call shape): yields
    *iterable* unchanged and ignores *length*.
    """
    yield iterable


def check_jsmol():
    """Print a warning to stderr if ``JSmol.min.js`` is missing.

    Looks for the file under ``root / 'ase/db/static'``, where ``root``
    comes from ``ase.db.app``, and explains how to install it.
    """
    from ase.db.app import root
    static = root / 'ase/db/static'
    if not (static / 'jsmol/JSmol.min.js').is_file():
        print(f"""
    WARNING:
        You don't have jsmol on your system.

        Download Jmol-*-binary.tar.gz from
        https://sourceforge.net/projects/jmol/files/Jmol/,
        extract jsmol.zip, unzip it and create a soft-link:

            $ tar -xf Jmol-*-binary.tar.gz
            $ unzip jmol-*/jsmol.zip
            $ ln -s $PWD/jsmol {static}/jsmol
    """,
              file=sys.stderr)