"""
Container for the components of a PyNN network model.

Defines :class:`Network`, which groups populations, population views,
assemblies and projections so that they can be inspected, recorded from and
written to file together.
"""

import sys
import inspect
from itertools import chain
from neo.io import get_io
from pyNN.common import Population, PopulationView, Projection, Assembly


class Network(object):
    """
    A collection of populations, population views, assemblies and projections.

    Components are stored in sets, so adding the same object twice has no
    further effect. The public ``populations``, ``views``, ``assemblies`` and
    ``projections`` properties return read-only (``frozenset``) snapshots.
    """

    def __init__(self, *components):
        """Create a network, optionally pre-filled with `components` (see `add()`)."""
        self._populations = set()
        self._views = set()
        self._assemblies = set()
        self._projections = set()
        self.add(*components)

    @property
    def populations(self):
        """Read-only snapshot of the populations in the network."""
        return frozenset(self._populations)

    @property
    def views(self):
        """Read-only snapshot of the population views in the network."""
        return frozenset(self._views)

    @property
    def assemblies(self):
        """Read-only snapshot of the assemblies in the network."""
        return frozenset(self._assemblies)

    @property
    def projections(self):
        """Read-only snapshot of the projections in the network."""
        return frozenset(self._projections)

    @property
    def sim(self):
        """Figure out which PyNN backend module this Network is using.

        The backend is looked up as the parent package of the module that
        defines the class of one of the network's populations.

        Raises IndexError if the network contains no populations.
        """
        populations = self.populations
        if not populations:
            raise IndexError(
                "cannot determine the simulator backend of a Network "
                "that contains no populations"
            )
        # we assume there is only one. Could be mixed if using multiple
        # simulators at once.
        population_class = next(iter(populations)).__class__
        populations_module = inspect.getmodule(population_class)
        # strip the last dotted component to get the enclosing package name
        package_name = populations_module.__name__.rpartition(".")[0]
        return sys.modules[package_name]

    def count_neurons(self):
        """Return the sum of the `size` attributes of all populations."""
        return sum(population.size for population in self.populations)

    def count_connections(self):
        """Return the sum of `size()` over all projections."""
        return sum(projection.size() for projection in self.projections)

    def add(self, *components):
        """Add one or more components to the network.

        Each component must be a Population, PopulationView, Assembly or
        Projection. Adding a view also adds its `parent`; adding an assembly
        also adds the assembly's `populations`.

        Raises TypeError for any other kind of object. Components that
        precede the offending one in the argument list have already been
        added when the exception is raised.
        """
        for component in components:
            if isinstance(component, Population):
                self._populations.add(component)
            elif isinstance(component, PopulationView):
                self._views.add(component)
                self._populations.add(component.parent)
            elif isinstance(component, Assembly):
                self._assemblies.add(component)
                self._populations.update(component.populations)
            elif isinstance(component, Projection):
                self._projections.add(component)
                # todo: check that pre and post populations/views/assemblies
                #       have been added
            else:
                raise TypeError(
                    "cannot add object of type '%s' to a Network; expected a "
                    "Population, PopulationView, Assembly or Projection"
                    % type(component).__name__
                )

    def get_component(self, label):
        """Return a component whose `label` equals `label`, or None.

        Populations are searched first, then views, assemblies and
        projections; the first match found is returned.
        """
        candidates = chain(
            self.populations, self.views, self.assemblies, self.projections
        )
        for obj in candidates:
            if obj.label == label:
                return obj
        return None

    def filter(self, cell_types=None):
        """Return an Assembly of the populations selected by `cell_types`.

        `cell_types` may be:

        - the string "all": select every population whose cell type is
          injectable;
        - a container of cell type classes: select every population whose
          cell type is an instance of exactly one of those classes.

        Raises NotImplementedError if `cell_types` is None.
        """
        if cell_types is None:
            raise NotImplementedError()
        if cell_types == "all":
            # or could use len(receptor_types) > 0
            selected = (pop for pop in self.populations
                        if pop.celltype.injectable)
        else:
            selected = (pop for pop in self.populations
                        if pop.celltype.__class__ in cell_types)
        return self.sim.Assembly(*selected)

    def record(self, variables, to_file=None, sampling_interval=None,
               include_spike_source=True):
        """Call `record()` on the populations and assemblies of the network.

        `variables`, `to_file` and `sampling_interval` are passed through
        unchanged. If `include_spike_source` is False, components whose
        `injectable` attribute is false are skipped.
        """
        for obj in chain(self.populations, self.assemblies):
            # spike sources are not injectable
            if include_spike_source or obj.injectable:
                obj.record(variables, to_file=to_file,
                           sampling_interval=sampling_interval)

    def get_data(self, variables='all', gather=True, clear=False,
                 annotations=None):
        """Return a list holding the result of `get_data()` for each assembly.

        All arguments are passed through to each assembly's `get_data()`.
        Only the assemblies are queried, not the individual populations
        or views.
        """
        return [assembly.get_data(variables, gather, clear, annotations)
                for assembly in self.assemblies]

    def write_data(self, io, variables='all', gather=True, clear=False,
                   annotations=None):
        """Collect the data from `get_data()` and write it using `io`.

        `io` is either an object providing a `write()` method, or a string,
        in which case it is converted to such an object with
        `neo.io.get_io()`. The remaining arguments are passed through to
        `get_data()`.
        """
        if isinstance(io, str):
            io = get_io(io)
        data = self.get_data(variables, gather, clear, annotations)
        # TODO: handle MPI - eventually only write when
        #       `self._simulator.state.mpi_rank == 0 or gather is False`;
        #       for now the data are always written.
        io.write(data)