1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4
5"""
6Utilities for defects module.
7"""
8
9import itertools
10import logging
11import math
12import operator
13from collections import defaultdict
14from copy import deepcopy
15
16import numpy as np
17import pandas as pd
18from monty.dev import requires
19from monty.json import MSONable
20from numpy.linalg import norm
21from scipy.cluster.hierarchy import fcluster, linkage
22from scipy.spatial import Voronoi
23from scipy.spatial.distance import squareform
24
25from pymatgen.analysis.local_env import (
26    LocalStructOrderParams,
27    MinimumDistanceNN,
28    cn_opt_params,
29)
30from pymatgen.analysis.phase_diagram import get_facets
31from pymatgen.analysis.structure_matcher import StructureMatcher
32from pymatgen.core.periodic_table import Element, get_el_sp
33from pymatgen.core.sites import PeriodicSite
34from pymatgen.core.structure import Structure
35from pymatgen.io.vasp.outputs import Chgcar
36from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
37from pymatgen.util.coord import pbc_diff
38from pymatgen.vis.structure_vtk import StructureVis
39
40try:
41    from skimage.feature import peak_local_max
42
43    peak_local_max_found = True
44except ImportError:
45    peak_local_max_found = False
46
47__author__ = "Danny Broberg, Shyam Dwaraknath, Bharat Medasani, Nils Zimmermann, Geoffroy Hautier"
48__copyright__ = "Copyright 2014, The Materials Project"
49__version__ = "1.0"
50__maintainer__ = "Danny Broberg, Shyam Dwaraknath"
51__email__ = "dbroberg@berkeley.edu, shyamd@lbl.gov"
52__status__ = "Development"
53__date__ = "January 11, 2018"
54
logger = logging.getLogger(__name__)

# Numeric conversion factors shared by the defect utilities.
hart_to_ev = 27.2114
ang_to_bohr = 1.8897
invang_to_ev = 3.80986
kumagai_to_V = 1.809512739e2  # = Electron charge * 1e10 / VacuumPermittivity Constant

# Flatten cn_opt_params ({cn: {motif: [optype, (params)]}}) into a lookup keyed
# by motif name. "params" is a private deep copy of the second list entry, or
# None when the entry only carries the order-parameter type.
motif_cn_op = {}
for cn, di in cn_opt_params.items():  # type: ignore
    for mot, li in di.items():
        motif_cn_op[mot] = {
            "cn": int(cn),
            "optype": li[0],
            "params": deepcopy(li[1]) if len(li) > 1 else None,
        }
66
67
class QModel(MSONable):
    """
    Model for the defect charge distribution.
    A combination of exponential tail and gaussian distribution is used
    (see Freysoldt (2011), DOI: 10.1002/pssb.201046289 )
    q_model(r) = q [x exp(-r/gamma) + (1-x) exp(-r^2/beta^2)]
            without normalization constants
    By default, gaussian distribution with 1 Bohr width is assumed.
    If defect charge is more delocalized, exponential tail is suggested.
    """

    def __init__(self, beta=1.0, expnorm=0.0, gamma=1.0):
        """
        Args:
            beta: Gaussian decay constant. Default value is 1 Bohr.
                  When delocalized (eg. diamond), 2 Bohr is more appropriate.
            expnorm: Weight for the exponential tail in the range of [0-1].
                     Default is 0.0 indicating no tail .
                     For delocalized charges ideal value is around 0.54-0.6.
            gamma: Exponential decay constant

        Raises:
            ValueError: if a non-zero ``expnorm`` is combined with a zero
                (falsy) ``gamma``.
        """
        self.beta = beta
        self.expnorm = expnorm
        self.gamma = gamma

        # Squares are cached because only beta^2 and gamma^2 enter the
        # reciprocal-space expressions below.
        self.beta2 = beta * beta
        self.gamma2 = gamma * gamma
        if expnorm and not gamma:
            raise ValueError("Please supply exponential decay constant.")

    def rho_rec(self, g2):
        """
        Reciprocal space model charge value
        for input squared reciprocal vector.

        The exponential tail contributes expnorm / (1 + gamma^2 g^2)^2 and the
        gaussian contributes (1 - expnorm) * exp(-beta^2 g^2 / 4). The squared
        denominator is what makes the g -> 0 expansion agree with
        ``rho_rec_limit0`` (slope -2 * gamma^2 * expnorm); a square root in
        its place would give a slope of -gamma^2 * expnorm / 2 instead.

        Args:
            g2: Square of reciprocal vector

        Returns:
            Charge density at the reciprocal vector magnitude
        """
        tail_denominator = 1 + self.gamma2 * g2
        exponential_tail = self.expnorm / (tail_denominator * tail_denominator)
        gaussian = (1 - self.expnorm) * np.exp(-0.25 * self.beta2 * g2)
        return exponential_tail + gaussian

    @property
    def rho_rec_limit0(self):
        """
        Reciprocal space model charge value
        close to reciprocal vector 0 .
        rho_rec(g->0) -> 1 + rho_rec_limit0 * g^2
        """
        return -2 * self.gamma2 * self.expnorm - 0.25 * self.beta2 * (1 - self.expnorm)
118
119
def eV_to_k(energy):
    """
    Convert energy to reciprocal vector magnitude k via hbar^2*k^2/2m

    Args:
        energy: Energy in eV.

    Returns:
        (double) Reciprocal vector magnitude (units of 1/Bohr).
    """
    # NOTE(review): energy / invang_to_ev looks like k^2 in 1/Angstrom^2, in
    # which case converting k to 1/Bohr would divide by ang_to_bohr rather
    # than multiply -- confirm the intended units before relying on them.
    k_squared = energy / invang_to_ev
    return math.sqrt(k_squared) * ang_to_bohr
130
131
def genrecip(a1, a2, a3, encut):
    """
    Generate the non-zero reciprocal lattice vectors below an energy cutoff.

    Args:
        a1, a2, a3: lattice vectors in bohr
        encut: energy cut off in eV
    Returns:
        generator over the reciprocal lattice vectors G = i*b1 + j*b2 + k*b3
        with 0 < |G| < eV_to_k(encut). The set is symmetric: whenever G is
        produced, -G is produced as well.
    """
    vol = np.dot(a1, np.cross(a2, a3))  # cell volume, bohr^3
    b1 = (2 * np.pi / vol) * np.cross(a2, a3)  # units 1/bohr
    b2 = (2 * np.pi / vol) * np.cross(a3, a1)
    b3 = (2 * np.pi / vol) * np.cross(a1, a2)

    # create list of recip space vectors that satisfy |i*b1+j*b2+k*b3|<=encut
    G_cut = eV_to_k(encut)

    # Largest index that can occur along each reciprocal direction. Because
    # G . a1 = 2*pi*i, every G with |G| < G_cut obeys |i| < G_cut*|a1|/(2*pi)
    # (likewise for j, k). For orthogonal cells this equals G_cut/|b1|; for
    # skewed cells G_cut/|b1| is too small and would drop valid vectors.
    i_max = int(math.ceil(G_cut * norm(a1) / (2 * np.pi)))
    j_max = int(math.ceil(G_cut * norm(a2) / (2 * np.pi)))
    k_max = int(math.ceil(G_cut * norm(a3) / (2 * np.pi)))

    # Build index lists that are symmetric about zero, so that +n and -n are
    # both candidates; the radius filter below discards the surplus.
    i = np.arange(-i_max, i_max + 1)
    j = np.arange(-j_max, j_max + 1)
    k = np.arange(-k_max, k_max + 1)

    # Convert index to vectors using meshgrid
    indicies = np.array(np.meshgrid(i, j, k)).T.reshape(-1, 3)
    # Multiply integer vectors to get reciprocal space vectors
    vecs = np.dot(indicies, [b1, b2, b3])
    # Calculate radii of all vectors
    radii = np.sqrt(np.einsum("ij,ij->i", vecs, vecs))

    # Yield based on radii; the G = 0 term is always excluded.
    for vec, r in zip(vecs, radii):
        if r < G_cut and r != 0:
            yield vec
168
169
def generate_reciprocal_vectors_squared(a1, a2, a3, encut):
    """
    Lazily produce the squared magnitude of every reciprocal vector that
    genrecip returns for the given lattice vectors and cutoff.

    Args:
        a1: Lattice vector a (in Bohrs)
        a2: Lattice vector b (in Bohrs)
        a3: Lattice vector c (in Bohrs)
        encut: Reciprocal vector energy cutoff

    Returns:
        Generator of g^2 values, i.e. the dot product of each reciprocal
        vector with itself, in the order genrecip yields the vectors.
    """
    yield from (np.dot(g_vec, g_vec) for g_vec in genrecip(a1, a2, a3, encut))
