1import numpy as np
2import pytest
3
4from ase.atoms import Atoms
5from ase.build import bulk
6from ase.calculators.calculator import all_changes
7from ase.calculators.lj import LennardJones
8from ase.spacegroup.symmetrize import FixSymmetry, check_symmetry, is_subgroup
9from ase.optimize.precon.lbfgs import PreconLBFGS
10from ase.constraints import UnitCellFilter, ExpCellFilter
11
# Skip every test in this module when spglib cannot be imported; the
# symmetry helpers imported from ase.spacegroup.symmetrize presumably
# rely on it (NOTE(review): confirm).
spglib = pytest.importorskip('spglib')
13
14
class NoisyLennardJones(LennardJones):
    """Lennard-Jones calculator that adds Gaussian noise to forces and stress.

    After the regular ``LennardJones`` calculation, normally distributed
    noise is added in place to ``results['forces']`` and
    ``results['stress']`` (whichever are present).  This makes the computed
    forces/stress slightly non-symmetric.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``LennardJones``.
    rng : random generator, optional
        Object providing ``normal(size=...)``, used to draw the noise.  When
        omitted, a fresh, unseeded ``np.random.RandomState`` is used.  Pass
        a seeded generator for reproducible results.
    noise : float, optional
        Scale (standard deviation) of the added noise; defaults to ``1e-4``.
    """

    def __init__(self, *args, rng=None, noise=1e-4, **kwargs):
        # Previously a missing rng only failed later, inside calculate().
        self.rng = np.random.RandomState() if rng is None else rng
        self.noise_level = noise
        super().__init__(*args, **kwargs)

    def calculate(self, atoms=None, properties=None,
                  system_changes=all_changes):
        """Run the LJ calculation, then perturb forces and stress.

        ``properties`` defaults to ``['energy']``, as before.  ``None`` is
        used instead of a shared mutable list default.
        """
        if properties is None:
            properties = ['energy']
        super().calculate(atoms, properties, system_changes)
        # Forces are perturbed before stress, so draws from a seeded rng
        # happen in the same order as before.
        for key in ('forces', 'stress'):
            if key in self.results:
                self.results[key] += self.noise_level * self.rng.normal(
                    size=self.results[key].shape)
29
30
def setup_cell():
    """Build a cubic bcc Al cell and a rigidly rotated copy of it.

    Returns
    -------
    at_init : Atoms
        Conventional cubic bcc Al cell with lattice constant ``2/sqrt(3)``.
        This gives a nearest-neighbour distance of ``a*sqrt(3)/2 = 1``.
    at_rot : Atoms
        Copy of ``at_init`` whose cell is multiplied by the product of three
        rotations, by 0.1, 0.2 and 0.3 rad about the x, y and z axes.  The
        atoms are moved with the cell (``scale_atoms=True``).
    """
    # set up a bcc Al cell
    at_init = bulk('Al', 'bcc', a=2 / np.sqrt(3), cubic=True)

    F = np.eye(3)
    for axis in range(3):
        # the two Cartesian axes spanning the plane of rotation about `axis`
        i, j = (m for m in range(3) if m != axis)
        theta = 0.1 * (axis + 1)
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        R = np.eye(3)
        R[i, i] = cos_t
        R[j, j] = cos_t
        R[i, j] = sin_t
        R[j, i] = -sin_t
        F = F @ R
    at_rot = at_init.copy()
    at_rot.set_cell(at_rot.cell @ F, scale_atoms=True)
    return at_init, at_rot
50
51
def symmetrized_optimisation(at_init, filter):
    """Relax a noisy-LJ copy of ``at_init`` and report symmetry before/after.

    A copy of ``at_init`` gets a ``NoisyLennardJones`` calculator seeded with
    1, is wrapped in ``filter`` and is relaxed with ``PreconLBFGS``
    (no preconditioner, at most 300 steps, ``fmax=0.001``).  Energies,
    volumes, forces and stress are printed along the way.

    Returns
    -------
    (dict, dict)
        ``check_symmetry`` results at tolerance 1e-6 for ``at_init`` and
        for the relaxed copy, in that order.
    """
    atoms = at_init.copy()
    # seeded generator so the injected noise is reproducible between runs
    atoms.calc = NoisyLennardJones(rng=np.random.RandomState(1))
    cell_filtered = filter(atoms)

    print("Initial Energy", atoms.get_potential_energy(), atoms.get_volume())
    with PreconLBFGS(cell_filtered, precon=None) as optimizer:
        optimizer.run(steps=300, fmax=0.001)
        print("n_steps", optimizer.get_number_of_steps())
    print("Final Energy", atoms.get_potential_energy(), atoms.get_volume())
    print("Final forces\n", atoms.get_forces())
    print("Final stress\n", atoms.get_stress())

    symmetry = {}
    for label, structure in (("initial", at_init), ("final", atoms)):
        print(f"{label} symmetry at 1e-6")
        symmetry[label] = check_symmetry(structure, 1.0e-6, verbose=True)
    return symmetry["initial"], symmetry["final"]
