1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4
5"""
This module provides classes to perform analyses of
the local environments (e.g., finding near neighbors)
of single sites in periodic structures based on
bonding analysis with Lobster (molecules are not supported).
10"""
11import collections
12import copy
13import math
14import os
15
16import numpy as np
17from pymatgen.analysis.bond_valence import BVAnalyzer
18from pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder import LocalGeometryFinder
19from pymatgen.analysis.chemenv.coordination_environments.structure_environments import LightStructureEnvironments
20from pymatgen.analysis.local_env import NearNeighbors
21from pymatgen.electronic_structure.cohp import CompleteCohp
22from pymatgen.electronic_structure.core import Spin
23from pymatgen.electronic_structure.plotter import CohpPlotter
24from pymatgen.io.lobster import Charge, Icohplist
25
# Module metadata: authorship, version, maintainer contact, development status and date.
__author__ = "Janine George"
__copyright__ = "Copyright 2021, The Materials Project"
__version__ = "1.0"
__maintainer__ = "J. George"
__email__ = "janinegeorge.ulfen@gmail.com"
__status__ = "Production"
__date__ = "February 2, 2021"
33
34
35class LobsterNeighbors(NearNeighbors):
36    """
37    This class combines capabilities from LocalEnv and ChemEnv to determine coordination environments based on
38    bonding analysis
39    """
40
41    def __init__(
42        self,
43        are_coops=False,
44        filename_ICOHP=None,
45        valences=None,
46        limits=None,
47        structure=None,
48        additional_condition=0,
49        only_bonds_to=None,
50        perc_strength_ICOHP=0.15,
51        valences_from_charges=False,
52        filename_CHARGE=None,
53        adapt_extremum_to_add_cond=False,
54    ):
55        """
56
57        Args:
58            are_coops: (Bool) if True, the file is a ICOOPLIST.lobster and not a ICOHPLIST.lobster; only tested for
59            ICOHPLIST.lobster so far
60            filename_ICOHP: (str) Path to ICOOPLIST.lobster
61            valences: (list of integers/floats) gives valence/charge for each element
62            limits: limit to decide which ICOHPs should be considered
63            structure: (Structure Object) typically constructed by: Structure.from_file("POSCAR") (Structure object
64            from pymatgen.core.structure)
65            additional_condition:   Additional condition that decides which kind of bonds will be considered
66                                    NO_ADDITIONAL_CONDITION = 0
67                                    ONLY_ANION_CATION_BONDS = 1
68                                    NO_ELEMENT_TO_SAME_ELEMENT_BONDS = 2
69                                    ONLY_ANION_CATION_BONDS_AND_NO_ELEMENT_TO_SAME_ELEMENT_BONDS = 3
70                                    ONLY_ELEMENT_TO_OXYGEN_BONDS = 4
71                                    DO_NOT_CONSIDER_ANION_CATION_BONDS=5
72                                    ONLY_CATION_CATION_BONDS=6
73            only_bonds_to: (list of str) will only consider bonds to certain elements (e.g. ["O"] for oxygen)
74            perc_strength_ICOHP: if no limits are given, this will decide which icohps will still be considered (
75            relative to
76            the strongest ICOHP)
77            valences_from_charges: if True and path to CHARGE.lobster is provided, will use Lobster charges (
78            Mulliken) instead of valences
79            filename_CHARGE: (str) Path to Charge.lobster
80            adapt_extremum_to_add_cond: (bool) will adapt the limits to only focus on the bonds determined by the
81            additional condition
82        """
83
84        self.ICOHP = Icohplist(are_coops=are_coops, filename=filename_ICOHP)
85        self.Icohpcollection = self.ICOHP.icohpcollection
86        self.structure = structure
87        self.limits = limits
88        self.only_bonds_to = only_bonds_to
89        self.adapt_extremum_to_add_cond = adapt_extremum_to_add_cond
90        self.are_coops = are_coops
91
92        if are_coops:
93            raise ValueError("Algorithm only works correctly for ICOHPLIST.lobster")
94
95        # will check if the additional condition is correctly delivered
96        if additional_condition in range(0, 7):
97            self.additional_condition = additional_condition
98        else:
99            raise ValueError("No correct additional condition")
100
101        # will read in valences, will prefer manual setting of valences
102        if valences is None:
103            if valences_from_charges and filename_CHARGE is not None:
104                chg = Charge(filename=filename_CHARGE)
105                self.valences = chg.Mulliken
106            else:
107                bv_analyzer = BVAnalyzer()
108                try:
109                    self.valences = bv_analyzer.get_valences(structure=self.structure)
110                except ValueError:
111                    self.valences = None
112                    if additional_condition in [1, 3, 5, 6]:
113                        print("Valences cannot be assigned, additional_conditions 1 and 3 and 5 and 6 will not work")
114        else:
115            self.valences = valences
116
117        if limits is None:
118            self.lowerlimit = None
119            self.upperlimit = None
120
121        else:
122            self.lowerlimit = limits[0]
123            self.upperlimit = limits[1]
124
125        # will evaluate coordination environments
126        self._evaluate_ce(
127            lowerlimit=self.lowerlimit,
128            upperlimit=self.upperlimit,
129            only_bonds_to=only_bonds_to,
130            additional_condition=self.additional_condition,
131            perc_strength_ICOHP=perc_strength_ICOHP,
132            adapt_extremum_to_add_cond=adapt_extremum_to_add_cond,
133        )
134
135    @property
136    def structures_allowed(self):
137        """
138        Boolean property: can this NearNeighbors class be used with Structure
139        objects?
140        """
141        return True
142
143    @property
144    def molecules_allowed(self):
145        """
146        Boolean property: can this NearNeighbors class be used with Molecule
147        objects?
148        """
149        return False
150
151    def get_anion_types(self):
152        """
153        will return the types of anions present in crystal structure
154        Returns:
155
156        """
157        if self.valences is None:
158            raise ValueError("No cations and anions defined")
159
160        anion_species = []
161        for site, val in zip(self.structure, self.valences):
162            if val < 0.0:
163                anion_species.append(site.specie)
164
165        return set(anion_species)
166
167    def get_nn_info(self, structure, n, use_weights=False):
168        """
169        Get coordination number, CN, of site with index n in structure.
170
171        Args:
172            structure (Structure): input structure.
173            n (integer): index of site for which to determine CN.
174            use_weights (boolean): flag indicating whether (True)
175                to use weights for computing the coordination number
176                or not (False, default: each coordinated site has equal
177                weight).
178                True is not implemented for LobsterNeighbors
179        Returns:
180            cn (integer or float): coordination number.
181        """
182        if use_weights:
183            raise ValueError("LobsterEnv cannot use weights")
184        if len(structure) != len(self.structure):
185            raise ValueError("The wrong structure was provided")
186        return self.sg_list[n]
187
188    def get_light_structure_environment(self, only_cation_environments=False, only_indices=None):
189        """
190        will return a LobsterLightStructureEnvironments object
191        if the structure only contains coordination environments smaller 13
192        Args:
193            only_cation_environments: only data for cations will be returned
194            only_indices: will only evaluate the list of isites in this list
195        Returns: LobsterLightStructureEnvironments Object
196
197        """
198
199        lgf = LocalGeometryFinder()
200        lgf.setup_structure(structure=self.structure)
201        list_ce_symbols = []
202        list_csm = []
203        list_permut = []
204        for ival, _neigh_coords in enumerate(self.list_coords):
205
206            if (len(_neigh_coords)) > 13:
207                raise ValueError("Environment cannot be determined. Number of neighbors is larger than 13.")
208            # to avoid problems if _neigh_coords is empty
209            if _neigh_coords != []:
210                lgf.setup_local_geometry(isite=ival, coords=_neigh_coords, optimization=2)
211                cncgsm = lgf.get_coordination_symmetry_measures(optimization=2)
212                list_ce_symbols.append(min(cncgsm.items(), key=lambda t: t[1]["csm_wcs_ctwcc"])[0])
213                list_csm.append(min(cncgsm.items(), key=lambda t: t[1]["csm_wcs_ctwcc"])[1]["csm_wcs_ctwcc"])
214                list_permut.append(min(cncgsm.items(), key=lambda t: t[1]["csm_wcs_ctwcc"])[1]["indices"])
215            else:
216                list_ce_symbols.append(None)
217                list_csm.append(None)
218                list_permut.append(None)
219
220        if only_indices is None:
221            if not only_cation_environments:
222                lse = LobsterLightStructureEnvironments.from_Lobster(
223                    list_ce_symbol=list_ce_symbols,
224                    list_csm=list_csm,
225                    list_permutation=list_permut,
226                    list_neighsite=self.list_neighsite,
227                    list_neighisite=self.list_neighisite,
228                    structure=self.structure,
229                    valences=self.valences,
230                )
231            else:
232                new_list_ce_symbols = []
233                new_list_csm = []
234                new_list_permut = []
235                new_list_neighsite = []
236                new_list_neighisite = []
237
238                for ival, val in enumerate(self.valences):
239
240                    if val >= 0.0:
241
242                        new_list_ce_symbols.append(list_ce_symbols[ival])
243                        new_list_csm.append(list_csm[ival])
244                        new_list_permut.append(list_permut[ival])
245                        new_list_neighisite.append(self.list_neighisite[ival])
246                        new_list_neighsite.append(self.list_neighsite[ival])
247                    else:
248                        new_list_ce_symbols.append(None)
249                        new_list_csm.append(None)
250                        new_list_permut.append([])
251                        new_list_neighisite.append([])
252                        new_list_neighsite.append([])
253
254                lse = LobsterLightStructureEnvironments.from_Lobster(
255                    list_ce_symbol=new_list_ce_symbols,
256                    list_csm=new_list_csm,
257                    list_permutation=new_list_permut,
258                    list_neighsite=new_list_neighsite,
259                    list_neighisite=new_list_neighisite,
260                    structure=self.structure,
261                    valences=self.valences,
262                )
263        else:
264            new_list_ce_symbols = []
265            new_list_csm = []
266            new_list_permut = []
267            new_list_neighsite = []
268            new_list_neighisite = []
269
270            for isite, site in enumerate(self.structure):
271
272                if isite in only_indices:
273
274                    new_list_ce_symbols.append(list_ce_symbols[isite])
275                    new_list_csm.append(list_csm[isite])
276                    new_list_permut.append(list_permut[isite])
277                    new_list_neighisite.append(self.list_neighisite[isite])
278                    new_list_neighsite.append(self.list_neighsite[isite])
279                else:
280                    new_list_ce_symbols.append(None)
281                    new_list_csm.append(None)
282                    new_list_permut.append([])
283                    new_list_neighisite.append([])
284                    new_list_neighsite.append([])
285
286            lse = LobsterLightStructureEnvironments.from_Lobster(
287                list_ce_symbol=new_list_ce_symbols,
288                list_csm=new_list_csm,
289                list_permutation=new_list_permut,
290                list_neighsite=new_list_neighsite,
291                list_neighisite=new_list_neighisite,
292                structure=self.structure,
293                valences=self.valences,
294            )
295
296        return lse
297
298    def get_info_icohps_to_neighbors(self, isites=[], onlycation_isites=True):
299        """
300        this method will return information of cohps of neighbors
301        Args:
302            isites: list of site ids, if isite==[], all isites will be used to add the icohps of the neighbors
303            onlycation_isites: will only use cations, if isite==[]
304
305
306        Returns:
307            sum of icohps of neighbors to certain sites [given by the id in structure], number of bonds to this site,
308            labels (from ICOHPLIST) for
309            these bonds
310            [the latter is useful for plotting summed COHP plots]
311        """
312
313        if self.valences is None and onlycation_isites:
314            raise ValueError("No valences are provided")
315        if isites == []:
316            if onlycation_isites:
317                isites = [i for i in range(len(self.structure)) if self.valences[i] >= 0.0]
318            else:
319                isites = list(range(len(self.structure)))
320
321        summed_icohps = 0.0
322        list_icohps = []
323        number_bonds = 0
324        labels = []
325        atoms = []
326        for ival, site in enumerate(self.structure):
327            if ival in isites:
328                for keys, icohpsum in zip(self.list_keys[ival], self.list_icohps[ival]):
329                    summed_icohps += icohpsum
330                    list_icohps.append(icohpsum)
331                    labels.append(keys)
332                    atoms.append(
333                        [
334                            self.Icohpcollection._list_atom1[int(keys) - 1],
335                            self.Icohpcollection._list_atom2[int(keys) - 1],
336                        ]
337                    )
338                    number_bonds += 1
339
340        return summed_icohps, list_icohps, number_bonds, labels, atoms
341
342    def plot_cohps_of_neighbors(
343        self,
344        path_to_COHPCAR="COHPCAR.lobster",
345        isites=[],
346        onlycation_isites=True,
347        only_bonds_to=None,
348        per_bond=False,
349        summed_spin_channels=False,
350        xlim=None,
351        ylim=[-10, 6],
352        integrated=False,
353    ):
354
355        """
356        will plot summed cohps (please be careful in the spin polarized case (plots might overlap (exactly!))
357        Args:
358            isites: list of site ids, if isite==[], all isites will be used to add the icohps of the neighbors
359            onlycation_isites: bool, will only use cations, if isite==[]
360            only_bonds_to: list of str, only anions in this list will be considered
361            per_bond: bool, will lead to a normalization of the plotted COHP per number of bond if True,
362            otherwise the sum
363            will be plotted
364            xlim: list of float, limits of x values
365            ylim: list of float, limits of y values
366            integrated: bool, if true will show integrated cohp instead of cohp
367
368        Returns:
369            plt of the cohps
370
371        """
372
373        # include COHPPlotter and plot a sum of these COHPs
374        # might include option to add Spin channels
375        # implement only_bonds_to
376        cp = CohpPlotter()
377
378        plotlabel, summed_cohp = self.get_info_cohps_to_neighbors(
379            path_to_COHPCAR,
380            isites,
381            only_bonds_to,
382            onlycation_isites,
383            per_bond,
384            summed_spin_channels=summed_spin_channels,
385        )
386
387        cp.add_cohp(plotlabel, summed_cohp)
388        plot = cp.get_plot(integrated=integrated)
389        if xlim is not None:
390            plot.xlim(xlim)
391
392        if ylim is not None:
393            plot.ylim(ylim)
394
395        return plot
396
397    def get_info_cohps_to_neighbors(
398        self,
399        path_to_COHPCAR="COHPCAR.lobster",
400        isites=[],
401        only_bonds_to=None,
402        onlycation_isites=True,
403        per_bond=True,
404        summed_spin_channels=False,
405    ):
406        """
407        will return info about the cohps from all sites mentioned in isites with neighbors
408        Args:
409            path_to_COHPCAR: str, path to COHPCAR
410            isites: list of int that indicate the number of the site
411            only_bonds_to: list of str, e.g. ["O"] to only show cohps of anything to oxygen
412            onlycation_isites: if isites=[], only cation sites will be returned
413            per_bond: will normalize per bond
414            summed_spin_channels: will sum all spin channels
415
416        Returns: label for cohp (str), CompleteCohp object which describes all cohps of the sites as given by isites
417        and the other parameters
418
419        """
420        # TODO: add options for orbital-resolved cohps
421        summed_icohps, list_icohps, number_bonds, labels, atoms = self.get_info_icohps_to_neighbors(
422            isites=isites, onlycation_isites=onlycation_isites
423        )
424
425        import tempfile
426
427        with tempfile.TemporaryDirectory() as t:
428            path = os.path.join(t, "POSCAR.vasp")
429
430            self.structure.to(filename=path, fmt="POSCAR")
431
432            if not hasattr(self, "completecohp"):
433                self.completecohp = CompleteCohp.from_file(fmt="LOBSTER", filename=path_to_COHPCAR, structure_file=path)
434
435        # will check that the number of bonds in ICOHPLIST and COHPCAR are identical
436        # further checks could be implemented
437        if len(self.Icohpcollection._list_atom1) != len(self.completecohp.bonds.keys()):
438            raise ValueError("COHPCAR and ICOHPLIST do not fit together")
439        is_spin_completecohp = Spin.down in self.completecohp.get_cohp_by_label("1").cohp
440        if self.Icohpcollection.is_spin_polarized != is_spin_completecohp:
441            raise ValueError("COHPCAR and ICOHPLIST do not fit together")
442
443        if only_bonds_to is None:
444            # sort by anion type
445            if per_bond:
446                divisor = len(labels)
447            else:
448                divisor = 1
449
450            plotlabel = self._get_plot_label(atoms, per_bond)
451            summed_cohp = self.completecohp.get_summed_cohp_by_label_list(
452                label_list=labels, divisor=divisor, summed_spin_channels=summed_spin_channels
453            )
454
455        else:
456            # TODO: check if this is okay
457            # labels of the COHPs that will be summed!
458            # iterate through labels and atoms and check which bonds can be included
459            new_labels = []
460            new_atoms = []
461            # print(labels)
462            # print(atoms)
463            for label, atompair in zip(labels, atoms):
464                # durchlaufe only_bonds_to=[] und sage ja, falls eines der Labels in atompair ist, dann speichere
465                # new_label
466                present = False
467                # print(only_bonds_to)
468                for atomtype in only_bonds_to:
469                    if atomtype in (self._split_string(atompair[0])[0], self._split_string(atompair[1])[0]):
470                        present = True
471                if present:
472                    new_labels.append(label)
473                    new_atoms.append(atompair)
474            # print(new_labels)
475            if len(new_labels) > 0:
476                if per_bond:
477                    divisor = len(new_labels)
478                else:
479                    divisor = 1
480
481                plotlabel = self._get_plot_label(new_atoms, per_bond)
482                summed_cohp = self.completecohp.get_summed_cohp_by_label_list(
483                    label_list=new_labels, divisor=divisor, summed_spin_channels=summed_spin_channels
484                )
485            else:
486                plotlabel = None
487
488                summed_cohp = None
489
490        return plotlabel, summed_cohp
491
492    def _get_plot_label(self, atoms, per_bond):
493        # count the types of bonds and append a label:
494        all_labels = []
495        for atomsnames in atoms:
496            new = [self._split_string(atomsnames[0])[0], self._split_string(atomsnames[1])[0]]
497            new.sort()
498            # print(new2)
499            string_here = new[0] + "-" + new[1]
500            all_labels.append(string_here)
501        count = collections.Counter(all_labels)
502        plotlabels = []
503        for key, item in count.items():
504            plotlabels.append(str(item) + " x " + str(key))
505        plotlabel = ", ".join(plotlabels)
506        if per_bond:
507            plotlabel = plotlabel + " (per bond)"
508        return plotlabel
509
510    def get_info_icohps_between_neighbors(self, isites=[], onlycation_isites=True):
511
512        """
513        will return infos about interactions between neighbors of a certain atom
514        Args:
515            isites: list of site ids, if isite==[], all isites will be used
516            onlycation_isites: will only use cations, if isite==[]
517
518        Returns:
519
520        """
521
522        lowerlimit = self.lowerlimit
523        upperlimit = self.upperlimit
524
525        if self.valences is None and onlycation_isites:
526            raise ValueError("No valences are provided")
527        if isites == []:
528            if onlycation_isites:
529                isites = [i for i in range(len(self.structure)) if self.valences[i] >= 0.0]
530            else:
531                isites = list(range(len(self.structure)))
532
533        summed_icohps = 0.0
534        list_icohps = []
535        number_bonds = 0
536        label_list = []
537        atoms_list = []
538        for iisite, isite in enumerate(isites):
539            for in_site, n_site in enumerate(self.list_neighsite[isite]):
540                for in_site2, n_site2 in enumerate(self.list_neighsite[isite]):
541                    if in_site < in_site2:
542                        unitcell1 = self._determine_unit_cell(n_site)
543                        unitcell2 = self._determine_unit_cell(n_site2)
544
545                        index_n_site = self._get_original_site(self.structure, n_site)
546                        index_n_site2 = self._get_original_site(self.structure, n_site2)
547
548                        if index_n_site < index_n_site2:
549                            translation = list(np.array(unitcell1) - np.array(unitcell2))
550                        elif index_n_site2 < index_n_site:
551                            translation = list(np.array(unitcell2) - np.array(unitcell1))
552                        else:
553                            translation = list(np.array(unitcell1) - np.array(unitcell2))
554
555                        icohps = self._get_icohps(
556                            icohpcollection=self.Icohpcollection,
557                            isite=index_n_site,
558                            lowerlimit=lowerlimit,
559                            upperlimit=upperlimit,
560                            only_bonds_to=self.only_bonds_to,
561                        )
562
563                        done = False
564                        for key, icohp in icohps.items():
565
566                            atomnr1 = self._get_atomnumber(icohp._atom1)
567                            atomnr2 = self._get_atomnumber(icohp._atom2)
568                            label = icohp._label
569
570                            if (index_n_site == atomnr1 and index_n_site2 == atomnr2) or (
571                                index_n_site == atomnr2 and index_n_site2 == atomnr1
572                            ):
573
574                                if atomnr1 != atomnr2:
575
576                                    if np.all(np.asarray(translation) == np.asarray(icohp._translation)):
577                                        summed_icohps += icohp.summed_icohp
578                                        list_icohps.append(icohp.summed_icohp)
579                                        number_bonds += 1
580                                        label_list.append(label)
581                                        atoms_list.append(
582                                            [
583                                                self.Icohpcollection._list_atom1[int(label) - 1],
584                                                self.Icohpcollection._list_atom2[int(label) - 1],
585                                            ]
586                                        )
587
588                                else:
589                                    if not done:
590                                        if (np.all(np.asarray(translation) == np.asarray(icohp._translation))) or (
591                                            np.all(
592                                                np.asarray(translation)
593                                                == np.asarray(
594                                                    [
595                                                        -icohp._translation[0],
596                                                        -icohp._translation[1],
597                                                        -icohp._translation[2],
598                                                    ]
599                                                )
600                                            )
601                                        ):
602                                            summed_icohps += icohp.summed_icohp
603                                            list_icohps.append(icohp.summed_icohp)
604                                            number_bonds += 1
605                                            label_list.append(label)
606                                            atoms_list.append(
607                                                [
608                                                    self.Icohpcollection._list_atom1[int(label) - 1],
609                                                    self.Icohpcollection._list_atom2[int(label) - 1],
610                                                ]
611                                            )
612                                            done = True
613
614        return summed_icohps, list_icohps, number_bonds, label_list, atoms_list
615
616    def _evaluate_ce(
617        self,
618        lowerlimit,
619        upperlimit,
620        only_bonds_to=None,
621        additional_condition=0,
622        perc_strength_ICOHP=0.15,
623        adapt_extremum_to_add_cond=False,
624    ):
625        """
626
627        Args:
628            lowerlimit: lower limit which determines the ICOHPs that are considered for the determination of the
629            neighbors
630            upperlimit: upper limit which determines the ICOHPs that are considered for the determination of the
631            neighbors
632            only_bonds_to: restricts the types of bonds that will be considered
633            additional_condition: Additional condition for the evaluation
634            perc_strength_ICOHP: will be used to determine how strong the ICOHPs (percentage*strongest ICOHP) will be
635            that are still considered for the evalulation
636            adapt_extremum_to_add_cond: will recalculate the limit based on the bonding type and not on the overall
637            extremum
638        Returns:
639
640        """
641        # get extremum
642        if lowerlimit is None and upperlimit is None:
643
644            lowerlimit, upperlimit = self._get_limit_from_extremum(
645                self.Icohpcollection,
646                percentage=perc_strength_ICOHP,
647                adapt_extremum_to_add_cond=adapt_extremum_to_add_cond,
648                additional_condition=additional_condition,
649            )
650
651        elif lowerlimit is None and upperlimit is not None:
652            raise ValueError("Please give two limits or leave them both at None")
653        elif upperlimit is None and lowerlimit is not None:
654            raise ValueError("Please give two limits or leave them both at None")
655
656        # find environments based on ICOHP values
657        list_icohps, list_keys, list_lengths, list_neighisite, list_neighsite, list_coords = self._find_environments(
658            additional_condition, lowerlimit, upperlimit, only_bonds_to
659        )
660
661        self.list_icohps = list_icohps
662        self.list_lengths = list_lengths
663        self.list_keys = list_keys
664        self.list_neighsite = list_neighsite
665        self.list_neighisite = list_neighisite
666        self.list_coords = list_coords
667
668        # make a structure graph
669        # make sure everything is relative to the given Structure and not just the atoms in the unit cell
670        self.sg_list = [
671            [
672                {
673                    "site": neighbor,
674                    "image": tuple(
675                        int(round(i))
676                        for i in (
677                            neighbor.frac_coords
678                            - self.structure[
679                                [
680                                    isite
681                                    for isite, site in enumerate(self.structure)
682                                    if neighbor.is_periodic_image(site)
683                                ][0]
684                            ].frac_coords
685                        )
686                    ),
687                    "weight": 1,
688                    "site_index": [
689                        isite for isite, site in enumerate(self.structure) if neighbor.is_periodic_image(site)
690                    ][0],
691                }
692                for neighbor in neighbors
693            ]
694            for neighbors in self.list_neighsite
695        ]
696
    def _find_environments(self, additional_condition, lowerlimit, upperlimit, only_bonds_to):
        """
        will find all relevant neighbors based on certain restrictions

        For every site of self.structure, the ICOHPs within [lowerlimit, upperlimit] (restricted to only_bonds_to if
        given) are collected via _get_icohps and then filtered by _find_relevant_atoms_additional_condition.
        The remaining bonds are matched to concrete (possibly periodic-image) sites: all sites within a sphere of
        radius max(bond length) + 0.5 around the central site are visited in order of increasing distance. A
        candidate is accepted when its structure index equals the ICOHP partner index and its distance agrees with
        the ICOHP bond length (np.isclose, rtol=1e-4). Each ICOHP bond is matched at most once.

        Args:
            additional_condition (int): additional condition (see above)
            lowerlimit (float): lower limit that tells you which ICOHPs are considered
            upperlimit (float): upper limit that tells you which ICOHPs are considered
            only_bonds_to (list): list of str, e.g. ["O"] that will ensure that only bonds to "O" will be considered

        Returns:
            tuple (list_icohps, list_keys, list_lengths, list_neighisite, list_neighsite, list_coords); each is a
            list with one entry per site of self.structure (an empty list for sites without selected bonds):
            list_icohps: summed ICOHP values of all bonds selected for the site
            list_keys: keys (from the dict returned by _get_icohps) of all bonds selected for the site
            list_lengths: bond lengths of all bonds selected for the site
            list_neighisite: structure indices of the neighbors that could be matched to a site
            list_neighsite: matched neighbor site objects (as returned by get_sites_in_sphere)
            list_coords: cartesian coordinates of the matched neighbors
            Note that the first three lists contain every selected bond, while the last three only contain
            the bonds for which a matching site was found.
        """
        # run over structure
        list_neighsite = []
        list_neighisite = []
        list_coords = []
        list_icohps = []
        list_lengths = []
        list_keys = []
        for isite, site in enumerate(self.structure):

            icohps = self._get_icohps(
                icohpcollection=self.Icohpcollection,
                isite=isite,
                lowerlimit=lowerlimit,
                upperlimit=upperlimit,
                only_bonds_to=only_bonds_to,
            )

            (
                keys_from_ICOHPs,
                lengths_from_ICOHPs,
                neighbors_from_ICOHPs,
                selected_ICOHPs,
            ) = self._find_relevant_atoms_additional_condition(isite, icohps, additional_condition)

            if len(neighbors_from_ICOHPs) > 0:
                centralsite = self.structure.sites[isite]

                # candidate sites: everything slightly beyond the longest selected bond
                neighbors_by_distance_start = self.structure.get_sites_in_sphere(
                    pt=centralsite.coords,
                    r=np.max(lengths_from_ICOHPs) + 0.5,
                    include_image=True,
                    include_index=True,
                )

                neighbors_by_distance = []
                list_distances = []
                index_here_list = []
                coords = []
                # entries are used as (site, distance, index, image); visit candidates closest first
                for neigh_new in sorted(neighbors_by_distance_start, key=lambda x: x[1]):
                    site_here = neigh_new[0].to_unit_cell()
                    index_here = neigh_new[2]
                    index_here_list.append(index_here)
                    cell_here = neigh_new[3]
                    # fractional coords of the unit-cell copy shifted by the image vector, then made cartesian
                    newcoords = [
                        site_here.frac_coords[0] + float(cell_here[0]),
                        site_here.frac_coords[1] + float(cell_here[1]),
                        site_here.frac_coords[2] + float(cell_here[2]),
                    ]
                    coords.append(site_here.lattice.get_cartesian_coords(newcoords))

                    # new_site = PeriodicSite(species=site_here.species_string,
                    #                         coords=site_here.lattice.get_cartesian_coords(newcoords),
                    #                         lattice=site_here.lattice, to_unit_cell=False, coords_are_cartesian=True)
                    neighbors_by_distance.append(neigh_new[0])
                    list_distances.append(neigh_new[1])
                _list_neighsite = []
                _list_neighisite = []
                # copies, because matched bonds are removed so that each ICOHP bond is used only once
                copied_neighbors_from_ICOHPs = copy.copy(neighbors_from_ICOHPs)
                copied_distances_from_ICOHPs = copy.copy(lengths_from_ICOHPs)
                _neigh_coords = []
                # NOTE(review): _neigh_frac_coords is filled below but never used or returned
                _neigh_frac_coords = []

                for ineigh, neigh in enumerate(neighbors_by_distance):
                    index_here2 = index_here_list[ineigh]

                    for idist, dist in enumerate(copied_distances_from_ICOHPs):
                        # a candidate matches a bond if both the distance and the partner site index agree
                        if (
                            np.isclose(dist, list_distances[ineigh], rtol=1e-4)
                            and copied_neighbors_from_ICOHPs[idist] == index_here2
                        ):
                            _list_neighsite.append(neigh)
                            _list_neighisite.append(index_here2)
                            _neigh_coords.append(coords[ineigh])
                            _neigh_frac_coords.append(neigh.frac_coords)
                            # deleting while enumerating is safe only because of the immediate break
                            del copied_distances_from_ICOHPs[idist]
                            del copied_neighbors_from_ICOHPs[idist]
                            break

                list_neighisite.append(_list_neighisite)
                list_neighsite.append(_list_neighsite)
                list_lengths.append(lengths_from_ICOHPs)
                list_keys.append(keys_from_ICOHPs)
                list_coords.append(_neigh_coords)
                list_icohps.append(selected_ICOHPs)

            else:
                # no relevant bonds: keep the per-site lists aligned with the structure
                list_neighsite.append([])
                list_neighisite.append([])
                list_icohps.append([])
                list_lengths.append([])
                list_keys.append([])
                list_coords.append([])
        return list_icohps, list_keys, list_lengths, list_neighisite, list_neighsite, list_coords
