1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4
5"""
This module implements the EnergyModel abstract class and some basic
implementations. An EnergyModel is any model that returns an "energy" (a
scalar score, not necessarily a physical energy) for a given structure.
9"""
10
11import abc
12
13from monty.json import MSONable
14
15from pymatgen.analysis.ewald import EwaldSummation
16from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
17
# Version tag written under the "version" key by each model's as_dict().
__version__ = "0.1"
19
20
class EnergyModel(MSONable, metaclass=abc.ABCMeta):
    """
    Abstract base class for energy models.

    An energy model maps a structure to a single scalar "energy". The value
    is a score and need not be a physical energy.
    Subclasses must implement :meth:`get_energy`.
    """

    @abc.abstractmethod
    def get_energy(self, structure) -> float:
        """
        Compute the "energy" of a structure under this model.

        :param structure: Structure
        :return: Energy value
        """
        return 0.0

    @classmethod
    def from_dict(cls, d):
        """
        Rebuild a model from its dict representation.

        The constructor keyword arguments are expected under
        ``d["init_args"]`` (the layout the ``as_dict`` methods in this module
        write) and are passed straight to ``cls``.

        :param d: Dict representation
        :return: EnergyModel
        """
        return cls(**d["init_args"])
41
42
class EwaldElectrostaticModel(EnergyModel):
    """
    Electrostatic energy model backed by
    :class:`pymatgen.analysis.ewald.EwaldSummation`.
    """

    def __init__(self, real_space_cut=None, recip_space_cut=None, eta=None, acc_factor=8.0):
        """
        Store the Ewald summation settings. Every argument is forwarded
        unchanged to :class:`pymatgen.analysis.ewald.EwaldSummation`, so the
        definitions match that class.

        Args:
            real_space_cut (float): Real-space cutoff radius, controlling how
                many terms enter the real-space sum. None (the default) means
                it is determined automatically using the formula given in the
                gulp 3.1 documentation.
            recip_space_cut (float): Reciprocal-space cutoff radius. None (the
                default) means it is determined automatically using the
                formula given in the gulp 3.1 documentation.
            eta (float): Screening parameter. None (the default) means it is
                determined automatically.
            acc_factor (float): Number of significant figures each sum is
                converged to. Defaults to 8.0.
        """
        self.real_space_cut = real_space_cut
        self.recip_space_cut = recip_space_cut
        self.eta = eta
        self.acc_factor = acc_factor

    def _ewald_kwargs(self):
        # Keyword arguments shared by EwaldSummation and the serialized form.
        # A new dict is built on every call, so callers never share state.
        return {
            "real_space_cut": self.real_space_cut,
            "recip_space_cut": self.recip_space_cut,
            "eta": self.eta,
            "acc_factor": self.acc_factor,
        }

    def get_energy(self, structure):
        """
        Total Ewald electrostatic energy of ``structure``.

        :param structure: Structure
        :return: Energy value (``EwaldSummation.total_energy``)
        """
        summation = EwaldSummation(structure, **self._ewald_kwargs())
        return summation.total_energy

    def as_dict(self):
        """
        :return: MSONable dict; constructor arguments are stored under
            "init_args".
        """
        klass = type(self)
        return {
            "version": __version__,
            "@module": klass.__module__,
            "@class": klass.__name__,
            "init_args": self._ewald_kwargs(),
        }
100
101
class SymmetryModel(EnergyModel):
    """
    Sets the energy to the negative of the space group number. Higher
    symmetry => lower "energy".

    Args have the same meaning as in
    :class:`pymatgen.symmetry.analyzer.SpacegroupAnalyzer`.
    """

    def __init__(self, symprec=0.1, angle_tolerance=5):
        """
        Args:
            symprec (float): Symmetry tolerance. Defaults to 0.1.
            angle_tolerance (float): Tolerance for angles. Defaults to 5 degrees.
        """
        self.symprec = symprec
        self.angle_tolerance = angle_tolerance

    def get_energy(self, structure):
        """
        :param structure: Structure
        :return: Negative of the space group number that SpacegroupAnalyzer
            finds for ``structure`` using this model's tolerances.
        """
        analyzer = SpacegroupAnalyzer(structure, symprec=self.symprec, angle_tolerance=self.angle_tolerance)
        return -analyzer.get_space_group_number()

    def as_dict(self):
        """
        :return: MSONable dict
        """
        return {
            "version": __version__,
            "@module": self.__class__.__module__,
            "@class": self.__class__.__name__,
            "init_args": {
                "symprec": self.symprec,
                "angle_tolerance": self.angle_tolerance,
            },
        }
141
142
class IsingModel(EnergyModel):
    """
    A very simple Ising model, with r^2 decay.

    The energy is the sum, over every site and each of its neighbors within
    ``max_radius``, of ``J * spin_site * spin_neighbor / r**2``. A species
    without a ``spin`` attribute is treated as having spin 0.
    """

    def __init__(self, j, max_radius):
        """
        Args:
            j (float): The interaction parameter. E = J * spin1 * spin2.
            max_radius (float): Maximum radius for the interaction.
        """
        self.j = j
        self.max_radius = max_radius

    def get_energy(self, structure):
        """
        :param structure: Structure
        :return: Energy value
        """
        all_nn = structure.get_all_neighbors(r=self.max_radius)
        energy = 0
        # NOTE(review): every site's neighbor list is visited, so if
        # get_all_neighbors is symmetric each pair is counted from both ends
        # (twice the unique-pair sum) — confirm this is intended.
        for site, nns in zip(structure, all_nn):
            s1 = getattr(site.specie, "spin", 0)
            for nn in nns:
                energy += self.j * s1 * getattr(nn.specie, "spin", 0) / (nn.nn_distance ** 2)
        return energy

    def as_dict(self):
        """
        :return: MSONable dict
        """
        return {
            "version": __version__,
            "@module": self.__class__.__module__,
            "@class": self.__class__.__name__,
            "init_args": {"j": self.j, "max_radius": self.max_radius},
        }
180
181
class NsitesModel(EnergyModel):
    """
    Energy model whose "energy" is simply the site count: more sites means a
    higher "energy". Useful for ordering enumerated structures from fewest
    to most sites.
    """

    def get_energy(self, structure):
        """
        :param structure: Structure
        :return: Number of sites, ``len(structure)``.
        """
        n_sites = len(structure)
        return n_sites

    def as_dict(self):
        """
        :return: MSONable dict. The model takes no constructor arguments, so
            "init_args" is empty.
        """
        klass = type(self)
        serialized = {"version": __version__}
        serialized["@module"] = klass.__module__
        serialized["@class"] = klass.__name__
        serialized["init_args"] = {}
        return serialized
206