1# -*- coding: utf-8 -*-
2#
3# voltage_trace.py
4#
5# This file is part of NEST.
6#
7# Copyright (C) 2004 The NEST Initiative
8#
9# NEST is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 2 of the License, or
12# (at your option) any later version.
13#
14# NEST is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with NEST.  If not, see <http://www.gnu.org/licenses/>.
21
22"""
23Functions to plot voltage traces.
24"""
25
26import nest
27import numpy
28
# Public API of this module: the two plotting entry points.
__all__ = ['from_device', 'from_file']
33
34
def from_file(fname, title=None, grayscale=False):
    """Plot voltage trace from file.

    The layout of the data is inferred from its number of columns:

    * 1 column: potentials of a single neuron, plotted against the sample
      index;
    * 2 columns: node ID and potential; one trace per node ID, plotted
      against the sample index;
    * 3 columns: node ID, time and potential; one trace per node ID, plotted
      against the times recorded for the first node ID found in the data.

    Parameters
    ----------
    fname : str or list
        Filename, or list/tuple of filenames, to load from. The contents of
        several files are concatenated in the given order.
    title : str, optional
        Plot title
    grayscale : bool, optional
        Plot in grayscale

    Returns
    -------
    list
        For single-column data, the result of one ``matplotlib.pyplot.plot``
        call; otherwise a list holding one ``plot`` result per node ID.

    Raises
    ------
    ValueError
        If ``fname`` is an empty list/tuple, or if the loaded data does not
        have 1, 2 or 3 columns.
    """
    # ndmin=2 keeps the (rows, columns) shape even for a file with a single
    # row, which would otherwise be squeezed to 1-D and mistaken for a
    # single-column trace.
    if isinstance(fname, (list, tuple)):
        if not fname:
            raise ValueError("No files given to plot from.")
        data = numpy.concatenate(
            [numpy.loadtxt(f, ndmin=2) for f in fname])
    else:
        data = numpy.loadtxt(fname, ndmin=2)

    num_columns = data.shape[1]
    if num_columns not in (1, 2, 3):
        # data.shape is a tuple: wrap it so it is formatted as one value.
        raise ValueError("Inappropriate data shape %s!" % (data.shape,))

    # Imported only after the input has been validated.
    import matplotlib.pyplot as plt

    line_style = "k" if grayscale else ""

    if num_columns == 1:
        print("INFO: only found 1 column in the file. "
              "Assuming that only one neuron was recorded.")
        plotid = plt.plot(data[:, 0], line_style)
        plt.xlabel("Time (steps of length interval)")

    elif num_columns == 2:
        print("INFO: found 2 columns in the file. "
              "Assuming them to be node ID, pot.")

        # Group potentials by node ID, preserving first-seen order.
        traces = {}
        for node_id, potential in data:
            traces.setdefault(node_id, []).append(potential)

        plotid = []
        for node_id, values in traces.items():
            plotid.append(
                plt.plot(values, line_style, label="Neuron %i" % node_id)
            )

        plt.xlabel("Time (steps of length interval)")
        plt.legend()

    else:
        first_id = data[0, 0]
        traces = {}
        times = []
        for node_id, time, potential in data:
            traces.setdefault(node_id, []).append(potential)
            # The time axis is taken from the first node ID only; all other
            # traces are assumed to be sampled at the same times.
            if node_id == first_id:
                times.append(time)

        plotid = []
        for node_id, values in traces.items():
            plotid.append(
                plt.plot(times, values, line_style,
                         label="Neuron %i" % node_id)
            )

        plt.xlabel("Time (ms)")
        plt.legend()

    if not title:
        # Wrap fname so a tuple of filenames is formatted, not unpacked.
        title = "Membrane potential from file '%s'" % (fname,)

    plt.title(title)
    plt.ylabel("Membrane potential (mV)")
    plt.draw()

    return plotid
