#
#  Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
#  Copyright (c) 2021, Greg Landrum
#  All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
#       copyright notice, this list of conditions and the following
#       disclaimer in the documentation and/or other materials provided
#       with the distribution.
#     * Neither the name of Novartis Institutes for BioMedical Research Inc.
#       nor the names of its contributors may be used to endorse or promote
#       products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Created by Sereina Riniker, Aug 2013

import copy
import math

# The original import targeted ``numpy.lib.arraysetops``, which became a
# private submodule in NumPy 2.0.  ``isin`` is not used in this module; it is
# re-exported from the public namespace only so that the module-level name
# keeps existing for any external code that relied on it.
from numpy import isin  # noqa: F401
try:
    from matplotlib import cm
    from matplotlib.colors import Colormap
    from matplotlib.colors import LinearSegmentedColormap
except (ImportError, RuntimeError):
    # matplotlib is optional: only the non-``draw2d`` rendering path needs it.
    cm = None

import numpy

from rdkit import Chem
from rdkit import DataStructs
from rdkit import Geometry
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import rdDepictor
from rdkit.Chem import rdMolDescriptors as rdMD


def _clearFpInfo(*mols):
    """Remove the cached ``_fpInfo`` attribute from each molecule, if present."""
    for mol in mols:
        if hasattr(mol, '_fpInfo'):
            delattr(mol, '_fpInfo')


def GetAtomicWeightsForFingerprint(refMol, probeMol, fpFunction, metric=DataStructs.DiceSimilarity):
    """
    Calculates the atomic weights for the probe molecule
    based on a fingerprint function and a metric.

    The weight of an atom is the similarity of the full probe fingerprint to
    the reference fingerprint minus the similarity obtained when
    ``fpFunction`` is called with that atom's index.

    Parameters:
      refMol -- the reference molecule
      probeMol -- the probe molecule
      fpFunction -- the fingerprint function, called as fpFunction(mol, atomId);
                    atomId == -1 requests the unmodified fingerprint
      metric -- the similarity metric

    Returns:
      a list with one weight per atom of probeMol

    Note:
      If fpFunction needs additional parameters, use a lambda construct
    """
    # Start from a clean slate: the fingerprint helpers in this module cache
    # their results on the molecule as ``_fpInfo``.
    _clearFpInfo(probeMol, refMol)
    try:
        refFP = fpFunction(refMol, -1)
        probeFP = fpFunction(probeMol, -1)
        baseSimilarity = metric(refFP, probeFP)
        weights = [
            baseSimilarity - metric(refFP, fpFunction(probeMol, atomId))
            for atomId in range(probeMol.GetNumAtoms())
        ]
    finally:
        # Always drop the cache, even when fpFunction/metric raised, so a
        # later call with different settings cannot pick up stale data.
        _clearFpInfo(probeMol, refMol)
    return weights


def GetAtomicWeightsForModel(probeMol, fpFunction, predictionFunction):
    """
    Calculates the atomic weights for the probe molecule based on
    a fingerprint function and the prediction function of a ML model.

    The weight of an atom is the prediction for the full fingerprint minus
    the prediction obtained when ``fpFunction`` is called with that atom's
    index.

    Parameters:
      probeMol -- the probe molecule
      fpFunction -- the fingerprint function, called as fpFunction(mol, atomId);
                    atomId == -1 requests the unmodified fingerprint
      predictionFunction -- the prediction function of the ML model

    Returns:
      a list with one weight per atom of probeMol
    """
    _clearFpInfo(probeMol)
    try:
        baseProba = predictionFunction(fpFunction(probeMol, -1))
        weights = [
            baseProba - predictionFunction(fpFunction(probeMol, atomId))
            for atomId in range(probeMol.GetNumAtoms())
        ]
    finally:
        # Drop the per-molecule cache even if the model or fingerprinter failed.
        _clearFpInfo(probeMol)
    return weights