802
803    def _find_relevant_atoms_additional_condition(self, isite, icohps, additional_condition):
804        """
805        will find all relevant atoms that fulfill the additional_conditions
806        Args:
807            isite: number of site in structure (starts with 0)
808            icohps: icohps
809            additional_condition (int): additonal condition
810
811        Returns:
812
813        """
814        neighbors_from_ICOHPs = []
815        lengths_from_ICOHPs = []
816        icohps_from_ICOHPs = []
817        keys_from_ICOHPs = []
818
819        for key, icohp in icohps.items():
820            atomnr1 = self._get_atomnumber(icohp._atom1)
821            atomnr2 = self._get_atomnumber(icohp._atom2)
822
823            # test additional conditions
824            if additional_condition in (1, 3, 5, 6):
825                val1 = self.valences[atomnr1]
826                val2 = self.valences[atomnr2]
827
828            if additional_condition == 0:
829                # NO_ADDITIONAL_CONDITION
830                if atomnr1 == isite:
831                    neighbors_from_ICOHPs.append(atomnr2)
832                    lengths_from_ICOHPs.append(icohp._length)
833                    icohps_from_ICOHPs.append(icohp.summed_icohp)
834                    keys_from_ICOHPs.append(key)
835                elif atomnr2 == isite:
836                    neighbors_from_ICOHPs.append(atomnr1)
837                    lengths_from_ICOHPs.append(icohp._length)
838                    icohps_from_ICOHPs.append(icohp.summed_icohp)
839                    keys_from_ICOHPs.append(key)
840
841            elif additional_condition == 1:
842                # ONLY_ANION_CATION_BONDS
843                if (val1 < 0.0 < val2) or (val2 < 0.0 < val1):
844                    if atomnr1 == isite:
845                        neighbors_from_ICOHPs.append(atomnr2)
846                        lengths_from_ICOHPs.append(icohp._length)
847                        icohps_from_ICOHPs.append(icohp.summed_icohp)
848                        keys_from_ICOHPs.append(key)
849
850                    elif atomnr2 == isite:
851                        neighbors_from_ICOHPs.append(atomnr1)
852                        lengths_from_ICOHPs.append(icohp._length)
853                        icohps_from_ICOHPs.append(icohp.summed_icohp)
854                        keys_from_ICOHPs.append(key)
855
856            elif additional_condition == 2:
857                # NO_ELEMENT_TO_SAME_ELEMENT_BONDS
858                if icohp._atom1.rstrip("0123456789") != icohp._atom2.rstrip("0123456789"):
859                    if atomnr1 == isite:
860                        neighbors_from_ICOHPs.append(atomnr2)
861                        lengths_from_ICOHPs.append(icohp._length)
862                        icohps_from_ICOHPs.append(icohp.summed_icohp)
863                        keys_from_ICOHPs.append(key)
864
865                    elif atomnr2 == isite:
866                        neighbors_from_ICOHPs.append(atomnr1)
867                        lengths_from_ICOHPs.append(icohp._length)
868                        icohps_from_ICOHPs.append(icohp.summed_icohp)
869                        keys_from_ICOHPs.append(key)
870
871            elif additional_condition == 3:
872                # ONLY_ANION_CATION_BONDS_AND_NO_ELEMENT_TO_SAME_ELEMENT_BONDS = 3
873                if (val1 < 0.0 < val2) or (val2 < 0.0 < val1):
874                    if icohp._atom1.rstrip("0123456789") != icohp._atom2.rstrip("0123456789"):
875                        if atomnr1 == isite:
876                            neighbors_from_ICOHPs.append(atomnr2)
877                            lengths_from_ICOHPs.append(icohp._length)
878                            icohps_from_ICOHPs.append(icohp.summed_icohp)
879                            keys_from_ICOHPs.append(key)
880
881                        elif atomnr2 == isite:
882                            neighbors_from_ICOHPs.append(atomnr1)
883                            lengths_from_ICOHPs.append(icohp._length)
884                            icohps_from_ICOHPs.append(icohp.summed_icohp)
885                            keys_from_ICOHPs.append(key)
886
887            elif additional_condition == 4:
888                # ONLY_ELEMENT_TO_OXYGEN_BONDS = 4
889                if icohp._atom1.rstrip("0123456789") == "O" or icohp._atom2.rstrip("0123456789") == "O":
890
891                    if atomnr1 == isite:
892                        neighbors_from_ICOHPs.append(atomnr2)
893                        lengths_from_ICOHPs.append(icohp._length)
894                        icohps_from_ICOHPs.append(icohp.summed_icohp)
895                        keys_from_ICOHPs.append(key)
896
897                    elif atomnr2 == isite:
898                        neighbors_from_ICOHPs.append(atomnr1)
899                        lengths_from_ICOHPs.append(icohp._length)
900                        icohps_from_ICOHPs.append(icohp.summed_icohp)
901                        keys_from_ICOHPs.append(key)
902
903            elif additional_condition == 5:
904                # DO_NOT_CONSIDER_ANION_CATION_BONDS=5
905                if (val1 > 0.0 and val2 > 0.0) or (val1 < 0.0 and val2 < 0.0):
906                    if atomnr1 == isite:
907                        neighbors_from_ICOHPs.append(atomnr2)
908                        lengths_from_ICOHPs.append(icohp._length)
909                        icohps_from_ICOHPs.append(icohp.summed_icohp)
910                        keys_from_ICOHPs.append(key)
911
912                    elif atomnr2 == isite:
913                        neighbors_from_ICOHPs.append(atomnr1)
914                        lengths_from_ICOHPs.append(icohp._length)
915                        icohps_from_ICOHPs.append(icohp.summed_icohp)
916                        keys_from_ICOHPs.append(key)
917
918            elif additional_condition == 6:
919                # ONLY_CATION_CATION_BONDS=6
920                if val1 > 0.0 and val2 > 0.0:
921                    if atomnr1 == isite:
922                        neighbors_from_ICOHPs.append(atomnr2)
923                        lengths_from_ICOHPs.append(icohp._length)
924                        icohps_from_ICOHPs.append(icohp.summed_icohp)
925                        keys_from_ICOHPs.append(key)
926
927                    elif atomnr2 == isite:
928                        neighbors_from_ICOHPs.append(atomnr1)
929                        lengths_from_ICOHPs.append(icohp._length)
930                        icohps_from_ICOHPs.append(icohp.summed_icohp)
931                        keys_from_ICOHPs.append(key)
932
933        return keys_from_ICOHPs, lengths_from_ICOHPs, neighbors_from_ICOHPs, icohps_from_ICOHPs
934
935    @staticmethod
936    def _get_icohps(icohpcollection, isite, lowerlimit, upperlimit, only_bonds_to):
937        """
938        will return icohp dict for certain site
939        Args:
940            icohpcollection: Icohpcollection object
941            isite (int): number of a site
942            lowerlimit (float): lower limit that tells you which ICOHPs are considered
943            upperlimit (float): upper limit that tells you which ICOHPs are considerd
944            only_bonds_to (list): list of str, e.g. ["O"] that will ensure that only bonds to "O" will be considered
945
946        Returns:
947
948        """
949        icohps = icohpcollection.get_icohp_dict_of_site(
950            site=isite,
951            maxbondlength=6.0,
952            minsummedicohp=lowerlimit,
953            maxsummedicohp=upperlimit,
954            only_bonds_to=only_bonds_to,
955        )
956        return icohps
957
958    @staticmethod
959    def _get_atomnumber(atomstring):
960        """
961        will return the number of the atom within the initial POSCAR (e.g., will return 0 for "Na1")
962        Args:
963            atomstring: string such as "Na1"
964
965        Returns: integer indicating the position in the POSCAR
966
967        """
968        return int(LobsterNeighbors._split_string(atomstring)[1]) - 1
969
970    @staticmethod
971    def _split_string(s):
972        """
973        will split strings such as "Na1" in "Na" and "1" and return "1"
974        Args:
975            s (str): string
976
977        Returns:
978
979        """
980        head = s.rstrip("0123456789")
981        tail = s[len(head) :]
982        return head, tail
983
984    @staticmethod
985    def _determine_unit_cell(site):
986        """
987        based on the site it will determine the unit cell, in which this site is based
988        Args:
989            site: site object
990
991        Returns:
992
993        """
994        unitcell = []
995        for coord in site.frac_coords:
996            value = math.floor(round(coord, 4))
997            unitcell.append(value)
998
999        return unitcell
1000
1001    def _get_limit_from_extremum(
1002        self, icohpcollection, percentage=0.15, adapt_extremum_to_add_cond=False, additional_condition=0
1003    ):
1004        # TODO: adapt this to give the extremum for the correct type of bond
1005        # TODO add tests
1006        """
1007        will return limits for the evaluation of the icohp values from an icohpcollection
1008        will return -100000, min(max_icohp*0.15,-0.1)
1009        Args:
1010            icohpcollection: icohpcollection object
1011            percentage: will determine which ICOHPs will be considered (only 0.15 from the maximum value)
1012            adapt_extremum_to_add_cond: should the extrumum be adapted to the additional condition
1013            additional_condition: additional condition to determine which bonds are relevant
1014        Returns: [-100000, min(max_icohp*0.15,-0.1)]
1015
1016        """
1017        # TODO: make it work for COOPs
1018        if not adapt_extremum_to_add_cond:
1019            extremum_based = icohpcollection.extremum_icohpvalue(summed_spin_channels=True) * percentage
1020        else:
1021            if additional_condition == 0:
1022                extremum_based = icohpcollection.extremum_icohpvalue(summed_spin_channels=True) * percentage
1023            elif additional_condition == 1:
1024                # only cation anion bonds
1025                list_icohps = []
1026                for key, value in icohpcollection._icohplist.items():
1027                    atomnr1 = LobsterNeighbors._get_atomnumber(value._atom1)
1028                    atomnr2 = LobsterNeighbors._get_atomnumber(value._atom2)
1029
1030                    val1 = self.valences[atomnr1]
1031                    val2 = self.valences[atomnr2]
1032                    if (val1 < 0.0 < val2) or (val2 < 0.0 < val1):
1033                        list_icohps.append(value.summed_icohp)
1034
1035                extremum_based = min(list_icohps) * percentage
1036
1037            elif additional_condition == 2:
1038                # NO_ELEMENT_TO_SAME_ELEMENT_BONDS
1039                list_icohps = []
1040                for key, value in icohpcollection._icohplist.items():
1041                    if value._atom1.rstrip("0123456789") != value._atom2.rstrip("0123456789"):
1042                        list_icohps.append(value.summed_icohp)
1043                extremum_based = min(list_icohps) * percentage
1044
1045            elif additional_condition == 3:
1046                # ONLY_ANION_CATION_BONDS_AND_NO_ELEMENT_TO_SAME_ELEMENT_BONDS = 3
1047                list_icohps = []
1048                for key, value in icohpcollection._icohplist.items():
1049                    atomnr1 = LobsterNeighbors._get_atomnumber(value._atom1)
1050                    atomnr2 = LobsterNeighbors._get_atomnumber(value._atom2)
1051                    val1 = self.valences[atomnr1]
1052                    val2 = self.valences[atomnr2]
1053
1054                    if (val1 < 0.0 < val2) or (val2 < 0.0 < val1):
1055                        if value._atom1.rstrip("0123456789") != value._atom2.rstrip("0123456789"):
1056                            list_icohps.append(value.summed_icohp)
1057                extremum_based = min(list_icohps) * percentage
1058            elif additional_condition == 4:
1059                list_icohps = []
1060                for key, value in icohpcollection._icohplist.items():
1061                    if value._atom1.rstrip("0123456789") == "O" or value._atom2.rstrip("0123456789") == "O":
1062                        list_icohps.append(value.summed_icohp)
1063                extremum_based = min(list_icohps) * percentage
1064            elif additional_condition == 5:
1065                # DO_NOT_CONSIDER_ANION_CATION_BONDS=5
1066                list_icohps = []
1067                for key, value in icohpcollection._icohplist.items():
1068                    atomnr1 = LobsterNeighbors._get_atomnumber(value._atom1)
1069                    atomnr2 = LobsterNeighbors._get_atomnumber(value._atom2)
1070                    val1 = self.valences[atomnr1]
1071                    val2 = self.valences[atomnr2]
1072
1073                    if (val1 > 0.0 and val2 > 0.0) or (val1 < 0.0 and val2 < 0.0):
1074                        list_icohps.append(value.summed_icohp)
1075                extremum_based = min(list_icohps) * percentage
1076
1077            elif additional_condition == 6:
1078                # ONLY_CATION_CATION_BONDS=6
1079                list_icohps = []
1080                for key, value in icohpcollection._icohplist.items():
1081                    atomnr1 = LobsterNeighbors._get_atomnumber(value._atom1)
1082                    atomnr2 = LobsterNeighbors._get_atomnumber(value._atom2)
1083                    val1 = self.valences[atomnr1]
1084                    val2 = self.valences[atomnr2]
1085
1086                    if val1 > 0.0 and val2 > 0.0:
1087                        list_icohps.append(value.summed_icohp)
1088                extremum_based = min(list_icohps) * percentage
1089
1090        # if not self.are_coops:
1091        max_here = min(extremum_based, -0.1)
1092        return -100000, max_here
1093        # else:
1094        #    return extremum_based, 100000
1095
1096
class LobsterLightStructureEnvironments(LightStructureEnvironments):
    """
    Class to store LightStructureEnvironments based on Lobster outputs
    """

    @classmethod
    def from_Lobster(
        cls, list_ce_symbol, list_csm, list_permutation, list_neighsite, list_neighisite, structure, valences=None
    ):
        """
        will set up a LightStructureEnvironments from Lobster

        Args:
            list_ce_symbol: list of symbols for coordination environments (one per site), or None
            list_csm: list of continuous symmetry measures (one per site)
            list_permutation: list of permutations (one per site)
            list_neighsite: list (one entry per site) of lists of neighboring site objects
            list_neighisite: list (one entry per site) of lists with the structure indices of those neighbors;
                an entry may be None for a site without neighbor information
            structure: Structure object
            valences: list of valences

        Returns: LobsterLightStructureEnvironments

        Raises:
            ValueError: if a neighbor's fractional coordinates differ from those of the referenced structure site
                by a non-integer vector
        """
        strategy = None
        valences_origin = "user-defined"

        coordination_environments = []
        # flat list of the neighbor entries of all sites; NeighborsSets refer to entries by their position here,
        # so every appended entry is referenced by exactly the position it is stored at
        all_nbs_sites = []
        neighbors_sets = []

        for isite, _site in enumerate(structure):
            # Coordination environment
            if list_ce_symbol is not None:
                ce_dict = {
                    "ce_symbol": list_ce_symbol[isite],
                    "ce_fraction": 1.0,
                    "csm": list_csm[isite],
                    "permutation": list_permutation[isite],
                }
            else:
                ce_dict = None

            if list_neighisite[isite] is not None:
                all_nbs_sites_indices_here = []
                for ineighsite, neighsite in enumerate(list_neighsite[isite]):
                    neighbor_index = list_neighisite[isite][ineighsite]
                    diff = neighsite.frac_coords - structure[neighbor_index].frac_coords
                    rounddiff = np.round(diff)
                    if not np.allclose(diff, rounddiff):
                        raise ValueError(
                            "Weird, differences between one site in a periodic image cell is not " "integer ..."
                        )
                    # the entry appended below is stored at position len(all_nbs_sites)
                    all_nbs_sites_indices_here.append(len(all_nbs_sites))
                    all_nbs_sites.append(
                        {"site": neighsite, "index": neighbor_index, "image_cell": np.array(rounddiff, int)}
                    )

                nb_set = cls.NeighborsSet(
                    structure=structure,
                    isite=isite,
                    all_nbs_sites=all_nbs_sites,
                    all_nbs_sites_indices=all_nbs_sites_indices_here,
                )
            else:
                # No placeholder is added to all_nbs_sites here: it would not be referenced by any index, would
                # shift the positions of all later entries, and its "site" of None could not be serialized by
                # as_dict.
                nb_set = cls.NeighborsSet(structure=structure, isite=isite, all_nbs_sites=[], all_nbs_sites_indices=[])

            coordination_environments.append([ce_dict])
            neighbors_sets.append([nb_set])

        return cls(
            strategy=strategy,
            coordination_environments=coordination_environments,
            all_nbs_sites=all_nbs_sites,
            neighbors_sets=neighbors_sets,
            structure=structure,
            valences=valences,
            valences_origin=valences_origin,
        )

    @property
    def uniquely_determines_coordination_environments(self):
        """
        True if the coordination environments are uniquely determined.
        """
        return True

    def as_dict(self):
        """
        Bson-serializable dict representation of the LightStructureEnvironments object.
        :return: Bson-serializable dict representation of the LightStructureEnvironments object.
        """
        return {
            "@module": self.__class__.__module__,
            "@class": self.__class__.__name__,
            "strategy": self.strategy,
            "structure": self.structure.as_dict(),
            "coordination_environments": self.coordination_environments,
            "all_nbs_sites": [
                {
                    "site": nb_site["site"].as_dict(),
                    "index": nb_site["index"],
                    # convert numpy integers to plain ints for serialization
                    "image_cell": [int(ii) for ii in nb_site["image_cell"]],
                }
                for nb_site in self._all_nbs_sites
            ],
            "neighbors_sets": [
                [nb_set.as_dict() for nb_set in site_nb_sets] if site_nb_sets is not None else None
                for site_nb_sets in self.neighbors_sets
            ],
            "valences": self.valences,
        }
1224