1"""This module generates and does computation with molecular surfaces.
2"""
3
4from __future__ import division
5from numbers import Number
6from distutils.version import LooseVersion
7import warnings
8
9import oddt.toolkits
10import numpy as np
11from scipy.spatial import cKDTree
12
13try:
14    from skimage.morphology import ball, binary_closing
15    from skimage import __version__ as skimage_version
16    if LooseVersion(skimage_version) >= LooseVersion('0.13'):
17        from skimage.measure import marching_cubes_lewiner as marching_cubes
18    else:
19        from skimage.measure import marching_cubes
20except ImportError as e:
21    warnings.warn('scikit-image could not be imported and is required for'
22                  'generating molecular surfaces.')
23    skimage = None
24
25
26def generate_surface_marching_cubes(molecule, remove_hoh=False, scaling=1.,
27                                    probe_radius=1.4):
28    """Generates a molecular surface mesh using the marching_cubes
29    method from scikit-image. Ignores hydrogens present in the molecule.
30
31    Parameters
32    ----------
33    molecule : oddt.toolkit.Molecule object
34        Molecule for which the surface will be generated
35
36    remove_hoh : bool (default = False)
37        If True, remove waters from the molecule before generating the surface.
38        Requires molecule.protein to be set to True.
39
40    scaling : float (default = 1.0)
41        Expands the grid in which computation is done by a factor of scaling.
42        Results in a more accurate representation of the surface, but increases
43        computation time.
44
45    probe_radius : float (default = 1.4)
46        Radius of a ball used to patch up holes inside the molecule
47        resulting from some molecular distances being larger
48        (usually in protein). Basically reduces the surface to one
49        accesible by other molecules of radius smaller than probe_radius.
50
51    Returns
52    -------
53    verts : numpy array
54        Spatial coordinates for mesh vertices.
55
56    faces : numpy array
57        Faces are defined by referencing vertices from verts.
58    """
59    # Input validation
60    if not isinstance(molecule, oddt.toolkit.Molecule):
61        raise TypeError('molecule needs to be of type oddt.toolkit.Molecule')
62    if not (isinstance(probe_radius, Number) and probe_radius >= 0):
63        raise ValueError('probe_radius needs to be a positive number')
64
65    # Removing waters and hydrogens
66    atom_dict = molecule.atom_dict
67    atom_dict = atom_dict[atom_dict['atomicnum'] != 1]
68    if remove_hoh:
69        if molecule.protein is not True:
70            raise ValueError('Residue names are needed for water removal, '
71                             'molecule.protein property must be set to True')
72        no_hoh = atom_dict['resname'] != 'HOH'
73        atom_dict = atom_dict[no_hoh]
74
75    # Take a molecule's coordinates and atom radii and scale if necessary
76    coords = atom_dict['coords'] * scaling
77    radii = atom_dict['radius'] * scaling
78
79    # More input validation
80    if radii.min() < 1:
81        raise ValueError('Scaling times the radius of the smallest atom must '
82                         'be larger than 1')
83    # Create a ball for each atom in the molecule
84    ball_dict = {radius: ball(radius, dtype=bool) for radius in set(radii)}
85    ball_radii = np.array([ball_dict[radius].shape[0] for radius in radii])
86
87    # Transform the coordinates because the grid starts at (0, 0 ,0)
88    min_coords = np.min(coords, axis=0)
89    max_rad = np.max(ball_radii, axis=0)
90    adjusted = np.round(coords - min_coords + max_rad * 5).astype(np.int64)
91    offset = adjusted[0] - coords[0]
92
93    # Calculate boundries in the grid for each ball.
94    ball_coord_min = (adjusted.T - np.floor(ball_radii / 2).astype(np.int64)).T
95    ball_coord_max = (ball_coord_min.T + ball_radii).T
96
97    # Create the grid
98    grid = np.zeros(shape=ball_coord_max.max(axis=0) + int(8 * scaling), dtype=bool)
99
100    # Place balls in grid
101    for radius, coord_min, coord_max in zip(radii, ball_coord_min, ball_coord_max):
102        grid[coord_min[0]:coord_max[0],
103             coord_min[1]:coord_max[1],
104             coord_min[2]:coord_max[2]] += ball_dict[radius]
105    spacing = (1 / scaling,) * 3
106
107    # Hole-filling with morphological closing
108    grid = binary_closing(grid, ball(probe_radius * 2 * scaling))
109
110    # Marching cubes
111    verts, faces = marching_cubes(grid, level=0, spacing=spacing)[:2]
112
113    # Verts already scaled by the marching cubes function (spacing parameter)
114    # Only need to scale the offset
115    # Results in skimage version lower than 0.11 are offset by 1 in each direction
116    if LooseVersion(skimage_version) < LooseVersion('0.11'):
117        verts += 1 / scaling
118    return verts - offset / scaling, faces
119
120
121def find_surface_residues(molecule, max_dist=None, scaling=1.):
122    """Finds residues close to the molecular surface using
123    generate_surface_marching_cubes. Ignores hydrogens and
124    waters present in the molecule.
125
126    Parameters
127    ----------
128    molecule : oddt.toolkit.Molecule
129        Molecule to find surface residues in.
130
131    max_dist : array_like, numeric or None (default = None)
132        Maximum distance from the surface where residues would
133        still be considered close. If None, compares distances
134        to radii of respective atoms.
135
136    scaling : float (default = 1.0)
137        Expands the grid in which computation is done by
138        generate_surface_marching_cubes by a factor of scaling.
139        Results in a more accurate representation of the surface,
140        and therefore more accurate computation of distances
141        but increases computation time.
142
143    Returns
144    -------
145    atom_dict : numpy array
146        An atom_dict containing only the surface residues
147        from the original molecule.
148    """
149    # Input validation
150    if not isinstance(molecule, oddt.toolkit.Molecule):
151        raise TypeError('molecule needs to be of type oddt.toolkit.Molecule')
152
153    # Copy the atom_dict, remove waters
154    atom_dict = molecule.atom_dict
155    mask = (atom_dict['resname'] != 'HOH') & (atom_dict['atomicnum'] != 1)
156    atom_dict = atom_dict[mask]
157    coords = atom_dict['coords']
158    if max_dist is None:
159        max_dist = atom_dict['radius']
160
161    # More input validation
162    elif isinstance(max_dist, Number):
163        max_dist = np.repeat(max_dist, coords.shape[0])
164    else:
165        max_dist = np.array(max_dist)
166    if not np.issubdtype(max_dist.dtype, np.number):
167        raise ValueError('max_dist has to be a number or an '
168                         'array_like object containing numbers')
169    if coords.shape[0] != len(max_dist):
170        raise ValueError('max_dist doesn\'t match coords\' length')
171
172    # Marching cubes
173    verts, _ = generate_surface_marching_cubes(molecule, remove_hoh=True,
174                                               scaling=scaling, probe_radius=1.4)
175
176    # Calculate distances between atoms and the surface
177    tree_verts = cKDTree(verts)
178    mask = [bool(tree_verts.query_ball_point(point, radius))
179            for point, radius in zip(coords, max_dist)]
180    return atom_dict[np.array(mask)]
181