from random import randint
from typing import Any, Dict, Optional, Tuple

import numpy as np

from ase import Atoms
from ase.constraints import dict2constraint
from ase.calculators.calculator import (get_calculator_class, all_properties,
                                        PropertyNotImplementedError,
                                        kptdensity2monkhorstpack)
from ase.calculators.singlepoint import SinglePointCalculator
from ase.data import chemical_symbols, atomic_masses
from ase.formula import Formula
from ase.geometry import cell_to_cellpar
from ase.io.jsonio import decode
16
17
class FancyDict(dict):
    """Dictionary whose keys can also be read as attributes.

    A value that is itself a ``dict`` is returned wrapped in a new
    ``FancyDict``, so nested lookups like ``d.a.b`` work.
    """

    def __getattr__(self, key):
        # Python only calls this after normal attribute lookup has failed,
        # so real dict methods (get, keys, ...) are never shadowed.
        if key in self:
            item = self[key]
            return FancyDict(item) if isinstance(item, dict) else item
        # Not a key either: let the regular machinery raise AttributeError.
        return dict.__getattribute__(self, key)

    def __dir__(self):
        # Offer the keys for interactive tab-completion.
        return self.keys()
30
31
def atoms2dict(atoms):
    """Convert an Atoms object to a plain dict suitable for an AtomsRow.

    The dict always holds ``numbers``, ``positions`` and a fresh random
    32-hex-digit ``unique_id``.  Optional entries are included only when
    present / non-trivial: ``pbc`` (any True), ``cell`` (any non-zero),
    the per-atom arrays, serialized constraints, and, when a calculator is
    attached, its name, parameters and (if its state matches ``atoms``)
    every property it can deliver.
    """
    row_dict = {
        'numbers': atoms.numbers,
        'positions': atoms.positions,
        'unique_id': '%x' % randint(16**31, 16**32 - 1),
    }

    if atoms.pbc.any():
        row_dict['pbc'] = atoms.pbc
    if atoms.cell.any():
        row_dict['cell'] = atoms.cell

    # (array name, getter method name) pairs; the getter is looked up only
    # when the array is actually present.
    per_atom_arrays = (
        ('initial_magmoms', 'get_initial_magnetic_moments'),
        ('initial_charges', 'get_initial_charges'),
        ('masses', 'get_masses'),
        ('tags', 'get_tags'),
        ('momenta', 'get_momenta'),
    )
    for array_name, getter_name in per_atom_arrays:
        if atoms.has(array_name):
            row_dict[array_name] = getattr(atoms, getter_name)()

    if atoms.constraints:
        row_dict['constraints'] = [
            constraint.todict() for constraint in atoms.constraints]

    calc = atoms.calc
    if calc is None:
        return row_dict

    row_dict['calculator'] = calc.name.lower()
    row_dict['calculator_parameters'] = calc.todict()

    # Only harvest results when the calculator's state matches these atoms
    # (check_state returned no changes).
    if len(calc.check_state(atoms)) != 0:
        return row_dict

    for prop in all_properties:
        try:
            # allow_calculation=False: never trigger a new calculation.
            value = calc.get_property(prop, atoms, False)
        except PropertyNotImplementedError:
            continue
        if value is not None:
            row_dict[prop] = value
    return row_dict
