1#
2#  Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
3#  Copyright (c) 2021, Greg Landrum
4#  All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are
8# met:
9#
10#     * Redistributions of source code must retain the above copyright
11#       notice, this list of conditions and the following disclaimer.
12#     * Redistributions in binary form must reproduce the above
13#       copyright notice, this list of conditions and the following
14#       disclaimer in the documentation and/or other materials provided
15#       with the distribution.
16#     * Neither the name of Novartis Institutes for BioMedical Research Inc.
17#       nor the names of its contributors may be used to endorse or promote
18#       products derived from this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31#
32# Created by Sereina Riniker, Aug 2013
33
34import copy
35import math
36
37from numpy.lib.arraysetops import isin
38try:
39  from matplotlib import cm
40  from matplotlib.colors import LinearSegmentedColormap
41except ImportError:
42  cm = None
43except RuntimeError:
44  cm = None
45
46import numpy
47
48from rdkit import Chem
49from rdkit import DataStructs
50from rdkit import Geometry
51from rdkit.Chem import Draw
52from rdkit.Chem.Draw import rdMolDraw2D
53from rdkit.Chem import rdDepictor
54from rdkit.Chem import rdMolDescriptors as rdMD
55
56
def GetAtomicWeightsForFingerprint(refMol, probeMol, fpFunction, metric=DataStructs.DiceSimilarity):
  """
    Calculates the atomic weights for the probe molecule
    based on a fingerprint function and a metric.

    The weight of atom i is
      metric(refFP, fpFunction(probeMol, -1)) - metric(refFP, fpFunction(probeMol, i))
    with refFP = fpFunction(refMol, -1).

    Parameters:
      refMol -- the reference molecule
      probeMol -- the probe molecule
      fpFunction -- the fingerprint function, called as fpFunction(mol, atomId);
                    atomId = -1 requests the unmodified fingerprint
      metric -- the similarity metric

    Returns:
      a list with one weight per atom of probeMol

    Note:
      If fpFunction needs additional parameters, use a lambda construct.
      Fingerprint data cached on the molecules in the _fpInfo attribute is
      discarded before and after the calculation, even if an exception is raised.
  """

  def _clearFpInfo():
    # refMol may be the same object as probeMol; the hasattr check handles that
    for m in (probeMol, refMol):
      if hasattr(m, '_fpInfo'):
        delattr(m, '_fpInfo')

  _clearFpInfo()
  try:
    refFP = fpFunction(refMol, -1)
    probeFP = fpFunction(probeMol, -1)
    baseSimilarity = metric(refFP, probeFP)
    # loop over atoms: similarity drop caused by removing each atom's contributions
    return [
      baseSimilarity - metric(refFP, fpFunction(probeMol, atomId))
      for atomId in range(probeMol.GetNumAtoms())
    ]
  finally:
    # never leave a (possibly half-built) cache behind on the molecules
    _clearFpInfo()
89
90
def GetAtomicWeightsForModel(probeMol, fpFunction, predictionFunction):
  """
    Calculates the atomic weights for the probe molecule based on
    a fingerprint function and the prediction function of a ML model.

    The weight of atom i is
      predictionFunction(fpFunction(probeMol, -1)) - predictionFunction(fpFunction(probeMol, i))

    Parameters:
      probeMol -- the probe molecule
      fpFunction -- the fingerprint function, called as fpFunction(mol, atomId);
                    atomId = -1 requests the unmodified fingerprint
      predictionFunction -- the prediction function of the ML model

    Returns:
      a list with one weight per atom of probeMol

    Note:
      Fingerprint data cached on probeMol in the _fpInfo attribute is
      discarded before and after the calculation, even if an exception is raised.
  """

  def _clearFpInfo():
    if hasattr(probeMol, '_fpInfo'):
      delattr(probeMol, '_fpInfo')

  _clearFpInfo()
  try:
    baseProba = predictionFunction(fpFunction(probeMol, -1))
    # loop over atoms: prediction change caused by removing each atom's contributions
    return [
      baseProba - predictionFunction(fpFunction(probeMol, atomId))
      for atomId in range(probeMol.GetNumAtoms())
    ]
  finally:
    # never leave a (possibly half-built) cache behind on the molecule
    _clearFpInfo()
