1#!/usr/bin/env python
2'''
3Get Tinker keywords for Boresch restraint
4
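> get_rot_rest.py sample.xyz sample.key [rotatable] [only]

Read a trajectory (Tinker arc, or any format mdtraj can load) and the
"ligand" keyword from the Tinker key file.
Print 6 distance/angle/torsion restraints and the standard state correction.

Optional arguments:
  rotatable  also print torsion restraints for the ligand's rotatable bonds
  only       skip the 6 Boresch restraints and the standard state correction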
9
10'''
11
12import numpy as np
13import mdtraj as md
14import sys
15import os
16import time
17import scipy.optimize
18from scipy.spatial import distance_matrix
19from scipy.cluster.hierarchy import fcluster, leaders, linkage
20from scipy.spatial.distance import pdist
21from collections import Counter, defaultdict
22from openbabel import openbabel
23from openbabel import pybel
24
def get_adjlist(top):
    '''
    Build a heavy-atom adjacency list from the bonds of a topology.

    top: object exposing `bonds`, an iterable of atom pairs; each atom
         provides `name` and `index`
    return: defaultdict(list) mapping an atom index to the indices of its
            bonded partners. Bonds in which either atom name starts with
            'H' are left out.
    '''
    neighbors = defaultdict(list)
    for first, second in top.bonds:
        # Skip any bond that involves an atom whose name starts with 'H'
        if first.name.startswith('H') or second.name.startswith('H'):
            continue
        neighbors[first.index].append(second.index)
        neighbors[second.index].append(first.index)
    return neighbors
33
def get_tab_branch(top):
    '''
    Given mdtraj Trajectory.Topology,
    returns an integer array (indexed by atom index, length top.n_atoms)
    holding the number of bonds each atom has in which neither atom name
    starts with 'H', i.e. the number of connected heavy atoms.
    '''
    # The builtin int replaces the np.int alias, which NumPy 1.24 removed.
    nbr = np.zeros(top.n_atoms, dtype=int)
    for a1, a2 in top.bonds:
        if not (a1.name.startswith('H') or a2.name.startswith('H')):
            nbr[a1.index] += 1
            nbr[a2.index] += 1
    return nbr
48
def error_exit(msg):
    '''Print msg prefixed with "ERROR:" and terminate with exit status 1.'''
    print("ERROR:", msg)
    raise SystemExit(1)
def write_tinker_idx(idxs):
    ''' Compress indices into Tinker range notation

    Example:
    [5, 6, 7, 10, 12, 13] -> [-5, 7, 10, 12, 13]

    idxs: iterable of integers
    return: list of integers; a run of three or more consecutive indices
            is written as the pair (-first, last), shorter runs are
            written out index by index
    '''
    # Collect [first, last] runs of consecutive (or repeated) indices
    runs = []
    for value in sorted(idxs):
        if runs and value <= runs[-1][1] + 1:
            runs[-1][1] = value
        else:
            runs.append([value, value])

    compact = []
    for first, last in runs:
        if first == last:
            compact.append(first)
        elif first == last - 1:
            # A run of two is not worth a range; list both members
            compact.extend((first, last))
        else:
            compact.extend((-first, last))
    return compact
70
def read_tinker_idx(args):
    ''' Expand Tinker index notation into a set of indices

    A negative number opens a range that is closed by the next positive
    number; a positive number outside a range stands for itself.

    Example:
    ['5'] -> {5}
    ['-5', '7', '10', '-12', '15'] -> {5, 6, 7, 10, 12, 13, 14, 15}

    args: list of strings of integers
    return: a set of indices

    '''
    collected = set()
    range_start = None  # first index of a range that is still open
    for token in args:
        value = int(token)
        if value < 0:
            # A second negative number while a range is open is ignored
            if range_start is None:
                range_start = -value
        elif value > 0:
            if range_start is None:
                collected.add(value)
            else:
                collected.update(range(range_start, value + 1))
                range_start = None
    return collected
94
def read_ligidx(fkey):
    '''Collect the ligand indices listed in a Tinker key file

    Every line that starts with "ligand" (case-insensitive) and has at
    least two whitespace-separated fields contributes its indices; commas
    are treated as separators.

    fkey: path of the key file
    return: a set of indices exactly as numbered in the key file
    '''
    found = set()
    with open(fkey, 'r') as handle:
        for raw in handle:
            if not raw.lower().startswith('ligand'):
                continue
            if len(raw.split()) < 2:
                continue
            fields = raw.replace(',', ' ').split()
            found.update(read_tinker_idx(fields[1:]))
    return found
106
107
def target_disp(weights, coord, alpha=0.1):
    '''target function for center of mass displacement and number of atoms in the group

    weights: one weight per row of coord; negative weights are clipped to
             zero and the result is rescaled to a mean of 1
    coord: array of shape (len(weights), ndim)
    alpha: scale factor applied to the two penalty terms
    return: loss = sum |sum_i w_i x_i|  +  alpha * sum |w_i x_i|
                   + alpha * sum over m in 2..19 of the weights exceeding m
    '''
    assert len(weights) == coord.shape[0]
    # Force a float copy: integer weights would make the in-place rescaling
    # below fail with a casting error.
    wts = np.array(weights, dtype=float).reshape((-1, 1))
    wts = np.maximum(wts, 0)
    assert np.sum(wts) > 0
    wts *= 1.0/np.mean(wts)

    weighted = wts*coord
    # Displacement of the weighted center from the origin
    loss = np.sum(np.abs(np.sum(weighted, axis=0)))
    loss += alpha*np.sum(np.abs(weighted))

    # Each threshold m adds a penalty for every weight that exceeds it
    for m in np.arange(2, 20):
        mask0 = (wts > m)
        loss += alpha*np.sum(np.abs(wts*mask0))

    return loss
125
def calc_coord(traj, idxs):
    '''
    Compute a distance, angle or dihedral for every frame of a trajectory.

    traj: mdtraj Trajectory
    idxs: 2, 3 or 4 atom indices selecting distance, angle or dihedral
    return: 1-D array with one value per frame. Distances are the mdtraj
            result multiplied by 10 (NM_IN_ANG), angles and dihedrals are
            the mdtraj result multiplied by 180/pi (RAD_IN_DEG).
    raise: ValueError if idxs does not hold 2, 3 or 4 indices
    '''
    NM_IN_ANG = 10.0
    RAD_IN_DEG = 180.0/np.pi
    idxs = np.array(idxs)
    natom = len(idxs)
    # Reject unsupported lengths up front rather than failing later with an
    # unrelated IndexError.
    if natom not in (2, 3, 4):
        raise ValueError("calc_coord expects 2, 3 or 4 atom indices, got %d"%natom)

    if natom == 2:
        vals = md.compute_distances(traj, idxs.reshape(1, -1))
        vals *= NM_IN_ANG
    elif natom == 3:
        vals = md.compute_angles(traj, idxs.reshape(1, -1))
        vals *= RAD_IN_DEG
    else:
        vals = md.compute_dihedrals(traj, idxs.reshape(1, -1))
        vals *= RAD_IN_DEG
    # mdtraj returns one column per requested index tuple; only one was asked for
    return vals[:, 0]
141
def calc_idx_ortho(traj, i1, i2, idx3, nbr=None, method='long'):
    '''
    find the index that gives the largest ortho vector

    Candidates are ranked by the component of (atom - i2) orthogonal to the
    i2 -> i1 axis, evaluated on the first frame.

    traj: object exposing `xyz` with shape (n_frames, n_atoms, 3)
    i1, i2: atom indices defining the axis
    idx3: candidate atom indices
    nbr: list of nr of branched atoms (indexable by atom index); every
         atom counts as 0 when omitted
    method: 'long'  - prefer the largest nbr, then the largest orthogonal
                      component
            'short' - prefer atoms with nbr >= 1, then the angle to the
                      axis closest to 90 deg (cosine binned in steps of
                      DELTA_COS), then the smallest orthogonal component
    return: the selected member of idx3
    raise: IndexError if idx3 is empty
    '''
    DELTA_COS = 0.3

    if nbr is None:
        nbr = defaultdict(int)
    if len(idx3) == 0:
        raise IndexError("calc_idx_ortho: no candidate atoms to choose from")
    t = traj
    a1 = t.xyz[0, [i1], :] - t.xyz[0, [i2], :]
    u1 = a1/np.linalg.norm(a1, axis=1)

    vec2 = t.xyz[0, idx3, :] - t.xyz[0, [i2], :]
    d2 = np.linalg.norm(vec2, axis=1)
    d2p = np.abs(np.sum(vec2 * u1, axis=1))
    # Clamp at zero: round-off for (nearly) collinear atoms can make the
    # difference slightly negative, and the resulting NaN would sort last
    # and be selected as the "largest" orthogonal vector.
    d2o = np.sqrt(np.maximum(d2**2.0 - d2p**2.0, 0.0))
    d2cos = d2p / np.maximum(1e-5, d2)
    # Builtin int/float replace the np.int/np.float aliases removed in NumPy 1.24
    if method == 'short':
        # large angle (~90 deg), short distance
        d2on = np.array(list(zip([nbr[_] < 1 for _ in idx3], -d2cos//DELTA_COS, -d2o)),
                        dtype=[('n', int), ('c', float), ('r', float)])
    else:
        # large ortho vector
        d2on = np.array(list(zip([nbr[_] for _ in idx3], d2o)),
                        dtype=[('n', int), ('r', float)])
    # Structured records sort field by field, so the last entry is the best one
    i3 = idx3[np.argsort(d2on)[-1]]
    return i3
170
def write_xyz_from_md(traj, idx):
    '''
    Format selected atoms of the first frame as XYZ-format text.

    traj: trajectory exposing `xyz` and `topology.atom(i).element.symbol`
    idx: atom indices to write, in output order
    return: string with the atom count, a title line and one
            "symbol x y z" line per atom; coordinates are the stored
            values multiplied by 10
    '''
    topology = traj.topology
    lines = ['%d'%(len(idx)), 'Extracted from MD']
    for atom_index in idx:
        symbol = topology.atom(atom_index).element.symbol
        x, y, z = traj.xyz[0, atom_index, :]*10
        lines.append('%5s %12.6f %12.6f %12.6f'%(symbol, x, y, z))
    return '\n'.join(lines) + '\n'
178
def get_rotlist(traj, ligidx):
    '''
    List the ligand bonds that Open Babel reports as rotors.

    The ligand atoms are written as XYZ text in ascending index order and
    parsed by pybel; Open Babel's 1-based atom numbers are mapped back to
    the original indices.

    traj: trajectory holding the coordinates
    ligidx: iterable of ligand atom indices
    return: list of (index, index) tuples, one per rotor bond
    '''
    ordered = sorted(ligidx)
    ligand = pybel.readstring('xyz', write_xyz_from_md(traj, ordered))
    rotors = []
    for bond in openbabel.OBMolBondIter(ligand.OBMol):
        if not bond.IsRotor():
            continue
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        # Open Babel numbers atoms from 1, in the order they were written
        rotors.append((ordered[begin-1], ordered[end-1]))
    return rotors
192
def get_idx_pocket(traj, idx1, idx2, rcutoff=0.5):
    '''
    Pick three protein and three ligand atoms and build the six
    distance/angle/torsion restraint definitions between them.

    traj: mdtraj Trajectory (only the first frame is used for selection)
    idx1: array of candidate protein atom indices
    idx2: array of candidate ligand atom indices
    rcutoff: interface cutoff (reported in nm by the error messages); also
             used as the distance threshold of the hierarchical clustering
    return: (int_idxs, r0s) - a list of six index lists (1 distance,
            2 angles, 3 torsions) and the matching values from calc_coord
            for the first frame

    Exits through error_exit when too few interface atoms are found.
    '''
    t = traj
    nbr = get_tab_branch(t.topology)
    prot_idx, lig_idx, _, _ = calc_idx_interface(traj, idx1, idx2, rcutoff=rcutoff)

    if len(prot_idx) == 0:
        error_exit('No interface atom found within %.3f nm'%rcutoff)
    # Three distinct anchors are needed on each side; with fewer, the
    # clustering or the candidate filtering below would fail obscurely.
    if len(prot_idx) < 3 or len(lig_idx) < 3:
        error_exit('Need at least 3 protein and 3 ligand atoms within %.3f nm, found %d and %d'
                   %(rcutoff, len(prot_idx), len(lig_idx)))

    # Cluster the interface protein atoms and keep the most populated
    # cluster (the last one wins when several clusters share the top size).
    Z = linkage(pdist(t.xyz[0, prot_idx, :]), method='average')
    clst_idx = fcluster(Z, rcutoff, criterion='distance')
    clst_size_list = sorted(Counter(clst_idx).items(), key=lambda item: item[1])
    iclstm = clst_size_list[-1][0]

    # prot_idx is ordered by distance to the ligand, so element 0 of the
    # cluster is its member closest to the ligand.
    iprot = prot_idx[clst_idx == iclstm][0]

    # Ligand anchor: prefer atoms with more than one heavy neighbor, then
    # the one closest to iprot. Builtin float/int replace the np.float/np.int
    # aliases removed in NumPy 1.24.
    lig_dist = np.linalg.norm(t.xyz[0, lig_idx, :] - t.xyz[0, [iprot], :], axis=1)
    dist_nr2 = np.array(list(zip(lig_dist, [nbr[_] <= 1 for _ in lig_idx])),
                        dtype=[('r', float), ('n', int)])
    ord_nr2 = np.argsort(dist_nr2, order=('n', 'r'))
    ilig = lig_idx[ord_nr2[0]]

    # Exclude atoms that are already chosen so that no atom can appear
    # twice within one restraint.
    prot_rest = prot_idx[prot_idx != iprot]
    iprot2 = calc_idx_ortho(traj, ilig, iprot, prot_rest, nbr)
    iprot3 = calc_idx_ortho(traj, iprot, iprot2, prot_rest[prot_rest != iprot2], nbr)

    lig_rest = lig_idx[lig_idx != ilig]
    ilig2 = calc_idx_ortho(traj, iprot, ilig, lig_rest, nbr, method='short')
    ilig3 = calc_idx_ortho(traj, ilig, ilig2, lig_rest[lig_rest != ilig2], nbr, method='short')

    #                             -r-
    #                        -a-      -a-
    # iprot3 ... iprot2 -t- iprot -t- ilig -t- ilig2 ... ilig3
    int_idxs = [[iprot, ilig]]
    int_idxs.extend([[iprot2, iprot, ilig], [iprot, ilig, ilig2]])
    int_idxs.extend([[iprot3, iprot2, iprot, ilig], [iprot2, iprot, ilig, ilig2], [iprot, ilig, ilig2, ilig3]])
    r0s = [calc_coord(t, _)[0] for _ in int_idxs]

    return int_idxs, r0s
238
def calc_idx_interface(traj, idx1, idx2, rcutoff=0.6):
    '''
    Find the atoms of two groups that lie within rcutoff of the other group.

    traj: object exposing `xyz`; only the first frame is used
    idx1, idx2: numpy arrays of atom indices of the two groups
    rcutoff: largest accepted distance to the closest atom of the other group
    return: (sel1, sel2, near1, near2) - the selected members of idx1 and
            idx2 ordered by increasing distance to the other group, and
            those distances
    '''
    frame = traj.xyz[0]
    pair_dist = distance_matrix(frame[idx1, :], frame[idx2, :])

    def closest_within(indices, nearest):
        # Order by distance and keep the leading entries inside the cutoff
        keep = np.argsort(nearest)[:np.sum(nearest <= rcutoff)]
        return indices[keep], nearest[keep]

    sel1, near1 = closest_within(idx1, np.min(pair_dist, axis=1))
    sel2, near2 = closest_within(idx2, np.min(pair_dist, axis=0))
    return sel1, sel2, near1, near2
252
def write_rest(int_idxs, r0s, k0s, fmt='tinker'):
    '''
    Format restraint definitions as keyword lines.

    int_idxs: list of 0-based atom index lists; the list length selects
              the keyword (2: distance, 3: angle, 4: torsion) and entries
              with five or more atoms are skipped
    r0s: target values, written twice per line
    k0s: force constants
    fmt: output format; only 'tinker' is supported
    return: the keyword lines as one string with 1-based atom numbers, or
            '' after printing a message when fmt is not supported
    '''
    keywords = ('', '', 'restrain-distance', 'restrain-angle', 'restrain-torsion')
    lines = []
    for atoms, target, force in zip(int_idxs, r0s, k0s):
        natom = len(atoms)
        if natom >= len(keywords):
            continue
        if fmt != 'tinker':
            print("Format %s not supported"%fmt)
            return ''
        serials = ' '.join('%4d'%(a+1) for a in atoms)
        lines.append('%s %s %.6f %.6f %.6f\n'%(keywords[natom], serials, force, target, target))
    return ''.join(lines)
266
def get_rottors(traj, idx1):
    '''
    Print a torsion restraint for every rotatable bond of the ligand.

    For each bond from get_rotlist, the first heavy-atom neighbor on each
    side (taken from the adjacency list) completes the torsion. The
    restraints are printed with a force constant of 0.01 and the
    first-frame torsion value as target.

    traj: mdtraj Trajectory
    idx1: ligand atom indices
    return: list of [a1, a2, a3, a4] index lists

    Exits through error_exit when one side of a bond has no other heavy neighbor.
    '''
    rotlist = get_rotlist(traj, idx1)
    adjlist = get_adjlist(traj.topology)
    tors = []
    for pair in rotlist:
        lo, hi = min(pair), max(pair)
        before = [nb for nb in adjlist[lo] if nb != hi]
        after = [nb for nb in adjlist[hi] if nb != lo]
        if not before or not after:
            error_exit("Cannot find heavy atom connected to rotable bond %d %d"%(lo, hi))
        tors.append([before[0], lo, hi, after[0]])

    r0s = [calc_coord(traj, quad)[0] for quad in tors]
    k0s = np.zeros_like(r0s) + 0.01
    print(write_rest(tors, r0s, k0s))
    return tors
284
def find_rotation_rest(fxyz, ligidx0, atomnames='CA', rcutoff=0.6, alpha=1.0, rest_rotbond=False, rotbond_only=False):
    '''
    Print the six protein-ligand restraints and the restraint free energy.

    fxyz: trajectory file; read with md.load_arc, falling back to md.load
    ligidx0: 0-based indices of the ligand atoms
    atomnames: atom name(s) inserted into the mdtraj selection 'name ...'
               used to pick candidate protein atoms
    rcutoff: currently not used (see the NOTE in the body)
    alpha: not used by this function
    rest_rotbond: also print torsion restraints for rotatable ligand bonds
    rotbond_only: return right after the optional rotatable-bond output
    return: None; the restraints and a "#dGrest(kcal/mol)" line are printed

    Exits through error_exit when no ligand heavy atoms or no protein
    atoms are selected.
    '''
    try:
        t = md.load_arc(fxyz)
    except IOError:
        t = md.load(fxyz)
    if rest_rotbond:
        get_rottors(t, ligidx0)
    if rotbond_only:
        return

    # Ligand atoms minus the atoms that the selection 'name H' matches
    ligidx = np.array(sorted(set(ligidx0) - set(t.topology.select('name H'))))
    if len(ligidx) == 0:
        error_exit("No ligand heavy atoms found")

    protidx0 = t.topology.select('name %s'%(atomnames))
    protidx0 = np.array(sorted(set(protidx0) - set(ligidx)))

    if len(protidx0) == 0:
        # No distance cutoff has been applied at this point; the name
        # selection itself came back empty.
        error_exit("No protein atoms named %s found"%(atomnames))

    # NOTE(review): rcutoff is not forwarded, so get_idx_pocket runs with its
    # own default cutoff. Forwarding it would change which atoms are selected
    # with the default arguments - confirm the intended value before changing.
    int_idxs, r0s = get_idx_pocket(t, protidx0, ligidx)
    RAD_IN_DEG = 180/np.pi
    # Angular force constants are expressed per degree^2, the distance one is 10.0
    k0s = np.zeros_like(r0s) + 10.0/(RAD_IN_DEG)**2.0
    k0s[0] = 10.0
    print(write_rest(int_idxs, r0s, k0s))
    RT = 8.314 * 298 / 4184
    # https://doi.org/10.1021/jp0217839
    # Eq. (14)
    # RAD_IN_DEG**10 converts the five angular force constants to per radian^2
    assert len(k0s) == 6
    dgrest = RT*np.log(1662*8*np.pi**2.0*np.sqrt(np.prod(k0s)*RAD_IN_DEG**10.0) \
                        /(r0s[0]**2.0*np.sin(r0s[1]/RAD_IN_DEG)*np.sin(r0s[2]/RAD_IN_DEG)*(2*np.pi*RT)**3))

    print("#dGrest(kcal/mol) %.5f"%dgrest)
320
321
322
def find_grp_idx(fxyz, ligidx0, atomnames='CA', rcutoff=1.2, alpha=1.0):
    '''
    Build Tinker "group" keywords for a protein-group / ligand-atom pair.

    Protein atoms within rcutoff of the ligand are weighted by minimizing
    target_disp around the ligand atom closest to the protein; atoms whose
    weight exceeds 0.4 times the mean positive weight form group 1. Group 2
    is the single ligand atom closest to the centroid of group 1.

    fxyz: trajectory file; read with md.load_arc, falling back to md.load
    ligidx0: 0-based indices of the ligand atoms
    atomnames: atom name(s) inserted into the mdtraj selection 'name ...'
    rcutoff: largest ligand distance for a protein atom to be considered
    alpha: penalty scale passed on to target_disp
    return: (r0, text) - the group distance multiplied by 10 and the
            keyword lines; None when no protein atom is within rcutoff

    Exits through error_exit when no ligand heavy atoms are selected.
    '''
    try:
        t = md.load_arc(fxyz)
    except IOError:
        t = md.load(fxyz)
    ligidx = np.array(sorted(set(ligidx0) - set(t.topology.select('name H'))))
    if len(ligidx) == 0:
        error_exit("No ligand heavy atoms found")

    protidx0 = t.topology.select('name %s'%(atomnames))
    protidx0 = np.array(sorted(set(protidx0) - set(ligidx)))

    distmat = distance_matrix(t.xyz[0, ligidx, :], t.xyz[0, protidx0, :])
    distpro = np.min(distmat, axis=0)
    protidx = protidx0[distpro <= rcutoff]
    if len(protidx) == 0:
        return

    # argmin works on the flattened (ligand x protein) matrix; integer
    # division by the row length recovers the ligand row.
    imindist = np.argmin(distmat)
    iminlig = ligidx[imindist // len(protidx0)]
    # Reference point: the ligand atom closest to the protein
    com_lig = t.xyz[0, [iminlig], :]

    wts0 = np.ones(len(protidx))
    xyz_prot = t.xyz[0, protidx, :]
    xyz1_prot = xyz_prot - com_lig.reshape((1, -1))
    res = scipy.optimize.minimize(target_disp, wts0, args=(xyz1_prot, alpha))
    wts1 = np.array(res.x)
    wts1 = np.maximum(0, wts1)
    wts1 *= 1.0/np.sum(wts1)
    wtm = np.mean(wts1[wts1 > 0])
    mask1 = wts1 > 0.4*wtm
    com_p1 = (np.mean(t.xyz[0, protidx[mask1], :], axis=0)).reshape((1, -1))

    xyz_lig = t.xyz[0, ligidx, :]
    dists = np.linalg.norm(xyz_lig - com_p1, axis=1)
    imin = np.argmin(dists)
    dmin = np.linalg.norm(xyz_lig[imin] - com_p1)

    idx_tinker = write_tinker_idx([_+1 for _ in protidx[mask1]])
    outp = ''
    sgrp = ''
    for n in idx_tinker:
        sgrp += ' %5d'%n
        # Break only after a positive number so that a "-first last" range
        # is never split across two lines.
        if len(sgrp) > 50 and n > 0:
            outp += ('group 1 %s\n'%sgrp)
            sgrp = ''
    # Flush the last, partially filled line; otherwise the trailing indices
    # would be missing from the output.
    if sgrp:
        outp += ('group 1 %s\n'%sgrp)

    outp += ('group 2 %s\n'%(' '.join(['%5d'%(_) for _ in [ligidx[imin] + 1]])))
    outp += ("#r_0=%.3f"%(dmin*10))
    return dmin*10, outp
387
388
def main():
    '''
    Command line entry point.

    Usage: script traj_file key_file [rotatable] [only]

    Prints the module documentation when fewer than two arguments are
    given. Otherwise reads the ligand indices from the key file, converts
    them to 0-based indices and prints the restraints.
    '''
    if len(sys.argv) <= 2:
        print(__doc__)
        return
    fxyz = sys.argv[1]
    fkey = sys.argv[2]
    # Key files number atoms from 1, the trajectory from 0
    ligidx = [_-1 for _ in sorted(read_ligidx(fkey))]
    args = sys.argv[3:]
    if len(ligidx) == 0:
        error_exit("ligand keyword not found")
    rest_rotbond = 'rotatable' in args
    no_orient = 'only' in args
    find_rotation_rest(fxyz, ligidx, "CA C N", rest_rotbond=rest_rotbond, rotbond_only=no_orient)


# Run only when executed as a script, so importing the module has no side effects
if __name__ == '__main__':
    main()
406
407