1__author__ = "Joshua Zosky"
2
3"""
4    Copyright 2015 Joshua Zosky
5    joshua.e.zosky@gmail.com
6
7    This file is part of "RetroTS".
8    "RetroTS" is free software: you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation, either version 3 of the License, or
11    (at your option) any later version.
12    "RetroTS" is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15    GNU General Public License for more details.
16    You should have received a copy of the GNU General Public License
17    along with "RetroTS".  If not, see <http://www.gnu.org/licenses/>.
18"""
19
20from numpy import zeros, ones, nonzero, pi, argmin, sin, cos
21from numpy import size, arange, clip, histogram, r_, Inf, divide, append, delete, array
22from .zscale import z_scale
23import matplotlib.pyplot as plt
24
25
def my_hist(x, bin_centers):
    """
    Count how many values of ``x`` fall nearest to each bin center.

    numpy.histogram works with bin *edges*; this places an edge halfway
    between each pair of neighbouring centers and opens the two outer bins
    to -inf/+inf, so every value of ``x`` is counted in some bin (the
    bin-centers behaviour discussed in the Stack Overflow answer below).
    This frivolous yet convenient conversion from bin-edges to bin-centers is from Stack Overflow user Bas Swinckels
    http://stackoverflow.com/questions/18065951/why-does-numpy-histogram-python-leave-off-one-element-as-compared-to-hist-in-m

    :param x: dataset
    :param bin_centers: increasing bin centers, as a list or a NumPy array
    :return: counts = one count per bin center, ready for pyplot.bar
    """
    # Local import: ``inf`` replaces the ``Inf`` alias removed in NumPy 2.0.
    from numpy import asarray, inf

    # Convert first so a plain list is added element-wise, not concatenated.
    centers = asarray(bin_centers, dtype=float)
    bin_edges = r_[-inf, 0.5 * (centers[:-1] + centers[1:]), inf]
    counts, _ = histogram(x, bin_edges)
    return counts
37
38
def phase_base(amp_type, phasee):
    """
    Estimate the phase of a trace and build per-slice phase regressors.

    :param amp_type:    if 0, it is a time-based phase estimation
                        if 1, it is an amplitude-based phase estimation
    :param phasee: dict of trace information; it is modified in place.
                   Keys read: "t", "tp_trace", "prd", "zero_phase_offset"
                   (amp_type 0); "v", "p_trace", "tp_trace", "tn_trace"
                   (amp_type 1); "volume_tr", "number_of_slices",
                   "slice_offset", "show_graphs", "quiet", and when plotting
                   "v_name" and optionally "phase_r"/"tR".
    :return: the same dict with "phase" (radians, one value per sample of
             "t"), "time_series_time", "phase_slice" (volumes x slices) and
             "phase_slice_reg" (volumes x 4 x slices holding sin, cos,
             sin(2*phase), cos(2*phase)) set; amp_type 1 also sets
             "phase_pol" and stores "tp_trace"/"tn_trace" back as arrays.
    """
    if amp_type == 0:
        _time_based_phase(phasee)
    else:
        _amplitude_based_phase(phasee)
    _slice_phase_regressors(phasee)
    if phasee["quiet"] == 0 and phasee["show_graphs"] == 1:
        _plot_phase(phasee)
        # TODO: honour phasee["demo"] by pausing here
        # (Matlab: uiwait(msgbox('Press button to resume', 'Pausing', 'modal')))
    return phasee


def _time_based_phase(phasee):
    """Set phasee["phase"] so that each sample's phase restarts at the preceding peak."""
    t = phasee["t"]
    peaks = phasee["tp_trace"]
    periods = phasee["prd"]
    offset = phasee["zero_phase_offset"]
    # -2 marks samples not covered by any peak-to-peak interval
    phase = -2 * ones(size(t))
    j = 0
    for i in range(len(peaks) - 1):
        while t[j] < peaks[i + 1]:
            if t[j] >= peaks[i]:
                # Note: using a constant period for each interval causes a
                # slope discontinuity within a period. One should resample
                # periods[i] so it is estimated at each time t[j]; unclear
                # whether that makes much of a difference in the end.
                phase[j] = (t[j] - peaks[i]) / periods[i] + offset
                if phase[j] < 0:
                    phase[j] = -phase[j]
                if phase[j] > 1:
                    phase[j] -= 1
            j += 1
    # Samples still flagged as unset get phase 0
    phase[phase < -1] = 0.0
    # Convert from fraction of a cycle to radians
    phasee["phase"] = phase * 2 * pi