114
115
def GetStandardizedWeights(weights):
  """
    Normalizes the weights,
    such that the absolute maximum weight equals 1.0.

    Parameters:
      weights -- the list with the atomic weights

    Returns:
      a tuple (normalizedWeights, currentMax), where currentMax is the largest
      absolute weight (0.0 for an empty input). normalizedWeights is always a
      new list; if currentMax is 0 the weights are copied unscaled.
  """
  # default=0.0 keeps an empty weight list from crashing max()
  currentMax = max((math.fabs(w) for w in weights), default=0.0)
  if currentMax > 0:
    return [w / currentMax for w in weights], currentMax
  # nothing to scale by: return an (unscaled) copy so both branches return a new list
  return list(weights), currentMax
130
131
def GetSimilarityMapFromWeights(mol, weights, colorMap=None, scale=-1, size=(250, 250), sigma=None,
                                coordScale=1.5, step=0.01, colors='k', contourLines=10, alpha=0.5,
                                draw2d=None, **kwargs):
  """
    Generates the similarity map for a molecule given the atomic weights.

    Parameters:
      mol -- the molecule of interest
      weights -- the atomic weights, one per atom of mol
      colorMap -- the color map scheme, default is custom PiWG color map.
                  A matplotlib Colormap; when draw2d is used, a sequence of three
                  colors (low, middle, high) is accepted as well
      scale -- the scaling: scale <= 0 -> the absolute maximum weight is used as maximum scale
                            scale = double -> this is the maximum scale
      size -- the size of the figure
      sigma -- the sigma for the Gaussians; by default 0.3 times the length of the
               first bond (or the distance between the first two atoms), rounded
      coordScale -- scaling factor for the coordinates
      step -- the step for calcAtomGaussian
      colors -- color of the contour lines
      contourLines -- if integer number N: N contour lines are drawn
                      if list(numbers): contour lines at these numbers are drawn
      alpha -- the alpha blending value for the contour lines
      draw2d -- optional MolDraw2D object; if given, the map is drawn with it and
                scale, size, coordScale, step, colors, alpha and kwargs are not used
      kwargs -- additional arguments for drawing (passed to MolToMPL and contour)

    Returns:
      draw2d if it was given, otherwise the matplotlib figure

    Raises:
      ValueError if mol has fewer than two atoms
      RuntimeError if the default color map is needed but matplotlib failed to import
  """
  if mol.GetNumAtoms() < 2:
    raise ValueError("too few atoms")

  if draw2d is not None:
    mol = rdMolDraw2D.PrepareMolForDrawing(mol, addChiralHs=False)
    if not mol.GetNumConformers():
      rdDepictor.Compute2DCoords(mol)
    conf = mol.GetConformer()
    if sigma is None:
      if mol.GetNumBonds() > 0:
        bond = mol.GetBondWithIdx(0)
        idx1, idx2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
      else:
        idx1, idx2 = 0, 1
      sigma = round(0.3 * (conf.GetAtomPosition(idx1) - conf.GetAtomPosition(idx2)).Length(), 2)
    sigmas = [sigma] * mol.GetNumAtoms()
    locs = []
    for i in range(mol.GetNumAtoms()):
      p = conf.GetAtomPosition(i)
      locs.append(Geometry.Point2D(p.x, p.y))
    draw2d.ClearDrawing()
    ps = Draw.ContourParams()
    ps.fillGrid = True
    ps.gridResolution = 0.1
    ps.extraGridPadding = 0.5
    if colorMap is not None:
      isMplColormap = False
      if cm is not None:
        # the common base class also covers ListedColormaps such as 'viridis'
        from matplotlib.colors import Colormap
        isMplColormap = isinstance(colorMap, Colormap)
      if isMplColormap:
        # sample the low, middle and high colors of the matplotlib colormap
        clrs = [tuple(x) for x in colorMap([0, 0.5, 1])]
      else:
        clrs = [colorMap[0], colorMap[1], colorMap[2]]
      ps.setColourMap(clrs)

    Draw.ContourAndDrawGaussians(draw2d, locs, weights, sigmas, nContours=contourLines, params=ps)
    draw2d.drawOptions().clearBackground = False
    draw2d.DrawMolecule(mol)
    return draw2d

  fig = Draw.MolToMPL(mol, coordScale=coordScale, size=size, **kwargs)
  if sigma is None:
    if mol.GetNumBonds() > 0:
      bond = mol.GetBondWithIdx(0)
      idx1, idx2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    else:
      idx1, idx2 = 0, 1
    # mol._atomPs holds the drawing coordinates set by MolToMPL
    sigma = round(
      0.3 * math.sqrt(sum((mol._atomPs[idx1][i] - mol._atomPs[idx2][i])**2 for i in range(2))), 2)
  x, y, z = Draw.calcAtomGaussians(mol, sigma, weights=weights, step=step)
  # scaling
  if scale <= 0.0:
    maxScale = max(math.fabs(numpy.min(z)), math.fabs(numpy.max(z)))
  else:
    maxScale = scale
  # coloring
  if colorMap is None:
    if cm is None:
      raise RuntimeError("matplotlib failed to import")
    try:
      # matplotlib >= 3.6; cm.get_cmap was removed in matplotlib 3.9
      import matplotlib
      PiYG_cmap = matplotlib.colormaps['PiYG'].resampled(2)
    except AttributeError:
      # older matplotlib without the colormap registry or Colormap.resampled
      PiYG_cmap = cm.get_cmap('PiYG', 2)
    colorMap = LinearSegmentedColormap.from_list(
      'PiWG', [PiYG_cmap(0), (1.0, 1.0, 1.0), PiYG_cmap(1)], N=255)

  fig.axes[0].imshow(z, cmap=colorMap, interpolation='bilinear', origin='lower',
                     extent=(0, 1, 0, 1), vmin=-maxScale, vmax=maxScale)
  # contour lines
  # only draw them when at least one weight is not zero
  if any(w != 0.0 for w in weights):
    contourset = fig.axes[0].contour(x, y, z, contourLines, colors=colors, alpha=alpha, **kwargs)
    levels = numpy.asarray(contourset.levels)
    # hide the zero-level line and draw negative levels dashed
    if hasattr(contourset, 'collections'):
      # matplotlib < 3.10: one line collection per contour level
      for level, coll in zip(levels, contourset.collections):
        if level == 0.0:
          coll.set_linewidth(0.0)
        elif level < 0:
          coll.set_dashes([(0, (3.0, 3.0))])
    else:
      # matplotlib >= 3.10: the ContourSet is itself a collection with one path per level
      widths = numpy.resize(numpy.asarray(contourset.get_linewidths(), dtype=float), len(levels))
      widths[levels == 0.0] = 0.0
      contourset.set_linewidths(widths)
      if 'linestyles' not in kwargs:
        contourset.set_linestyles([(0, (3.0, 3.0)) if level < 0 else 'solid' for level in levels])
  fig.axes[0].set_axis_off()
  return fig