def GetStandardizedWeights(weights):
    """
    Normalizes the weights,
    such that the absolute maximum weight equals 1.0.

    Parameters:
      weights -- the list with the atomic weights

    Returns:
      a tuple (normalizedWeights, maxAbsWeight).  If every weight is zero, or
      the list is empty, the input is returned unchanged together with a
      maximum of 0.
    """
    # ``default`` keeps an empty weight list from raising in max().
    currentMax = max((math.fabs(w) for w in weights), default=0.0)
    if currentMax > 0:
        return [w / currentMax for w in weights], currentMax
    return weights, currentMax


def _referenceAtomPair(mol):
    """Return the atom index pair whose distance is used to derive the default sigma."""
    if mol.GetNumBonds() > 0:
        bond = mol.GetBondWithIdx(0)
        return bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    # no bonds at all: fall back to the first two atoms
    return 0, 1


def _getTwoColorPiYG():
    """Return the 'PiYG' matplotlib colormap resampled to two entries."""
    try:
        # Registry API; ``cm.get_cmap`` is deprecated/removed in recent
        # matplotlib releases.
        from matplotlib import colormaps
        return colormaps['PiYG'].resampled(2)
    except (ImportError, AttributeError):
        # older matplotlib without ``colormaps`` or ``Colormap.resampled``
        return cm.get_cmap('PiYG', 2)


def _styleContourLines(contourset):
    """Hide the zero-level contour line and dash the negative-level ones."""
    levels = list(contourset.levels)
    collections = getattr(contourset, 'collections', None)
    if collections is not None:
        # matplotlib versions that expose one collection per contour level
        for level, coll in zip(levels, collections):
            if level == 0.0:
                coll.set_linewidth(0.0)
            elif level < 0:
                coll.set_dashes([(0, (3.0, 3.0))])
        return
    # NOTE(review): fallback for matplotlib releases where the ContourSet is a
    # single Collection and ``.collections`` no longer exists; per-level
    # properties are passed as lists.  Confirm against the installed version.
    if not levels:
        return
    widths = list(numpy.atleast_1d(contourset.get_linewidth()))
    if widths:
        contourset.set_linewidth(
            [0.0 if level == 0.0 else widths[j % len(widths)] for j, level in enumerate(levels)])
    contourset.set_linestyle([(0, (3.0, 3.0)) if level < 0 else 'solid' for level in levels])


def _drawSimilarityMapWithDraw2D(mol, weights, draw2d, colorMap, sigma, contourLines):
    """Render the weight contours and the molecule onto an rdMolDraw2D drawer."""
    mol = rdMolDraw2D.PrepareMolForDrawing(mol, addChiralHs=False)
    if not mol.GetNumConformers():
        rdDepictor.Compute2DCoords(mol)
    conf = mol.GetConformer()
    if sigma is None:
        idx1, idx2 = _referenceAtomPair(mol)
        separation = (conf.GetAtomPosition(idx1) - conf.GetAtomPosition(idx2)).Length()
        sigma = round(0.3 * separation, 2)
    nAtoms = mol.GetNumAtoms()
    sigmas = [sigma] * nAtoms
    locs = []
    for i in range(nAtoms):
        p = conf.GetAtomPosition(i)
        locs.append(Geometry.Point2D(p.x, p.y))

    draw2d.ClearDrawing()
    ps = Draw.ContourParams()
    ps.fillGrid = True
    ps.gridResolution = 0.1
    ps.extraGridPadding = 0.5
    if colorMap is not None:
        if cm is not None and isinstance(colorMap, Colormap):
            # any matplotlib colormap: sample its low, middle and high colors
            clrs = [tuple(x) for x in colorMap([0, 0.5, 1])]
        else:
            # otherwise expect an indexable holding at least three colors
            clrs = [colorMap[0], colorMap[1], colorMap[2]]
        ps.setColourMap(clrs)

    Draw.ContourAndDrawGaussians(draw2d, locs, weights, sigmas, nContours=contourLines, params=ps)
    # keep the contours that were just drawn when the molecule is added on top
    draw2d.drawOptions().clearBackground = False
    draw2d.DrawMolecule(mol)
    return draw2d


