1from __future__ import print_function
2from graph import ProteinGraph
3from uncertainty import resolve_uncertainty
4from collections import defaultdict
5import math
6from pprint import pprint
7import sys
8
# Molar gas constant R in J mol^-1 K^-1, using the value listed at
# https://en.wikipedia.org/wiki/Gas_constant
gas_constant = 8.3144621
# Temperature used for the derived factors below (presumably kelvin).
T = 300.0

# Derived conversion factors: R*ln(10), R*ln(10)*T and R*T.
Rln10 = gas_constant * math.log(10)
Rln10_T = Rln10 * T
RT = gas_constant * T

# TODO - figure out why modPkaHIP is hard-coded here!
# Reference pKa values used for histidine in get_curve_values; the HIE and
# HID forms simply reuse the HIP value.
modPkaHIP = modPkaHIE = modPkaHID = 6.6
20
def print_pc_state(pc, normal_form, out_file):
    """Write the energies of a protein complex to out_file.

    Two sections are written, one value per line, rounded to 4 places:
      1. "<(instance, other_instance)> <energy>" for every pair of instances
         that belong to two different residue variables;
      2. "<instance> <energy>" for every instance of every residue variable.
    In normal form a final "Normalized constant energy: <value>" line follows.

    pc - protein complex providing residue_variables and the energy tables.
    normal_form - if true, dump normalized_interaction_energies,
        instance.energyNF and the normalized constant energy; otherwise dump
        interaction_energies_for_ph and instance.energy_with_ph.
    out_file - writable text stream.
    """
    residues = pc.residue_variables
    if normal_form:
        pair_energy = pc.normalized_interaction_energies
        single_attr = "energyNF"
    else:
        pair_energy = pc.interaction_energies_for_ph
        single_attr = "energy_with_ph"

    for res_a in residues.values():
        for inst_a in res_a.instances.values():
            for res_b in residues.values():
                # Only interactions between different residues are dumped.
                if res_a == res_b:
                    continue
                for inst_b in res_b.instances.values():
                    pair = (inst_a, inst_b)
                    out_file.write("%s %s\n" % (pair, round(pair_energy[pair], 4)))

    for residue in residues.values():
        for inst in residue.instances.values():
            energy = getattr(inst, single_attr)
            out_file.write("%s %s\n" % (inst, round(energy, 4)))

    if normal_form:
        out_file.write("Normalized constant energy: %s\n"
                       % round(pc.normalized_constant_energy, 4))
43
def print_dg_state(dg, out_file):
    """Dump the flow network (directed graph) state to out_file.

    Writes a "Vertices" section with one vertex label per line, followed by
    an "Edges" section with one "(u, v)= capacity" line per edge, where the
    capacity is rounded to 4 decimal places. Both sections are sorted so the
    dump is deterministic.

    Tuple vertices are labelled by joining their parts with '_'; any other
    vertex is written via str(). The same labelling is used in both sections.

    dg - directed graph exposing the NetworkX methods ``nodes()`` and
        ``edges(data="capacity")`` (works with NetworkX >= 1.10, including
        2.x/3.x, which removed ``dg.node`` and ``edges_iter``).
    out_file - writable text stream.
    """
    def label(node):
        if isinstance(node, tuple):
            return '_'.join(str(part) for part in node)
        # Non-tuple vertices (which the edge dump has always allowed for)
        # are written as-is instead of being split character by character.
        return str(node)

    def vertex_sort_key(node):
        # Group plain vertices before tuple vertices (the order Python 2 gave
        # for str vs tuple) so mixed node types never get compared directly,
        # which raises TypeError on Python 3.
        return (isinstance(node, tuple), node)

    out_file.write("Flow network:\nVertices:\n")
    for node in sorted(dg.nodes(), key=vertex_sort_key):
        out_file.write(label(node) + "\n")

    out_file.write("\nEdges:\n")
    edges = sorted(
        (label(u), label(v), capacity)
        for u, v, capacity in dg.edges(data="capacity")
    )
    for u_label, v_label, capacity in edges:
        out_file.write("(%s, %s)= %s\n" % (u_label, v_label, round(capacity, 4)))