232
233
def GetSimilarityMapForFingerprint(refMol, probeMol, fpFunction, metric=DataStructs.DiceSimilarity,
                                   **kwargs):
  """
    Generates the similarity map for a given reference and probe molecule,
    fingerprint function and similarity metric.

    The atomic weights from GetAtomicWeightsForFingerprint are normalized with
    GetStandardizedWeights before they are drawn with GetSimilarityMapFromWeights.

    Parameters:
      refMol -- the reference molecule
      probeMol -- the probe molecule
      fpFunction -- the fingerprint function
      metric -- the similarity metric.
      kwargs -- additional arguments for drawing, forwarded to GetSimilarityMapFromWeights

    Returns:
      a tuple (map, maxWeight): the result of GetSimilarityMapFromWeights and the
      absolute maximum of the raw weights used for normalization
  """
  rawWeights = GetAtomicWeightsForFingerprint(refMol, probeMol, fpFunction, metric)
  scaledWeights, maxWeight = GetStandardizedWeights(rawWeights)
  return GetSimilarityMapFromWeights(probeMol, scaledWeights, **kwargs), maxWeight
251
252
def GetSimilarityMapForModel(probeMol, fpFunction, predictionFunction, **kwargs):
  """
    Generates the similarity map for a given ML model and probe molecule,
    and fingerprint function.

    The atomic weights from GetAtomicWeightsForModel are normalized with
    GetStandardizedWeights before they are drawn with GetSimilarityMapFromWeights.

    Parameters:
      probeMol -- the probe molecule
      fpFunction -- the fingerprint function
      predictionFunction -- the prediction function of the ML model
      kwargs -- additional arguments for drawing, forwarded to GetSimilarityMapFromWeights

    Returns:
      a tuple (map, maxWeight): the result of GetSimilarityMapFromWeights and the
      absolute maximum of the raw weights used for normalization
  """
  rawWeights = GetAtomicWeightsForModel(probeMol, fpFunction, predictionFunction)
  scaledWeights, maxWeight = GetStandardizedWeights(rawWeights)
  return GetSimilarityMapFromWeights(probeMol, scaledWeights, **kwargs), maxWeight
