1__author__ = "Joshua Zosky"
2
3"""
4    Copyright 2015 Joshua Zosky
5    joshua.e.zosky@gmail.com
6
7    This file is part of "RetroTS".
8    "RetroTS" is free software: you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation, either version 3 of the License, or
11    (at your option) any later version.
12    "RetroTS" is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15    GNU General Public License for more details.
16    You should have received a copy of the GNU General Public License
17    along with "RetroTS".  If not, see <http://www.gnu.org/licenses/>.
18"""
19
20import numpy  # delete this and replace with specific functions
21from numpy import nonzero, add, subtract, divide, mean, zeros, around
22from scipy.signal import firwin, lfilter
23from scipy.interpolate import interp1d
24
25# rcr: omit new style sub-library of pylab
26from pylab import plot, subplot, show
27from .zscale import z_scale
28
29
def _match_peak_trace_lengths(r):
    """Make r["p_trace"] and r["n_trace"] the same length, in place.

    Returns True when the traces already match, or differ by exactly one
    sample (the longer trace and its time vector are clipped). Returns False,
    after printing an error, when they differ by more than one sample.
    """
    n_p = len(r["p_trace"])
    n_n = len(r["n_trace"])
    if n_p == n_n:
        return True
    dd = abs(n_p - n_n)
    if dd > 1:  # have not seen this yet, trap for it.
        print(
            "Error RVT_from_PeakFinder:\n"
            "  Peak trace lengths differ by %d\n"
            "  This is unusual, please upload data\n"
            "  sample to afni.nimh.nih.gov" % dd
        )
        return False
    # Just a difference of 1: happens sometimes, seems ok to discard one sample.
    print(
        "Notice RVT_from_PeakFinder:\n"
        "   Peak trace lengths differ by %d\n"
        "   Clipping longer trace." % dd
    )
    dm = min(n_p, n_n)
    if n_p != dm:
        r["p_trace"] = r["p_trace"][0:dm]
        r["tp_trace"] = r["tp_trace"][0:dm]
    else:
        r["n_trace"] = r["n_trace"][0:dm]
        r["tn_trace"] = r["tn_trace"][0:dm]
    return True


def _smooth_rvt(r):
    """Compute r["rvr"], r["rvtr"] and the FIR low-pass smoothed r["rvtrs"].

    Smoothing lets RVT be resampled at the volume TR later.
    """
    r["rvr"] = subtract(r["p_trace_r"], r["n_trace_r"])
    r["rvtr"] = numpy.empty(numpy.shape(r["rvr"]))
    divide(r["rvr"], r["prdR"], out=r["rvtr"])

    fnyq = r["phys_fs"] / 2  # nyquist of physio signal
    w = float(r["frequency_cutoff"]) / float(fnyq)  # cut off frequency normalized
    b = firwin(numtaps=(r["fir_order"] + 1), cutoff=w, window="hamming")

    v = r["rvtr"]
    around(v, 6, out=v)  # v aliases r["rvtr"], so this rounds it in place too
    mv = mean(v)
    v = v - mv  # remove the mean before filtering

    # Filter forward, flip, filter again, flip back: cancels the phase shift.
    # With legacy_transform != 0 the flips are skipped (per the original note
    # they "don't do anything in the MATLAB version"), so the signal is
    # filtered forward twice and the phase shift is NOT cancelled.
    flip = r["legacy_transform"] == 0
    for _ in range(2):
        v = lfilter(b, 1, v)
        if flip:
            v = numpy.flipud(v)
    r["rvtrs"] = v + mv


def _build_rvt_regressors(r):
    """Fill r["rvtrs_slc"] with r["rvtrs"] shifted by each of r["rvt_shifts"].

    Row i is r["rvtrs"] delayed by r["rvt_shifts"][i] (converted to samples at
    r["phys_fs"], clamped at the trace ends), interpolated onto
    r["time_series_time"].
    """
    n_t = len(r["t"])
    show_each = r["quiet"] == 0 and r["show_graphs"] == 1
    r["rvtrs_slc"] = zeros((len(r["rvt_shifts"]), len(r["time_series_time"])))
    for i, shf in enumerate(r["rvt_shifts"]):
        nsamp = int(round(shf * r["phys_fs"]))
        # Shifted sample indices, clamped to the valid range [0, n_t - 1].
        sind = numpy.clip(numpy.arange(n_t) + nsamp, 0, n_t - 1)
        rvt_shf = interp1d(
            r["t"], r["rvtrs"][sind], kind=r["interpolation_style"], bounds_error=True
        )
        rvt_shf_y = rvt_shf(r["time_series_time"])
        if show_each:
            # pacify matplotlib by passing a label (to get new instance)
            subplot(111, label="plot #%d" % i)
            plot(r["time_series_time"], rvt_shf_y)
        r["rvtrs_slc"][i, :] = rvt_shf_y


def _plot_rvt(r):
    """Plot z-scaled RVT (black) and, when present, smoothed RVT (magenta)."""
    print("--> Calculated RVT \n--> Created RVT regressors")
    subplot(211)
    p_min = min(r["p_trace"])
    p_max = max(r["p_trace"])
    plot(r["t_mid_prd"], z_scale(r["rvt"], p_min, p_max), "k")
    if numpy.size(r["p_trace_r"]) > 0:
        plot(r["tR"], z_scale(r["rvtrs"], p_min, p_max), "m")
    show()


def rvt_from_peakfinder(r):
    """Compute RVT from peak-finder output and build shifted RVT regressors.

    ``r`` is the peak-finder state dict; it is updated in place and returned.

    Keys read: "demo", "quiet", "show_graphs", "p_trace", "n_trace",
    "tp_trace", "tn_trace", "prd", "p_trace_r", "n_trace_r", "prdR",
    "phys_fs", "frequency_cutoff", "fir_order", "legacy_transform",
    "rvt_shifts", "t", "time_series_time", "interpolation_style"; plus
    "t_mid_prd" and "tR" when plotting.

    Keys written: "rv" (p_trace - n_trace), "rvt" (rv / prd) and "rvtrs_slc"
    (one row per entry of "rvt_shifts", sampled at "time_series_time"). When
    "p_trace_r" is non-empty, also "rvr", "rvtr" and the smoothed "rvtrs".
    The regressors are built from "rvtrs", so a non-empty "rvt_shifts"
    requires a non-empty "p_trace_r".

    If the peak and trough traces differ in length by one sample, the longer
    one is clipped (a notice is printed).

    Returns:
        ``r``, or None if the traces differ in length by more than one sample
        (an error message is printed in that case).
    """
    if r["demo"]:
        quiet = 0
        print("quiet = %s" % quiet)

    # Calculate RVT
    if not _match_peak_trace_lengths(r):
        return None

    r["rv"] = subtract(r["p_trace"], r["n_trace"])
    # NOTE: still to consider which starts first, whether to initialize the
    # first two values by means, and what to do with one incomplete pair at
    # the end.
    nptrc = len(r["tp_trace"])
    r["rvt"] = r["rv"][0 : nptrc - 1] / r["prd"]

    # Only smooth when a trace is actually present.
    # (The former `.any` without a call was always truthy.)
    if numpy.size(r["p_trace_r"]) > 0:
        _smooth_rvt(r)

    # create RVT regressors
    _build_rvt_regressors(r)

    if r["quiet"] == 0 and r["show_graphs"] == 1:
        _plot_rvt(r)

    return r
141