1from ase.constraints import FixBondLengths
2from ase.calculators.tip3p import TIP3P
3from ase.calculators.tip3p import qH, sigma0, epsilon0
4
5from _gpaw import adjust_positions, adjust_momenta, calculate_forces_H2O
6from ase.calculators.calculator import Calculator, all_changes
7
8import numpy as np
9
# Prefactors built from the TIP3P parameters sigma0 and epsilon0; both are
# handed to the C routine calculate_forces_H2O below.
# NOTE(review): these look like the r**-12 (A) and r**-6 (B) coefficients of
# a Lennard-Jones 12-6 potential, E = A / r**12 + B / r**6 -- confirm
# against the C implementation.
A = 4 * epsilon0 * sigma0**12
B = -4 * epsilon0 * sigma0**6
12
13
class TIP3PWaterModel(TIP3P):
    """TIP3P calculator whose water-water term is evaluated in C.

    Energy and forces are obtained from ``_gpaw.calculate_forces_H2O``.
    The attributes ``rc``, ``width`` and ``pcpot`` are read here but not
    set in this class (presumably provided by the ``TIP3P`` base class).
    """

    def calculate(self, atoms=None,
                  properties=['energy'],
                  system_changes=all_changes):
        """Compute energy and forces and store them in ``self.results``.

        Parameters:

        atoms: Atoms object or None
            Forwarded to ``Calculator.calculate``.  All data used below is
            read from ``self.atoms``, so ``None`` (re-use the atoms already
            attached to the calculator) is supported.
        properties: list of str
            Forwarded to ``Calculator.calculate``.  Both 'energy' and
            'forces' are always computed, whatever is requested.  The
            default list is never mutated here.
        system_changes: list of str
            Forwarded to ``Calculator.calculate``.

        Raises AssertionError if the cell is not orthorhombic, if a
        periodic cell axis is shorter than ``2 * self.rc``, or if the atoms
        are not ordered as repeated (O, H, H) triples.
        """
        Calculator.calculate(self, atoms, properties, system_changes)

        # One (3, 3) slab per molecule; the reshape also rejects atom
        # counts that are not a multiple of three.
        nh2o = len(self.atoms.positions.reshape((-1, 3, 3)))
        Z = self.atoms.numbers
        pbc = self.atoms.pbc
        diagcell = self.atoms.cell.diagonal()

        assert (self.atoms.cell == np.diag(diagcell)).all(), 'not orthorhombic'
        assert ((diagcell >= 2 * self.rc) | ~pbc).all(), 'cutoff too large'

        # Only the O, H, H ordering is accepted: oxygen first in every
        # molecule, followed by its two hydrogens.
        assert Z[0] == 8, 'first atom must be oxygen (O, H, H ordering)'
        assert (Z[0::3] == 8).all(), 'expected oxygen at indices 0, 3, 6, ...'
        assert (Z[1::3] == 1).all(), 'expected hydrogen at indices 1, 4, ...'
        assert (Z[2::3] == 1).all(), 'expected hydrogen at indices 2, 5, ...'

        # Per-molecule charges in (O, H, H) order; the oxygen carries
        # -2 * qH, so the three charges sum to zero.
        charges = np.array([qH, qH, qH])
        charges[0] *= -2

        forces = np.zeros((3 * nh2o, 3))

        cellmat = np.zeros((3, 3))
        np.fill_diagonal(cellmat, diagcell)   # c code wants 3x3 cell ...

        # Use the pbc taken from self.atoms: the ``atoms`` argument may be
        # None, in which case ``atoms.pbc`` would raise AttributeError.
        # ``forces`` is presumably filled in place by the C routine.
        energy = 0.0
        energy += calculate_forces_H2O(np.array(pbc, dtype=np.uint8),
                                       cellmat, A, B, self.rc, self.width,
                                       charges, self.atoms.get_positions(),
                                       forces)

        if self.pcpot:
            # Add the contribution of the attached point-charge potential.
            e, f = self.pcpot.calculate(np.tile(charges, nh2o),
                                        self.atoms.positions)
            energy += e
            forces += f

        self.results['energy'] = energy
        self.results['forces'] = forces