81
82
def get_titration_curves(protein_complex, state_file=None,
                         start_ph=0.0, end_ph=20.0, step_size=0.1):
    """Compute titration curves for every titratable residue over a pH grid.

    For each pH value from start_ph to end_ph (inclusive), spaced step_size apart:
        Get the normal form of the protein energies.
        Build a flow graph.
        Get the min cut of the graph.
        Find which state each residue is in from the cut (labeling) and
        which states are unknown (uncertain).
        Use brute force or MC to resolve the uncertain states.
        Calculate the curve value for each residue.

    protein_complex - the protein complex whose residues are titrated.
    state_file - optional writable text stream; when given, the regular
        energies, the normal-form energies and the flow network are dumped
        to it for every pH value.
    start_ph, end_ph, step_size - the pH grid. The defaults reproduce the
        original fixed grid 0.0, 0.1, ..., 20.0 (201 points).

    Returns a defaultdict(list) mapping each curve key (as produced by
    get_curve_values) to a list of (pH, value) pairs, one per pH value.

    Raises ValueError if step_size is not positive or end_ph < start_ph.
    """
    if step_size <= 0:
        raise ValueError("step_size must be positive, got %r" % (step_size,))
    if end_ph < start_ph:
        raise ValueError("end_ph (%r) must not be below start_ph (%r)"
                         % (end_ph, start_ph))

    curves = defaultdict(list)
    pg = ProteinGraph(protein_complex)
    # The small tolerance keeps float error in the division (e.g.
    # 0.3 / 0.1 == 2.9999999999999996) from dropping the final grid point.
    steps = int((end_ph - start_ph) / step_size + 1e-9) + 1

    for step in range(steps):
        # Computed from the step index rather than accumulated, so rounding
        # error does not build up across the grid.
        pH = start_ph + step * step_size
        print("pH", pH)

        if state_file is not None:
            state_file.write("pH=" + str(pH) + "\n")
            state_file.write("REGULAR ENERGIES\n")
            protein_complex.energy_at_pH(pH)
            print_pc_state(protein_complex, False, state_file)
            state_file.write('\n')
            state_file.write("NORMAL FORM ENERGIES\n")

        protein_complex.normalize(pH)

        if state_file is not None:
            print_pc_state(protein_complex, True, state_file)
            state_file.write('\n')

        pg.update_graph()

        if state_file is not None:
            print_dg_state(pg.DG, state_file)
            state_file.write('\n')

        cv, s_nodes, t_nodes = pg.get_cut()
        labeling, uncertain = pg.get_labeling_from_cut(s_nodes, t_nodes)
        new_labeling = resolve_uncertainty(protein_complex, labeling, uncertain, verbose=True)

        curve_values = get_curve_values(protein_complex, new_labeling, pH)
        for key, value in curve_values.items():
            curves[key].append((pH, value))

    return curves
134
def get_curve_values(protein_complex, labeling, pH):
    """Using the given selected residue states (labeling) and pH, get the
    current curve value for all titratable residues.

    Non-histidine residues: the value is exp(-d) / (1 + exp(-d)), where d is
    the normal-form energy difference (protonated - deprotonated) returned by
    protein_complex.evaluate_energy_diff. It is 1.0 when exp() overflows.

    Histidine: the "HId" and "HIe" residue variables that share a
    (chain, location) are combined into a single ("HIS", chain, location)
    entry. That entry holds the fraction of the HSP state among the
    HSD/HSE/HSP weights.

    Returns a dict keyed by the residue key (name, chain, location), or by
    ("HIS", chain, location) for histidines.

    Raises KeyError if only one of "HId"/"HIe" exists for a histidine site.
    """
    his_seen = set()
    results = {}
    aH = math.pow(10, -pH)

    for key, residue in protein_complex.residue_variables.items():
        name, chain, location = key

        if name not in ("HId", "HIe"):
            # energy_diff = protonated_energy - deprotonated_energy
            energy_diff = protein_complex.evaluate_energy_diff(residue, labeling, normal_form=True)
            results[key] = _protonated_fraction(energy_diff)
            continue

        # Both histidine variables of a site are handled together, once.
        if (chain, location) in his_seen:
            continue
        his_seen.add((chain, location))

        if name == "HId":
            hid_residue = residue
            hie_residue = protein_complex.residue_variables["HIe", chain, location]
        else:
            hie_residue = residue
            hid_residue = protein_complex.residue_variables["HId", chain, location]

        # dGe = HSP - HSE, dGd = HSP - HSD
        dGe, dGd = protein_complex.evaluate_energy_diff_his(hie_residue, hid_residue, labeling,
                                                            normal_form=True)
        results["HIS", chain, location] = _his_protonated_fraction(
            aH, dGe, dGd, modPkaHIE, modPkaHID)

    return results


def _protonated_fraction(energy_diff):
    """Return exp(-energy_diff) / (1 + exp(-energy_diff)) for a non-HIS residue.

    energy_diff - protonated minus deprotonated energy.
    Returns 1.0 when exp() overflows (e.g. an unresolved bump).
    """
    try:
        e_exp = math.exp(-energy_diff)
    except OverflowError:
        return 1.0
    return e_exp / (1.0 + e_exp)


def _his_protonated_fraction(aH, dGe, dGd, pka_hie, pka_hid):
    """Return the HSP fraction of one histidine site.

    aH - 10**-pH.
    dGe, dGd - energy differences from evaluate_energy_diff_his
        (dGe = HSP - HSE, dGd = HSP - HSD).
    pka_hie, pka_hid - reference pKa values for the HIE and HID forms.

    The unnormalized weights are HSD: 1, HSE: exp(-ddG), HSP: aH*exp(-dGp).
    Large energy differences fall back to a max-shifted log-space evaluation
    instead of raising OverflowError.
    """
    ln10 = math.log(10.0)
    dGdref = pka_hid * ln10
    dGeref = pka_hie * ln10
    # Shift both differences by -ln(aH) and by their reference pKa term. The
    # original code described this as removing the extra pH and pKa
    # contributions.
    dGd = dGd - math.log(aH) - dGdref
    dGe = dGe - math.log(aH) - dGeref
    ddG = (dGe + dGeref) - (dGd + dGdref)
    dGp = dGd - dGdref
    try:
        pHSE = math.exp(-ddG)
        pHSP = aH * math.exp(-dGp)
    except OverflowError:
        # exp() overflowed (e.g. an unresolved bump). Work with the log
        # weights shifted by their maximum so the dominant state gets a
        # fraction close to 1, rather than crashing.
        log_weights = (0.0, -ddG, math.log(aH) - dGp)
        top = max(log_weights)
        hsd, hse, hsp = [math.exp(w - top) for w in log_weights]
        return hsp / (hsd + hse + hsp)
    pHSD = 1.0
    return pHSP / (pHSD + pHSE + pHSP)
221