186
187
def closestsites(struct_blk, struct_def, pos):
    """
    Find, in both the bulk and the defect structure, the site nearest to pos.

    Only sites returned by get_sites_in_sphere(pos, 5, include_index=True)
    are considered; an IndexError results if either search comes back empty.

    Args:
        struct_blk: Bulk structure
        struct_def: Defect structure
        pos: Position
    Return: pair (bulk entry, defect entry), each being the neighbor record
        with the smallest distance (element 1 of the record).
    """
    ranked = [
        sorted(
            structure.get_sites_in_sphere(pos, 5, include_index=True),
            key=operator.itemgetter(1),
        )
        for structure in (struct_blk, struct_def)
    ]
    bulk_ranked, defect_ranked = ranked
    return bulk_ranked[0], defect_ranked[0]
204
205
206class StructureMotifInterstitial:
207    """
208    Generate interstitial sites at positions
209    where the interstitialcy is coordinated by nearest neighbors
210    in a way that resembles basic structure motifs
211    (e.g., tetrahedra, octahedra).  The algorithm is called InFiT
212    (Interstitialcy Finding Tool), it was introducted by
213    Nils E. R. Zimmermann, Matthew K. Horton, Anubhav Jain,
214    and Maciej Haranczyk (Front. Mater., 4, 34, 2017),
215    and it is used by the Python Charged Defect Toolkit
216    (PyCDT: D. Broberg et al., Comput. Phys. Commun., in press, 2018).
217    """
218
    def __init__(
        self,
        struct,
        inter_elem,
        motif_types=("tetrahedral", "octahedral"),
        op_threshs=(0.3, 0.5),
        dl=0.2,
        doverlap=1,
        facmaxdl=1.01,
        verbose=False,
    ):
        """
        Generates symmetrically distinct interstitial sites at positions
        where the interstitial is coordinated by nearest neighbors
        in a pattern that resembles a supported structure motif
        (e.g., tetrahedra, octahedra).

        The search runs in three stages: (1) scan a regular fractional grid
        and keep trial positions whose order parameter exceeds the motif
        threshold, (2) cluster neighboring trial positions per motif and keep
        the one with the largest order parameter in each cluster, and
        (3) discard sites that map onto an already kept site under the
        symmetry operations of the input structure.

        Args:
            struct (Structure): input structure for which symmetrically
                distinct interstitial sites are to be found.
            inter_elem (string): element symbol of desired interstitial.
            motif_types ([string]): list of structure motif types that are
                to be considered.  Every entry has to be a key of the
                module-level motif_cn_op mapping (the defaults are
                "tetrahedral" and "octahedral").
            op_threshs ([float]): threshold values for the underlying order
                parameters to still recognize a given structural motif.
                The threshold for a motif is looked up at the position of
                that motif in motif_types, and a match requires the OP
                value to be strictly greater than the threshold.
            dl (float): grid fineness in Angstrom.  The input
                structure is divided into a grid of dimension
                a/dl x b/dl x c/dl along the three crystallographic
                directions, with a, b, and c being the lengths of
                the three lattice vectors of the input unit cell.
            doverlap (float): distance that is considered
                to flag an overlap between any trial interstitial site
                and a host atom.
            facmaxdl (float): factor to be multiplied with the maximum grid
                width that is then used as a cutoff distance for the
                clustering prune step.
            verbose (bool): flag indicating whether (True) or not (False;
                default) to print additional information to screen.

        Raises:
            RuntimeError: if motif_types is empty or contains a motif that
                is not a key of motif_cn_op.
        """
        # Initialize interstitial finding.
        self._structure = struct.copy()
        self._motif_types = motif_types[:]
        if len(self._motif_types) == 0:
            raise RuntimeError("no motif types provided.")
        self._op_threshs = op_threshs[:]
        # cn_motif_lostop: {coordination number: {motif: order-parameter calculator}}
        self.cn_motif_lostop = {}
        # target_cns: distinct coordination numbers of the requested motifs.
        self.target_cns = []
        for motif in self._motif_types:
            if motif not in list(motif_cn_op.keys()):
                raise RuntimeError("unsupported motif type: {}.".format(motif))
            cn = int(motif_cn_op[motif]["cn"])
            if cn not in self.target_cns:
                self.target_cns.append(cn)
            if cn not in list(self.cn_motif_lostop.keys()):
                self.cn_motif_lostop[cn] = {}
            # The "*_max" order-parameter types are replaced by their plain
            # "tet" / "oct" counterparts before building the calculator.
            tmp_optype = motif_cn_op[motif]["optype"]
            if tmp_optype == "tet_max":
                tmp_optype = "tet"
            if tmp_optype == "oct_max":
                tmp_optype = "oct"
            self.cn_motif_lostop[cn][motif] = LocalStructOrderParams(
                [tmp_optype], parameters=[motif_cn_op[motif]["params"]], cutoff=-10.0
            )
        self._dl = dl
        # Parallel result lists; index n in each refers to the same defect.
        self._defect_sites = []
        self._defect_types = []
        self._defect_site_multiplicity = []
        self._defect_cns = []
        self._defect_opvals = []

        rots, trans = SpacegroupAnalyzer(struct)._get_symmetry()
        # Number of grid divisions along each lattice vector (truncated), and
        # the actual grid spacing that results from that integer count.
        nbins = [
            int(struct.lattice.a / dl),
            int(struct.lattice.b / dl),
            int(struct.lattice.c / dl),
        ]
        dls = [
            struct.lattice.a / float(nbins[0]),
            struct.lattice.b / float(nbins[1]),
            struct.lattice.c / float(nbins[2]),
        ]
        maxdl = max(dls)
        if verbose:
            print("Grid size: {} {} {}".format(nbins[0], nbins[1], nbins[2]))
            print("dls: {} {} {}".format(dls[0], dls[1], dls[2]))
        # Working copy of the host with one extra site appended last; that
        # site is moved to every trial position in the grid loop below.
        struct_w_inter = struct.copy()
        struct_w_inter.append(inter_elem, [0, 0, 0])
        natoms = len(list(struct_w_inter.sites))
        trialsites = []

        # Build index list (bin centers, hence the +0.5 offset)
        i = np.arange(0, nbins[0]) + 0.5
        j = np.arange(0, nbins[1]) + 0.5
        k = np.arange(0, nbins[2]) + 0.5

        # Convert index to vectors using meshgrid
        indicies = np.array(np.meshgrid(i, j, k)).T.reshape(-1, 3)
        # Scale the bin-center indices by 1/nbins to obtain fractional
        # coordinates of the trial positions.
        vecs = np.multiply(indicies, np.divide(1, nbins))

        # Loop over trial positions that are based on a regular
        # grid in fractional coordinate space
        # within the unit cell.
        for vec in vecs:
            struct_w_inter.replace(natoms - 1, inter_elem, coords=vec, coords_are_cartesian=False)
            # Continue only if exactly one site lies within doverlap of the
            # trial position (presumably the trial site itself, i.e. no host
            # atom overlaps -- TODO confirm against get_sites_in_sphere).
            if len(struct_w_inter.get_sites_in_sphere(struct_w_inter.sites[natoms - 1].coords, doverlap)) == 1:
                neighs_images_weigths = MinimumDistanceNN(tol=0.8, cutoff=6).get_nn_info(struct_w_inter, natoms - 1)
                # Highest-weight neighbors first.
                neighs_images_weigths_sorted = sorted(neighs_images_weigths, key=lambda x: x["weight"], reverse=True)
                # Try every neighbor count that matches a requested
                # coordination number, using the nsite highest-weight neighbors.
                for nsite in range(1, len(neighs_images_weigths_sorted) + 1):
                    if nsite not in self.target_cns:
                        continue

                    allsites = [neighs_images_weigths_sorted[i]["site"] for i in range(nsite)]
                    indices_neighs = list(range(len(allsites)))
                    # The trial site goes last; it is the central site for the
                    # order-parameter evaluation.
                    allsites.append(struct_w_inter.sites[natoms - 1])
                    for mot, ops in self.cn_motif_lostop[nsite].items():
                        opvals = ops.get_order_parameters(allsites, len(allsites) - 1, indices_neighs=indices_neighs)
                        if opvals[0] > op_threshs[motif_types.index(mot)]:
                            # Count the coordinating neighbors per element symbol.
                            cns = {}
                            for isite in range(nsite):
                                site = neighs_images_weigths_sorted[isite]["site"]
                                if isinstance(site.specie, Element):
                                    elem = site.specie.symbol
                                else:
                                    elem = site.specie.element.symbol
                                if elem in list(cns.keys()):
                                    cns[elem] = cns[elem] + 1
                                else:
                                    cns[elem] = 1
                            trialsites.append(
                                {
                                    "mtype": mot,
                                    "opval": opvals[0],
                                    "coords": struct_w_inter.sites[natoms - 1].coords[:],
                                    "fracs": vec,
                                    "cns": dict(cns),
                                }
                            )
                            # Only the first matching motif for this neighbor
                            # count is recorded; the nsite loop continues.
                            break

        # Prune list of trial sites by clustering and find the site
        # with the largest order parameter value in each cluster.
        nintersites = len(trialsites)
        unique_motifs = []
        for ts in trialsites:
            if ts["mtype"] not in unique_motifs:
                unique_motifs.append(ts["mtype"])
        labels = {}
        # connected[i][j] is True when trial sites i and j are closer than
        # maxdl * facmaxdl (distance taken from get_distance_and_image).
        connected = []
        for i in range(nintersites):
            connected.append([])
            for j in range(nintersites):
                dist, image = struct_w_inter.lattice.get_distance_and_image(
                    trialsites[i]["fracs"], trialsites[j]["fracs"]
                )
                connected[i].append(bool(dist < (maxdl * facmaxdl)))
        include = []
        for motif in unique_motifs:
            # Start with every trial site of this motif in its own cluster
            # (label = own index); sites of other motifs are marked -1.
            labels[motif] = []
            for i, ts in enumerate(trialsites):
                labels[motif].append(i if ts["mtype"] == motif else -1)
            # Label propagation: merge one connected pair at a time into the
            # smaller label and rescan from the start until nothing changes.
            change = True
            while change:
                change = False
                for i in range(nintersites - 1):
                    if change:
                        break
                    if labels[motif][i] == -1:
                        continue
                    for j in range(i + 1, nintersites):
                        if labels[motif][j] == -1:
                            continue
                        if connected[i][j] and labels[motif][i] != labels[motif][j]:
                            if labels[motif][i] < labels[motif][j]:
                                labels[motif][j] = labels[motif][i]
                            else:
                                labels[motif][i] = labels[motif][j]
                            change = True
                            break
            unique_ids = []
            for l in labels[motif]:
                if l != -1 and l not in unique_ids:
                    unique_ids.append(l)
            if verbose:
                print("unique_ids {} {}".format(motif, unique_ids))
            # Keep the member with the largest order-parameter value as the
            # representative of each cluster.
            for uid in unique_ids:
                maxq = 0.0
                imaxq = -1
                for i in range(nintersites):
                    if labels[motif][i] == uid:
                        if imaxq < 0 or trialsites[i]["opval"] > maxq:
                            imaxq = i
                            maxq = trialsites[i]["opval"]
                include.append(imaxq)

        # Prune by symmetry.
        # multiplicity[i] counts site i plus every later site discarded as
        # an image of it.
        multiplicity = {}
        discard = []
        for motif in unique_motifs:
            discard_motif = []
            for indi, i in enumerate(include):
                if trialsites[i]["mtype"] != motif or i in discard_motif:
                    continue
                multiplicity[i] = 1
                # Candidate images of site i: its fractional coordinates
                # multiplied by each rotation, followed by its fractional
                # coordinates shifted by each translation (the two kinds of
                # operation are applied separately here, not combined).
                symposlist = [trialsites[i]["fracs"].dot(np.array(m, dtype=float)) for m in rots]
                for t in trans:
                    symposlist.append(trialsites[i]["fracs"] + np.array(t))
                for indj in range(indi + 1, len(include)):
                    j = include[indj]
                    if trialsites[j]["mtype"] != motif or j in discard_motif:
                        continue
                    for sympos in symposlist:
                        dist, image = struct.lattice.get_distance_and_image(sympos, trialsites[j]["fracs"])
                        if dist < maxdl * facmaxdl:
                            discard_motif.append(j)
                            multiplicity[i] += 1
                            break
            for i in discard_motif:
                if i not in discard:
                    discard.append(i)

        if verbose:
            print(
                "Initial trial sites: {}\nAfter clustering: {}\n"
                "After symmetry pruning: {}".format(len(trialsites), len(include), len(include) - len(discard))
            )
        # Store the surviving sites and their per-site metadata.
        for i in include:
            if i not in discard:
                self._defect_sites.append(
                    PeriodicSite(
                        Element(inter_elem),
                        trialsites[i]["fracs"],
                        self._structure.lattice,
                        to_unit_cell=False,
                        coords_are_cartesian=False,
                        properties=None,
                    )
                )
                self._defect_types.append(trialsites[i]["mtype"])
                self._defect_cns.append(trialsites[i]["cns"])
                self._defect_site_multiplicity.append(multiplicity[i])
                self._defect_opvals.append(trialsites[i]["opval"])