66
67
class AtomsRow:
    """Row of an ASE database.

    Every stored value (``numbers``, ``positions``, calculator results,
    key-value pairs, ...) is available both as an attribute and as
    ``row[key]``.  Derived quantities (``natoms``, ``formula``, ``fmax``,
    ``volume``, ...) are computed on demand, and constraints and the
    ``data`` dict are decoded lazily on first access.
    """

    def __init__(self, dct):
        """Create a row from a dict of row values or from an Atoms object.

        A dict is shallow-copied, so the caller's dict is not modified.
        Anything else is converted with ``atoms2dict``.  The resulting dict
        must contain ``'numbers'``.
        """
        if isinstance(dct, dict):
            dct = dct.copy()
            if 'calculator_parameters' in dct:
                # Earlier version of ASE would encode the calculator
                # parameter dict again and again and again ...
                while isinstance(dct['calculator_parameters'], str):
                    dct['calculator_parameters'] = decode(
                        dct['calculator_parameters'])
        else:
            dct = atoms2dict(dct)
        assert 'numbers' in dct
        # Constraints and data may still be encoded; they are decoded
        # lazily by the ``constraints`` and ``data`` properties.
        self._constraints = dct.pop('constraints', [])
        self._constrained_forces = None  # cache for constrained_forces
        self._data = dct.pop('data', {})
        kvp = dct.pop('key_value_pairs', {})
        self._keys = list(kvp.keys())
        # Key-value pairs go in first, so ordinary row values win if a
        # name appears in both.
        self.__dict__.update(kvp)
        self.__dict__.update(dct)
        if 'cell' not in dct:
            self.cell = np.zeros((3, 3))
        if 'pbc' not in dct:
            self.pbc = np.zeros(3, bool)

    def __contains__(self, key):
        """True if *key* is stored on the row.

        Values computed by properties (e.g. ``natoms``) do not count.
        """
        return key in self.__dict__

    def __iter__(self):
        """Iterate over the names of the stored (non-private) values."""
        return (key for key in self.__dict__ if key[0] != '_')

    def get(self, key, default=None):
        """Return value of key if present or default if not.

        Uses ``getattr``, so properties work too.  A property that raises
        AttributeError (e.g. ``volume`` for a zero cell, ``fmax`` without
        forces) also yields *default*.
        """
        return getattr(self, key, default)

    @property
    def key_value_pairs(self):
        """Return dict of key-value pairs."""
        return {key: self.get(key) for key in self._keys}

    def count_atoms(self):
        """Count atoms.

        Return dict mapping chemical symbol strings to number of atoms.
        """
        count = {}
        for symbol in self.symbols:
            count[symbol] = count.get(symbol, 0) + 1
        return count

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __str__(self):
        return '<AtomsRow: formula={0}, keys={1}>'.format(
            self.formula, ','.join(self._keys))

    @property
    def constraints(self):
        """List of constraint objects (decoded on first access)."""
        if not isinstance(self._constraints, list):
            # Lazy decoding:
            cs = decode(self._constraints)
            self._constraints = []
            for c in cs:
                # Convert old ``{'__name__': ..., **kwargs}`` entries to the
                # ``{'name': ..., 'kwargs': ...}`` format:
                name = c.pop('__name__', None)
                if name:
                    c = {'name': name, 'kwargs': c}
                # Strip a module path such as 'ase.constraints.FixAtoms'.
                if c['name'].startswith('ase'):
                    c['name'] = c['name'].rsplit('.', 1)[1]
                self._constraints.append(c)
        return [dict2constraint(d) for d in self._constraints]

    @property
    def data(self):
        """Data dict (decoded on first access from str or bytes)."""
        if isinstance(self._data, str):
            self._data = decode(self._data)  # lazy decoding
        elif isinstance(self._data, bytes):
            from ase.db.core import bytes_to_object
            self._data = bytes_to_object(self._data)  # lazy decoding
        return FancyDict(self._data)

    @property
    def natoms(self):
        """Number of atoms."""
        return len(self.numbers)

    @property
    def formula(self):
        """Chemical formula string."""
        return Formula('', _tree=[(self.symbols, 1)]).format('metal')

    @property
    def symbols(self):
        """List of chemical symbols."""
        return [chemical_symbols[Z] for Z in self.numbers]

    @property
    def fmax(self):
        """Maximum atomic force (after constraints are applied).

        Raises AttributeError when the row has no forces.
        """
        forces = self.constrained_forces
        return (forces**2).sum(1).max()**0.5

    @property
    def constrained_forces(self):
        """Forces after applying constraints (computed once, then cached)."""
        if self._constrained_forces is not None:
            return self._constrained_forces
        forces = self.forces
        constraints = self.constraints
        if constraints:
            # Work on a copy: adjust_forces modifies the array in place.
            forces = forces.copy()
            atoms = self.toatoms()
            for constraint in constraints:
                constraint.adjust_forces(atoms, forces)

        self._constrained_forces = forces
        return forces

    @property
    def smax(self):
        """Maximum stress tensor component (absolute value)."""
        return (self.stress**2).max()**0.5

    @property
    def mass(self):
        """Total mass (stored masses if present, else standard masses)."""
        if 'masses' in self:
            return self.masses.sum()
        return atomic_masses[self.numbers].sum()

    @property
    def volume(self):
        """Volume of unit cell.

        Raises AttributeError for a zero-volume cell, so that
        ``row.get('volume')`` returns None in that case.
        """
        if self.cell is None:
            return None
        vol = abs(np.linalg.det(self.cell))
        if vol == 0.0:
            raise AttributeError
        return vol

    @property
    def charge(self):
        """Total charge: sum of the initial charges, 0.0 if none stored."""
        # Bug fix: this used to look up the misspelled key 'inital_charges'
        # and therefore always returned 0.0.
        charges = self.get('initial_charges')
        if charges is None:
            return 0.0
        return charges.sum()

    def toatoms(self, attach_calculator=False,
                add_additional_information=False):
        """Create Atoms object.

        Parameters
        ----------
        attach_calculator: bool
            If True, create a calculator from the stored calculator name
            and parameters.  Otherwise, stored results (if any) are
            attached through a SinglePointCalculator.
        add_additional_information: bool
            If True, put ``unique_id``, the key-value pairs and the data
            dict (when non-empty) into ``atoms.info``.
        """
        atoms = Atoms(self.numbers,
                      self.positions,
                      cell=self.cell,
                      pbc=self.pbc,
                      magmoms=self.get('initial_magmoms'),
                      charges=self.get('initial_charges'),
                      tags=self.get('tags'),
                      masses=self.get('masses'),
                      momenta=self.get('momenta'),
                      constraint=self.constraints)

        if attach_calculator:
            params = self.get('calculator_parameters', {})
            atoms.calc = get_calculator_class(self.calculator)(**params)
        else:
            results = {}
            for prop in all_properties:
                if prop in self:
                    results[prop] = self[prop]
            if results:
                atoms.calc = SinglePointCalculator(atoms, **results)
                atoms.calc.name = self.get('calculator', 'unknown')

        if add_additional_information:
            atoms.info = {}
            atoms.info['unique_id'] = self.unique_id
            if self._keys:
                atoms.info['key_value_pairs'] = self.key_value_pairs
            data = self.get('data')
            if data:
                atoms.info['data'] = data

        return atoms
