1# Copyright (C) 2002, Thomas Hamelryck (thamelry@binf.ku.dk)
2#
3# This file is part of the Biopython distribution and governed by your
4# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
5# Please see the LICENSE file that should have been included as part of this
6# package.
7"""Base class for Residue, Chain, Model and Structure classes.
8
9It is a simple container class, with list and dictionary like properties.
10"""
11
12from collections import deque
13from copy import copy
14
15import numpy as np
16
17from Bio.PDB.PDBExceptions import PDBConstructionException
18
19
class Entity:
    """Basic container object for PDB hierarchy.

    Structure, Model, Chain and Residue are subclasses of Entity.
    It deals with storage and lookup.
    """

    def __init__(self, id):
        """Initialize the class."""
        self._id = id
        self.full_id = None
        self.parent = None
        self.child_list = []
        self.child_dict = {}
        # Dictionary that keeps additional properties
        self.xtra = {}

    # Special methods

    def __len__(self):
        """Return the number of children."""
        return len(self.child_list)

    def __getitem__(self, id):
        """Return the child with given id."""
        return self.child_dict[id]

    def __delitem__(self, id):
        """Remove a child."""
        return self.detach_child(id)

    def __contains__(self, id):
        """Check if there is a child element with the given id."""
        return id in self.child_dict

    def __iter__(self):
        """Iterate over children."""
        yield from self.child_list

    # Generic id-based comparison methods consider all parents as well as children.
    # Works for all Entities - Atoms have comparable custom operators.
    def _comparison_operands(self, other):
        """Return the pair of values the comparison operators compare (PRIVATE).

        A parentless entity is compared by its own id. Otherwise the full ids
        are compared with their first (top-level) id dropped. get_full_id() is
        used so an entity whose full_id was never generated does not crash
        the comparison.
        """
        if self.parent is None:
            return self.id, other.id
        return self.get_full_id()[1:], other.get_full_id()[1:]

    def __eq__(self, other):
        """Test for equality.

        Parented entities compare their full_id without its top-level id;
        a parentless entity compares by its own id.
        """
        if not isinstance(other, type(self)):
            return NotImplemented
        mine, theirs = self._comparison_operands(other)
        return mine == theirs

    def __ne__(self, other):
        """Test for inequality."""
        if not isinstance(other, type(self)):
            return NotImplemented
        mine, theirs = self._comparison_operands(other)
        return mine != theirs

    def __gt__(self, other):
        """Test greater than."""
        if not isinstance(other, type(self)):
            return NotImplemented
        mine, theirs = self._comparison_operands(other)
        return mine > theirs

    def __ge__(self, other):
        """Test greater or equal."""
        if not isinstance(other, type(self)):
            return NotImplemented
        mine, theirs = self._comparison_operands(other)
        return mine >= theirs

    def __lt__(self, other):
        """Test less than."""
        if not isinstance(other, type(self)):
            return NotImplemented
        mine, theirs = self._comparison_operands(other)
        return mine < theirs

    def __le__(self, other):
        """Test less or equal."""
        if not isinstance(other, type(self)):
            return NotImplemented
        mine, theirs = self._comparison_operands(other)
        return mine <= theirs

    def __hash__(self):
        """Hash method to allow uniqueness (set).

        Hashes the same value __eq__ compares: the id for a parentless
        entity, otherwise the full_id without its top-level id. Entities that
        compare equal (both parented, or both parentless) therefore hash
        equal, as required for sets and dicts.
        """
        if self.parent is None:
            return hash(self.id)
        return hash(self.get_full_id()[1:])

    # Private methods

    def _reset_full_id(self):
        """Reset the full_id (PRIVATE).

        Resets the full_id of this entity and
        recursively of all its children based on their ID.
        """
        for child in self:
            # Atoms do not cache their full ids and have no reset method.
            # Looking the method up (rather than catching AttributeError
            # around the call) avoids hiding errors raised inside a child's reset.
            reset = getattr(child, "_reset_full_id", None)
            if reset is not None:
                reset()
        self.full_id = self._generate_full_id()

    def _generate_full_id(self):
        """Generate full_id (PRIVATE).

        Generate the full_id of the Entity based on its
        Id and the IDs of the parents.
        """
        entity_id = self.get_id()
        parts = [entity_id]
        parent = self.get_parent()
        while parent is not None:
            entity_id = parent.get_id()
            parts.append(entity_id)
            parent = parent.get_parent()
        parts.reverse()
        return tuple(parts)

    # Public methods

    @property
    def id(self):
        """Return identifier."""
        return self._id

    @id.setter
    def id(self, value):
        """Change the id of this entity.

        This will update the child_dict of this entity's parent
        and invalidate all cached full ids involving this entity.

        @raises: ValueError
        """
        if value == self._id:
            return
        if self.parent:
            if value in self.parent.child_dict:
                raise ValueError(
                    f"Cannot change id from `{self._id}` to `{value}`."
                    f" The id `{value}` is already used for a sibling of this entity."
                )
            del self.parent.child_dict[self._id]
            self.parent.child_dict[value] = self

        self._id = value
        self._reset_full_id()

    def get_level(self):
        """Return level in hierarchy.

        A - atom
        R - residue
        C - chain
        M - model
        S - structure
        """
        return self.level

    def set_parent(self, entity):
        """Set the parent Entity object."""
        self.parent = entity
        self._reset_full_id()

    def detach_parent(self):
        """Detach the parent.

        The cached full ids of this entity and its descendants are recomputed,
        so they no longer include the ids of the former parents.
        """
        self.parent = None
        self._reset_full_id()

    def detach_child(self, id):
        """Remove a child."""
        child = self.child_dict[id]
        child.detach_parent()
        del self.child_dict[id]
        self.child_list.remove(child)

    def add(self, entity):
        """Add a child to the Entity."""
        entity_id = entity.get_id()
        if self.has_id(entity_id):
            raise PDBConstructionException(f"{entity_id} defined twice")
        entity.set_parent(self)
        self.child_list.append(entity)
        self.child_dict[entity_id] = entity

    def insert(self, pos, entity):
        """Add a child to the Entity at a specified position."""
        entity_id = entity.get_id()
        if self.has_id(entity_id):
            raise PDBConstructionException(f"{entity_id} defined twice")
        entity.set_parent(self)
        self.child_list[pos:pos] = [entity]
        self.child_dict[entity_id] = entity

    def get_iterator(self):
        """Return iterator over children."""
        yield from self.child_list

    def get_list(self):
        """Return a copy of the list of children."""
        return copy(self.child_list)

    def has_id(self, id):
        """Check if a child with given id exists."""
        return id in self.child_dict

    def get_parent(self):
        """Return the parent Entity object."""
        return self.parent

    def get_id(self):
        """Return the id."""
        return self.id

    def get_full_id(self):
        """Return the full id.

        The full id is a tuple containing all id's starting from
        the top object (Structure) down to the current object. A full id for
        a Residue object e.g. is something like:

        ("1abc", 0, "A", (" ", 10, "A"))

        This corresponds to:

        Structure with id "1abc"
        Model with id 0
        Chain with id "A"
        Residue with id (" ", 10, "A")

        The Residue id indicates that the residue is not a hetero-residue
        (or a water) because it has a blank hetero field, that its sequence
        identifier is 10 and its insertion code "A".
        """
        if self.full_id is None:
            self.full_id = self._generate_full_id()
        return self.full_id

    def transform(self, rot, tran):
        """Apply rotation and translation to the atomic coordinates.

        :param rot: A right multiplying rotation matrix
        :type rot: 3x3 Numeric array

        :param tran: the translation vector
        :type tran: size 3 Numeric array

        Examples
        --------
        This is an incomplete but illustrative example::

            from numpy import pi, array
            from Bio.PDB.vectors import Vector, rotmat
            rotation = rotmat(pi, Vector(1, 0, 0))
            translation = array((0, 0, 1), 'f')
            entity.transform(rotation, translation)

        """
        for o in self.get_list():
            o.transform(rot, tran)

    def center_of_mass(self, geometric=False):
        """Return the center of mass of the Entity as a numpy array.

        If geometric is True, returns the center of geometry instead.

        Raises ValueError if the Entity has no children, or if no atoms are
        found anywhere below it.
        """
        if not len(self):
            raise ValueError(f"{self} does not have children")

        # Residue and chain levels are expanded with get_unpacked_list (to
        # unpack disordered children); all other levels use child_list.
        maybe_disordered = {"R", "C"}

        # Breadth-first walk collecting atoms as they are reached. This is
        # linear in the number of descendants and terminates cleanly even
        # when some branches contain no atoms.
        atoms = []
        pending = deque([self])
        while pending:
            entity = pending.popleft()
            if entity.level == "A":
                atoms.append(entity)
            elif entity.level in maybe_disordered:
                pending.extend(entity.get_unpacked_list())
            else:
                pending.extend(entity.child_list)

        if not atoms:
            raise ValueError(f"{self} does not contain any atoms")

        coords = np.asarray([a.coord for a in atoms], dtype=np.float32)
        if geometric:
            masses = None
        else:
            masses = np.asarray([a.mass for a in atoms], dtype=np.float32)

        return np.average(coords, axis=0, weights=masses)

    def copy(self):
        """Copy entity recursively."""
        shallow = copy(self)

        shallow.child_list = []
        shallow.child_dict = {}
        shallow.xtra = copy(self.xtra)

        # Also recomputes the copy's full_id, which was inherited from self.
        shallow.detach_parent()

        for child in self.child_list:
            shallow.add(child.copy())
        return shallow
