#!/usr/bin/env python
'''
Get Tinker keywords for Boresch restraint

> get_rot_rest.py sample.xyz sample.key [rotatable] [only]

Read any traj format and "ligand keyword" from Tinker key file
Print 6 distance/angle/torsion restraints and the standard state correction

Optional trailing words:
  rotatable   also print weak torsion restraints on the ligand's rotatable bonds
  only        together with "rotatable": print only those torsion restraints

'''

import numpy as np
import mdtraj as md
import sys
import os
import time
import scipy.optimize
from scipy.spatial import distance_matrix
from scipy.cluster.hierarchy import fcluster, leaders, linkage
from scipy.spatial.distance import pdist
from collections import Counter, defaultdict
from openbabel import openbabel
from openbabel import pybel


def _is_hydrogen(atom):
    '''Hydrogen test used throughout this script: atom name starts with 'H'.'''
    return atom.name.startswith('H')


def get_adjlist(top):
    '''
    Given an mdtraj Topology, return the heavy-atom adjacency list.

    return: defaultdict(list) mapping atom index -> indices of bonded atoms;
            bonds that involve a hydrogen (see _is_hydrogen) are skipped
    '''
    adjlist = defaultdict(list)
    for bond in top.bonds:
        a1, a2 = bond
        if not (_is_hydrogen(a1) or _is_hydrogen(a2)):
            adjlist[a1.index].append(a2.index)
            adjlist[a2.index].append(a1.index)
    return adjlist


def get_tab_branch(top):
    '''
    Given mdtraj Trajectory.Topology,
    return an integer array (length top.n_atoms) with the number of
    heavy-atom neighbours of every atom. Bonds involving a hydrogen
    (see _is_hydrogen) are not counted, so hydrogens themselves get 0.
    '''
    # builtin int: the np.int alias was removed in NumPy 1.24
    nbr = np.zeros(top.n_atoms, dtype=int)
    for bond in top.bonds:
        a1, a2 = bond
        if not (_is_hydrogen(a1) or _is_hydrogen(a2)):
            nbr[a1.index] += 1
            nbr[a2.index] += 1
    return nbr


def error_exit(msg):
    '''Print an error message to stdout and exit with status 1.'''
    print("ERROR:", msg)
    sys.exit(1)


def write_tinker_idx(idxs):
    '''
    Compress indices into Tinker list notation (inverse of read_tinker_idx).

    Runs of 3 or more consecutive indices become ``-first, last``; shorter
    runs are written out one by one.

    Example:
        [5, 6, 7, 10, 12, 13] -> [-5, 7, 10, 12, 13]

    idxs: iterable of positive integers (duplicates allowed)
    return: list of integers
    '''
    idx_out = []
    rs = []  # list of [first, last] runs of consecutive indices
    for i0 in sorted(idxs):
        if len(rs) == 0 or rs[-1][1]+1 < i0:
            rs.append([i0, i0])
        else:
            rs[-1][1] = i0
    for r in rs:
        if r[0] == r[1]:
            idx_out.append(r[0])
        elif r[0] == r[1] - 1:
            idx_out.append(r[0])
            idx_out.append(r[1])
        else:
            idx_out.append(-r[0])
            idx_out.append(r[1])
    return idx_out


def read_tinker_idx(args):
    ''' Read tinker indices

    A negative number opens a range that the next positive number closes.

    Example:
        ['5'] -> {5}
        ['-5', '7', '10', '-12', '15'] -> {5, 6, 7, 10, 12, 13, 14, 15}

    args: list of strings of integers
    return: a set of indices

    Zeros, a negative number while a range is already open, and a range
    still open at the end of args are ignored.
    '''
    _range = []
    idxs = []
    for a in args:
        n = int(a)
        if n < 0 and len(_range) == 0:
            _range.append(-n)
        if n > 0:
            if len(_range) == 1:
                idxs.extend(range(_range.pop(), n+1))
            else:
                idxs.append(n)
    return set(idxs)