465
466    def enumerate_defectsites(self):
467        """
468        Get all defect sites.
469
470        Returns:
471            defect_sites ([PeriodicSite]): list of periodic sites
472                    representing the interstitials.
473        """
474        return self._defect_sites
475
476    def get_motif_type(self, i):
477        """
478        Get the motif type of defect with index i (e.g., "tet").
479
480        Returns:
481            motif (string): motif type.
482        """
483        return self._defect_types[i]
484
485    def get_defectsite_multiplicity(self, n):
486        """
487        Returns the symmtric multiplicity of the defect site at the index.
488        """
489        return self._defect_site_multiplicity[n]
490
491    def get_coordinating_elements_cns(self, i):
492        """
493        Get element-specific coordination numbers of defect with index i.
494
495        Returns:
496            elem_cn (dict): dictionary storing the coordination numbers (int)
497                    with string representation of elements as keys.
498                    (i.e., {elem1 (string): cn1 (int), ...}).
499        """
500        return self._defect_cns[i]
501
502    def get_op_value(self, i):
503        """
504        Get order-parameter value of defect with index i.
505
506        Returns:
507            opval (float): OP value.
508        """
509        return self._defect_opvals[i]
510
511    def make_supercells_with_defects(self, scaling_matrix):
512        """
513        Generate a sequence of supercells
514        in which each supercell contains a single interstitial,
515        except for the first supercell in the sequence
516        which is a copy of the defect-free input structure.
517
518        Args:
519            scaling_matrix (3x3 integer array): scaling matrix
520                to transform the lattice vectors.
521        Returns:
522            scs ([Structure]): sequence of supercells.
523
524        """
525        scs = []
526        sc = self._structure.copy()
527        sc.make_supercell(scaling_matrix)
528        scs.append(sc)
529        for ids, defect_site in enumerate(self._defect_sites):
530            sc_with_inter = sc.copy()
531            sc_with_inter.append(
532                defect_site.species_string,
533                defect_site.frac_coords,
534                coords_are_cartesian=False,
535                validate_proximity=False,
536                properties=None,
537            )
538            if not sc_with_inter:
539                raise RuntimeError("could not generate supercell with" " interstitial {}".format(ids + 1))
540            scs.append(sc_with_inter.copy())
541        return scs
542
543
544class TopographyAnalyzer:
545    """
546    This is a generalized module to perform topological analyses of a crystal
547    structure using Voronoi tessellations. It can be used for finding potential
548    interstitial sites. Applications including using these sites for
549    inserting additional atoms or for analyzing diffusion pathways.
550
551    Note that you typically want to do some preliminary postprocessing after
552    the initial construction. The initial construction will create a lot of
553    points, especially for determining potential insertion sites. Some helper
554    methods are available to perform aggregation and elimination of nodes. A
555    typical use is something like::
556
557        a = TopographyAnalyzer(structure, ["O"], ["P"])
558        a.cluster_nodes()
559        a.remove_collisions()
560    """
561
562    def __init__(
563        self,
564        structure,
565        framework_ions,
566        cations,
567        tol=0.0001,
568        max_cell_range=1,
569        check_volume=True,
570        constrained_c_frac=0.5,
571        thickness=0.5,
572    ):
573        """
574        Init.
575
576        Args:
577            structure (Structure): An initial structure.
578            framework_ions ([str]): A list of ions to be considered as a
579                framework. Typically, this would be all anion species. E.g.,
580                ["O", "S"].
581            cations ([str]): A list of ions to be considered as non-migrating
582                cations. E.g., if you are looking at Li3PS4 as a Li
583                conductor, Li is a mobile species. Your cations should be [
584                "P"]. The cations are used to exclude polyhedra from
585                diffusion analysis since those polyhedra are already occupied.
586            tol (float): A tolerance distance for the analysis, used to
587                determine if something are actually periodic boundary images of
588                each other. Default is usually fine.
589            max_cell_range (int): This is the range of periodic images to
590                construct the Voronoi tesselation. A value of 1 means that we
591                include all points from (x +- 1, y +- 1, z+- 1) in the
592                voronoi construction. This is because the Voronoi poly
593                extends beyond the standard unit cell because of PBC.
594                Typically, the default value of 1 works fine for most
595                structures and is fast. But for really small unit
596                cells with high symmetry, you may need to increase this to 2
597                or higher.
598            check_volume (bool): Set False when ValueError always happen after
599                tuning tolerance.
600            constrained_c_frac (float): Constraint the region where users want
601                to do Topology analysis the default value is 0.5, which is the
602                fractional coordinate of the cell
603            thickness (float): Along with constrained_c_frac, limit the
604                thickness of the regions where we want to explore. Default is
605                0.5, which is mapping all the site of the unit cell.
606
607        """
608        self.structure = structure
609        self.framework_ions = {get_el_sp(sp) for sp in framework_ions}
610        self.cations = {get_el_sp(sp) for sp in cations}
611
612        # Let us first map all sites to the standard unit cell, i.e.,
613        # 0 ≤ coordinates < 1.
614        # structure = Structure.from_sites(structure, to_unit_cell=True)
615        # lattice = structure.lattice
616
617        # We could constrain the region where we want to dope/explore by setting
618        # the value of constrained_c_frac and thickness. The default mode is
619        # mapping all sites to the standard unit cell
620        s = structure.copy()
621        constrained_sites = []
622        for i, site in enumerate(s):
623            if (
624                site.frac_coords[2] >= constrained_c_frac - thickness
625                and site.frac_coords[2] <= constrained_c_frac + thickness
626            ):
627                constrained_sites.append(site)
628        structure = Structure.from_sites(sites=constrained_sites)
629        lattice = structure.lattice
630
631        # Divide the sites into framework and non-framework sites.
632        framework = []
633        non_framework = []
634        for site in structure:
635            if self.framework_ions.intersection(site.species.keys()):
636                framework.append(site)
637            else:
638                non_framework.append(site)
639
640        # We construct a supercell series of coords. This is because the
641        # Voronoi polyhedra can extend beyond the standard unit cell. Using a
642        # range of -2, -1, 0, 1 should be fine.
643        coords = []
644        cell_range = list(range(-max_cell_range, max_cell_range + 1))
645        for shift in itertools.product(cell_range, cell_range, cell_range):
646            for site in framework:
647                shifted = site.frac_coords + shift
648                coords.append(lattice.get_cartesian_coords(shifted))
649
650        # Perform the voronoi tessellation.
651        voro = Voronoi(coords)
652
653        # Store a mapping of each voronoi node to a set of points.
654        node_points_map = defaultdict(set)
655        for pts, vs in voro.ridge_dict.items():
656            for v in vs:
657                node_points_map[v].update(pts)
658
659        logger.debug("%d total Voronoi vertices" % len(voro.vertices))
660
661        # Vnodes store all the valid voronoi polyhedra. Cation vnodes store
662        # the voronoi polyhedra that are already occupied by existing cations.
663        vnodes = []
664        cation_vnodes = []
665
666        def get_mapping(poly):
667            """
668            Helper function to check if a vornoi poly is a periodic image
669            of one of the existing voronoi polys.
670            """
671            for v in vnodes:
672                if v.is_image(poly, tol):
673                    return v
674            return None
675
676        # Filter all the voronoi polyhedra so that we only consider those
677        # which are within the unit cell.
678        for i, vertex in enumerate(voro.vertices):
679            if i == 0:
680                continue
681            fcoord = lattice.get_fractional_coords(vertex)
682            poly = VoronoiPolyhedron(lattice, fcoord, node_points_map[i], coords, i)
683            if np.all([-tol <= c < 1 + tol for c in fcoord]):
684                if len(vnodes) == 0:
685                    vnodes.append(poly)
686                else:
687                    ref = get_mapping(poly)
688                    if ref is None:
689                        vnodes.append(poly)
690
691        logger.debug("%d voronoi vertices in cell." % len(vnodes))
692
693        # Eliminate all voronoi nodes which are closest to existing cations.
694        if len(cations) > 0:
695            cation_coords = [
696                site.frac_coords for site in non_framework if self.cations.intersection(site.species.keys())
697            ]
698
699            vertex_fcoords = [v.frac_coords for v in vnodes]
700            dist_matrix = lattice.get_all_distances(cation_coords, vertex_fcoords)
701            indices = np.where(dist_matrix == np.min(dist_matrix, axis=1)[:, None])[1]
702            cation_vnodes = [v for i, v in enumerate(vnodes) if i in indices]
703            vnodes = [v for i, v in enumerate(vnodes) if i not in indices]
704
705        logger.debug("%d vertices in cell not with cation." % len(vnodes))
706        self.coords = coords
707        self.vnodes = vnodes
708        self.cation_vnodes = cation_vnodes
709        self.framework = framework
710        self.non_framework = non_framework
711        if check_volume:
712            self.check_volume()
713
714    def check_volume(self):
715        """
716        Basic check for volume of all voronoi poly sum to unit cell volume
717        Note that this does not apply after poly combination.
718        """
719        vol = sum((v.volume for v in self.vnodes)) + sum((v.volume for v in self.cation_vnodes))
720        if abs(vol - self.structure.volume) > 1e-8:
721            raise ValueError(
722                "Sum of voronoi volumes is not equal to original volume of "
723                "structure! This may lead to inaccurate results. You need to "
724                "tweak the tolerance and max_cell_range until you get a "
725                "correct mapping."
726            )
727
728    def cluster_nodes(self, tol=0.2):
729        """
730        Cluster nodes that are too close together using a tol.
731
732        Args:
733            tol (float): A distance tolerance. PBC is taken into account.
734        """
735        lattice = self.structure.lattice
736
737        vfcoords = [v.frac_coords for v in self.vnodes]
738
739        # Manually generate the distance matrix (which needs to take into
740        # account PBC.
741        dist_matrix = np.array(lattice.get_all_distances(vfcoords, vfcoords))
742        dist_matrix = (dist_matrix + dist_matrix.T) / 2
743        for i in range(len(dist_matrix)):
744            dist_matrix[i, i] = 0
745        condensed_m = squareform(dist_matrix)
746        z = linkage(condensed_m)
747        cn = fcluster(z, tol, criterion="distance")
748        merged_vnodes = []
749        for n in set(cn):
750            poly_indices = set()
751            frac_coords = []
752            for i, j in enumerate(np.where(cn == n)[0]):
753                poly_indices.update(self.vnodes[j].polyhedron_indices)
754                if i == 0:
755                    frac_coords.append(self.vnodes[j].frac_coords)
756                else:
757                    fcoords = self.vnodes[j].frac_coords
758                    # We need the image to combine the frac_coords properly.
759                    d, image = lattice.get_distance_and_image(frac_coords[0], fcoords)
760                    frac_coords.append(fcoords + image)
761            merged_vnodes.append(VoronoiPolyhedron(lattice, np.average(frac_coords, axis=0), poly_indices, self.coords))
762        self.vnodes = merged_vnodes
763        logger.debug("%d vertices after combination." % len(self.vnodes))
764
765    def remove_collisions(self, min_dist=0.5):
766        """
767        Remove vnodes that are too close to existing atoms in the structure
768
769        Args:
770            min_dist(float): The minimum distance that a vertex needs to be
771                from existing atoms.
772        """
773        vfcoords = [v.frac_coords for v in self.vnodes]
774        sfcoords = self.structure.frac_coords
775        dist_matrix = self.structure.lattice.get_all_distances(vfcoords, sfcoords)
776        all_dist = np.min(dist_matrix, axis=1)
777        new_vnodes = []
778        for i, v in enumerate(self.vnodes):
779            if all_dist[i] > min_dist:
780                new_vnodes.append(v)
781        self.vnodes = new_vnodes
782
783    def get_structure_with_nodes(self):
784        """
785        Get the modified structure with the voronoi nodes inserted. The
786        species is set as a DummySpecies X.
787        """
788        new_s = Structure.from_sites(self.structure)
789        for v in self.vnodes:
790            new_s.append("X", v.frac_coords)
791        return new_s
792
793    def print_stats(self):
794        """
795        Print stats such as the MSE dist.
796        """
797        latt = self.structure.lattice
798
799        def get_min_dist(fcoords):
800            n = len(fcoords)
801            dist = latt.get_all_distances(fcoords, fcoords)
802            all_dist = [dist[i, j] for i in range(n) for j in range(i + 1, n)]
803            return min(all_dist)
804
805        voro = [s.frac_coords for s in self.vnodes]
806        print("Min dist between voronoi vertices centers = %.4f" % get_min_dist(voro))
807
808        def get_non_framework_dist(fcoords):
809            cations = [site.frac_coords for site in self.non_framework]
810            dist_matrix = latt.get_all_distances(cations, fcoords)
811            min_dist = np.min(dist_matrix, axis=1)
812            if len(cations) != len(min_dist):
813                raise Exception("Could not calculate distance to all cations")
814            return np.linalg.norm(min_dist), min(min_dist), max(min_dist)
815
816        print(len(self.non_framework))
817        print("MSE dist voro = %s" % str(get_non_framework_dist(voro)))
818
819    def write_topology(self, fname="Topo.cif"):
820        """
821        Write topology to a file.
822
823        :param fname: Filename
824        """
825        new_s = Structure.from_sites(self.structure)
826        for v in self.vnodes:
827            new_s.append("Mg", v.frac_coords)
828        new_s.to(filename=fname)
829
830    def analyze_symmetry(self, tol):
831        """
832        :param tol: Tolerance for SpaceGroupAnalyzer
833        :return: List
834        """
835        s = Structure.from_sites(self.framework)
836        site_to_vindex = {}
837        for i, v in enumerate(self.vnodes):
838            s.append("Li", v.frac_coords)
839            site_to_vindex[s[-1]] = i
840
841        print(len(s))
842        finder = SpacegroupAnalyzer(s, tol)
843        print(finder.get_space_group_operations())
844        symm_structure = finder.get_symmetrized_structure()
845        print(len(symm_structure.equivalent_sites))
846        return [
847            [site_to_vindex[site] for site in sites]
848            for sites in symm_structure.equivalent_sites
849            if sites[0].specie.symbol == "Li"
850        ]
851
852    def vtk(self):
853        """
854        Show VTK visualization.
855        """
856        if StructureVis is None:
857            raise NotImplementedError("vtk must be present to view.")
858        lattice = self.structure.lattice
859        vis = StructureVis()
860        vis.set_structure(Structure.from_sites(self.structure))
861        for v in self.vnodes:
862            vis.add_site(PeriodicSite("K", v.frac_coords, lattice))
863            vis.add_polyhedron(
864                [PeriodicSite("S", c, lattice, coords_are_cartesian=True) for c in v.polyhedron_coords],
865                PeriodicSite("Na", v.frac_coords, lattice),
866                color="element",
867                draw_edges=True,
868                edges_color=(0, 0, 0),
869            )
870        vis.show()
871
872
class VoronoiPolyhedron:
    """
    Convenience container for a voronoi point in PBC and its associated polyhedron.
    """

    def __init__(self, lattice, frac_coords, polyhedron_indices, all_coords, name=None):
        """
        :param lattice: Lattice the point lives in. ``is_image`` uses its
            ``get_fractional_coords`` method.
        :param frac_coords: Fractional coordinates of the voronoi point.
        :param polyhedron_indices: Indices into ``all_coords`` selecting the
            vertices of the polyhedron.
        :param all_coords: Coordinates of all candidate vertices; the rows
            picked by ``polyhedron_indices`` are stored as
            ``polyhedron_coords``. They are treated as cartesian by
            ``is_image``, which converts them to fractional coordinates.
        :param name: Optional label, only used by ``__str__``.
        """
        self.lattice = lattice
        self.frac_coords = frac_coords
        self.polyhedron_indices = polyhedron_indices
        self.polyhedron_coords = np.array(all_coords)[list(polyhedron_indices), :]
        self.name = name

    def is_image(self, poly, tol):
        """
        Check whether another polyhedron is a periodic image of this one.

        The two voronoi points must coincide modulo lattice translations and
        every vertex of this polyhedron must coincide, again modulo lattice
        translations, with at least one vertex of ``poly``.

        :param poly: VoronoiPolyhedron
        :param tol: Coordinate tolerance (absolute, on fractional coordinates).
        :return: Whether a poly is an image of the current one.
        """
        frac_diff = pbc_diff(poly.frac_coords, self.frac_coords)
        if not np.allclose(frac_diff, [0, 0, 0], atol=tol):
            return False
        to_frac = self.lattice.get_fractional_coords
        # Convert the other polyhedron once instead of once per vertex pair.
        other_fcoords = [to_frac(c2) for c2 in poly.polyhedron_coords]
        for c1 in self.polyhedron_coords:
            f1 = to_frac(c1)
            # A vertex is accounted for only when some vertex of ``poly``
            # MATCHES it; the previous test was inverted and accepted a
            # vertex as soon as any vertex of ``poly`` differed from it.
            if not any(np.allclose(pbc_diff(f1, f2), [0, 0, 0], atol=tol) for f2 in other_fcoords):
                return False
        return True

    @property
    def coordination(self):
        """
        :return: Coordination number, i.e. the number of polyhedron indices.
        """
        return len(self.polyhedron_indices)

    @property
    def volume(self):
        """
        :return: Volume of the polyhedron, computed by ``calculate_vol`` from
            ``polyhedron_coords``.
        """
        return calculate_vol(self.polyhedron_coords)

    def __str__(self):
        return "Voronoi polyhedron %s" % self.name