61
62
class FixBondLengthsWaterModel(FixBondLengths):
    """Bond-length constraint for water molecules using the C routines.

    Positions and momenta are adjusted by ``_gpaw.adjust_positions`` and
    ``_gpaw.adjust_momenta``.  The pair list must describe a contiguous
    run of three-atom molecules, each constrained by the three pairs
    (a, a+1), (a+1, a+2), (a+2, a), with identical bond lengths and masses
    in every molecule; this is verified before the first adjustment.
    """

    def __init__(self, pairs, tolerance=1e-13, bondlengths=None,
                 iterations=None):
        FixBondLengths.__init__(self, pairs, tolerance=tolerance,
                                bondlengths=bondlengths,
                                iterations=iterations)

    def initialize_bond_lengths(self, atoms):
        """Return the bond lengths from the base class after a layout check.

        Raises AssertionError if the pairs, bond lengths or masses do not
        match the layout required by the C code.
        """
        bondlengths = FixBondLengths.initialize_bond_lengths(self, atoms)
        self._check_layout(atoms, bondlengths)
        return bondlengths

    def _check_layout(self, atoms, bondlengths):
        """Verify the C-code layout and set ``start``, ``end`` and ``NW``."""
        # Make sure that the constraints are compatible with the C-code
        assert len(self.pairs) % 3 == 0
        masses = atoms.get_masses()

        # First atom of the first and of the last constrained molecule.
        self.start = self.pairs[0][0]
        self.end = self.pairs[-3][0]
        self.NW = (self.end - self.start) // 3 + 1
        assert (self.end - self.start) % 3 == 0
        for i in range(self.NW):
            for j in range(3):
                k = i * 3 + j
                # Pair k must join atom j of molecule i to the next atom
                # (cyclically) of the same molecule.
                assert self.pairs[k][0] == self.start + k
                assert self.pairs[k][1] == (self.start +
                                            i * 3 + (j + 1) % 3)
                # Every molecule must match the first one, since only the
                # first three bond lengths and masses reach the C code.
                assert abs(bondlengths[k] - bondlengths[j]) < 1e-6
                assert masses[k + self.start] == masses[j + self.start]

    def _ensure_initialized(self, atoms):
        """Make sure bond lengths and the molecule range are available."""
        if self.bondlengths is None:
            self.bondlengths = self.initialize_bond_lengths(atoms)
        elif not hasattr(self, 'NW'):
            # Bond lengths were passed to the constructor, so
            # initialize_bond_lengths never ran and start/NW are unset.
            self._check_layout(atoms, self.bondlengths)

    def select_indices(self, r):
        """Return the rows of ``r`` belonging to the constrained molecules."""
        return r[self.start:self.start + self.NW * 3, :]

    def adjust_positions(self, atoms, new):
        """Adjust the proposed positions ``new`` via the C routine.

        Returns whatever ``_gpaw.adjust_positions`` returns.
        """
        masses = atoms.get_masses()
        self._ensure_initialized(atoms)

        return adjust_positions(self.bondlengths[:3],
                                masses[self.start:self.start + 3],
                                self.select_indices(atoms.get_positions()),
                                self.select_indices(new))

    def adjust_momenta(self, atoms, p):
        """Adjust the momenta ``p`` via the C routine.

        Returns whatever ``_gpaw.adjust_momenta`` returns.
        """
        masses = atoms.get_masses()
        self._ensure_initialized(atoms)

        return adjust_momenta(masses[self.start:self.start + 3],
                              self.select_indices(atoms.get_positions()),
                              self.select_indices(p))

    def index_shuffle(self, atoms, ind):
        """Not supported for this constraint."""
        raise NotImplementedError
116