def GetSimilarityMapFromWeights(mol, weights, colorMap=None, scale=-1, size=(250, 250), sigma=None,
                                coordScale=1.5, step=0.01, colors='k', contourLines=10, alpha=0.5,
                                draw2d=None, **kwargs):
    """
    Generates the similarity map for a molecule given the atomic weights.

    Two rendering paths exist.  If draw2d is given, the map is drawn with the
    RDKit drawing code and only colorMap, sigma and contourLines are used.
    Otherwise a matplotlib figure is created via Draw.MolToMPL.

    Parameters:
      mol -- the molecule of interest
      weights -- the atomic weights (presumably one per atom of mol, in atom order)
      colorMap -- the matplotlib color map scheme, default is custom PiWG color map;
                  with draw2d it may also be an indexable of three colors
      scale -- the scaling: scale <= 0 -> the absolute maximum weight is used as maximum scale
                            scale = double -> this is the maximum scale
      size -- the size of the figure
      sigma -- the sigma for the Gaussians; by default 0.3 times the length of
               the first bond (or the distance between the first two atoms if
               there are no bonds), rounded to two decimals
      coordScale -- scaling factor for the coordinates
      step -- the step for calcAtomGaussian
      colors -- color of the contour lines
      contourLines -- if integer number N: N contour lines are drawn
                      if list(numbers): contour lines at these numbers are drawn
      alpha -- the alpha blending value for the contour lines
      draw2d -- optional drawing object; if provided it is cleared, drawn into and returned
      kwargs -- additional arguments for drawing (forwarded to both
                Draw.MolToMPL and the matplotlib contour call)

    Returns:
      draw2d if it was provided, otherwise the matplotlib figure

    Raises:
      ValueError -- if the molecule has fewer than two atoms
      RuntimeError -- if the default color map is needed but matplotlib could not be imported
    """
    if mol.GetNumAtoms() < 2:
        raise ValueError("too few atoms")

    if draw2d is not None:
        return _drawSimilarityMapWithDraw2D(mol, weights, draw2d, colorMap, sigma, contourLines)

    fig = Draw.MolToMPL(mol, coordScale=coordScale, size=size, **kwargs)
    if sigma is None:
        # mol._atomPs holds the figure-space atom coordinates (presumably
        # populated by the Draw.MolToMPL call above).
        idx1, idx2 = _referenceAtomPair(mol)
        p1, p2 = mol._atomPs[idx1], mol._atomPs[idx2]
        sigma = round(0.3 * math.sqrt(sum((p1[i] - p2[i])**2 for i in range(2))), 2)
    x, y, z = Draw.calcAtomGaussians(mol, sigma, weights=weights, step=step)

    # scaling
    if scale <= 0.0:
        maxScale = max(math.fabs(numpy.min(z)), math.fabs(numpy.max(z)))
    else:
        maxScale = scale

    # coloring
    if colorMap is None:
        if cm is None:
            raise RuntimeError("matplotlib failed to import")
        PiYG_cmap = _getTwoColorPiYG()
        colorMap = LinearSegmentedColormap.from_list(
            'PiWG', [PiYG_cmap(0), (1.0, 1.0, 1.0), PiYG_cmap(1)], N=255)

    axes = fig.axes[0]
    axes.imshow(z, cmap=colorMap, interpolation='bilinear', origin='lower', extent=(0, 1, 0, 1),
                vmin=-maxScale, vmax=maxScale)
    # contour lines
    # only draw them when at least one weight is not zero
    if any(w != 0.0 for w in weights):
        contourset = axes.contour(x, y, z, contourLines, colors=colors, alpha=alpha, **kwargs)
        _styleContourLines(contourset)
    axes.set_axis_off()
    return fig


