1"""
2Defines classes and functions for managing recordings (spikes, membrane
3potential etc).
4
5These classes and functions are not part of the PyNN API, and are only for
6internal use.
7
8:copyright: Copyright 2006-2021 by the PyNN team, see AUTHORS.
9:license: CeCILL, see LICENSE for details.
10
11"""

import logging
import os
from collections import defaultdict
from copy import copy
from datetime import datetime
from warnings import warn

import neo
import numpy as np
import quantities as pq

from pyNN import errors

logger = logging.getLogger("PyNN")

MPI_ROOT = 0


def get_mpi_comm():
    try:
        from mpi4py import MPI
    except ImportError:
        raise Exception(
            "Trying to gather data without MPI installed. If you are not running a "
            "distributed simulation, this is a bug in PyNN.")
    return MPI.COMM_WORLD, {'DOUBLE': MPI.DOUBLE, 'SUM': MPI.SUM}


def rename_existing(filename):
    if os.path.exists(filename):
        os.replace(filename, filename + "_old")
        logger.warning("File %s already exists. Renaming the original file to %s_old" %
                       (filename, filename))


def gather_array(data):
    # gather 1D or 2D numpy arrays
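    # A usage sketch (illustrative; only meaningful when the simulation runs
    # under MPI, e.g. ``mpirun -np 4``; ``local_rows`` is a hypothetical
    # per-rank array):
    #
    #     local_rows = np.array([[0., 1.], [2., 3.]])
    #     all_rows = gather_array(local_rows)   # complete array on rank MPI_ROOT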
    mpi_comm, mpi_flags = get_mpi_comm()
    assert isinstance(data, np.ndarray)
    assert len(data.shape) < 3
    # first we pass the data size
    size = data.size
    sizes = mpi_comm.gather(size, root=MPI_ROOT) or []
    # now we pass the data
    displacements = [sum(sizes[:i]) for i in range(len(sizes))]
    gdata = np.empty(sum(sizes))
    mpi_comm.Gatherv([data.flatten(), size, mpi_flags['DOUBLE']],
                     [gdata, (sizes, displacements), mpi_flags['DOUBLE']],
                     root=MPI_ROOT)
    if len(data.shape) == 1:
        return gdata
    else:
        num_columns = data.shape[1]
        # integer division: reshape() requires integer dimensions
        return gdata.reshape((gdata.size // num_columns, num_columns))


def gather_dict(D, all=False):
    # Note that if the same key exists on multiple nodes, the value from the
    # node with the highest rank will appear in the final dict.
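    # A usage sketch (illustrative; the dict contents are hypothetical):
    #
    #     local_counts = {cell_id: n_spikes}                 # one entry per local cell
    #     all_counts = gather_dict(local_counts)             # complete dict on rank MPI_ROOT only
    #     everywhere = gather_dict(local_counts, all=True)   # complete dict on every rank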
    mpi_comm, mpi_flags = get_mpi_comm()
    if all:
        Ds = mpi_comm.allgather(D)
    else:
        Ds = mpi_comm.gather(D, root=MPI_ROOT)
    if Ds:
        for otherD in Ds:
            D.update(otherD)
    return D


def gather_blocks(data, ordered=True):
    """Gather Neo Blocks"""
    mpi_comm, mpi_flags = get_mpi_comm()
    assert isinstance(data, neo.Block)
    # for now, use gather_dict, which will probably be slow. Can optimize later
    D = {mpi_comm.rank: data}
    D = gather_dict(D)
    blocks = list(D.values())
    merged = data
    if mpi_comm.rank == MPI_ROOT:
        merged = blocks[0]
        # the following business with setting sig.segment is a workaround for a bug in Neo
        for seg in merged.segments:
            for sig in seg.analogsignals:
                sig.segment = seg
        for block in blocks[1:]:
            for seg, mseg in zip(block.segments, merged.segments):
                for sig in seg.analogsignals:
                    sig.segment = mseg
            merged.merge(block)
    if ordered:
        for segment in merged.segments:
            ordered_spiketrains = sorted(
                segment.spiketrains, key=lambda s: s.annotations['source_id'])
            segment.spiketrains = ordered_spiketrains
    return merged


def mpi_sum(x):
    mpi_comm, mpi_flags = get_mpi_comm()
    if mpi_comm.size > 1:
        return mpi_comm.allreduce(x, op=mpi_flags['SUM'])
    else:
        return x


def normalize_variables_arg(variables):
    """If variables is a single string, encapsulate it in a list."""
    if isinstance(variables, str) and variables != 'all':
        return [variables]
    else:
        return variables


def safe_makedirs(dir):
    """
    Version of makedirs not subject to race condition when using MPI.
    """
    if dir and not os.path.exists(dir):
        try:
            os.makedirs(dir)
        except OSError as e:
            if e.errno != 17:  # errno 17 is EEXIST: another MPI process already created the directory
                raise


def get_io(filename):
    """
    Return a Neo IO instance, guessing the type based on the filename suffix.
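
    A minimal usage sketch (the filename and ``block`` are hypothetical)::

        io = get_io("results/sim_data.pkl")   # -> neo.io.PickleIO
        io.write_block(block)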
139    """
140    logger.debug("Creating Neo IO for filename %s" % filename)
141    dir = os.path.dirname(filename)
142    safe_makedirs(dir)
143    extension = os.path.splitext(filename)[1]
144    if extension in ('.txt', '.ras', '.v', '.gsyn'):
145        raise IOError(
146            "ASCII-based formats are not currently supported for output data. Try using the file extension '.pkl' or '.h5'")
147    elif extension in ('.h5',):
148        return neo.io.NeoHdf5IO(filename=filename)
149    elif extension in ('.pkl', '.pickle'):
150        return neo.io.PickleIO(filename=filename)
151    elif extension == '.mat':
152        return neo.io.NeoMatlabIO(filename=filename)
153    else:  # function to be improved later
154        raise Exception("file extension %s not supported" % extension)
155
156
157def filter_by_variables(segment, variables):
158    """
159    Return a new `Segment` containing only recordings of the variables given in
160    the list `variables`
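
    A minimal sketch (assumes ``segment`` holds both spike trains and a 'v' signal)::

        vm_only = filter_by_variables(segment, ['v'])   # spike trains are dropped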
161    """
162    if variables == 'all':
163        return segment
164    else:
165        new_segment = copy(segment)  # shallow copy
166        if 'spikes' not in variables:
167            new_segment.spiketrains = []
168        new_segment.analogsignals = [sig for sig in segment.analogsignals if sig.name in variables]
169        # also need to handle Units, RecordingChannels
170        return new_segment
171
172
173def remove_duplicate_spiketrains(data):
174    for segment in data.segments:
175        spiketrains = {}
176        for spiketrain in segment.spiketrains:
177            index = spiketrain.annotations["source_index"]
178            spiketrains[index] = spiketrain
179        min_index = min(spiketrains.keys())
180        max_index = max(spiketrains.keys())
181        segment.spiketrains = [spiketrains[i] for i in range(min_index, max_index + 1)]
182    return data


class DataCache(object):
    # primitive implementation for now, storing in memory - later can consider caching to disk

    def __init__(self):
        self._data = []

    def __iter__(self):
        return iter(self._data)

    def store(self, obj):
        if obj not in self._data:
            logger.debug("Adding %s to cache" % obj)
            self._data.append(obj)

    def clear(self):
        self._data = []


class Recorder(object):
    """Encapsulates data and functions related to recording model variables."""

    def __init__(self, population, file=None):
        """
        Create a recorder.

        `population` -- the Population instance which is being recorded by the
                        recorder
        `file` -- one of:
            - a file-name,
            - `None` (write to a temporary file)
            - `False` (write to memory).
        """
        self.file = file
        self.population = population  # needed for writing header information
        self.recorded = defaultdict(set)
        self.cache = DataCache()
        self._simulator.state.recorders.add(self)
        self.clear_flag = False
        self._recording_start_time = self._simulator.state.t * pq.ms
        self.sampling_interval = self._simulator.state.dt

    def record(self, variables, ids, sampling_interval=None):
        """
        Add the cells in `ids` to the sets of recorded cells for the given variables.
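
        A hedged sketch of typical use (``population`` and the chosen variables
        are illustrative)::

            recorder.record('spikes', population.all_cells)
            recorder.record(['v'], population.all_cells, sampling_interval=0.5)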
229        """
230        logger.debug('Recorder.record(<%d cells>)' % len(ids))
231        self._check_sampling_interval(sampling_interval)
232
233        ids = set([id for id in ids if id.local])
234        for variable in normalize_variables_arg(variables):
235            if not self.population.can_record(variable):
236                raise errors.RecordingError(variable, self.population.celltype)
237            new_ids = ids.difference(self.recorded[variable])
238            self.recorded[variable] = self.recorded[variable].union(ids)
239            self._record(variable, new_ids, sampling_interval)
240
241    def _check_sampling_interval(self, sampling_interval):
242        """
243        Check whether record() has been called previously with a different sampling interval
244        (we exclude recording of spikes, as the sampling interval does not apply in that case)
245        """
246        if sampling_interval is not None and sampling_interval != self.sampling_interval:
247            recorded_variables = list(self.recorded.keys())
248            if "spikes" in recorded_variables:
249                recorded_variables.remove("spikes")
250            if len(recorded_variables) > 0:
251                raise ValueError(
252                    "All neurons in a population must be recorded with the same sampling interval.")

    def reset(self):
        """Reset the list of things to be recorded."""
        self._reset()
        self.recorded = defaultdict(set)

    def filter_recorded(self, variable, filter_ids):
        if filter_ids is not None:
            return set(filter_ids).intersection(self.recorded[variable])
        else:
            return self.recorded[variable]

    def _get_current_segment(self, filter_ids=None, variables='all', clear=False):
        segment = neo.Segment(name="segment%03d" % self._simulator.state.segment_counter,
                              description=self.population.describe(),
                              rec_datetime=datetime.now())  # would be nice to get the time at the start of the recording, not the end
        variables_to_include = set(self.recorded.keys())
        if variables != 'all':
            variables_to_include = variables_to_include.intersection(set(variables))
        for variable in variables_to_include:
            if variable == 'spikes':
                t_stop = self._simulator.state.t * pq.ms  # must run on all MPI nodes
                sids = sorted(self.filter_recorded('spikes', filter_ids))
                data = self._get_spiketimes(sids, clear=clear)

                segment.spiketrains = []
                for id in sids:
                    times = pq.Quantity(data.get(int(id), []), pq.ms)
                    if times.size > 0 and times.max() > t_stop:
                        warn("Recorded at least one spike after t_stop")
                        times = times[times <= t_stop]
                    segment.spiketrains.append(
                        neo.SpikeTrain(times,
                                       t_start=self._recording_start_time,
                                       t_stop=t_stop,
                                       units='ms',
                                       source_population=self.population.label,
                                       source_id=int(id), source_index=self.population.id_to_index(int(id)))
                    )
            else:
                ids = sorted(self.filter_recorded(variable, filter_ids))
                signal_array = self._get_all_signals(variable, ids, clear=clear)
                t_start = self._recording_start_time
                t_stop = self._simulator.state.t * pq.ms
                sampling_period = self.sampling_interval * pq.ms
                current_time = self._simulator.state.t * pq.ms
                mpi_node = self._simulator.state.mpi_rank  # for debugging
                if signal_array.size > 0:  # may be empty if none of the recorded cells are on this MPI node
                    units = self.population.find_units(variable)
                    source_ids = np.fromiter(ids, dtype=int)
                    signal = neo.AnalogSignal(
                        signal_array,
                        units=units,
                        t_start=t_start,
                        sampling_period=sampling_period,
                        name=variable,
                        source_population=self.population.label,
                        source_ids=source_ids,
                        array_annotations={"channel_index": np.array([self.population.id_to_index(id) for id in ids])})
                    segment.analogsignals.append(signal)
                    logger.debug("%d **** ids=%s, channels=%s", mpi_node,
                                 source_ids, signal.array_annotations["channel_index"])
                    assert segment.analogsignals[0].t_stop - \
                        current_time - 2 * sampling_period < 1e-10
        return segment

    def get(self, variables, gather=False, filter_ids=None, clear=False,
            annotations=None):
        """Return the recorded data as a Neo `Block`."""
        variables = normalize_variables_arg(variables)
        data = neo.Block()
        data.segments = [filter_by_variables(segment, variables)
                         for segment in self.cache]
        if self._simulator.state.running:  # reset() has not been called, so current segment is not in cache
            data.segments.append(self._get_current_segment(
                filter_ids=filter_ids, variables=variables, clear=clear))
        data.name = self.population.label
        data.description = self.population.describe()
        data.rec_datetime = data.segments[0].rec_datetime
        data.annotate(**self.metadata)
        if annotations:
            data.annotate(**annotations)
        if gather and self._simulator.state.num_processes > 1:
            data = gather_blocks(data)
            if hasattr(self.population.celltype, "always_local") and self.population.celltype.always_local:
                data = remove_duplicate_spiketrains(data)
        if clear:
            self.clear()
        return data

    def clear(self):
        """
        Clear all recorded data, both from the cache and the simulator.
        """
        self.cache.clear()
        self.clear_flag = True
        self._recording_start_time = self._simulator.state.t * pq.ms
        self._clear_simulator()

    def write(self, variables, file=None, gather=False, filter_ids=None,
              clear=False, annotations=None):
        """Write recorded data to a Neo IO"""
        if isinstance(file, str):
            file = get_io(file)
        io = file or self.file
        if gather is False and self._simulator.state.num_processes > 1:
            io.filename += '.%d' % self._simulator.state.mpi_rank
        logger.debug("Recorder is writing '%s' to file '%s' with gather=%s" % (
            variables, io.filename, gather))
        data = self.get(variables, gather, filter_ids, clear, annotations=annotations)
        if self._simulator.state.mpi_rank == 0 or gather is False:
            # Open the output file, if necessary, and write the data
            logger.debug("Writing data to file %s" % io)
            io.write_block(data)

    @property
    def metadata(self):
        metadata = {
            'size': self.population.size,
            'first_index': 0,
            'last_index': len(self.population),
            'first_id': int(self.population.first_id),
            'last_id': int(self.population.last_id),
            'label': self.population.label,
            'simulator': self._simulator.name,
        }
        metadata.update(self.population.annotations)
        # note that this has to run on all nodes (at least for NEST)
        metadata['dt'] = self._simulator.state.dt
        metadata['mpi_processes'] = self._simulator.state.num_processes
        return metadata

    def count(self, variable, gather=True, filter_ids=None):
        """
        Return the number of data points for each cell, as a dict. This is mainly
        useful for spike counts or for variable-time-step integration methods.
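
        For example (a sketch; only 'spikes' is currently supported)::

            n_spikes = recorder.count('spikes')    # e.g. {cell_id: n_points, ...}
            total_spikes = sum(n_spikes.values())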
389        """
390        if variable == 'spikes':
391            N = self._local_count(variable, filter_ids)
392        else:
393            raise Exception("Only implemented for spikes.")
394        if gather and self._simulator.state.num_processes > 1:
395            N = gather_dict(N)
396        return N
397
398    def store_to_cache(self, annotations=None):
399        # make sure we haven't called get with clear=True since last reset
400        # and that we did not do two resets in a row
401        if (self._simulator.state.t != 0) and (not self.clear_flag):
402            if annotations is None:
403                annotations = {}
404            segment = self._get_current_segment()
405            segment.annotate(**annotations)
406            self.cache.store(segment)
407        self.clear_flag = False
408        self._recording_start_time = 0.0 * pq.ms
409