1"""
2Functionality for enforcing rotational sum rules
3"""
4from sklearn.linear_model import Ridge
5import numpy as np
6from scipy.sparse import coo_matrix
7from .utilities import SparseMatrix
8
9
def enforce_rotational_sum_rules(cs, parameters, sum_rules=None, alpha=1e-6):
    """ Enforces rotational sum rules by projecting parameters.

    A correction ``dx`` is obtained by ridge regression of ``M dx ~ M x``,
    where ``M`` is the rotational constraint matrix and ``x`` are the
    parameters. It is then subtracted from ``x``. The norm ``||Mx||`` is
    printed before and after the correction.

    Note
    ----
    The interface to this function might change in future releases.

    Parameters
    ----------
    cs : ClusterSpace
        the underlying cluster space
    parameters : numpy.ndarray
        parameters to be constrained; the input array is not modified
    sum_rules : list(str)
        type of sum rules to enforce; possible values: 'Huang', 'Born-Huang';
        if None (default) all sum rules are enforced
    alpha : float
        hyperparameter to the ridge regression algorithm; keyword argument
        passed to the optimizer; larger values specify stronger regularization,
        i.e. less correction but higher stability [default: 1e-6]

    Returns
    -------
    numpy.ndarray
        constrained parameters

    Raises
    ------
    ValueError
        if an unknown sum rule is requested (raised while building the
        constraint matrix)

    Examples
    --------
    The rotational sum rules can be enforced to the parameters before
    constructing a force constant potential as illustrated by the following
    snippet::

        cs = ClusterSpace(reference_structure, cutoffs)
        sc = StructureContainer(cs)
        # add structures to structure container
        opt = Optimizer(sc.get_fit_data())
        opt.train()
        new_params = enforce_rotational_sum_rules(cs, opt.parameters,
            sum_rules=['Huang', 'Born-Huang'])
        fcp = ForceConstantPotential(cs, new_params)

    """

    all_sum_rules = ['Huang', 'Born-Huang']

    # setup: np.array always copies, so the caller's data is never modified.
    # Non-inexact input (lists of ints, integer arrays) is promoted to float
    # so the in-place correction below cannot fail on casting; float inputs
    # keep their dtype.
    parameters = np.array(parameters)
    if not np.issubdtype(parameters.dtype, np.inexact):
        parameters = parameters.astype(np.float64)
    if sum_rules is None:
        sum_rules = all_sum_rules

    # get constraint-matrix
    M = get_rotational_constraint_matrix(cs, sum_rules)

    # before fit
    d = M.dot(parameters)
    delta = np.linalg.norm(d)
    print('Rotational sum-rules before, ||Ax|| = {:20.15f}'.format(delta))

    # fitting: solve M dx ~= M x with L2 regularization on dx; subtracting dx
    # drives the constraint residual M x towards zero while keeping the
    # correction small.
    ridge = Ridge(alpha=alpha, fit_intercept=False, solver='sparse_cg')
    ridge.fit(M, d)
    parameters -= ridge.coef_

    # after fit
    d = M.dot(parameters)
    delta = np.linalg.norm(d)
    print('Rotational sum-rules after,  ||Ax|| = {:20.15f}'.format(delta))

    return parameters
78
79
def get_rotational_constraint_matrix(cs, sum_rules=None):
    """ Returns the rotational sum-rule constraint matrix for a cluster space.

    One sparse constraint matrix is built per requested sum rule, expressed
    in the eigentensor parameters. Each is converted to ``coo_matrix`` and
    multiplied by ``cs._cvs``. The dense results are stacked row-wise.

    Parameters
    ----------
    cs : ClusterSpace
        the underlying cluster space
    sum_rules : list(str) or str
        sum rules to include; possible values: 'Huang', 'Born-Huang'.
        A single sum rule may also be given as a plain string. If None
        (default), all sum rules are included.

    Returns
    -------
    numpy.ndarray
        dense constraint matrix; ``enforce_rotational_sum_rules`` uses the
        norm of its product with the parameter vector as the sum-rule
        violation

    Raises
    ------
    ValueError
        if ``sum_rules`` is empty or contains an unknown sum rule
    """

    all_sum_rules = ['Huang', 'Born-Huang']

    if sum_rules is None:
        sum_rules = all_sum_rules
    elif isinstance(sum_rules, str):
        # a bare string would otherwise be iterated character by character
        sum_rules = [sum_rules]
    # materialize tuples/generators so they can be validated and reused
    sum_rules = list(sum_rules)

    # validate before any (potentially expensive) setup; raise instead of
    # assert so the check is not stripped under ``python -O``
    if not sum_rules:
        raise ValueError('At least one sum rule must be given, select from {}'.format(
            all_sum_rules))
    for s in sum_rules:
        if s not in all_sum_rules:
            raise ValueError('Sum rule {} not allowed, select from {}'.format(s, all_sum_rules))

    # make orbit-parameter index map and force-constant lookup table
    params = _get_orbit_parameter_map(cs)
    lookup = _get_fc_lookup(cs)

    # build one constraint matrix per requested sum rule (in request order)
    builders = {
        'Huang': _create_Huang_constraint,
        'Born-Huang': _create_Born_Huang_constraint,
    }
    args = (lookup, params, cs.atom_list, cs._prim)
    constraint_matrices = [builders[rule](*args) for rule in sum_rules]

    # convert each sparse matrix to scipy format, transform with cs._cvs,
    # densify and stack
    cvs_trans = cs._cvs
    dense_blocks = []
    for M in constraint_matrices:
        row, col, data = [], [], []
        for r, c, v in M.row_list():
            row.append(r)
            col.append(c)
            data.append(np.float64(v))
        coo = coo_matrix((data, (row, col)), shape=M.shape)
        dense_blocks.append(coo.dot(cvs_trans).toarray())

    return np.vstack(dense_blocks)