def GetSimilarityMapForFingerprint(refMol, probeMol, fpFunction, metric=DataStructs.DiceSimilarity,
                                   **kwargs):
    """
    Generates the similarity map for a given reference and probe molecule,
    fingerprint function and similarity metric.

    Parameters:
      refMol -- the reference molecule
      probeMol -- the probe molecule
      fpFunction -- the fingerprint function
      metric -- the similarity metric.
      kwargs -- additional arguments for drawing

    Returns:
      a tuple (figure, maxWeight) where maxWeight is the absolute maximum
      weight before normalization
    """
    weights = GetAtomicWeightsForFingerprint(refMol, probeMol, fpFunction, metric)
    weights, maxWeight = GetStandardizedWeights(weights)
    fig = GetSimilarityMapFromWeights(probeMol, weights, **kwargs)
    return fig, maxWeight


def GetSimilarityMapForModel(probeMol, fpFunction, predictionFunction, **kwargs):
    """
    Generates the similarity map for a given ML model and probe molecule,
    and fingerprint function.

    Parameters:
      probeMol -- the probe molecule
      fpFunction -- the fingerprint function
      predictionFunction -- the prediction function of the ML model
      kwargs -- additional arguments for drawing

    Returns:
      a tuple (figure, maxWeight) where maxWeight is the absolute maximum
      weight before normalization
    """
    weights = GetAtomicWeightsForModel(probeMol, fpFunction, predictionFunction)
    weights, maxWeight = GetStandardizedWeights(weights)
    fig = GetSimilarityMapFromWeights(probeMol, weights, **kwargs)
    return fig, maxWeight


# Dispatch table for the atom pair fingerprint flavours.  All entries share one
# call signature (mol, nBits, minLength, maxLength, nBitsPerEntry, ignoreAtoms);
# arguments a flavour does not support are simply not forwarded.
apDict = {
    'normal':
    lambda m, bits, minl, maxl, bpe, ia, **kwargs: rdMD.GetAtomPairFingerprint(
        m, minLength=minl, maxLength=maxl, ignoreAtoms=ia, **kwargs),
    'hashed':
    lambda m, bits, minl, maxl, bpe, ia, **kwargs: rdMD.GetHashedAtomPairFingerprint(
        m, nBits=bits, minLength=minl, maxLength=maxl, ignoreAtoms=ia, **kwargs),
    'bv':
    lambda m, bits, minl, maxl, bpe, ia, **kwargs: rdMD.GetHashedAtomPairFingerprintAsBitVect(
        m, nBits=bits, minLength=minl, maxLength=maxl, nBitsPerEntry=bpe, ignoreAtoms=ia,
        **kwargs),
}


# usage:   lambda m,i: GetAPFingerprint(m, i, fpType, nBits, minLength, maxLength, nBitsPerEntry)
def GetAPFingerprint(mol, atomId=-1, fpType='normal', nBits=2048, minLength=1, maxLength=30,
                     nBitsPerEntry=4, **kwargs):
    """
    Calculates the atom pairs fingerprint with the pairs of atomId removed.

    Parameters:
      mol -- the molecule of interest
      atomId -- the atom to remove the pairs for (if -1, no pair is removed)
      fpType -- the type of AP fingerprint ('normal', 'hashed', 'bv')
      nBits -- the size of the fingerprint (only for fpType='hashed' and 'bv')
      minLength -- the minimum path length for an atom pair
      maxLength -- the maximum path length for an atom pair
      nBitsPerEntry -- the number of bits available for each pair (only for fpType='bv')

    any additional keyword arguments will be passed to the fingerprinting function.

    Raises:
      ValueError -- for an unknown fpType or an atomId that is not smaller
                    than the number of atoms
    """
    if fpType not in ['normal', 'hashed', 'bv']:
        raise ValueError("Unknown Atom pairs fingerprint type")
    if atomId < 0:
        # 0 is handed over as ignoreAtoms, i.e. no ignore list is supplied
        return apDict[fpType](mol, nBits, minLength, maxLength, nBitsPerEntry, 0, **kwargs)
    if atomId >= mol.GetNumAtoms():
        raise ValueError("atom index greater than number of atoms")
    return apDict[fpType](mol, nBits, minLength, maxLength, nBitsPerEntry, [atomId], **kwargs)