def read_ligidx(fkey):
    '''Read tinker key file and return indices of the ligand

    Every line whose first word is ``ligand`` (case-insensitive) contributes
    its remaining comma/space separated values, parsed by read_tinker_idx.

    return: set of 1-based atom indices (as written in the key file)
    '''
    ligand_idx = set()  # ligand indices
    with open(fkey, 'r') as fh:
        for line in fh:
            words = line.replace(',', ' ').split()
            # compare the whole first word so leading blanks and "ligand,1"
            # work and longer keywords starting with "ligand" are not matched
            if len(words) >= 2 and words[0].lower() == 'ligand':
                ligand_idx |= read_tinker_idx(words[1:])
    return ligand_idx


def target_disp(weights, coord, alpha=0.1):
    '''target function for center of mass displacement and number of atoms in the group

    weights: one weight per row of coord; negative weights are clipped to 0
             and the rest rescaled to a mean of 1
    coord: (n, 3) coordinates relative to the reference point
    return: |weighted coordinate sum| + alpha * sum|w * coord|
            + alpha * sum over m = 2..19 of the weights exceeding m
    '''
    assert len(weights) == coord.shape[0]
    # float dtype so the in-place rescale below also works for integer input
    wts = np.asarray(weights, dtype=float).reshape((-1, 1))
    wts = np.maximum(wts, 0)
    assert np.sum(wts) > 0
    wts *= 1.0/np.mean(wts)

    loss = np.sum(np.abs(np.sum(wts*coord, axis=0)))
    loss += alpha*np.sum(np.abs(wts*coord))

    # penalize individual atoms carrying a large share of the weight
    for m in range(2, 20):
        mask0 = (wts > m)
        loss += alpha*np.sum(np.abs(wts*mask0))

    return loss


def calc_coord(traj, idxs):
    '''
    Return a distance, angle or torsion for every frame of traj.

    idxs: 2, 3 or 4 zero-based atom indices
    return: 1-D array of length traj.n_frames; distances in Angstrom
            (mdtraj nm * 10), angles and torsions in degrees
    raises: ValueError if idxs does not contain 2, 3 or 4 atoms
    '''
    ANG_PER_NM = 10.0
    DEG_PER_RAD = 180.0/np.pi
    idxs = np.array(idxs)
    atoms = idxs.reshape(1, -1)
    if len(idxs) == 2:
        vals = md.compute_distances(traj, atoms) * ANG_PER_NM
    elif len(idxs) == 3:
        vals = md.compute_angles(traj, atoms) * DEG_PER_RAD
    elif len(idxs) == 4:
        vals = md.compute_dihedrals(traj, atoms) * DEG_PER_RAD
    else:
        raise ValueError('calc_coord needs 2, 3 or 4 atom indices, got %d' % len(idxs))
    return vals[:, 0]