344
345
class DisorderedEntityWrapper:
    """Wrapper class to group equivalent Entities.

    This class is a simple wrapper class that groups a number of equivalent
    Entities and forwards all method calls to one of them (the currently selected
    object). DisorderedResidue and DisorderedAtom are subclasses of this class.

    E.g.: A DisorderedAtom object contains a number of Atom objects,
    where each Atom object represents a specific position of a disordered
    atom in the structure.
    """

    def __init__(self, id):
        """Initialize the class."""
        self.id = id
        self.child_dict = {}
        self.selected_child = None
        self.parent = None

    # Special methods

    def __getattr__(self, method):
        """Forward the method call to the selected child.

        Python only calls this for attributes not found on the wrapper itself.
        """
        if method == "__setstate__":
            # Avoid issues with recursion when attempting deepcopy
            raise AttributeError(method)
        try:
            # object.__getattribute__ never falls back to __getattr__, so an
            # instance without selected_child (e.g. mid-unpickling, or created
            # via __new__) gets an AttributeError instead of infinite recursion.
            selected = object.__getattribute__(self, "selected_child")
        except AttributeError:
            raise AttributeError(method) from None
        return getattr(selected, method)

    def __getitem__(self, id):
        """Return the child with the given id."""
        return self.selected_child[id]

    # XXX Why doesn't this forward to selected_child?
    # (NB: setitem was here before getitem, iter, len, sub)
    def __setitem__(self, id, child):
        """Add a child, associated with a certain id."""
        self.child_dict[id] = child

    def __contains__(self, id):
        """Check if the selected child contains the given id."""
        return id in self.selected_child

    def __iter__(self):
        """Iterate over the selected child."""
        return iter(self.selected_child)

    def __len__(self):
        """Return the length of the selected child."""
        return len(self.selected_child)

    def __sub__(self, other):
        """Subtraction with another object."""
        return self.selected_child - other

    # Sorting
    # Directly compare the selected child
    def __gt__(self, other):
        """Return if child is greater than other."""
        return self.selected_child > other

    def __ge__(self, other):
        """Return if child is greater or equal than other."""
        return self.selected_child >= other

    def __lt__(self, other):
        """Return if child is less than other."""
        return self.selected_child < other

    def __le__(self, other):
        """Return if child is less or equal than other."""
        return self.selected_child <= other

    # Public methods
    def copy(self):
        """Copy disordered entity recursively.

        The copy holds copies of all children. It selects the copy of the
        child that is selected in the original, not the original child itself.
        """
        shallow = copy(self)
        shallow.child_dict = {}
        shallow.detach_parent()

        for child in self.disordered_get_list():
            shallow.disordered_add(child.copy())

        # The shallow copy inherited a reference to *our* selected child, and
        # disordered_add implementations may not reselect (or may select a
        # different one). Point the copy at its own counterpart, looked up by
        # the key the original child is stored under.
        for key, child in self.child_dict.items():
            if child is self.selected_child:
                if key in shallow.child_dict:
                    shallow.disordered_select(key)
                break

        return shallow

    def get_id(self):
        """Return the id."""
        return self.id

    def disordered_has_id(self, id):
        """Check if there is an object present associated with this id."""
        return id in self.child_dict

    def detach_parent(self):
        """Detach the parent."""
        self.parent = None
        for child in self.disordered_get_list():
            child.detach_parent()

    def get_parent(self):
        """Return parent."""
        return self.parent

    def set_parent(self, parent):
        """Set the parent for the object and its children."""
        self.parent = parent
        for child in self.disordered_get_list():
            child.set_parent(parent)

    def disordered_select(self, id):
        """Select the object with given id as the currently active object.

        Uncaught method calls are forwarded to the selected child object.
        """
        self.selected_child = self.child_dict[id]

    def disordered_add(self, child):
        """Add disordered entry.

        This is implemented by DisorderedAtom and DisorderedResidue.
        """
        raise NotImplementedError

    def disordered_remove(self, child):
        """Remove disordered entry.

        This is implemented by DisorderedAtom and DisorderedResidue.
        """
        raise NotImplementedError

    def is_disordered(self):
        """Return 2, indicating that this Entity is a collection of Entities."""
        return 2

    def disordered_get_id_list(self):
        """Return a list of id's."""
        # sort id list alphabetically
        return sorted(self.child_dict)

    def disordered_get(self, id=None):
        """Get the child object associated with id.

        If id is None, the currently selected child is returned.
        """
        if id is None:
            return self.selected_child
        return self.child_dict[id]

    def disordered_get_list(self):
        """Return list of children."""
        return list(self.child_dict.values())
500