def _amplitude_based_phase(phasee):
    """Set phasee["phase"] from the amplitude histogram and breathing polarity (Glover 2000, eq. 3)."""
    t = phasee["t"]
    n_samples = len(phasee["v"])
    # At first scale to the max
    mxamp = max(phasee["p_trace"])
    gR = z_scale(phasee["v"], 0, mxamp)  # Scale, per Glover 2000's paper
    bins = arange(0.01, 1.01, 0.01) * mxamp
    hb_value = my_hist(gR, bins)
    if phasee["show_graphs"] == 1:
        center = (bins[:-1] + bins[1:]) / 2
        plt.bar(center, hb_value[:-1])
        plt.show()

    # Find the polarity of each time point in v. Samples before both the
    # first peak and the first trough keep polarity 0.
    phase_pol = zeros(size(phasee["v"]))
    tp = phasee["tp_trace"][0]
    tn = phasee["tn_trace"][0]
    i = 0
    while i < n_samples and t[i] < tp and t[i] < tn:
        i += 1
    if tp < tn:
        # Expiring phase (peak behind us)
        cpol, itp, inp = -1, 1, 0
    else:
        # Inspiring phase (bottom behind us)
        cpol, itp, inp = 1, 0, 1
    # Add a fake final point to the peak/trough times so the index never runs off the end
    tp_ext = append(phasee["tp_trace"], t[-1])
    tn_ext = append(phasee["tn_trace"], t[-1])
    while i < n_samples:
        phase_pol[i] = cpol
        if t[i] == tp_ext[itp]:
            cpol = -1
            itp = min(itp + 1, len(tp_ext) - 1)
        elif t[i] == tn_ext[inp]:
            cpol = 1
            inp = min(inp + 1, len(tn_ext) - 1)
        i += 1
    phasee["phase_pol"] = phase_pol
    # Drop the fake points again (the traces are stored back as arrays)
    phasee["tp_trace"] = delete(tp_ext, -1)
    phasee["tn_trace"] = delete(tn_ext, -1)
    if phasee["show_graphs"] == 1:
        _plot_polarity(t, gR, phase_pol, mxamp)

    # Now that we have the polarity, without computing sign(dR/dt) as in
    # Glover et al 2000, calculate the phase per eq. 3 of that paper.
    # Map each amplitude to a 1-based bin index; clipping to 1..100 keeps all
    # 100 cumulative-histogram entries reachable and stops index gR - 1 from
    # wrapping around to -1.
    for k, val in enumerate(gR):
        gR[k] = round(val / mxamp * 100) + 1
    gR = clip(gR, 1, 100)
    # Normalised cumulative histogram (the sum in the numerator)
    shb = sum(hb_value)
    hbsum = [float(hb_value[0]) / shb]
    for k in range(1, 100):
        hbsum.append(hbsum[k - 1] + float(hb_value[k]) / shb)
    # Build a fresh array rather than appending to whatever "phase" held before
    phasee["phase"] = array(
        [pi * hbsum[int(gR[k]) - 1] * phase_pol[k] for k in range(len(t))]
    )


def _plot_polarity(t, gR, phase_pol, mxamp):
    """Plot the scaled trace with red/green markers at positive/negative polarity samples."""
    plt.plot(t, gR, "b")
    positive = nonzero(phase_pol > 0)[0]
    negative = nonzero(phase_pol < 0)[0]
    plt.plot([t[k] for k in positive], 0.55 * mxamp * ones(len(positive)), "r.")
    plt.plot([t[k] for k in negative], 0.45 * mxamp * ones(len(negative)), "g.")
    plt.show()