929
930
class ChargeDensityAnalyzer(MSONable):
    """
    Analyzer to find potential interstitial sites based on charge density. The
    `total` charge density is used.
    """

    def __init__(self, chgcar):
        """
        Initialization.

        Args:
            chgcar (pmg.Chgcar): input Chgcar object.
        """
        self.chgcar = chgcar
        self.structure = chgcar.structure
        self.extrema_coords = []  # list of frac_coords of local extrema
        self.extrema_type = None  # "local maxima" or "local minima"
        self._extrema_df = None  # extrema frac_coords - chg density table
        self._charge_distribution_df = None  # frac_coords - chg density table

    @classmethod
    def from_file(cls, chgcar_filename):
        """
        Init from a CHGCAR.

        :param chgcar_filename: Path of the file handed to ``Chgcar.from_file``.
        :return: A new analyzer wrapping the parsed Chgcar.
        """
        chgcar = Chgcar.from_file(chgcar_filename)
        return cls(chgcar=chgcar)

    @property
    def charge_distribution_df(self):
        """
        :return: Charge distribution table (columns "a", "b", "c" and
            "Charge Density"). It is built on first access and cached.
        """
        if self._charge_distribution_df is None:
            return self._get_charge_distribution_df()
        return self._charge_distribution_df

    @property
    def extrema_df(self):
        """
        :return: The extrema in charge density. This is None, and a warning
            is logged, until an extrema search has been run.
        """
        if self.extrema_type is None:
            logger.warning("Please run ChargeDensityAnalyzer.get_local_extrema first!")
        return self._extrema_df

    @staticmethod
    def _coords_from_df(df):
        """
        Return the "a", "b", "c" columns of df as a list of coordinate arrays,
        one per row and in row order.
        """
        return [np.array(row, dtype=float) for row in df[["a", "b", "c"]].to_numpy()]

    def _get_charge_distribution_df(self):
        """
        Return a complete table of fractional coordinates - charge density.

        The table is also cached in ``self._charge_distribution_df``.
        """
        abc = self.structure.lattice.abc
        # Fractional coordinate of every grid point along each axis. The axes
        # are kept as three separate arrays because they may hold different
        # numbers of points, and stacking unequal arrays into one np.array is
        # rejected by recent NumPy versions.
        axis_grid = [np.array(self.chgcar.get_axis_grid(i)) / abc[i] for i in range(3)]
        axis_index = [range(len(grid)) for grid in axis_grid]
        total = self.chgcar.data["total"]

        data = {}

        for a, b, c in itertools.product(*axis_index):
            f_coords = (axis_grid[0][a], axis_grid[1][b], axis_grid[2][c])
            data[f_coords] = total[a][b][c]

        # Fraction coordinates - charge density table
        df = pd.Series(data).reset_index()
        df.columns = ["a", "b", "c", "Charge Density"]
        self._charge_distribution_df = df

        return df

    def _update_extrema(self, f_coords, extrema_type, threshold_frac=None, threshold_abs=None):
        """
        Update _extrema_df, extrema_type and extrema_coords.

        Args:
            f_coords: Fractional coordinates of the candidate extrema.
            extrema_type (str): "local minima" or "local maxima"; minima are
                sorted by ascending charge density, anything else descending.
            threshold_frac (float): Keep only the first
                ``int(threshold_frac * len(f_coords))`` rows after sorting.
                Must lie in [0, 1].
            threshold_abs (float): Keep only rows whose charge density is
                <= (minima) or >= (maxima) this value.

        Raises:
            ValueError: If threshold_frac lies outside [0, 1].

        If both thresholds are given a warning is logged and nothing is
        updated.
        """
        if threshold_frac is not None:
            if threshold_abs is not None:
                logger.warning("Filter can be either threshold_frac or threshold_abs!")  # Exit if both filter are set
                return
            if threshold_frac > 1 or threshold_frac < 0:
                # ValueError is still caught by callers handling Exception.
                raise ValueError("threshold_frac range is [0, 1]!")

        columns = ["a", "b", "c", "Charge Density"]

        # Return empty result if coords list is empty. The table keeps the
        # regular column names and the extrema type is recorded, so an empty
        # search result is not mistaken for "search never run".
        if len(f_coords) == 0:
            df = pd.DataFrame({}, columns=columns)
            self._extrema_df = df
            self.extrema_type = extrema_type
            self.extrema_coords = []
            logger.info("Find {} {}.".format(len(df), extrema_type))
            return

        data = {}
        dim = np.array(self.chgcar.dim)  # number of grid points along a, b, c
        total = self.chgcar.data["total"]

        for fc in f_coords:
            # Nearest grid point, wrapped periodically. Plain truncation of
            # fc / (1 / dim) can land one index too low because of floating
            # point round-off (e.g. 2.9999999 -> 2).
            a, b, c = (int(idx) for idx in np.rint(np.asarray(fc) * dim).astype(int) % dim)
            data[tuple(fc)] = total[a][b][c]

        df = pd.Series(data).reset_index()
        df.columns = columns
        ascending = extrema_type == "local minima"
        df = df.sort_values(by="Charge Density", ascending=ascending)

        if threshold_abs is None:
            threshold_frac = threshold_frac if threshold_frac is not None else 1.0
            num_extrema = int(threshold_frac * len(f_coords))
            df = df[0:num_extrema]
        elif ascending:
            df = df[df["Charge Density"] <= threshold_abs]
        else:
            df = df[df["Charge Density"] >= threshold_abs]
        # Both filters hand back a freshly numbered, independent table.
        df = df.reset_index(drop=True)

        self._extrema_df = df
        self.extrema_type = extrema_type
        self.extrema_coords = self._coords_from_df(df)
        logger.info("Find {} {}.".format(len(df), extrema_type))

    @requires(
        peak_local_max_found,
        "get_local_extrema requires skimage.feature.peak_local_max module"
        " to be installed. Please confirm your skimage installation.",
    )
    def get_local_extrema(self, find_min=True, threshold_frac=None, threshold_abs=None):
        """
        Get all local extrema fractional coordinates in charge density,
        searching for local minimum by default. Note that sites are NOT grouped
        symmetrically.

        Args:
            find_min (bool): True to find local minima, False to find local
                maxima.

            threshold_frac (float): optional fraction of extrema shown, which
                returns `threshold_frac * tot_num_extrema` extrema fractional
                coordinates based on highest/lowest intensity.

                E.g. set 0.2 to show the extrema with 20% highest or lowest
                intensity. Value range: 0 <= threshold_frac <= 1

                Note that threshold_abs and threshold_frac should not be set
                at the same time.

            threshold_abs (float): optional filter. When searching for local
                minima, intensity <= threshold_abs returns; when searching for
                local maxima, intensity >= threshold_abs returns.

                Note that threshold_abs and threshold_frac should not be set
                at the same time.

        Returns:
            extrema_coords (list): list of fractional coordinates corresponding
                to local extrema.
        """
        sign, extrema_type = 1, "local maxima"

        if find_min:
            # peak_local_max only finds maxima, so minima are searched on the
            # negated density.
            sign, extrema_type = -1, "local minima"

        # Make 3x3x3 supercell
        # This is a trick to resolve the periodical boundary issue.
        total_chg = sign * self.chgcar.data["total"]
        total_chg = np.tile(total_chg, reps=(3, 3, 3))
        coordinates = peak_local_max(total_chg, min_distance=1)

        # Remove duplicated sites introduced by supercell: only peaks inside
        # the central copy (scaled coordinate in [1, 2)) are kept and shifted
        # back into [0, 1).
        f_coords = [coord / total_chg.shape * 3 for coord in coordinates]
        f_coords = [f - 1 for f in f_coords if all(np.array(f) < 2) and all(np.array(f) >= 1)]

        # Update information
        self._update_extrema(
            f_coords,
            extrema_type,
            threshold_frac=threshold_frac,
            threshold_abs=threshold_abs,
        )

        return self.extrema_coords

    def cluster_nodes(self, tol=0.2):
        """
        Cluster nodes that are too close together using a tol.

        Extrema are grouped with scipy's hierarchical clustering and every
        group is replaced by the average of its members, wrapped back into
        the unit cell. The extrema table and coordinates are refreshed.

        Args:
            tol (float): A distance tolerance. PBC is taken into account.

        Returns:
            An empty list when there are no extrema but a search has been
            run; None in every other case (results are available through
            ``extrema_coords`` / ``extrema_df``).
        """
        lattice = self.structure.lattice
        vf_coords = self.extrema_coords

        if len(vf_coords) == 0:
            if self.extrema_type is None:
                logger.warning("Please run ChargeDensityAnalyzer.get_local_extrema first!")
                return None
            new_f_coords = []
            self._update_extrema(new_f_coords, self.extrema_type)
            return new_f_coords

        if len(vf_coords) == 1:
            # A single site forms its own cluster; linkage() cannot be called
            # on the empty condensed distance matrix it would produce.
            cn = np.array([1])
        else:
            # Manually generate the distance matrix (which needs to take into
            # account PBC).
            dist_matrix = np.array(lattice.get_all_distances(vf_coords, vf_coords))
            dist_matrix = (dist_matrix + dist_matrix.T) / 2
            np.fill_diagonal(dist_matrix, 0)
            condensed_m = squareform(dist_matrix)
            z = linkage(condensed_m)
            cn = fcluster(z, tol, criterion="distance")

        merged_fcoords = []

        for n in set(cn):
            members = np.where(cn == n)[0]
            frac_coords = [vf_coords[members[0]]]
            for j in members[1:]:
                f_coords = vf_coords[j]
                # We need the image to combine the frac_coords properly.
                _, image = lattice.get_distance_and_image(frac_coords[0], f_coords)
                frac_coords.append(f_coords + image)
            merged_fcoords.append(np.average(frac_coords, axis=0))

        merged_fcoords = [f - np.floor(f) for f in merged_fcoords]
        merged_fcoords = [f * (np.abs(f - 1) > 1e-15) for f in merged_fcoords]
        # the second line for fringe cases like
        # np.array([ 5.0000000e-01 -4.4408921e-17  5.0000000e-01])
        # where the shift to [0,1) does not work due to float precision
        self._update_extrema(merged_fcoords, extrema_type=self.extrema_type)
        logger.debug("{} vertices after combination.".format(len(self.extrema_coords)))
        return None

    def remove_collisions(self, min_dist=0.5):
        """
        Remove predicted sites that are too close to existing atoms in the
        structure.

        Args:
            min_dist (float): The minimum distance (in Angstrom) that
                a predicted site needs to be from existing atoms. A min_dist
                with value <= 0 returns all sites without distance checking.

        Returns:
            The list of remaining fractional coordinates, or None when no
            extrema search has been run yet.
        """
        s_f_coords = self.structure.frac_coords
        f_coords = self.extrema_coords
        if len(f_coords) == 0:
            if self.extrema_type is None:
                logger.warning("Please run ChargeDensityAnalyzer.get_local_extrema first!")
                return None
            new_f_coords = []
            self._update_extrema(new_f_coords, self.extrema_type)
            return new_f_coords

        # Rows are predicted sites, columns are atoms.
        dist_matrix = self.structure.lattice.get_all_distances(f_coords, s_f_coords)
        all_dist = np.min(dist_matrix, axis=1)
        new_f_coords = [f for i, f in enumerate(f_coords) if all_dist[i] > min_dist]
        self._update_extrema(new_f_coords, self.extrema_type)

        return new_f_coords

    def get_structure_with_nodes(
        self,
        find_min=True,
        min_dist=0.5,
        tol=0.2,
        threshold_frac=None,
        threshold_abs=None,
    ):
        """
        Get the modified structure with the possible interstitial sites added.
        The species is set as a DummySpecies X.

        Args:
            find_min (bool): True to find local minima, False to find local
                maxima.

            min_dist (float): The minimum distance (in Angstrom) that
                a predicted site needs to be from existing atoms. A min_dist
                with value <= 0 returns all sites without distance checking.

            tol (float): A distance tolerance for node clustering; sites too
                close to other predicted sites will be merged. PBC is taken
                into account.

            threshold_frac (float): optional fraction of extrema, which returns
                `threshold_frac * tot_num_extrema` extrema fractional
                coordinates based on highest/lowest intensity.

                E.g. set 0.2 to insert DummySpecies atom at the extrema with 20%
                highest or lowest intensity.
                Value range: 0 <= threshold_frac <= 1

                Note that threshold_abs and threshold_frac should not be set
                at the same time.

            threshold_abs (float): optional filter. When searching for local
                minima, intensity <= threshold_abs returns; when searching for
                local maxima, intensity >= threshold_abs returns.

                Note that threshold_abs and threshold_frac should not be set
                at the same time.

        Returns:
            structure (Structure)
        """

        structure = self.structure.copy()
        self.get_local_extrema(
            find_min=find_min,
            threshold_frac=threshold_frac,
            threshold_abs=threshold_abs,
        )

        self.remove_collisions(min_dist)
        self.cluster_nodes(tol=tol)
        for fc in self.extrema_coords:
            structure.append("X", fc)

        return structure

    def sort_sites_by_integrated_chg(self, r=0.4):
        """
        Get the average charge density around each local extremum in the
        charge density and store the result in _extrema_df.

        The table gains an "avg_charge_den" column and is sorted by it in
        ascending order; ``extrema_coords`` is reordered to match the table.
        If no extrema search has been run, ``get_local_extrema()`` is called
        with its defaults first.

        Args:
            r (float): radius of sphere around each site to evaluate the average
        """

        if self.extrema_type is None:
            self.get_local_extrema()
        int_den = []
        for isite in self.extrema_coords:
            # Grid points lying within r of the site.
            mask = self._dist_mat(isite) < r
            vol_sphere = self.chgcar.structure.volume * (mask.sum() / self.chgcar.ngridpts)
            chg_in_sphere = np.sum(self.chgcar.data["total"] * mask) / mask.size / vol_sphere
            int_den.append(chg_in_sphere)
        self._extrema_df["avg_charge_den"] = int_den
        self._extrema_df.sort_values(by=["avg_charge_den"], inplace=True)
        self._extrema_df.reset_index(drop=True, inplace=True)
        # Keep the coordinate list aligned with the re-sorted table; otherwise
        # a later call would pair averages with the wrong rows.
        self.extrema_coords = self._coords_from_df(self._extrema_df)

    def _dist_mat(self, pos_frac):
        """
        Return an array, shaped like the charge density grid, holding the
        periodic distance from every grid point to ``pos_frac``.
        """
        aa = np.linspace(0, 1, len(self.chgcar.get_axis_grid(0)), endpoint=False)
        bb = np.linspace(0, 1, len(self.chgcar.get_axis_grid(1)), endpoint=False)
        cc = np.linspace(0, 1, len(self.chgcar.get_axis_grid(2)), endpoint=False)
        AA, BB, CC = np.meshgrid(aa, bb, cc, indexing="ij")
        dist_from_pos = self.chgcar.structure.lattice.get_all_distances(
            fcoords1=np.vstack([AA.flatten(), BB.flatten(), CC.flatten()]).T,
            fcoords2=pos_frac,
        )
        return dist_from_pos.reshape(AA.shape)