258
259
def row2dct(row,
            key_descriptions: Optional[Dict[str, Tuple[str, str, str]]] = None
            ) -> Dict[str, Any]:
    """Convert row to dict of things for printing or a web-page.

    Parameters
    ----------
    row: AtomsRow
        Database row to summarize.
    key_descriptions: dict, optional
        Maps a key to a 3-tuple whose last two entries are used as the
        description and the unit shown in the table.  None (the default)
        means no descriptions.

    Returns
    -------
    dict
        Display strings: ``'size'`` (k-point grid), ``'cell'``,
        ``'lengths'``, ``'angles'``, ``'formula'`` and ``'table'`` (a list
        of ``(key, description, value)`` tuples ordered by key name), plus
        ``'stress'``, ``'dipole'``, ``'data'`` and ``'constraints'`` when
        the row has them.
    """
    from ase.db.core import float_to_time_string, now

    # Avoid a shared mutable default argument.
    if key_descriptions is None:
        key_descriptions = {}

    dct = {}

    # k-point grid size for a k-point density of 1.8 (even=False).
    atoms = Atoms(cell=row.cell, pbc=row.pbc)
    dct['size'] = kptdensity2monkhorstpack(atoms,
                                           kptdensity=1.8,
                                           even=False)

    dct['cell'] = [['{:.3f}'.format(a) for a in axis] for axis in row.cell]
    par = ['{:.3f}'.format(x) for x in cell_to_cellpar(row.cell)]
    dct['lengths'] = par[:3]
    dct['angles'] = par[3:]

    stress = row.get('stress')
    if stress is not None:
        dct['stress'] = ', '.join('{0:.3f}'.format(s) for s in stress)

    dct['formula'] = Formula(row.formula).format('abc')

    dipole = row.get('dipole')
    if dipole is not None:
        dct['dipole'] = ', '.join('{0:.3f}'.format(d) for d in dipole)

    data = row.get('data')
    if data:
        dct['data'] = ', '.join(data.keys())

    constraints = row.get('constraints')
    if constraints:
        dct['constraints'] = ', '.join(c.__class__.__name__
                                       for c in constraints)

    keys = ({'id', 'energy', 'fmax', 'smax', 'mass', 'age'} |
            set(key_descriptions) |
            set(row.key_value_pairs))
    dct['table'] = []
    # Sort so the table order does not depend on string hash randomization.
    for key in sorted(keys):
        if key == 'age':
            # Derived from the creation time; rows without ``ctime``
            # (e.g. built directly from an Atoms object) get no Age entry.
            ctime = row.get('ctime')
            if ctime is not None:
                age = float_to_time_string(now() - ctime, True)
                dct['table'].append(('ctime', 'Age', age))
            continue
        value = row.get(key)
        if value is not None:
            if isinstance(value, float):
                value = '{:.3f}'.format(value)
            elif not isinstance(value, str):
                value = str(value)
            desc, unit = key_descriptions.get(key, ['', '', ''])[1:]
            if unit:
                value += ' ' + unit
            dct['table'].append((key, desc, value))

    return dct
319