120
121
122def _get_orbit_parameter_map(cs):
123    # make orbit-parameter index map
124    params = []
125    n = 0
126    for orbit_index, orbit in enumerate(cs.orbits):
127        n_params_in_orbit = len(orbit.eigentensors)
128        params.append(list(range(n, n + n_params_in_orbit)))
129        n += n_params_in_orbit
130    return params
131
132
133def _get_fc_lookup(cs):
134    # create lookuptable for force constants
135    lookup = {}
136    for orbit_index, orbit in enumerate(cs.orbits):
137        for of in orbit.orientation_families:
138            for cluster_index, perm_index in zip(of.cluster_indices, of.permutation_indices):
139                cluster = cs._cluster_list[cluster_index]
140                perm = cs._permutations[perm_index]
141                lookup[tuple(cluster)] = [et.transpose(perm) for et in of.eigentensors], orbit_index
142    return lookup
143
144
def _create_Huang_constraint(lookup, parameter_map, atom_list, prim):
    """Build the Huang rotational sum-rule constraint matrix.

    For every atom ``i`` of the primitive cell and every atom ``j`` in
    ``atom_list`` whose pair ``(i, j)`` has force constants in ``lookup``,
    the tensor ``C^{abcd} = Phi_ij^{ab} R_ij^c R_ij^d`` is formed for every
    eigentensor. Here ``R_ij`` is the position of ``i`` minus that of ``j``.
    The tensor is antisymmetrized as ``C^{abcd} - C^{cdab}`` and accumulated
    into the column of that eigentensor's parameter.

    Parameters
    ----------
    lookup : dict
        maps sorted atom-index pairs to ``(eigentensors, orbit_index)``
    parameter_map : list(list(int))
        global parameter indices of each orbit's eigentensors
    atom_list
        atoms providing ``pos(basis, cell)``
    prim
        primitive structure providing ``basis``, ``cell`` and ``len``

    Returns
    -------
    SparseMatrix
        matrix with 3**4 rows (flattened ``abcd`` index) and one column per
        parameter
    """
    # Total parameter count: summing block sizes (instead of reading
    # parameter_map[-1][-1]) stays valid when the last orbit, or every
    # orbit, has no eigentensors.
    n_params = sum(len(indices) for indices in parameter_map)
    m = SparseMatrix(3**4, n_params, 0)

    def R(i, j):
        pi = atom_list[i].pos(prim.basis, prim.cell)
        pj = atom_list[j].pos(prim.basis, prim.cell)
        return pi - pj

    for i in range(len(prim)):
        for j in range(len(atom_list)):
            ets, orbit_index = lookup.get(tuple(sorted((i, j))), (None, None))
            if ets is None:
                continue
            # lookup stores tensors for the sorted pair; permute back to (i, j)
            inv_perm = np.argsort(np.argsort((i, j)))
            et_indices = parameter_map[orbit_index]
            # R_ij does not depend on the eigentensor; compute it once per pair
            Rij = R(i, j)
            for et, et_index in zip(ets, et_indices):
                et = et.transpose(inv_perm)
                Cij = np.einsum(et, [0, 1], Rij, [2], Rij, [3])
                # out-of-place: the transpose is a view of Cij itself
                Cij = Cij - Cij.transpose([2, 3, 0, 1])
                for k, value in enumerate(Cij.flat):
                    m[k, et_index] += value
    return m
169
170
def _create_Born_Huang_constraint(lookup, parameter_map, atom_list, prim):
    """Build the Born-Huang rotational sum-rule constraint matrix.

    For every atom ``i`` of the primitive cell, the tensor
    ``T^{abc} = Phi_ij^{ab} R_j^c`` is formed for every atom ``j`` in
    ``atom_list`` whose pair ``(i, j)`` has force constants in ``lookup``.
    Here ``R_j`` is the position of ``j``. The tensor is antisymmetrized as
    ``T^{abc} - T^{acb}`` and accumulated into the column of each
    eigentensor's parameter. The per-atom blocks of 3**3 rows are stacked
    vertically in atom order.

    Parameters
    ----------
    lookup : dict
        maps sorted atom-index pairs to ``(eigentensors, orbit_index)``
    parameter_map : list(list(int))
        global parameter indices of each orbit's eigentensors
    atom_list
        atoms providing ``pos(basis, cell)``
    prim
        primitive structure providing ``basis``, ``cell`` and ``len``

    Returns
    -------
    SparseMatrix
        matrix with ``len(prim) * 3**3`` rows and one column per parameter
    """
    # Total parameter count: summing block sizes (instead of reading
    # parameter_map[-1][-1]) stays valid when the last orbit, or every
    # orbit, has no eigentensors.
    n_params = sum(len(indices) for indices in parameter_map)
    constraints = []

    for i in range(len(prim)):
        m = SparseMatrix(3**3, n_params, 0)
        for j in range(len(atom_list)):
            ets, orbit_index = lookup.get(tuple(sorted((i, j))), (None, None))
            if ets is None:
                continue
            # lookup stores tensors for the sorted pair; permute back to (i, j)
            inv_perm = np.argsort(np.argsort((i, j)))
            et_indices = parameter_map[orbit_index]
            R = atom_list[j].pos(prim.basis, prim.cell)
            for et, et_index in zip(ets, et_indices):
                et = et.transpose(inv_perm)
                tmp = np.einsum(et, [0, 1], R, [2])
                # out-of-place: the transpose is a view of tmp itself
                tmp = tmp - tmp.transpose([0, 2, 1])
                for k, value in enumerate(tmp.flat):
                    m[k, et_index] += value
        constraints.append(m)

    return SparseMatrix.vstack(*constraints)
194