268
269
# Atom-pair fingerprint builders keyed by fpType. Each one is called as
# builder(mol, nBits, minLength, maxLength, nBitsPerEntry, ignoreAtoms, **kwargs);
# arguments a given variant does not need are accepted and ignored.
apDict = {
  'normal':
    lambda m, bits, minl, maxl, bpe, ia, **kwargs: rdMD.GetAtomPairFingerprint(
      m, minLength=minl, maxLength=maxl, ignoreAtoms=ia, **kwargs),
  'hashed':
    lambda m, bits, minl, maxl, bpe, ia, **kwargs: rdMD.GetHashedAtomPairFingerprint(
      m, nBits=bits, minLength=minl, maxLength=maxl, ignoreAtoms=ia, **kwargs),
  'bv':
    lambda m, bits, minl, maxl, bpe, ia, **kwargs: rdMD.GetHashedAtomPairFingerprintAsBitVect(
      m, nBits=bits, minLength=minl, maxLength=maxl, nBitsPerEntry=bpe, ignoreAtoms=ia, **kwargs),
}
278
279
# usage:   lambda m,i: GetAPFingerprint(m, i, fpType, nBits, minLength, maxLength, nBitsPerEntry)
def GetAPFingerprint(mol, atomId=-1, fpType='normal', nBits=2048, minLength=1, maxLength=30,
                     nBitsPerEntry=4, **kwargs):
  """
    Calculates the atom pairs fingerprint with the pairs of atomId removed.

    Parameters:
      mol -- the molecule of interest
      atomId -- the atom to remove the pairs for (if -1, no pair is removed);
                it is passed to the fingerprint function as ignoreAtoms
      fpType -- the type of AP fingerprint ('normal', 'hashed', 'bv')
      nBits -- the size of the bit vector (used for fpType='hashed' and 'bv')
      minLength -- the minimum path length for an atom pair
      maxLength -- the maximum path length for an atom pair
      nBitsPerEntry -- the number of bits available for each pair (only for fpType='bv')

    any additional keyword arguments will be passed to the fingerprinting function.

    Raises:
      ValueError for an unknown fpType or an atomId >= mol.GetNumAtoms()
  """
  if fpType not in ('normal', 'hashed', 'bv'):
    raise ValueError("Unknown Atom pairs fingerprint type")
  ignored = 0  # no atom is ignored for the full fingerprint
  if atomId >= 0:
    if atomId >= mol.GetNumAtoms():
      raise ValueError("atom index greater than number of atoms")
    ignored = [atomId]
  return apDict[fpType](mol, nBits, minLength, maxLength, nBitsPerEntry, ignored, **kwargs)
