1""" 2Defines classes and functions for managing recordings (spikes, membrane 3potential etc). 4 5These classes and functions are not part of the PyNN API, and are only for 6internal use. 7 8:copyright: Copyright 2006-2021 by the PyNN team, see AUTHORS. 9:license: CeCILL, see LICENSE for details. 10 11""" 12 13import logging 14import numpy as np 15import os 16from copy import copy 17from collections import defaultdict 18from warnings import warn 19from pyNN import errors 20import neo 21from datetime import datetime 22import quantities as pq 23 24logger = logging.getLogger("PyNN") 25 26MPI_ROOT = 0 27 28 29def get_mpi_comm(): 30 try: 31 from mpi4py import MPI 32 except ImportError: 33 raise Exception( 34 "Trying to gather data without MPI installed. If you are not running a distributed simulation, this is a bug in PyNN.") 35 return MPI.COMM_WORLD, {'DOUBLE': MPI.DOUBLE, 'SUM': MPI.SUM} 36 37 38def rename_existing(filename): 39 if os.path.exists(filename): 40 os.system('mv %s %s_old' % (filename, filename)) 41 logger.warning("File %s already exists. Renaming the original file to %s_old" % 42 (filename, filename)) 43 44 45def gather_array(data): 46 # gather 1D or 2D numpy arrays 47 mpi_comm, mpi_flags = get_mpi_comm() 48 assert isinstance(data, np.ndarray) 49 assert len(data.shape) < 3 50 # first we pass the data size 51 size = data.size 52 sizes = mpi_comm.gather(size, root=MPI_ROOT) or [] 53 # now we pass the data 54 displacements = [sum(sizes[:i]) for i in range(len(sizes))] 55 gdata = np.empty(sum(sizes)) 56 mpi_comm.Gatherv([data.flatten(), size, mpi_flags['DOUBLE']], 57 [gdata, (sizes, displacements), mpi_flags['DOUBLE']], 58 root=MPI_ROOT) 59 if len(data.shape) == 1: 60 return gdata 61 else: 62 num_columns = data.shape[1] 63 return gdata.reshape((gdata.size / num_columns, num_columns)) 64 65 66def gather_dict(D, all=False): 67 # Note that if the same key exists on multiple nodes, the value from the 68 # node with the highest rank will appear in the final dict. 69 mpi_comm, mpi_flags = get_mpi_comm() 70 if all: 71 Ds = mpi_comm.allgather(D) 72 else: 73 Ds = mpi_comm.gather(D, root=MPI_ROOT) 74 if Ds: 75 for otherD in Ds: 76 D.update(otherD) 77 return D 78 79 80def gather_blocks(data, ordered=True): 81 """Gather Neo Blocks""" 82 mpi_comm, mpi_flags = get_mpi_comm() 83 assert isinstance(data, neo.Block) 84 # for now, use gather_dict, which will probably be slow. 
def gather_blocks(data, ordered=True):
    """Gather Neo Blocks"""
    mpi_comm, mpi_flags = get_mpi_comm()
    assert isinstance(data, neo.Block)
    # for now, use gather_dict, which will probably be slow. Can optimize later
    D = {mpi_comm.rank: data}
    D = gather_dict(D)
    blocks = list(D.values())
    merged = data
    if mpi_comm.rank == MPI_ROOT:
        merged = blocks[0]
        # the following business with setting sig.segment is a workaround for a bug in Neo
        for seg in merged.segments:
            for sig in seg.analogsignals:
                sig.segment = seg
        for block in blocks[1:]:
            for seg, mseg in zip(block.segments, merged.segments):
                for sig in seg.analogsignals:
                    sig.segment = mseg
            merged.merge(block)
    if ordered:
        for segment in merged.segments:
            ordered_spiketrains = sorted(
                segment.spiketrains, key=lambda s: s.annotations['source_id'])
            segment.spiketrains = ordered_spiketrains
    return merged


def mpi_sum(x):
    mpi_comm, mpi_flags = get_mpi_comm()
    if mpi_comm.size > 1:
        return mpi_comm.allreduce(x, op=mpi_flags['SUM'])
    else:
        return x


def normalize_variables_arg(variables):
    """If `variables` is a single variable name, encapsulate it in a list."""
    if isinstance(variables, str) and variables != 'all':
        return [variables]
    else:
        return variables


def safe_makedirs(dir):
    """
    Version of makedirs not subject to race condition when using MPI.
    """
    if dir and not os.path.exists(dir):
        try:
            os.makedirs(dir)
        except OSError as e:
            if e.errno != errno.EEXIST:  # another process created it first
                raise


def get_io(filename):
    """
    Return a Neo IO instance, guessing the type based on the filename suffix.
    """
    logger.debug("Creating Neo IO for filename %s" % filename)
    dir = os.path.dirname(filename)
    safe_makedirs(dir)
    extension = os.path.splitext(filename)[1]
    if extension in ('.txt', '.ras', '.v', '.gsyn'):
        raise IOError(
            "ASCII-based formats are not currently supported for output data. "
            "Try using the file extension '.pkl' or '.h5'")
    elif extension in ('.h5',):
        return neo.io.NeoHdf5IO(filename=filename)
    elif extension in ('.pkl', '.pickle'):
        return neo.io.PickleIO(filename=filename)
    elif extension == '.mat':
        return neo.io.NeoMatlabIO(filename=filename)
    else:  # function to be improved later
        raise Exception("file extension %s not supported" % extension)
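
# Sketch of how get_io is typically used (illustration only; the function
# name and file path below are hypothetical). The '.pkl' suffix selects Neo's
# PickleIO, and safe_makedirs creates the parent directory if needed.
def _demo_get_io():
    io = get_io("results/demo_recording.pkl")  # arbitrary example path
    io.write_block(neo.Block(name="demo"))     # any Neo Block can be written
    return io.filename
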
def filter_by_variables(segment, variables):
    """
    Return a new `Segment` containing only recordings of the variables given in
    the list `variables`.
    """
    if variables == 'all':
        return segment
    else:
        new_segment = copy(segment)  # shallow copy
        if 'spikes' not in variables:
            new_segment.spiketrains = []
        new_segment.analogsignals = [sig for sig in segment.analogsignals if sig.name in variables]
        # also need to handle Units, RecordingChannels
        return new_segment


def remove_duplicate_spiketrains(data):
    for segment in data.segments:
        spiketrains = {}
        for spiketrain in segment.spiketrains:
            index = spiketrain.annotations["source_index"]
            spiketrains[index] = spiketrain
        if spiketrains:  # min()/max() would fail on a segment with no spike trains
            min_index = min(spiketrains.keys())
            max_index = max(spiketrains.keys())
            segment.spiketrains = [spiketrains[i] for i in range(min_index, max_index + 1)]
    return data


class DataCache(object):
    # primitive implementation for now, storing in memory - later can consider caching to disk

    def __init__(self):
        self._data = []

    def __iter__(self):
        return iter(self._data)

    def store(self, obj):
        if obj not in self._data:
            logger.debug("Adding %s to cache" % obj)
            self._data.append(obj)

    def clear(self):
        self._data = []
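
# Illustrative sketch (not part of PyNN) of filter_by_variables: given a
# Segment holding a 'v' signal and a spike train, requesting only ['v']
# yields a shallow copy with the spike trains dropped.
def _demo_filter_by_variables():
    seg = neo.Segment()
    seg.analogsignals.append(
        neo.AnalogSignal(np.zeros((10, 2)), units="mV",
                         sampling_period=0.1 * pq.ms, name="v"))
    seg.spiketrains.append(neo.SpikeTrain([0.5] * pq.ms, t_stop=1.0 * pq.ms))
    v_only = filter_by_variables(seg, ["v"])
    assert len(v_only.spiketrains) == 0                       # 'spikes' not requested
    assert [s.name for s in v_only.analogsignals] == ["v"]    # only 'v' kept
    return v_only
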
229 """ 230 logger.debug('Recorder.record(<%d cells>)' % len(ids)) 231 self._check_sampling_interval(sampling_interval) 232 233 ids = set([id for id in ids if id.local]) 234 for variable in normalize_variables_arg(variables): 235 if not self.population.can_record(variable): 236 raise errors.RecordingError(variable, self.population.celltype) 237 new_ids = ids.difference(self.recorded[variable]) 238 self.recorded[variable] = self.recorded[variable].union(ids) 239 self._record(variable, new_ids, sampling_interval) 240 241 def _check_sampling_interval(self, sampling_interval): 242 """ 243 Check whether record() has been called previously with a different sampling interval 244 (we exclude recording of spikes, as the sampling interval does not apply in that case) 245 """ 246 if sampling_interval is not None and sampling_interval != self.sampling_interval: 247 recorded_variables = list(self.recorded.keys()) 248 if "spikes" in recorded_variables: 249 recorded_variables.remove("spikes") 250 if len(recorded_variables) > 0: 251 raise ValueError( 252 "All neurons in a population must be recorded with the same sampling interval.") 253 254 def reset(self): 255 """Reset the list of things to be recorded.""" 256 self._reset() 257 self.recorded = defaultdict(set) 258 259 def filter_recorded(self, variable, filter_ids): 260 if filter_ids is not None: 261 return set(filter_ids).intersection(self.recorded[variable]) 262 else: 263 return self.recorded[variable] 264 265 def _get_current_segment(self, filter_ids=None, variables='all', clear=False): 266 segment = neo.Segment(name="segment%03d" % self._simulator.state.segment_counter, 267 description=self.population.describe(), 268 rec_datetime=datetime.now()) # would be nice to get the time at the start of the recording, not the end 269 variables_to_include = set(self.recorded.keys()) 270 if variables != 'all': 271 variables_to_include = variables_to_include.intersection(set(variables)) 272 for variable in variables_to_include: 273 if variable == 'spikes': 274 t_stop = self._simulator.state.t * pq.ms # must run on all MPI nodes 275 sids = sorted(self.filter_recorded('spikes', filter_ids)) 276 data = self._get_spiketimes(sids, clear=clear) 277 278 segment.spiketrains = [] 279 for id in sids: 280 times = pq.Quantity(data.get(int(id), []), pq.ms) 281 if times.size > 0 and times.max() > t_stop: 282 warn("Recorded at least one spike after t_stop") 283 times = times[times <= t_stop] 284 segment.spiketrains.append( 285 neo.SpikeTrain(times, 286 t_start=self._recording_start_time, 287 t_stop=t_stop, 288 units='ms', 289 source_population=self.population.label, 290 source_id=int(id), source_index=self.population.id_to_index(int(id))) 291 ) 292 else: 293 ids = sorted(self.filter_recorded(variable, filter_ids)) 294 signal_array = self._get_all_signals(variable, ids, clear=clear) 295 t_start = self._recording_start_time 296 t_stop = self._simulator.state.t * pq.ms 297 sampling_period = self.sampling_interval * pq.ms 298 current_time = self._simulator.state.t * pq.ms 299 mpi_node = self._simulator.state.mpi_rank # for debugging 300 if signal_array.size > 0: # may be empty if none of the recorded cells are on this MPI node 301 units = self.population.find_units(variable) 302 source_ids = np.fromiter(ids, dtype=int) 303 signal = neo.AnalogSignal( 304 signal_array, 305 units=units, 306 t_start=t_start, 307 sampling_period=sampling_period, 308 name=variable, 309 source_population=self.population.label, 310 source_ids=source_ids, 311 array_annotations={"channel_index": 
    def get(self, variables, gather=False, filter_ids=None, clear=False,
            annotations=None):
        """Return the recorded data as a Neo `Block`."""
        variables = normalize_variables_arg(variables)
        data = neo.Block()
        data.segments = [filter_by_variables(segment, variables)
                         for segment in self.cache]
        if self._simulator.state.running:
            # reset() has not been called, so the current segment is not yet in the cache
            data.segments.append(self._get_current_segment(
                filter_ids=filter_ids, variables=variables, clear=clear))
        data.name = self.population.label
        data.description = self.population.describe()
        data.rec_datetime = data.segments[0].rec_datetime
        data.annotate(**self.metadata)
        if annotations:
            data.annotate(**annotations)
        if gather and self._simulator.state.num_processes > 1:
            data = gather_blocks(data)
            if hasattr(self.population.celltype, "always_local") and self.population.celltype.always_local:
                data = remove_duplicate_spiketrains(data)
        if clear:
            self.clear()
        return data

    def clear(self):
        """
        Clear all recorded data, both from the cache and the simulator.
        """
        self.cache.clear()
        self.clear_flag = True
        self._recording_start_time = self._simulator.state.t * pq.ms
        self._clear_simulator()

    def write(self, variables, file=None, gather=False, filter_ids=None,
              clear=False, annotations=None):
        """Write recorded data to a Neo IO"""
        if isinstance(file, str):
            file = get_io(file)
        io = file or self.file
        if gather is False and self._simulator.state.num_processes > 1:
            io.filename += '.%d' % self._simulator.state.mpi_rank
        logger.debug("Recorder is writing '%s' to file '%s' with gather=%s" % (
            variables, io.filename, gather))
        data = self.get(variables, gather, filter_ids, clear, annotations=annotations)
        if self._simulator.state.mpi_rank == 0 or gather is False:
            # open the output file, if necessary, and write the data
            logger.debug("Writing data to file %s" % io)
            io.write_block(data)

    @property
    def metadata(self):
        metadata = {
            'size': self.population.size,
            'first_index': 0,
            'last_index': len(self.population),
            'first_id': int(self.population.first_id),
            'last_id': int(self.population.last_id),
            'label': self.population.label,
            'simulator': self._simulator.name,
        }
        metadata.update(self.population.annotations)
        # note that this has to run on all nodes (at least for NEST)
        metadata['dt'] = self._simulator.state.dt
        metadata['mpi_processes'] = self._simulator.state.num_processes
        return metadata

    def count(self, variable, gather=True, filter_ids=None):
        """
        Return the number of data points for each cell, as a dict. This is
        mainly useful for spike counts or for variable-time-step integration
        methods.
        """
389 """ 390 if variable == 'spikes': 391 N = self._local_count(variable, filter_ids) 392 else: 393 raise Exception("Only implemented for spikes.") 394 if gather and self._simulator.state.num_processes > 1: 395 N = gather_dict(N) 396 return N 397 398 def store_to_cache(self, annotations=None): 399 # make sure we haven't called get with clear=True since last reset 400 # and that we did not do two resets in a row 401 if (self._simulator.state.t != 0) and (not self.clear_flag): 402 if annotations is None: 403 annotations = {} 404 segment = self._get_current_segment() 405 segment.annotate(**annotations) 406 self.cache.store(segment) 407 self.clear_flag = False 408 self._recording_start_time = 0.0 * pq.ms 409