126
127
def from_device(detec, neurons=None, title=None, grayscale=False,
                timeunit="ms"):
    """Plot the membrane potential of a set of neurons recorded by
    the given voltmeter or multimeter.

    Parameters
    ----------
    detec : nest.NodeCollection
        A single voltmeter or multimeter (a collection of length 1)
    neurons : list, optional
        Node IDs of the neurons to plot; defaults to every sender found in
        the recorded events. IDs without recorded data are reported with an
        ``INFO`` message and skipped. Only used when recording to memory.
    title : str, optional
        Plot title
    grayscale : bool, optional
        Plot in grayscale
    timeunit : str, optional
        Unit of the time axis when recording to memory: ``"s"`` divides the
        recorded times by 1000, any other value falls back to ``"ms"``.
        Ignored when the device records time in steps.

    Returns
    -------
    list
        When recording to memory, one ``matplotlib.pyplot.plot`` result per
        plotted neuron; when recording to ASCII files, the value returned by
        ``from_file``.

    Raises
    ------
    nest.kernel.NESTError
        If ``detec`` holds more than one node, is not a voltmeter or a
        multimeter measuring ``V_m``, recorded no events, or records
        neither to memory nor to ASCII files.
    """
    import matplotlib.pyplot as plt

    if len(detec) > 1:
        raise nest.kernel.NESTError("Please provide a single voltmeter.")

    record_to = detec.get("record_to")

    type_id = nest.GetDefaults(detec.get('model'), 'type_id')
    if type_id not in ('voltmeter', 'multimeter'):
        raise nest.kernel.NESTError("Please provide a voltmeter or a "
                                    "multimeter measuring V_m.")
    elif type_id == 'multimeter':
        record_from = detec.get("record_from")
        if "V_m" not in record_from:
            raise nest.kernel.NESTError("Please provide a multimeter "
                                        "measuring V_m.")
        elif record_to != "memory" and len(record_from) > 1:
            # File-based traces are parsed by column count, so only a single
            # recorded quantity can be interpreted there.
            raise nest.kernel.NESTError("Please provide a multimeter "
                                        "measuring only V_m or record to "
                                        "memory!")

    if record_to == "memory":
        time_in_steps = detec.get('time_in_steps')

        timefactor = 1.0
        if not time_in_steps:
            if timeunit == "s":
                timefactor = 1000.0
            else:
                timeunit = "ms"

        times, voltages = _from_memory(detec)

        if not len(times):
            raise nest.kernel.NESTError("No events recorded!")

        if neurons is None:
            neurons = voltages.keys()

        line_style = "k" if grayscale else ""

        plotids = []
        for neuron in neurons:
            # Check membership up front: both `times` and `voltages` are
            # keyed by sender, and an unknown ID must be skipped, not crash.
            if neuron not in voltages:
                print("INFO: Wrong ID: {0}".format(neuron))
                continue

            time_values = numpy.array(times[neuron]) / timefactor
            plotids.append(
                plt.plot(time_values, voltages[neuron],
                         line_style, label="Neuron %i" % neuron)
            )

        if not title:
            title = "Membrane potential"
        plt.title(title)

        plt.ylabel("Membrane potential (mV)")

        if time_in_steps:
            plt.xlabel("Steps")
        else:
            plt.xlabel("Time (%s)" % timeunit)

        plt.legend(loc="best")
        plt.draw()

        return plotids

    elif record_to == "ascii":
        fname = detec.get("filenames")
        return from_file(fname, title, grayscale)
    else:
        raise nest.kernel.NESTError("Provided devices neither record to "
                                    "ascii file, nor to memory.")
225
226
227def _from_memory(detec):
228    """Get voltage traces from memory.
229    ----------
230    detec : list
231        Global id of voltmeter or multimeter
232    """
233    import array
234
235    ev = detec.get('events')
236    potentials = ev['V_m']
237    senders = ev['senders']
238
239    v = {}
240    t = {}
241
242    if 'times' in ev:
243        times = ev['times']
244        for s, currentsender in enumerate(senders):
245            if currentsender not in v:
246                v[currentsender] = array.array('f')
247                t[currentsender] = array.array('f')
248
249            v[currentsender].append(float(potentials[s]))
250            t[currentsender].append(float(times[s]))
251    else:
252        # reconstruct the time vector, if not stored explicitly
253        origin = detec.get('origin')
254        start = detec.get('start')
255        interval = detec.get('interval')
256        senders_uniq = numpy.unique(senders)
257        num_intvls = len(senders) / len(senders_uniq)
258        times_s = origin + start + interval + \
259            interval * numpy.array(range(num_intvls))
260
261        for s, currentsender in enumerate(senders):
262            if currentsender not in v:
263                v[currentsender] = array.array('f')
264                t[currentsender] = times_s
265            v[currentsender].append(float(potentials[s]))
266
267    return t, v
268