1"""
2
3
4"""
5
6import sys
7import inspect
8from itertools import chain
9from neo.io import get_io
10from pyNN.common import Population, PopulationView, Projection, Assembly
11
12
class Network(object):
    """
    A container grouping the components of a PyNN model.

    A Network keeps track of the Populations, PopulationViews, Assemblies and
    Projections that make up a model, so that they can be counted, looked up
    by label, recorded from and have their data retrieved together.

    Components may be passed to the constructor or added later with
    :meth:`add`.
    """

    def __init__(self, *components):
        self._populations = set()
        self._views = set()
        self._assemblies = set()
        self._projections = set()
        self.add(*components)

    @property
    def populations(self):
        """Immutable snapshot of the Populations in this network."""
        return frozenset(self._populations)

    @property
    def views(self):
        """Immutable snapshot of the PopulationViews in this network."""
        return frozenset(self._views)

    @property
    def assemblies(self):
        """Immutable snapshot of the Assemblies in this network."""
        return frozenset(self._assemblies)

    @property
    def projections(self):
        """Immutable snapshot of the Projections in this network."""
        return frozenset(self._projections)

    @property
    def sim(self):
        """
        Return the PyNN backend module this Network is using.

        The backend is deduced from the module defining the class of one of the
        network's populations, by dropping the last component of that module's
        dotted name (e.g. a class defined in ``pyNN.nest.populations`` yields
        the module ``pyNN.nest``).

        :raises ValueError: if the network contains no populations.
        """
        # We assume there is only one backend. Populations could come from
        # different backends if several simulators are used at once.
        population = next(iter(self._populations), None)
        if population is None:
            raise ValueError(
                "Cannot determine the simulator backend of a Network "
                "that contains no populations")
        populations_module = inspect.getmodule(population.__class__)
        backend_name = populations_module.__name__.rpartition(".")[0]
        return sys.modules[backend_name]

    def count_neurons(self):
        """Return the total number of neurons across all populations."""
        return sum(population.size for population in self._populations)

    def count_connections(self):
        """Return the total number of connections across all projections."""
        return sum(projection.size() for projection in self._projections)

    def add(self, *components):
        """
        Add one or more components to the network.

        - A Population is stored as is.
        - A PopulationView is stored. The Population at the root of its chain
          of parents is also added.
        - An Assembly is stored. Each of its members is then added following
          the rules above.
        - A Projection is stored. Its pre- and post-synaptic components are
          *not* added automatically.

        :raises TypeError: if a component is of any other type.
        """
        for component in components:
            if isinstance(component, Population):
                self._populations.add(component)
            elif isinstance(component, PopulationView):
                self._views.add(component)
                # Nested views have a view as parent. Only the underlying
                # Population belongs in ``populations``, otherwise its neurons
                # would be counted twice by count_neurons().
                self._populations.add(self._root_population(component))
            elif isinstance(component, Assembly):
                self._assemblies.add(component)
                # Members may be Populations or PopulationViews. Dispatch each
                # one, so that views are stored as views and not as populations.
                self.add(*component.populations)
            elif isinstance(component, Projection):
                self._projections.add(component)
                # todo: check that pre and post populations/views/assemblies have been added
            else:
                raise TypeError(
                    "Cannot add an object of type %s to a Network"
                    % type(component).__name__)

    @staticmethod
    def _root_population(view):
        """Follow ``parent`` links from a PopulationView up to its Population."""
        parent = view.parent
        while isinstance(parent, PopulationView):
            parent = parent.parent
        return parent

    def get_component(self, label):
        """
        Return a component whose ``label`` equals ``label``, or None if none does.

        Populations are searched first, then views, assemblies and projections.
        If several components of the same kind share a label, which one is
        returned is unspecified, because they are held in sets.
        """
        for component in chain(self._populations, self._views,
                               self._assemblies, self._projections):
            if component.label == label:
                return component
        return None

    def filter(self, cell_types=None):
        """
        Return an Assembly of the populations selected by cell type.

        :param cell_types: either the string ``"all"``, which selects every
            population whose cell type is injectable, or a collection of cell
            type classes, which selects every population whose cell type's
            class is exactly one of them (subclasses do not match).
        :returns: an Assembly created with the backend's ``Assembly`` class
            (see :attr:`sim`).
        :raises NotImplementedError: if ``cell_types`` is None.
        """
        if cell_types is None:
            raise NotImplementedError()
        if cell_types == "all":
            # An alternative criterion would be len(receptor_types) > 0.
            selected = (pop for pop in self.populations if pop.celltype.injectable)
        else:
            selected = (pop for pop in self.populations
                        if pop.celltype.__class__ in cell_types)
        return self.sim.Assembly(*selected)

    def record(self, variables, to_file=None, sampling_interval=None, include_spike_source=True):
        """
        Record ``variables`` from every population and assembly in the network.

        ``variables``, ``to_file`` and ``sampling_interval`` are passed
        unchanged to each component's ``record()`` method. If
        ``include_spike_source`` is False, components that are not injectable
        (spike sources) are skipped.
        """
        # NOTE(review): a population that belongs to an assembly has record()
        # called on it twice, once directly and once via the assembly. Confirm
        # this is harmless for all backends.
        for component in chain(self._populations, self._assemblies):
            if include_spike_source or component.injectable:  # spike sources are not injectable
                component.record(variables, to_file=to_file, sampling_interval=sampling_interval)

    def get_data(self, variables='all', gather=True, clear=False, annotations=None):
        """
        Return a list of the recorded data of each assembly in the network.

        All arguments are passed through to each assembly's ``get_data()``.
        Populations that do not belong to any assembly are not included.
        """
        return [assembly.get_data(variables, gather, clear, annotations)
                for assembly in self._assemblies]

    def write_data(self, io, variables='all', gather=True, clear=False, annotations=None):
        """
        Write the data returned by :meth:`get_data` using a Neo IO.

        :param io: a Neo IO object, or a filename from which one is created
            with ``neo.io.get_io``.

        The remaining arguments are passed to :meth:`get_data`.
        """
        if isinstance(io, str):
            io = get_io(io)
        data = self.get_data(variables, gather, clear, annotations)
        # TODO: handle MPI. The intended guard is presumably
        # ``self._simulator.state.mpi_rank == 0 or gather is False``.
        # Until it is implemented, every process writes.
        io.write(data)
105