def _slice_phase_regressors(phasee):
    """Sample phasee["phase"] at each slice's volume times and build sin/cos regressors."""
    t = phasee["t"]
    tr = phasee["volume_tr"]
    # Time series time vector
    stop = max(t) - 0.5 * tr
    tst = arange(0, stop, tr)
    # Python uses half open ranges, so we need to catch the case when the stop
    # is evenly divisible by the step and add one more to the time series in
    # order to match Matlab, which uses closed ranges  1 Jun 2017 [D Nielson]
    # When stop == 0 the half-open range is empty and the closed one is [0].
    if stop >= 0 and stop % tr == 0:
        tst = append(tst, [tst[-1] + tr if len(tst) else 0.0])
    phasee["time_series_time"] = tst
    n_vols = len(tst)
    n_slices = phasee["number_of_slices"]
    phasee["phase_slice"] = zeros((n_vols, n_slices))
    phasee["phase_slice_reg"] = zeros((n_vols, 4, n_slices))
    for i_slice in range(n_slices):
        tslc = tst + phasee["slice_offset"][i_slice]
        for i in range(n_vols):
            # Phase of the sample closest in time to this slice acquisition
            imin = argmin(abs(tslc[i] - t))
            phasee["phase_slice"][i, i_slice] = phasee["phase"][imin]
        # Make regressors for each slice
        slice_phase = phasee["phase_slice"][:, i_slice]
        phasee["phase_slice_reg"][:, 0, i_slice] = sin(slice_phase)
        phasee["phase_slice_reg"][:, 1, i_slice] = cos(slice_phase)
        phasee["phase_slice_reg"][:, 2, i_slice] = sin(2 * slice_phase)
        phasee["phase_slice_reg"][:, 3, i_slice] = cos(2 * slice_phase)


def _plot_phase(phasee):
    """Plot the estimated phase and the sampled phase of the first two slices."""
    print("--> Calculated phase")
    plt.subplot(413)
    plt.plot(phasee["t"], divide(divide(phasee["phase"], 2), pi), "m")
    if "phase_r" in phasee:
        plt.plot(phasee["tR"], divide(divide(phasee["phase_r"], 2), pi), "m-.")
    plt.subplot(414)
    tst = phasee["time_series_time"]
    phase_slice = phasee["phase_slice"]
    # Columns are 0-based: first slice in red, second slice in blue
    if phase_slice.shape[1] > 0:
        plt.plot(tst, phase_slice[:, 0], "ro")
    if phase_slice.shape[1] > 1:
        plt.plot(tst, phase_slice[:, 1], "bo", tst, phase_slice[:, 1], "b-")
    plt.plot(phasee["t"], phasee["phase"], "k")
    plt.title(phasee["v_name"])
    plt.show()
233
234
def phase_estimator(amp_phase, phase_info):
    """
    Estimate the phase of a physiological trace.

    Example: PhaseEstimator.phase_estimator(amp_phase, info_dictionary)

    ``phase_info`` is merged over the defaults below and the result is passed
    to ``phase_base``. If the merged "phasee_list" entry is a list, each of
    its elements is instead passed to ``phase_base`` on its own, and a list
    of results is returned.

    :param amp_phase: 0 for time-based phase estimation, 1 for
                      amplitude-based phase estimation
    :param phase_info: dict overriding any of the defaults: v_name, t, x,
        iz (zero crossing (peak) locations), volume_tr, p_trace, tp_trace,
        n_trace, tn_trace, prd, t_mid_prd, p_trace_mid_prd, phase, rv, rvt,
        var_vector, phys_fs (sampling frequency), zero_phase_offset
        (fraction of the period that corresponds to a phase of 0; 0.5 means
        the middle of the period, 0 means the 1st peak), quiet, resample_fs,
        frequency_cutoff, fir_order, resample_kernel, demo, as_window_width,
        as_percover, as_fftwin, sep_dups, phasee_list, show_graphs,
        number_of_slices
    :return: the dict returned by phase_base (phase estimation of the input
             signal), or a list of such dicts when "phasee_list" is a list
    """
    # Built fresh on every call so the list defaults are never shared.
    phasee = dict(
        v_name="",
        t=[],
        x=[],
        iz=[],  # zero crossing (peak) locations
        volume_tr=2,
        p_trace=[],
        tp_trace=[],
        n_trace=[],
        tn_trace=[],
        prd=[],
        t_mid_prd=[],
        p_trace_mid_prd=[],
        phase=[],
        rv=[],
        rvt=[],
        var_vector=[],
        phys_fs=(1 / 0.025),
        zero_phase_offset=0.5,
        quiet=0,
        resample_fs=(1 / 0.025),
        frequency_cutoff=10,
        fir_order=80,
        resample_kernel="linear",
        demo=0,
        as_window_width=0,
        as_percover=0,
        as_fftwin=0,
        sep_dups=0,
        phasee_list=0,
        show_graphs=0,
        number_of_slices=0,
    )
    phasee.update(phase_info)
    if isinstance(phasee["phasee_list"], list):
        # NOTE(review): list entries are passed through without the defaults
        # above merged in; confirm callers always supply complete dicts.
        return [
            phase_base(amp_phase, phasee_column)
            for phasee_column in phasee["phasee_list"]
        ]
    return phase_base(amp_phase, phasee)
330