1284
1285
class ChargeInsertionAnalyzer(ChargeDensityAnalyzer):
    """
    Analyze the charge density and create new candidate structures by inserting at each charge minimum.
    Similar inserted structures are given the same uniqueness label.
    This works best with AECCAR data since CHGCAR data often contains spurious local minima in the core.
    However you can still use CHGCAR with an appropriate max_avg_charge value.

    Application of this for Li can be found at:
    J.-X. Shen et al.: npj Comput. Mater. 6, 1 (2020)
    https://www.nature.com/articles/s41524-020-00422-3
    """

    def __init__(
        self,
        chgcar,
        working_ion="Li",
        avg_radius=0.4,
        max_avg_charge=1.0,
        clustering_tol=0.6,
        ltol=0.2,
        stol=0.3,
        angle_tol=5,
    ):
        """
        Args:
            chgcar: The charge density object to analyze
            working_ion: The working ion to be inserted
            avg_radius: The radius used to calculate average charge density at each site
            max_avg_charge: Do not consider local minima with avg charge above this value.
            clustering_tol: Distance tolerance for grouping sites together
            ltol: StructureMatcher ltol parameter
            stol: StructureMatcher stol parameter
            angle_tol: StructureMatcher angle_tol parameter
        """
        self.working_ion = working_ion
        self.sm = StructureMatcher(ltol=ltol, stol=stol, angle_tol=angle_tol)
        self.max_avg_charge = max_avg_charge
        self.avg_radius = avg_radius
        self.clustering_tol = clustering_tol

        super().__init__(chgcar)

    def get_labels(self):
        """
        Populate the extrema dataframe (self._extrema_df) with the insertion structure.
        Then, group the sites by structure similarity.
        Finally store a full list of the insertion sites, with their labels as a Structure Object

        After the call ``self._extrema_df`` holds the columns
        "inserted_struct" and "site_label", and ``self.allsites_struct``
        holds one working-ion site per remaining extremum with the label as
        site property.
        """

        self.get_local_extrema()

        # Clustering needs at least two sites to compare.
        if len(self._extrema_df) > 1:
            self.cluster_nodes(tol=self.clustering_tol)

        self.sort_sites_by_integrated_chg(r=self.avg_radius)

        # Drop sites above the charge cut-off. reset_index hands back an
        # independent, consecutively numbered table, so the column
        # assignments below neither write into a slice of the old table nor
        # depend on the old row labels.
        below_cutoff = self._extrema_df.avg_charge_den <= self.max_avg_charge
        self._extrema_df = self._extrema_df[below_cutoff].reset_index(drop=True)

        inserted_structs = []
        for _, li_site in self._extrema_df.iterrows():
            tmp_struct = self.chgcar.structure.copy()
            tmp_struct.insert(
                0,
                self.working_ion,
                [li_site["a"], li_site["b"], li_site["c"]],
                properties=dict(magmom=0),
            )
            tmp_struct.sort()
            inserted_structs.append(tmp_struct)
        self._extrema_df["inserted_struct"] = inserted_structs
        # Group on the plain list so positions, not row labels, are compared.
        site_labels = generic_groupby(inserted_structs, comp=self.sm.fit)
        self._extrema_df["site_label"] = site_labels

        # generate the structure with only Li atoms for NN analysis
        self.allsites_struct = Structure(
            self.structure.lattice,
            np.repeat(self.working_ion, len(self._extrema_df)),
            self._extrema_df[["a", "b", "c"]].values,
            site_properties={"label": self._extrema_df[["site_label"]].values.flatten()},
        )