def calc_idx_ortho(traj, i1, i2, idx3, nbr=None, method='long'):
    '''
    Pick from idx3 a third anchor atom relative to the axis i2 -> i1.

    For every candidate, the vector from i2 (frame 0) is split into its
    component along the axis and the orthogonal remainder.

    method='long':  maximize (heavy-atom neighbour count, orthogonal distance)
    method='short': maximize (nbr < 1, bucketed -|cos| with bin width 0.3,
                    -orthogonal distance), i.e. an angle close to 90 degrees
                    first, then a short orthogonal distance

    traj: mdtraj Trajectory
    i1, i2: atom indices defining the axis
    idx3: numpy array of candidate atom indices (must be non-empty)
    nbr: list of nr of branched atoms, e.g. from get_tab_branch;
         treated as all zero when None
    return: the selected atom index (an element of idx3)
    '''
    DELTA_COS = 0.3

    if nbr is None:
        nbr = defaultdict(int)
    xyz = traj.xyz[0]
    a1 = xyz[[i1], :] - xyz[[i2], :]
    u1 = a1/np.linalg.norm(a1, axis=1)

    vec2 = xyz[idx3, :] - xyz[[i2], :]
    d2 = np.linalg.norm(vec2, axis=1)
    d2p = np.abs(np.sum(vec2 * u1, axis=1))
    # Clamp at 0: rounding can make d2**2 - d2p**2 slightly negative for
    # nearly collinear atoms, and the resulting NaN would sort after every
    # real value and be picked as the "largest" orthogonal distance.
    d2o = np.sqrt(np.maximum(d2**2.0 - d2p**2.0, 0.0))
    d2cos = d2p / np.maximum(1e-5, d2)
    if method == 'short':
        # large angle (~90 deg), short distance
        # NOTE(review): this key prefers candidates with *no* heavy-atom
        # neighbours, while get_idx_pocket ranks atoms with <= 1 neighbour
        # last when choosing ilig -- confirm which preference is intended.
        keys = list(zip([nbr[_] < 1 for _ in idx3], -d2cos // DELTA_COS, -d2o))
        d2on = np.array(keys, dtype=[('n', int), ('c', float), ('r', float)])
    else:
        # large ortho vector
        keys = list(zip([nbr[_] for _ in idx3], d2o))
        d2on = np.array(keys, dtype=[('n', int), ('r', float)])
    # structured argsort compares fields in order; the last entry is the max
    i3 = idx3[np.argsort(d2on)[-1]]
    return i3


def write_xyz_from_md(traj, idx):
    '''Return an XYZ-format string (element symbols, Angstrom) for atoms idx in frame 0.'''
    lines = ['%d' % (len(idx)), 'Extracted from MD']
    for i in idx:
        atom = traj.topology.atom(i)
        xyz = traj.xyz[0, i, :]*10  # nm -> Angstrom
        lines.append('%5s %12.6f %12.6f %12.6f' % (atom.element.symbol, xyz[0], xyz[1], xyz[2]))
    return '\n'.join(lines) + '\n'


def get_rotlist(traj, ligidx):
    '''
    Return the rotatable bonds among the atoms ligidx.

    The atoms (frame 0) are handed to Open Babel as an XYZ string and the
    bonds for which OBBond.IsRotor() is true are returned.

    traj: mdtraj Trajectory
    ligidx: ligand atom indices
    return: list of (i, j) atom index pairs, indices into traj
    '''
    ligidx_sort = sorted(ligidx)
    mymol = pybel.readstring('xyz', write_xyz_from_md(traj, ligidx_sort))
    rotlist = []
    for bond in openbabel.OBMolBondIter(mymol.OBMol):
        if bond.IsRotor():
            # Open Babel atom indices are 1-based positions in ligidx_sort
            i1, i2 = (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
            rotlist.append((ligidx_sort[i1-1], ligidx_sort[i2-1]))
    return rotlist


def get_idx_pocket(traj, idx1, idx2, rcutoff=0.5):
    '''
    Choose the six anchor atoms of a Boresch restraint.

    traj: mdtraj Trajectory; only frame 0 is used
    idx1: numpy array of candidate protein atom indices
    idx2: numpy array of ligand heavy-atom indices
    rcutoff: interface cutoff in nm, also used as the clustering distance

    Anchors:
      iprot  - closest interface protein atom of the largest cluster
               (average linkage cut at rcutoff) of protein interface atoms
      ilig   - interface ligand atom closest to iprot, ranking atoms with
               at most one heavy-atom neighbour last
      iprot2, iprot3, ilig2, ilig3 - chosen with calc_idx_ortho

    return: (int_idxs, r0s) - six atom index lists (1 distance, 2 angles,
            3 torsions) and their frame-0 values (Angstrom / degrees)
    '''
    nbr = get_tab_branch(traj.topology)
    prot_idx, lig_idx, _, _ = calc_idx_interface(traj, idx1, idx2, rcutoff=rcutoff)

    if len(prot_idx) == 0:
        error_exit('No interface atom found within %.3f nm' % rcutoff)
    # three distinct anchors are needed on each side (linkage also needs >= 2 points)
    if len(prot_idx) < 3 or len(lig_idx) < 3:
        error_exit('Need at least 3 protein and 3 ligand interface atoms within %.3f nm (found %d and %d)'
                   % (rcutoff, len(prot_idx), len(lig_idx)))

    xyz = traj.xyz[0]
    Z = linkage(pdist(xyz[prot_idx, :]), method='average')
    clst_idx = fcluster(Z, rcutoff, criterion='distance')
    # largest cluster; on equal sizes the cluster first seen last wins
    clst_size_list = sorted(Counter(clst_idx).items(), key=lambda kv: kv[1])
    iclstm = clst_size_list[-1][0]

    # prot_idx is sorted by distance to the ligand: [0] is the closest member
    iprot = prot_idx[clst_idx == iclstm][0]

    lig_dist = np.linalg.norm(xyz[lig_idx, :] - xyz[[iprot], :], axis=1)
    dist_nr2 = np.array(list(zip(lig_dist, [nbr[_] <= 1 for _ in lig_idx])), dtype=[('r', float), ('n', int)])
    ord_nr2 = np.argsort(dist_nr2, order=('n', 'r'))
    ilig = lig_idx[ord_nr2[0]]

    # Exclude anchors already chosen so that no restraint repeats an atom.
    iprot2 = calc_idx_ortho(traj, ilig, iprot, prot_idx[prot_idx != iprot], nbr)
    iprot3 = calc_idx_ortho(traj, iprot, iprot2, prot_idx[(prot_idx != iprot) & (prot_idx != iprot2)], nbr)

    ilig2 = calc_idx_ortho(traj, iprot, ilig, lig_idx[lig_idx != ilig], nbr, method='short')
    ilig3 = calc_idx_ortho(traj, ilig, ilig2, lig_idx[(lig_idx != ilig) & (lig_idx != ilig2)], nbr, method='short')

    #                          -r-
    #                 -a-              -a-
    # iprot3 ... iprot2 -t- iprot -t- ilig -t- ilig2 ... ilig3
    int_idxs = [[iprot, ilig]]
    int_idxs.extend([[iprot2, iprot, ilig], [iprot, ilig, ilig2]])
    int_idxs.extend([[iprot3, iprot2, iprot, ilig], [iprot2, iprot, ilig, ilig2], [iprot, ilig, ilig2, ilig3]])
    r0s = [calc_coord(traj, _)[0] for _ in int_idxs]

    return int_idxs, r0s


def calc_idx_interface(traj, idx1, idx2, rcutoff=0.6):
    '''
    Find the atoms of two groups that lie within rcutoff of the other group.

    traj: mdtraj Trajectory; only frame 0 is used
    idx1, idx2: numpy arrays of atom indices
    rcutoff: cutoff in nm
    return: (sel1, sel2, dist1, dist2) - atoms of idx1 / idx2 whose nearest
            atom of the other group is within rcutoff, sorted by that
            nearest distance, and those distances (nm)
    '''
    t = traj
    distmat = distance_matrix(t.xyz[0, idx1, :], t.xyz[0, idx2, :])
    dist2 = np.min(distmat, axis=0)
    dist1 = np.min(distmat, axis=1)

    ord1 = np.argsort(dist1)
    ord2 = np.argsort(dist2)

    n1 = np.sum(dist1 <= rcutoff)
    n2 = np.sum(dist2 <= rcutoff)

    return idx1[ord1[:n1]], idx2[ord2[:n2]], dist1[ord1[:n1]], dist2[ord2[:n2]]


def write_rest(int_idxs, r0s, k0s, fmt='tinker'):
    '''
    Format restraints as keyword lines.

    int_idxs: zero-based atom index lists; 2, 3 or 4 atoms give
              restrain-distance / restrain-angle / restrain-torsion,
              longer lists are skipped
    r0s: reference values, written as both lower and upper bound
    k0s: force constants
    fmt: only 'tinker' is supported; otherwise a message is printed and ''
         is returned
    return: keyword text, one line per restraint, atom indices 1-based
    '''
    rest_name = ['', '', 'restrain-distance', 'restrain-angle', 'restrain-torsion']
    outp = ''
    for ridx, r0, k0 in zip(int_idxs, r0s, k0s):
        nat = len(ridx)
        if nat >= len(rest_name):
            continue
        if fmt == 'tinker':
            outp += '%s %s %.6f %.6f %.6f\n' % (rest_name[nat], ' '.join('%4d' % (_+1) for _ in ridx), k0, r0, r0)
        else:
            print("Format %s not supported" % fmt)
            return ''
    return outp


def get_rottors(traj, idx1):
    '''
    Print weak torsion restraints for the rotatable bonds of a ligand.

    For each rotatable bond (get_rotlist) a torsion is built from the first
    heavy-atom neighbour on each side. It is printed as a restrain-torsion
    keyword with force constant 0.01 at its frame-0 value.

    idx1: ligand atom indices
    return: list of the torsion atom index lists
    '''
    t = traj
    rotlist = get_rotlist(t, idx1)
    tors = []
    adjlist = get_adjlist(traj.topology)
    for bond in rotlist:
        a2 = min(bond)
        a3 = max(bond)
        a1s = [_ for _ in adjlist[a2] if _ != a3]
        a4s = [_ for _ in adjlist[a3] if _ != a2]
        if len(a1s) * len(a4s) == 0:
            error_exit("Cannot find heavy atom connected to rotable bond %d %d" % (a2, a3))
        tors.append([a1s[0], a2, a3, a4s[0]])
    r0s = [calc_coord(t, _)[0] for _ in tors]
    k0s = np.zeros_like(r0s) + 0.01
    print(write_rest(tors, r0s, k0s))
    return tors


def _load_traj(fxyz):
    '''Load fxyz with md.load_arc, falling back to md.load on IOError.'''
    try:
        return md.load_arc(fxyz)
    except IOError:
        return md.load(fxyz)


def _heavy_atom_idx(top, idxs):
    '''Sorted numpy array of idxs without hydrogens (see _is_hydrogen).'''
    hydrogens = {atom.index for atom in top.atoms if _is_hydrogen(atom)}
    return np.array(sorted(set(idxs) - hydrogens))


def find_rotation_rest(fxyz, ligidx0, atomnames='CA', rcutoff=None, alpha=1.0, rest_rotbond=False, rotbond_only=False):
    '''
    Print Tinker Boresch restraints between a ligand and nearby protein atoms.

    fxyz: structure/trajectory file; restraint values are taken from frame 0
    ligidx0: zero-based ligand atom indices (hydrogens are dropped for anchors)
    atomnames: space-separated protein atom names eligible as anchors
    rcutoff: interface cutoff in nm forwarded to get_idx_pocket;
             None keeps get_idx_pocket's own default
    alpha: unused, kept for interface compatibility
    rest_rotbond: first print weak torsion restraints on rotatable bonds
    rotbond_only: together with rest_rotbond, stop after those torsions

    Prints the restraint keywords followed by a "#dGrest(kcal/mol)" line.
    '''
    t = _load_traj(fxyz)
    if rest_rotbond:
        get_rottors(t, ligidx0)
        if rotbond_only:
            return

    ligidx = _heavy_atom_idx(t.topology, ligidx0)
    if len(ligidx) == 0:
        error_exit("No ligand heavy atoms found")

    protidx0 = t.topology.select('name %s' % (atomnames))
    protidx0 = np.array(sorted(set(protidx0) - set(ligidx)))

    if len(protidx0) == 0:
        error_exit("No protein %s atoms found outside the ligand" % (atomnames))

    pocket_kwargs = {} if rcutoff is None else {'rcutoff': rcutoff}
    int_idxs, r0s = get_idx_pocket(t, protidx0, ligidx, **pocket_kwargs)
    RAD_IN_DEG = 180/np.pi
    # distance k = 10; angular k = 10 per rad^2, written per deg^2 because
    # r0s holds angles in degrees
    k0s = np.zeros_like(r0s) + 10.0/(RAD_IN_DEG)**2.0
    k0s[0] = 10.0
    print(write_rest(int_idxs, r0s, k0s))
    RT = 8.314 * 298 / 4184  # kcal/mol at 298 K
    # https://doi.org/10.1021/jp0217839
    # Eq. (14); 1662 ~ 1 M standard volume in Angstrom^3, RAD_IN_DEG**10
    # converts the five angular force constants from per-deg^2 to per-rad^2.
    # NOTE(review): the paper writes restraints as K/2 (x - x0)^2; confirm
    # that the Tinker force constants above follow the same convention,
    # otherwise each K differs by a factor of 2.
    assert len(k0s) == 6
    dgrest = RT*np.log(1662*8*np.pi**2.0*np.sqrt(np.prod(k0s)*RAD_IN_DEG**10.0)
                       / (r0s[0]**2.0*np.sin(r0s[1]/RAD_IN_DEG)*np.sin(r0s[2]/RAD_IN_DEG)*(2*np.pi*RT)**3))

    print("#dGrest(kcal/mol) %.5f" % dgrest)


def find_grp_idx(fxyz, ligidx0, atomnames='CA', rcutoff=1.2, alpha=1.0):
    '''
    Build Tinker "group" keywords for a group-based distance restraint.

    Protein atoms named atomnames within rcutoff (nm) of the ligand are
    weighted by minimizing target_disp around the ligand heavy atom closest
    to the protein. Atoms with weight > 0.4 * mean positive weight form
    group 1; the ligand heavy atom closest to their centroid forms group 2.

    return: (r_0 in Angstrom, keyword text), or None when no protein atom
            lies within rcutoff
    '''
    t = _load_traj(fxyz)
    ligidx = _heavy_atom_idx(t.topology, ligidx0)
    if len(ligidx) == 0:
        error_exit("No ligand heavy atoms found")

    protidx0 = t.topology.select('name %s' % (atomnames))
    protidx0 = np.array(sorted(set(protidx0) - set(ligidx)))

    distmat = distance_matrix(t.xyz[0, ligidx, :], t.xyz[0, protidx0, :])
    distpro = np.min(distmat, axis=0)
    protidx = protidx0[distpro <= rcutoff]
    if len(protidx) == 0:
        return None

    # distmat is (n_lig, n_prot), so the row of the flat argmin is the ligand atom
    imindist = np.argmin(distmat)
    iminlig = ligidx[imindist // len(protidx0)]
    ref_xyz = t.xyz[0, [iminlig], :]

    wts0 = np.ones(len(protidx))
    xyz1_prot = t.xyz[0, protidx, :] - ref_xyz.reshape((1, -1))
    res = scipy.optimize.minimize(target_disp, wts0, args=(xyz1_prot, alpha))
    wts1 = np.maximum(0, np.array(res.x))
    wts1 *= 1.0/np.sum(wts1)
    wtm = np.mean(wts1[wts1 > 0])
    mask1 = wts1 > 0.4*wtm
    com_p1 = (np.mean(t.xyz[0, protidx[mask1], :], axis=0)).reshape((1, -1))

    xyz_lig = t.xyz[0, ligidx, :]
    dists = np.linalg.norm(xyz_lig - com_p1, axis=1)
    imin = np.argmin(dists)
    dmin = np.linalg.norm(xyz_lig[imin] - com_p1)

    idx_tinker = write_tinker_idx([_+1 for _ in protidx[mask1]])
    outp = ''
    sgrp = ''
    for n in idx_tinker:
        sgrp += ' %5d' % n
        # break lines only after a positive entry so "-first last" stays together
        if len(sgrp) > 50 and n > 0:
            outp += ('group 1 %s\n' % sgrp)
            sgrp = ''
    if sgrp:
        # flush the last, shorter chunk of group-1 indices
        outp += ('group 1 %s\n' % sgrp)

    outp += ('group 2 %s\n' % ('%5d' % (ligidx[imin] + 1)))
    outp += ("#r_0=%.3f" % (dmin*10))
    return dmin*10, outp


def main():
    '''Command-line entry point; see the module docstring for usage.'''
    if len(sys.argv) <= 2:
        print(__doc__)
        return
    fxyz = sys.argv[1]
    fkey = sys.argv[2]
    # key-file indices are 1-based; mdtraj indices are 0-based
    ligidx = [_-1 for _ in sorted(read_ligidx(fkey))]
    args = sys.argv[3:]
    if len(ligidx) == 0:
        error_exit("ligand keyword not found")
    rest_rotbond = 'rotatable' in args
    no_orient = 'only' in args
    find_rotation_rest(fxyz, ligidx, "CA C N", rest_rotbond=rest_rotbond, rotbond_only=no_orient)


if __name__ == '__main__':
    main()