302
303
# Topological-torsion fingerprint builders keyed by fpType. Each one is called as
# builder(mol, nBits, targetSize, nBitsPerEntry, ignoreAtoms, **kwargs);
# arguments a given variant does not need are accepted and ignored.
ttDict = {
  'normal':
    lambda m, bits, ts, bpe, ia, **kwargs: rdMD.GetTopologicalTorsionFingerprint(
      m, targetSize=ts, ignoreAtoms=ia, **kwargs),
  'hashed':
    lambda m, bits, ts, bpe, ia, **kwargs: rdMD.GetHashedTopologicalTorsionFingerprint(
      m, nBits=bits, targetSize=ts, ignoreAtoms=ia, **kwargs),
  'bv':
    lambda m, bits, ts, bpe, ia, **kwargs: rdMD.GetHashedTopologicalTorsionFingerprintAsBitVect(
      m, nBits=bits, targetSize=ts, nBitsPerEntry=bpe, ignoreAtoms=ia, **kwargs),
}
313
314
# usage:   lambda m,i: GetTTFingerprint(m, i, fpType, nBits, targetSize)
def GetTTFingerprint(mol, atomId=-1, fpType='normal', nBits=2048, targetSize=4, nBitsPerEntry=4,
                     **kwargs):
  """
    Calculates the topological torsion fingerprint with the torsions of atomId removed.

    Parameters:
      mol -- the molecule of interest
      atomId -- the atom to remove the torsions for (if -1, no torsion is removed);
                it is passed to the fingerprint function as ignoreAtoms
      fpType -- the type of TT fingerprint ('normal', 'hashed', 'bv')
      nBits -- the size of the bit vector (used for fpType='hashed' and 'bv')
      targetSize -- the torsion size, passed on as targetSize (default 4)
      nBitsPerEntry -- the number of bits available for each torsion (only for fpType='bv')

    any additional keyword arguments will be passed to the fingerprinting function.

    Raises:
      ValueError for an unknown fpType or an atomId >= mol.GetNumAtoms()
  """
  if fpType not in ('normal', 'hashed', 'bv'):
    raise ValueError("Unknown Topological torsion fingerprint type")
  ignored = 0  # no atom is ignored for the full fingerprint
  if atomId >= 0:
    if atomId >= mol.GetNumAtoms():
      raise ValueError("atom index greater than number of atoms")
    ignored = [atomId]
  return ttDict[fpType](mol, nBits, targetSize, nBitsPerEntry, ignored, **kwargs)