71
72
@pytest.fixture(params=[UnitCellFilter, ExpCellFilter])
def filter(request):
    """Supply each cell-filter class in turn to the tests that request it."""
    cell_filter_class = request.param
    return cell_filter_class
76
77
@pytest.mark.filterwarnings('ignore:ASE Atoms-like input is deprecated')
@pytest.mark.filterwarnings('ignore:Armijo linesearch failed')
def test_no_symmetrization(filter):
    """Without FixSymmetry, the noisy relaxation should not keep the symmetry."""
    print("NO SYM")
    at_init, _ = setup_cell()
    at_unsym = at_init.copy()
    di, df = symmetrized_optimisation(at_unsym, filter)
    # separate asserts so a failure shows which condition broke
    assert di["number"] == 229  # starting structure is bcc (Im-3m)
    assert not is_subgroup(sub_data=di, sup_data=df)
86
87
@pytest.mark.filterwarnings('ignore:ASE Atoms-like input is deprecated')
@pytest.mark.filterwarnings('ignore:Armijo linesearch failed')
def test_no_sym_rotated(filter):
    """Same as the unconstrained test, but starting from the rotated cell."""
    print("NO SYM ROT")
    _, at_rot = setup_cell()
    at_unsym_rot = at_rot.copy()
    di, df = symmetrized_optimisation(at_unsym_rot, filter)
    # separate asserts so a failure shows which condition broke
    assert di["number"] == 229  # rotated structure is still bcc (Im-3m)
    assert not is_subgroup(sub_data=di, sup_data=df)
96
97
@pytest.mark.filterwarnings('ignore:ASE Atoms-like input is deprecated')
@pytest.mark.filterwarnings('ignore:Armijo linesearch failed')
def test_sym_adj_cell(filter):
    """With FixSymmetry on positions and cell, the symmetry must survive."""
    print("SYM POS+CELL")
    at_init, _ = setup_cell()
    at_sym_3 = at_init.copy()
    at_sym_3.set_constraint(
        FixSymmetry(at_sym_3, adjust_positions=True, adjust_cell=True))
    di, df = symmetrized_optimisation(at_sym_3, filter)
    # separate asserts so a failure shows which condition broke
    assert di["number"] == 229  # starting structure is bcc (Im-3m)
    assert is_subgroup(sub_data=di, sup_data=df)
108
109
@pytest.mark.filterwarnings('ignore:ASE Atoms-like input is deprecated')
@pytest.mark.filterwarnings('ignore:Armijo linesearch failed')
def test_sym_rot_adj_cell(filter):
    """FixSymmetry on positions and cell must also preserve a rotated cell."""
    print("SYM POS+CELL ROT")
    _, at_rot = setup_cell()
    # Start from the rotated cell; copying at_init made this test an exact
    # duplicate of test_sym_adj_cell.
    at_sym_3_rot = at_rot.copy()
    at_sym_3_rot.set_constraint(
        FixSymmetry(at_sym_3_rot, adjust_positions=True, adjust_cell=True))
    di, df = symmetrized_optimisation(at_sym_3_rot, filter)
    # separate asserts so a failure shows which condition broke
    assert di["number"] == 229  # rotated structure is still bcc (Im-3m)
    assert is_subgroup(sub_data=di, sup_data=df)
120
121
@pytest.mark.filterwarnings('ignore:ASE Atoms-like input is deprecated')
def test_fix_symmetry_shuffle_indices():
    """FixSymmetry must follow its atoms when they are re-indexed.

    A FixSymmetry-constrained structure is copied with its atoms permuted.
    The same physical atom is then nudged in both the original and the
    permuted copy.  The displacement that survives ``set_positions`` must
    agree between the two.
    """
    atoms = Atoms(
        'AlFeAl6',
        cell=[6] * 3,
        positions=[[0, 0, 0], [2.9, 2.9, 2.9], [0, 0, 3], [0, 3, 0],
                   [0, 3, 3], [3, 0, 0], [3, 0, 3], [3, 3, 0]],
        pbc=True,
    )
    atoms.set_constraint(FixSymmetry(atoms))
    # reorder so the Fe atom (index 1) becomes the last atom (index 7);
    # the original atom 2 then sits at index 1
    at_permut = atoms[[0, 2, 3, 4, 5, 6, 7, 1]]
    step = (0.0, 0.1, -0.1)

    def applied_shift(target, reference, index):
        # Nudge one atom via set_positions and return the difference
        # between the reference and resulting positions of that atom.
        trial = reference.copy()
        trial[index] += step
        target.set_positions(trial)
        return reference[index] - target.get_positions()[index]

    reference = atoms.get_positions()
    dp1 = applied_shift(atoms, reference, 1)
    dp2 = applied_shift(atoms, reference, 2)

    reference = at_permut.get_positions()
    permut_dp1 = applied_shift(at_permut, reference, 7)
    permut_dp2 = applied_shift(at_permut, reference, 1)

    for expected, actual in ((dp1, permut_dp1), (dp2, permut_dp2)):
        assert np.max(np.abs(expected - actual)) < 1.0e-10
145