1370
1371
def generic_groupby(list_in, comp=operator.eq):
    """
    Group a list of unsortable objects

    Items are labelled in order of first appearance: the first item gets
    label 0 and every later item that ``comp`` reports as matching it gets
    the same label; the next still unlabelled item opens label 1, and so on.

    If ``comp`` is not transitive, an unlabelled item may match a later item
    that already belongs to an earlier group. In that case the item, together
    with everything it labelled in the current pass, joins that earlier group
    and no new label is consumed.

    Args:
        list_in: A list of generic objects
        comp: (Default value = operator.eq) The comparator
    Returns:
        [int] list of labels for the input list
    """
    # Work on a positional copy so any sized iterable behaves like a list.
    items = list(list_in)
    labels = [None] * len(items)
    next_label = 0
    for i1, item1 in enumerate(items):
        if labels[i1] is not None:
            continue
        labels[i1] = next_label
        joined_existing = False
        for i2 in range(i1 + 1, len(items)):
            if not comp(item1, items[i2]):
                continue
            if labels[i2] is None:
                labels[i2] = labels[i1]
            elif not joined_existing and labels[i2] != labels[i1]:
                # Fold the group opened in this pass into the existing one.
                # Relabelling every member (not just item1) keeps the
                # abandoned label from being handed to an unrelated group.
                fresh, target = labels[i1], labels[i2]
                labels = [target if lab == fresh else lab for lab in labels]
                joined_existing = True
        if not joined_existing:
            next_label += 1
    return labels