# Dispatch table for the topological torsion fingerprint flavours.  All entries
# share one call signature (mol, nBits, targetSize, nBitsPerEntry, ignoreAtoms).
ttDict = {
    'normal':
    lambda m, bits, ts, bpe, ia, **kwargs: rdMD.GetTopologicalTorsionFingerprint(
        m, targetSize=ts, ignoreAtoms=ia, **kwargs),
    'hashed':
    lambda m, bits, ts, bpe, ia, **kwargs: rdMD.GetHashedTopologicalTorsionFingerprint(
        m, nBits=bits, targetSize=ts, ignoreAtoms=ia, **kwargs),
    'bv':
    lambda m, bits, ts, bpe, ia, **kwargs: rdMD.GetHashedTopologicalTorsionFingerprintAsBitVect(
        m, nBits=bits, targetSize=ts, nBitsPerEntry=bpe, ignoreAtoms=ia, **kwargs),
}


# usage:   lambda m,i: GetTTFingerprint(m, i, fpType, nBits, targetSize)
def GetTTFingerprint(mol, atomId=-1, fpType='normal', nBits=2048, targetSize=4, nBitsPerEntry=4,
                     **kwargs):
    """
    Calculates the topological torsion fingerprint with the torsions of atomId removed.

    Parameters:
      mol -- the molecule of interest
      atomId -- the atom to remove the torsions for (if -1, no torsion is removed)
      fpType -- the type of TT fingerprint ('normal', 'hashed', 'bv')
      nBits -- the size of the fingerprint (only for fpType='hashed' and 'bv')
      targetSize -- forwarded as targetSize to the RDKit torsion fingerprint
                    function (presumably the number of atoms per torsion)
      nBitsPerEntry -- the number of bits available for each torsion (only for fpType='bv')

    any additional keyword arguments will be passed to the fingerprinting function.

    Raises:
      ValueError -- for an unknown fpType or an atomId that is not smaller
                    than the number of atoms
    """
    if fpType not in ['normal', 'hashed', 'bv']:
        raise ValueError("Unknown Topological torsion fingerprint type")
    if atomId < 0:
        # 0 is handed over as ignoreAtoms, i.e. no ignore list is supplied
        return ttDict[fpType](mol, nBits, targetSize, nBitsPerEntry, 0, **kwargs)
    if atomId >= mol.GetNumAtoms():
        raise ValueError("atom index greater than number of atoms")
    return ttDict[fpType](mol, nBits, targetSize, nBitsPerEntry, [atomId], **kwargs)