340
341
# usage:   lambda m,i: GetMorganFingerprint(m, i, radius, fpType, nBits, useFeatures)
def GetMorganFingerprint(mol, atomId=-1, radius=2, fpType='bv', nBits=2048, useFeatures=False,
                         **kwargs):
  """
    Calculates the Morgan fingerprint with the environments of atomId removed.

    Parameters:
      mol -- the molecule of interest
      atomId -- the atom to remove the environments for (if -1, no environment is removed)
      radius -- the maximum radius
      fpType -- the type of Morgan fingerprint: 'count' or 'bv'
      nBits -- the size of the bit vector (only for fpType = 'bv')
      useFeatures -- if false: ConnectivityMorgan, if true: FeatureMorgan

    any additional keyword arguments will be passed to the fingerprinting function.

    Note:
      On the first call the full fingerprint and, for every atom, the bits of all
      environments that contain the atom are cached on mol as mol._fpInfo. Later
      calls reuse this cache regardless of the other parameters. For 'bv' the
      atom's bits are cleared with an XOR; for 'count' each recorded occurrence
      decrements the corresponding count by one.

    Raises:
      ValueError for an unknown fpType or an atomId >= mol.GetNumAtoms()
  """
  if fpType not in ['bv', 'count']:
    raise ValueError("Unknown Morgan fingerprint type")
  isBitVect = fpType == 'bv'

  if not hasattr(mol, '_fpInfo'):
    bitInfo = {}
    # full fingerprint plus an empty per-atom bit container
    if isBitVect:
      fullFp = rdMD.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits,
                                                  useFeatures=useFeatures, bitInfo=bitInfo,
                                                  **kwargs)
      perAtom = [DataStructs.ExplicitBitVect(nBits) for _ in range(mol.GetNumAtoms())]
    else:
      fullFp = rdMD.GetMorganFingerprint(mol, radius, useFeatures=useFeatures, bitInfo=bitInfo,
                                         **kwargs)
      perAtom = [[] for _ in range(mol.GetNumAtoms())]

    def _mark(atomIdx, bit):
      # attribute one occurrence of `bit` to atom atomIdx
      if isBitVect:
        perAtom[atomIdx][bit] = 1
      else:
        perAtom[atomIdx].append(bit)

    for bit, environments in bitInfo.items():
      for centre, envRadius in environments:
        if envRadius == 0:
          # radius 0: the environment is the central atom alone
          _mark(centre, bit)
          continue
        # radius > 0: every atom of the environment's submolecule gets the bit
        bondPath = Chem.FindAtomEnvironmentOfRadiusN(mol, envRadius, centre)
        atomMap = {}
        Chem.PathToSubmol(mol, bondPath, atomMap=atomMap)
        for member in atomMap:
          _mark(member, bit)
    mol._fpInfo = (fullFp, perAtom)

  if atomId < 0:
    return mol._fpInfo[0]
  # remove the bits of atomId
  if atomId >= mol.GetNumAtoms():
    raise ValueError("atom index greater than number of atoms")
  if len(mol._fpInfo) != 2:
    raise ValueError("_fpInfo not set")
  cachedFp, cachedBits = mol._fpInfo
  if isBitVect:
    return cachedFp ^ cachedBits[atomId]  # xor
  reduced = copy.deepcopy(cachedFp)
  for bit in cachedBits[atomId]:
    reduced[bit] -= 1
  return reduced
407
408
# usage:   lambda m,i: GetRDKFingerprint(m, i, fpType, nBits, minPath, maxPath, nBitsPerHash)
def GetRDKFingerprint(mol, atomId=-1, fpType='bv', nBits=2048, minPath=1, maxPath=5, nBitsPerHash=2,
                      **kwargs):
  """
    Calculates the RDKit fingerprint with the paths of atomId removed.

    Parameters:
      mol -- the molecule of interest
      atomId -- the atom to remove the paths for (if -1, no path is removed)
      fpType -- the type of RDKit fingerprint: 'bv' ('' is accepted as well)
      nBits -- the size of the bit vector
      minPath -- minimum path length
      maxPath -- maximum path length
      nBitsPerHash -- number of bits to set per path

    any additional keyword arguments will be passed to Chem.RDKFingerprint.

    Note:
      The full fingerprint and the per-atom bit lists are cached on mol as
      mol._fpInfo on the first call; later calls reuse this cache regardless
      of the fingerprint parameters.

    Raises:
      ValueError for an unknown fpType or an atomId >= mol.GetNumAtoms()
  """
  if fpType not in ['bv', '']:
    raise ValueError("Unknown RDKit fingerprint type")
  if not hasattr(mol, '_fpInfo'):
    atomBits = []  # list with bits for each atom, filled by RDKFingerprint
    # get the fingerprint
    molFp = Chem.RDKFingerprint(mol, fpSize=nBits, minPath=minPath, maxPath=maxPath,
                                nBitsPerHash=nBitsPerHash, atomBits=atomBits, **kwargs)
    mol._fpInfo = (molFp, atomBits)

  if atomId < 0:
    return mol._fpInfo[0]
  # remove the bits of atomId from a copy, leaving the cached fingerprint intact
  if atomId >= mol.GetNumAtoms():
    raise ValueError("atom index greater than number of atoms")
  if len(mol._fpInfo) != 2:
    raise ValueError("_fpInfo not set")
  molFp = copy.deepcopy(mol._fpInfo[0])
  molFp.UnSetBitsFromList(mol._fpInfo[1][atomId])
  return molFp
444