1396
1397
def calculate_vol(coords):
    """
    Calculate volume given a set of coords.

    Four points are treated as a tetrahedron whose volume follows directly
    from a determinant. Any other number of points is split into facets with
    ``get_facets``; each facet is joined with the average of all points and
    the volumes of the resulting pieces are summed recursively.

    :param coords: List of coords.
    :return: Volume
    """
    if len(coords) != 4:
        facets = get_facets(coords, joggle=True)
        centroid = np.average(coords, axis=0)
        # Every facet plus the centroid is evaluated by the recursive call.
        pieces = (calculate_vol([coords[i] for i in facet] + [centroid]) for facet in facets)
        return sum(pieces)

    # Tetrahedron: |det([x y z 1])| / 6 over the four corner points.
    homogeneous = np.ones((4, 4))
    homogeneous[:, :3] = np.array(coords)
    return abs(np.linalg.det(homogeneous)) / 6
1418
1419
def converge(f, step, tol, max_h):
    """
    Evaluate ``f`` at 0, step, 2*step, ... until two successive values differ
    by at most ``tol`` and return the last value computed.

    An Exception is raised as soon as the next evaluation point exceeds
    ``max_h``; this check runs after every evaluation, including the one that
    satisfied the tolerance. If ``tol`` is 10000 or larger no step is taken
    and ``f(0)`` is returned.
    """
    latest = f(0)
    change = 10000  # start value that forces at least one step for any smaller tol
    position = step
    while change > tol:
        candidate = f(position)
        change = abs(latest - candidate)
        latest = candidate
        position = position + step

        if position > max_h:
            raise Exception("Did not converge before {}".format(position))
    return latest