# usage:   lambda m,i: GetMorganFingerprint(m, i, radius, fpType, nBits, useFeatures)
def GetMorganFingerprint(mol, atomId=-1, radius=2, fpType='bv', nBits=2048, useFeatures=False,
                         **kwargs):
    """
    Calculates the Morgan fingerprint with the environments of atomId removed.

    The full fingerprint and a per-atom bit map are computed once and cached
    on the molecule as ``mol._fpInfo``; later calls reuse the cache whatever
    radius/fpType/nBits they are given, so delete ``mol._fpInfo`` before
    switching settings (GetAtomicWeightsForFingerprint/-ForModel do this).

    Parameters:
      mol -- the molecule of interest
      radius -- the maximum radius
      fpType -- the type of Morgan fingerprint: 'count' or 'bv'
      atomId -- the atom to remove the environments for (if -1, no environments is removed)
      nBits -- the size of the bit vector (only for fpType = 'bv')
      useFeatures -- if false: ConnectivityMorgan, if true: FeatureMorgan

    any additional keyword arguments will be passed to the fingerprinting function.

    Raises:
      ValueError -- for an unknown fpType, an atomId that is not smaller than
                    the number of atoms, or a malformed ``_fpInfo`` cache
    """
    if fpType not in ['bv', 'count']:
        raise ValueError("Unknown Morgan fingerprint type")
    isBitVect = fpType == 'bv'

    if not hasattr(mol, '_fpInfo'):
        info = {}
        nAtoms = mol.GetNumAtoms()
        # get the fingerprint and prepare the (empty) per-atom bit map
        if isBitVect:
            molFp = rdMD.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits,
                                                       useFeatures=useFeatures, bitInfo=info,
                                                       **kwargs)
            bitmap = [DataStructs.ExplicitBitVect(nBits) for _ in range(nAtoms)]
        else:
            molFp = rdMD.GetMorganFingerprint(mol, radius, useFeatures=useFeatures, bitInfo=info,
                                              **kwargs)
            bitmap = [[] for _ in range(nAtoms)]

        # construct the bit map: for every atom, the bits it takes part in
        for bit, envs in info.items():
            for center, rad in envs:
                if rad == 0:
                    # radius 0: only the central atom itself is involved
                    members = (center, )
                else:
                    # radii > 0: every atom of the circular environment
                    env = Chem.FindAtomEnvironmentOfRadiusN(mol, rad, center)
                    amap = {}
                    Chem.PathToSubmol(mol, env, atomMap=amap)
                    members = amap
                for at in members:
                    if isBitVect:
                        bitmap[at][bit] = 1
                    else:
                        bitmap[at].append(bit)
        mol._fpInfo = (molFp, bitmap)

    if atomId < 0:
        return mol._fpInfo[0]

    # remove the bits of atomId
    if atomId >= mol.GetNumAtoms():
        raise ValueError("atom index greater than number of atoms")
    if len(mol._fpInfo) != 2:
        raise ValueError("_fpInfo not set")
    fullFp, bitmap = mol._fpInfo
    if isBitVect:
        # the atom's bits were recorded from the full fingerprint's bitInfo,
        # so xor switches exactly those bits off
        return fullFp ^ bitmap[atomId]
    # count fingerprint: decrement once per environment containing atomId
    molFp = copy.deepcopy(fullFp)
    for bit in bitmap[atomId]:
        molFp[bit] -= 1
    return molFp


# usage:   lambda m,i: GetRDKFingerprint(m, i, fpType, nBits, minPath, maxPath, nBitsPerHash)
def GetRDKFingerprint(mol, atomId=-1, fpType='bv', nBits=2048, minPath=1, maxPath=5, nBitsPerHash=2,
                      **kwargs):
    """
    Calculates the RDKit fingerprint with the paths of atomId removed.

    The full fingerprint and the per-atom bit lists are computed once and
    cached on the molecule as ``mol._fpInfo``; later calls reuse the cache
    whatever settings they are given, so delete ``mol._fpInfo`` before
    switching settings.

    Parameters:
      mol -- the molecule of interest
      atomId -- the atom to remove the paths for (if -1, no path is removed)
      fpType -- the type of RDKit fingerprint: 'bv' (the empty string is
                accepted as well and treated identically)
      nBits -- the size of the bit vector
      minPath -- minimum path length
      maxPath -- maximum path length
      nBitsPerHash -- number of bits to set per path

    any additional keyword arguments will be passed to the fingerprinting function.

    Raises:
      ValueError -- for an unknown fpType, an atomId that is not smaller than
                    the number of atoms, or a malformed ``_fpInfo`` cache
    """
    if fpType not in ['bv', '']:
        raise ValueError("Unknown RDKit fingerprint type")
    if not hasattr(mol, '_fpInfo'):
        info = []  # list with bits for each atom
        # get the fingerprint
        molFp = Chem.RDKFingerprint(mol, fpSize=nBits, minPath=minPath, maxPath=maxPath,
                                    nBitsPerHash=nBitsPerHash, atomBits=info, **kwargs)
        mol._fpInfo = (molFp, info)

    if atomId < 0:
        return mol._fpInfo[0]

    # remove the bits of atomId
    if atomId >= mol.GetNumAtoms():
        raise ValueError("atom index greater than number of atoms")
    if len(mol._fpInfo) != 2:
        raise ValueError("_fpInfo not set")
    # work on a copy so the cached full fingerprint stays intact
    molFp = copy.deepcopy(mol._fpInfo[0])
    molFp.UnSetBitsFromList(mol._fpInfo[1][atomId])
    return molFp