1436
1437
def tune_for_gamma(lattice, epsilon):
    """
    Tune the gamma parameter for the Kumagai anisotropic Ewald calculation.

    Starting from a guess based on the average lattice length, gamma is
    rescaled until the numbers of real and reciprocal lattice vectors
    produced by ``generate_R_and_G_vecs`` (at a fixed precision of 25, using
    the cut off radii suggested by Kumagai and Oba) agree to within 5%.

    :param lattice: Lattice object, passed on to ``generate_R_and_G_vecs``.
    :param epsilon: Dielectric tensor, passed on to ``generate_R_and_G_vecs``.
    :return: The tuned gamma.
    """
    logger.debug("Converging for ewald parameter...")
    prec = 25  # a reasonable precision to tune gamma for

    def vector_sets(trial_gamma):
        # Only the vector lists for the single precision value are needed.
        recip_sets, _, real_sets, _ = generate_R_and_G_vecs(trial_gamma, prec, lattice, epsilon)
        return recip_sets[0], real_sets[0]

    gamma = (2 * np.average(lattice.abc)) ** (-1 / 2.0)
    recip_set, real_set = vector_sets(gamma)

    logger.debug(
        "First approach with gamma ={}\nProduced {} real vecs and {} recip "
        "vecs.".format(gamma, len(real_set), len(recip_set))
    )

    while float(len(real_set)) / len(recip_set) > 1.05 or float(len(recip_set)) / len(real_set) > 1.05:
        # Rescale gamma by the (damped) ratio of the two vector counts.
        ratio = float(len(real_set)) / float(len(recip_set))
        gamma *= ratio ** 0.17
        logger.debug("\tNot converged...Try modifying gamma to {}.".format(gamma))
        recip_set, real_set = vector_sets(gamma)
        logger.debug("Now have {} real vecs and {} recip vecs.".format(len(real_set), len(recip_set)))

    logger.debug("Converged with gamma = {}".format(gamma))

    return gamma
1469
1470
def generate_R_and_G_vecs(gamma, prec_set, lattice, epsilon):
    """
    Build the real and reciprocal lattice vectors, together with the
    real/reciprocal summation values, for each precision in ``prec_set``.

    For a precision ``prec`` the reciprocal cut off is ``2 * gamma * prec``
    and the real space cut off is ``prec / gamma``. The zero reciprocal
    vector is left out; the zero real vector is listed but does not enter the
    real space sum.

    gamma (float): Ewald parameter
    prec_set (list or number): for prec values to consider (20, 25, 30 are sensible numbers)
    lattice: Lattice object of supercell in question
    epsilon: dielectric tensor; it is inverted and its determinant is taken

    Returns:
        recip_set, recip_summation_values, real_set, real_summation_values,
        where the sets are lists (one entry per precision) of lists of
        vectors and the summation values are arrays with one entry per
        precision.
    """
    if not isinstance(prec_set, list):
        prec_set = [prec_set]

    a1, a2, a3 = lattice.matrix  # Angstrom
    volume = lattice.volume
    b1, b2, b3 = lattice.reciprocal_lattice.matrix  # 1/ Angstrom
    invepsilon = np.linalg.inv(epsilon)
    rd_epsilon = np.sqrt(np.linalg.det(epsilon))

    def index_triples(cutoffs, basis):
        # All (i, j, k) whose multiples of the basis vectors can reach the
        # largest cut off, first index varying slowest.
        largest = max(cutoffs)
        spans = []
        for vec in basis:
            bound = int(math.ceil(largest / np.linalg.norm(vec)))
            spans.append(np.arange(-bound, bound + 1))
        return itertools.product(*spans)

    # generate reciprocal vector set (for each prec_set)
    recip_set = [[] for _ in prec_set]
    recip_summation_values = [0.0 for _ in prec_set]
    recip_cut_set = [(2 * gamma * prec) for prec in prec_set]

    for i, j, k in index_triples(recip_cut_set, (b1, b2, b3)):
        if not i and not j and not k:
            continue
        gvec = i * b1 + j * b2 + k * b3
        normgvec = np.linalg.norm(gvec)
        for cut_ind, cutoff in enumerate(recip_cut_set):
            if normgvec <= cutoff:
                recip_set[cut_ind].append(gvec)

                Gdotdiel = np.dot(gvec, np.dot(epsilon, gvec))
                recip_summation_values[cut_ind] += math.exp(-Gdotdiel / (4 * (gamma ** 2))) / Gdotdiel

    recip_summation_values = np.array(recip_summation_values)
    recip_summation_values /= volume

    # generate real vector set (for each prec_set)
    real_set = [[] for _ in prec_set]
    real_summation_values = [0.0 for _ in prec_set]
    real_cut_set = [(prec / gamma) for prec in prec_set]

    for i, j, k in index_triples(real_cut_set, (a1, a2, a3)):
        rvec = i * a1 + j * a2 + k * a3
        normrvec = np.linalg.norm(rvec)
        for cut_ind, cutoff in enumerate(real_cut_set):
            if normrvec <= cutoff:
                real_set[cut_ind].append(rvec)
                # The origin is kept in the list but skipped in the sum,
                # where it would divide by zero.
                if normrvec > 1e-8:
                    sqrt_loc_res = np.sqrt(np.dot(rvec, np.dot(invepsilon, rvec)))
                    real_summation_values[cut_ind] += math.erfc(gamma * sqrt_loc_res) / sqrt_loc_res

    real_summation_values = np.array(real_summation_values)
    real_summation_values /= 4 * np.pi * rd_epsilon

    return recip_set, recip_summation_values, real_set, real_summation_values
1542