1# -*- coding: utf-8 -*-
2#
3# hl_api_spatial.py
4#
5# This file is part of NEST.
6#
7# Copyright (C) 2004 The NEST Initiative
8#
9# NEST is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 2 of the License, or
12# (at your option) any later version.
13#
14# NEST is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with NEST.  If not, see <http://www.gnu.org/licenses/>.
21
22"""
23Functions relating to spatial properties of nodes
24"""
25
26
27import numpy as np
28
29from ..ll_api import *
30from .. import pynestkernel as kernel
31from .hl_api_helper import *
32from .hl_api_connections import GetConnections
33from .hl_api_parallel_computing import NumProcesses, Rank
34from .hl_api_types import NodeCollection
35
36try:
37    import matplotlib as mpl
38    import matplotlib.path as mpath
39    import matplotlib.patches as mpatches
40    HAVE_MPL = True
41except ImportError:
42    HAVE_MPL = False
43
# Public API of this module: the names exported by ``from ... import *``.
# Kept in alphabetical order.
__all__ = [
    'CreateMask',
    'Displacement',
    'Distance',
    'DumpLayerConnections',
    'DumpLayerNodes',
    'FindCenterElement',
    'FindNearestElement',
    'GetPosition',
    'GetTargetNodes',
    'GetTargetPositions',
    'PlotLayer',
    'PlotProbabilityParameter',
    'PlotTargets',
    'SelectNodesByMask',
]
60
61
def CreateMask(masktype, specs, anchor=None):
    """
    Create a spatial mask for connections.

    Masks are used when creating connections. A mask describes the area of
    the pool population that is searched for to connect for any given
    node in the driver population. Several mask types are available. Examples
    are the grid region, the rectangular, circular or doughnut region.

    The command :py:func:`.CreateMask` creates a `Mask` object which may be combined
    with other `Mask` objects using Boolean operators. The mask is specified
    in a dictionary.

    ``Mask`` objects can be passed to :py:func:`.Connect` in a connection dictionary with the key `'mask'`.

    Parameters
    ----------
    masktype : str, ['rectangular' | 'circular' | 'doughnut' | 'elliptical']
        for 2D masks, ['box' | 'spherical' | 'ellipsoidal'] for 3D masks,
        ['grid'] only for grid-based layers in 2D.
        The mask name corresponds to the geometrical shape of the mask. There
        are different types for 2- and 3-dimensional layers.
    specs : dict
        Dictionary specifying the parameters of the provided `masktype`,
        see **Mask types**.
    anchor : [tuple/list of floats | dict with the keys `'column'` and \
        `'row'` (for grid masks only)], optional, default: None
        By providing anchor coordinates, the location of the mask relative to
        the driver node can be changed. The list of coordinates has a length
        of 2 or 3 dependent on the number of dimensions. If `None`, no
        ``'anchor'`` entry is added to the mask specification.

    Returns
    -------
    Mask:
        Object representing the mask

    See also
    --------
    Connect

    Notes
    -----
    - All angles must be given in degrees.

    **Mask types**

    Available mask types (`masktype`) and their corresponding parameter
    dictionaries:

    * 2D free and grid-based layers
        ::

            'rectangular' :
                {'lower_left'    : [float, float],
                 'upper_right'   : [float, float],
                 'azimuth_angle' : float}  # default: 0.0
            #or
            'circular' :
                {'radius' : float}
            #or
            'doughnut' :
                {'inner_radius' : float,
                 'outer_radius' : float}
            #or
            'elliptical' :
                {'major_axis' : float,
                 'minor_axis' : float,
                 'azimuth_angle' : float,   # default: 0.0
                 'anchor' : [float, float]} # default: [0.0, 0.0]

    * 3D free and grid-based layers
        ::

            'box' :
                {'lower_left'    : [float, float, float],
                 'upper_right'   : [float, float, float],
                 'azimuth_angle' : float,  # default: 0.0
                 'polar_angle'   : float}  # default: 0.0
            #or
            'spherical' :
                {'radius' : float}
            #or
            'ellipsoidal' :
                {'major_axis' : float,
                 'minor_axis' : float,
                 'polar_axis' : float,
                 'azimuth_angle' : float,   # default: 0.0
                 'polar_angle' : float,     # default: 0.0
                 'anchor' : [float, float, float]}  # default: [0.0, 0.0, 0.0]

    * 2D grid-based layers only
        ::

            'grid' :
                {'rows' : float,
                 'columns' : float}

        By default the top-left corner of a grid mask, i.e., the grid
        mask element with grid index [0, 0], is aligned with the driver
        node. It can be changed by means of the 'anchor' parameter:
            ::

                'anchor' :
                    {'row' : float,
                     'column' : float}

    **Example**
        ::

            import nest

            # create a grid-based layer
            l = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5]))

            # create a circular mask
            m = nest.CreateMask('circular', {'radius': 0.2})

            # connectivity specifications
            conndict = {'rule': 'pairwise_bernoulli',
                        'p': 1.0,
                        'mask': m}

            # connect layer l with itself according to the specifications
            nest.Connect(l, l, conndict)
    """
    mask_spec = {masktype: specs}
    if anchor is not None:
        # The anchor travels in the same dictionary as the mask-type entry.
        mask_spec['anchor'] = anchor
    return sli_func('CreateMask', mask_spec)
192
193
def GetPosition(nodes):
    """
    Return the spatial locations of nodes.

    Parameters
    ----------
    nodes : NodeCollection
        `NodeCollection` of nodes we want the positions to

    Returns
    -------
    tuple or tuple of tuple(s):
        Tuple of position with 2- or 3-elements or list of positions

    Raises
    ------
    TypeError
        If `nodes` is not a `NodeCollection`.

    See also
    --------
    Displacement: Get vector of lateral displacement between nodes.
    Distance: Get lateral distance between nodes.
    DumpLayerConnections: Write connectivity information to file.
    DumpLayerNodes: Write node positions to file.

    Notes
    -----
    - The functions :py:func:`.GetPosition`, :py:func:`.Displacement` and :py:func:`.Distance`
      only work for nodes local to the current MPI process, if used in a
      MPI-parallel simulation.

    Example
    -------
        ::

            import nest

            # Reset kernel
            nest.ResetKernel()

            # create a NodeCollection with spatial extent
            s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5]))

            # retrieve positions of all (local) nodes belonging to the population
            pos = nest.GetPosition(s_nodes)

            # retrieve positions of the first node in the NodeCollection
            pos = nest.GetPosition(s_nodes[0])

            # retrieve positions of a subset of nodes in the population
            pos = nest.GetPosition(s_nodes[2:18])
    """
    if not isinstance(nodes, NodeCollection):
        raise TypeError("nodes must be a NodeCollection with spatial extent")

    return sli_func('GetPosition', nodes)
246
247
def Displacement(from_arg, to_arg):
    """
    Get vector of lateral displacement from node(s)/Position(s) `from_arg`
    to node(s) `to_arg`.

    Displacement is the shortest displacement, taking into account
    periodic boundary conditions where applicable. If explicit positions
    are given in the `from_arg` list, they are interpreted in the `to_arg`
    population.

    - If one of `from_arg` or `to_arg` has length 1, and the other is longer,
      the displacement from/to the single item to all other items is given.
    - If `from_arg` and `to_arg` both have more than one element, they have
      to be of the same length and the displacement between each
      pair is returned.

    Parameters
    ----------
    from_arg : NodeCollection or tuple/list with tuple(s)/list(s) of floats
        `NodeCollection` of node IDs or tuple/list of position(s)
    to_arg : NodeCollection
        `NodeCollection` of node IDs

    Returns
    -------
    tuple:
        Displacement vectors between pairs of nodes in `from_arg` and `to_arg`

    Raises
    ------
    TypeError
        If `to_arg` is not a `NodeCollection`.
    ValueError
        If both arguments have more than one element but differ in length.

    See also
    --------
    Distance: Get lateral distances between nodes.
    DumpLayerConnections: Write connectivity information to file.
    GetPosition: Return the spatial locations of nodes.

    Notes
    -----
    - The functions :py:func:`.GetPosition`, :py:func:`.Displacement` and :py:func:`.Distance`
      only work for nodes local to the current MPI process, if used in a
      MPI-parallel simulation.

    **Example**
        ::

            import nest

            # create a spatial population
            s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5]))

            # displacement between node 2 and 3
            print(nest.Displacement(s_nodes[1], s_nodes[2]))

            # displacement between the position (0.0, 0.0) and node 2
            print(nest.Displacement([(0.0, 0.0)], s_nodes[1]))
    """
    if not isinstance(to_arg, NodeCollection):
        raise TypeError("to_arg must be a NodeCollection")

    # A numpy array is taken to be a single position and wrapped, so that it
    # is not mistaken for a sequence of positions (one per coordinate).
    if isinstance(from_arg, np.ndarray):
        from_arg = (from_arg, )

    if len(from_arg) > 1 and len(to_arg) > 1 and len(from_arg) != len(to_arg):
        raise ValueError("to_arg and from_arg must have same size unless one have size 1.")

    return sli_func('Displacement', from_arg, to_arg)
313
314
def Distance(from_arg, to_arg):
    """
    Get lateral distances from node(s)/position(s) `from_arg` to node(s) `to_arg`.

    The distance between two nodes is the length of its displacement.

    If explicit positions are given in the `from_arg` list, they are
    interpreted in the `to_arg` population. Distance is the shortest distance,
    taking into account periodic boundary conditions where applicable.

    - If one of `from_arg` or `to_arg` has length 1, and the other is longer,
      the distance from/to the single item to all other items is given.
    - If `from_arg` and `to_arg` both have more than one element, they have
      to be of the same length and the distance for each pair is
      returned.

    Parameters
    ----------
    from_arg : NodeCollection or tuple/list with tuple(s)/list(s) of floats
        `NodeCollection` of node IDs or tuple/list of position(s)
    to_arg : NodeCollection
        `NodeCollection` of node IDs

    Returns
    -------
    tuple:
        Distances between `from_arg` and `to_arg`

    Raises
    ------
    TypeError
        If `to_arg` is not a `NodeCollection`.
    ValueError
        If both arguments have more than one element but differ in length.

    See also
    --------
    Displacement: Get vector of lateral displacements between nodes.
    DumpLayerConnections: Write connectivity information to file.
    GetPosition: Return the spatial locations of nodes.

    Notes
    -----
    - The functions :py:func:`.GetPosition`, :py:func:`.Displacement` and :py:func:`.Distance`
      only work for nodes local to the current MPI process, if used in a
      MPI-parallel simulation.

    Example
    -------
        ::

            import nest

            # create a spatial population
            s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5]))

            # distance between node 2 and 3
            print(nest.Distance(s_nodes[1], s_nodes[2]))

            # distance between the position (0.0, 0.0) and node 2
            print(nest.Distance([(0.0, 0.0)], s_nodes[1]))
    """
    if not isinstance(to_arg, NodeCollection):
        raise TypeError("to_arg must be a NodeCollection")

    # A numpy array is taken to be a single position and wrapped, so that it
    # is not mistaken for a sequence of positions (one per coordinate).
    if isinstance(from_arg, np.ndarray):
        from_arg = (from_arg, )

    if len(from_arg) > 1 and len(to_arg) > 1 and len(from_arg) != len(to_arg):
        raise ValueError("to_arg and from_arg must have same size unless one have size 1.")

    return sli_func('Distance', from_arg, to_arg)
381
382
def FindNearestElement(layer, locations, find_all=False):
    """
    Return the node(s) closest to the `locations` in the given `layer`.

    Closeness is measured with :py:func:`.Distance`, so periodic boundary
    conditions are taken into account where applicable.

    * If `locations` is a single 2-element array giving a location, return a
      `NodeCollection` of the `layer` element closest to that location.
    * If `locations` is a list of coordinates, the function returns a list
      with one result per location.

    Parameters
    ----------
    layer : NodeCollection
        `NodeCollection` of spatially distributed node IDs
    locations : tuple(s)/list(s) of tuple(s)/list(s)
        2-element list with coordinates of a single position, or list of
        2-element list of positions
    find_all : bool, default: False
        If there are several nodes with same minimal distance, return only the
        first found, if `False`.
        If `True`, instead of returning a single `NodeCollection`, return a list of `NodeCollection`
        containing all nodes with minimal distance. Distances are considered
        equal if they differ by at most a relative tolerance of 1e-14.

    Returns
    -------
    NodeCollection:
        `NodeCollection` of the nearest node, if `locations` is a single
        position and `find_all` is `False`
    list:
        If `find_all` is `False` and several positions are given, a list with
        one `NodeCollection` per position. If `find_all` is `True`, a list of
        single-node `NodeCollection` objects (all nodes at minimal distance)
        for a single position, or a list of such lists for several positions.

    Raises
    ------
    TypeError
        If `layer` is not a `NodeCollection` or `locations` is not iterable.
    ValueError
        If `layer` is empty.

    See also
    --------
    FindCenterElement: Return NodeCollection of node closest to center of layers.
    GetPosition: Return the spatial locations of nodes.

    Example
    -------
        ::

            import nest

            # create a spatial population
            s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5]))

            # get node ID of element closest to some location
            nest.FindNearestElement(s_nodes, [3.0, 4.0], True)
    """

    if not isinstance(layer, NodeCollection):
        raise TypeError("layer must be a NodeCollection")

    if not len(layer) > 0:
        raise ValueError("layer cannot be empty")

    if not is_iterable(locations):
        raise TypeError("locations must be coordinate array or list of coordinate arrays")

    # Ensure locations is sequence, keeps code below simpler
    if not is_iterable(locations[0]):
        locations = (locations, )

    result = []

    for loc in locations:
        d = Distance(np.array(loc), layer)

        if not find_all:
            dx = int(np.argmin(d))  # index of the first minimum
            result.append(layer[dx])
        else:
            # Single-node NodeCollections are obtained by indexing throughout,
            # so all entries of minnode have the same type.
            minnode = [layer[0]]
            minval = d[0]
            for idx in range(1, len(layer)):
                if d[idx] < minval:
                    minnode = [layer[idx]]
                    minval = d[idx]
                elif np.abs(d[idx] - minval) <= 1e-14 * minval:
                    # Equal to the current minimum up to a relative tolerance.
                    minnode.append(layer[idx])
            result.append(minnode)

    # A single location yields its result directly instead of a 1-element list.
    if len(result) == 1:
        result = result[0]

    return result
468
469
def _rank_specific_filename(basename):
    """
    Return file name decorated with the MPI rank of this process.

    If only one MPI process is used, `basename` is returned unchanged.
    Otherwise ``-<rank>`` is inserted in front of the first dot of the file
    name, or appended if the file name contains no dot. The rank is
    zero-padded to the number of digits of the highest rank; e.g. rank 3 of
    12 processes turns ``'out/conns.txt'`` into ``'out/conns-03.txt'``.

    Only the final path component is searched for a dot, so dots in the
    directory part (e.g. ``'./'`` or ``'../'``) are left untouched.

    Parameters
    ----------
    basename : str
        File name, optionally including a directory part.

    Returns
    -------
    str:
        Rank-specific file name.
    """
    import os.path

    num_procs = NumProcesses()
    if num_procs == 1:
        return basename

    num_digits = len(str(num_procs - 1))  # for pretty formatting
    rank = Rank()

    # Separate the directory part (kept verbatim) from the file name, so that
    # a dot in a directory name is not taken as the start of the suffix.
    filename = os.path.split(basename)[1]
    directory = basename[:len(basename) - len(filename)]

    dot = filename.find('.')
    if dot < 0:
        decorated = '%s-%0*d' % (filename, num_digits, rank)
    else:
        decorated = '%s-%0*d%s' % (filename[:dot], num_digits, rank, filename[dot:])
    return directory + decorated
484
485
def DumpLayerNodes(layer, outname):
    """
    Write `node ID` and position data of `layer` to file.

    Write `node ID` and position data to `outname` file. For each node in `layer`,
    a line with the following information is written:
        ::

            node ID x-position y-position [z-position]

    If `layer` contains several `node IDs`, data for all nodes in `layer` will be written to a
    single file.

    Parameters
    ----------
    layer : NodeCollection
        `NodeCollection` of spatially distributed node IDs
    outname : str
        Name of file to write to (existing files are overwritten)

    Raises
    ------
    TypeError
        If `layer` is not a `NodeCollection`.

    See also
    --------
    DumpLayerConnections: Write connectivity information to file.
    GetPosition: Return the spatial locations of nodes.

    Notes
    -----
    * If calling this function from a distributed simulation, this function
      will write to one file per MPI rank.
    * File names are formed by adding the MPI Rank into the file name before
      the file name suffix.
    * Each file stores data for the nodes local to the MPI process that
      writes it.

    Example
    -------
        ::

            import nest

            # create a spatial population
            s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5]))

            # write layer node positions to file
            nest.DumpLayerNodes(s_nodes, 'positions.txt')

    """
    if not isinstance(layer, NodeCollection):
        raise TypeError("layer must be a NodeCollection")

    # SLI: open the (rank-specific) file for writing, swap it below the layer
    # argument for DumpLayerNodes, then close the file.
    sli_func("""
             (w) file exch DumpLayerNodes close
             """,
             layer, _rank_specific_filename(outname))
539
540
def DumpLayerConnections(source_layer, target_layer, synapse_model, outname):
    """
    Write connectivity information to file.

    This function writes connection information to file for all outgoing
    connections from the given layers with the given synapse model.

    For each connection, one line is stored, in the following format:
        ::

            source_node_id target_node_id weight delay dx dy [dz]

    where (dx, dy [, dz]) is the displacement from source to target node.
    If targets do not have positions (e.g. spike recorders outside any layer),
    NaN is written for each displacement coordinate.

    Parameters
    ----------
    source_layer : NodeCollection
        `NodeCollection` of spatially distributed node IDs
    target_layer : NodeCollection
        `NodeCollection` of (spatially distributed) node IDs
    synapse_model : str
        NEST synapse model
    outname : str
        Name of file to write to (will be overwritten if it exists)

    Raises
    ------
    TypeError
        If `source_layer` or `target_layer` is not a `NodeCollection`.

    See also
    --------
    DumpLayerNodes: Write layer node positions to file.
    GetPosition: Return the spatial locations of nodes.
    GetConnections: Return connection identifiers between
        sources and targets

    Notes
    -----
    * If calling this function from a distributed simulation, this function
      will write to one file per MPI rank.
    * File names are formed by inserting
      the MPI Rank into the file name before the file name suffix.
    * Each file stores data for local nodes.

    **Example**
        ::

            import nest

            # create a spatial population
            s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5]))

            nest.Connect(s_nodes, s_nodes,
                         {'rule': 'pairwise_bernoulli', 'p': 1.0},
                         {'synapse_model': 'static_synapse'})

            # write connectivity information to file
            nest.DumpLayerConnections(s_nodes, s_nodes, 'static_synapse', 'conns.txt')
    """
    if not isinstance(source_layer, NodeCollection):
        raise TypeError("source_layer must be a NodeCollection")
    if not isinstance(target_layer, NodeCollection):
        raise TypeError("target_layer must be a NodeCollection")

    # SLI: bind the four arguments (popped in reverse order) to names, open the
    # output file and call DumpLayerConnections on it, then close the file.
    sli_func("""
             /oname  Set
             cvlit /synmod Set
             /lyr_target Set
             /lyr_source Set
             oname (w) file lyr_source lyr_target synmod
             DumpLayerConnections close
             """,
             source_layer, target_layer, synapse_model,
             _rank_specific_filename(outname))
613
614
def FindCenterElement(layer):
    """
    Return `NodeCollection` of node closest to center of `layer`.

    The center is read from ``layer.spatial['center']`` and the nearest node
    is located with :py:func:`.FindNearestElement`.

    Parameters
    ----------
    layer : NodeCollection
        `NodeCollection` with spatially distributed node IDs

    Returns
    -------
    NodeCollection:
        One-element slice of `layer` holding the node closest to the center of
        the `layer`. If several nodes are equally close to the center, an
        arbitrary one of them is returned.

    Raises
    ------
    TypeError
        If `layer` is not a `NodeCollection`.

    See also
    --------
    FindNearestElement: Return the node(s) closest to the location(s) in the given `layer`.
    GetPosition: Return the spatial locations of nodes.

    Example
    -------
        ::

            import nest

            # create a spatial population
            s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5]))

            # get NodeCollection of the element closest to the center of the layer
            nest.FindCenterElement(s_nodes)
    """

    if not isinstance(layer, NodeCollection):
        raise TypeError("layer must be a NodeCollection")

    layer_center = layer.spatial['center']
    closest = FindNearestElement(layer, layer_center)[0]

    # Locate the closest node inside `layer` by its global ID and return the
    # corresponding one-element slice of `layer`.
    idx_in_layer = layer.index(closest.get('global_id'))
    return layer[idx_in_layer:idx_in_layer + 1]
654
655
def GetTargetNodes(sources, tgt_layer, syn_model=None):
    """
    Obtain targets of `sources` in the given target population `tgt_layer`.

    For each neuron in `sources`, this function finds all target elements
    in `tgt_layer`. If `syn_model` is not given (default), all targets are
    returned, otherwise only targets of specific type.

    Parameters
    ----------
    sources : NodeCollection
        NodeCollection with node IDs of `sources`
    tgt_layer : NodeCollection
        NodeCollection with node IDs of `tgt_layer`
    syn_model : [None | str], optional, default: None
        Return only target positions for a given synapse model.

    Returns
    -------
    tuple of NodeCollection:
        Tuple of `NodeCollections` of target neurons fulfilling the given criteria, one `NodeCollection` per
        source node ID in `sources`. Each target appears only once, even if
        it is connected several times to the same source.

    Raises
    ------
    TypeError
        If `sources` or `tgt_layer` is not a `NodeCollection`.

    See also
    --------
    GetTargetPositions: Obtain positions of targets in a given target layer connected to given source.
    GetConnections: Return connection identifiers between
        sources and targets

    Notes
    -----
    * For distributed simulations, this function only returns targets on the
      local MPI process.

    Example
    -------
        ::

            import nest

            # create a spatial population
            s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[11, 11], extent=[11., 11.]))

            # connectivity specifications with a mask
            conndict = {'rule': 'pairwise_bernoulli', 'p': 1.,
                        'mask': {'rectangular': {'lower_left' : [-2.0, -1.0],
                                                 'upper_right': [2.0, 1.0]}}}

            # connect population s_nodes with itself according to the given
            # specifications
            nest.Connect(s_nodes, s_nodes, conndict)

            # get the node IDs of the targets of a source neuron
            nest.GetTargetNodes(s_nodes[4], s_nodes)
    """
    if not isinstance(sources, NodeCollection):
        raise TypeError("sources must be a NodeCollection.")

    if not isinstance(tgt_layer, NodeCollection):
        raise TypeError("tgt_layer must be a NodeCollection")

    conns = GetConnections(sources, tgt_layer, synapse_model=syn_model)

    # Collect the distinct target node IDs of every source; the set drops
    # duplicates from multiple connections between the same pair of nodes.
    source_ids = sources.tolist()
    src_tgt_map = {snode_id: set() for snode_id in source_ids}
    for src, tgt in zip(conns.sources(), conns.targets()):
        src_tgt_map[src].add(tgt)

    # A NodeCollection must be created from a sorted list of node IDs;
    # sorted() also keeps the IDs as plain Python ints. Results follow the
    # order of `sources`.
    return tuple(NodeCollection(sorted(src_tgt_map[snode_id])) for snode_id in source_ids)
729
730
def GetTargetPositions(sources, tgt_layer, syn_model=None):
    """
    Obtain positions of targets to a given `NodeCollection` of `sources`.

    For each neuron in `sources`, this function finds all target elements
    in `tgt_layer`. If `syn_model` is not given (default), all targets are
    returned, otherwise only targets of specific type.

    Parameters
    ----------
    sources : NodeCollection
        `NodeCollection` with node ID(s) of source neurons
    tgt_layer : NodeCollection
        `NodeCollection` of spatially distributed target nodes
    syn_model : [None | str], optional, default: None
        Return only target positions for a given synapse model.

    Returns
    -------
    list of list(s) of tuple(s) of floats:
        Positions of target neurons fulfilling the given criteria as a nested
        list, containing one list of positions per node in sources.

    Raises
    ------
    TypeError
        If `sources` or `tgt_layer` is not a `NodeCollection`.

    See also
    --------
    GetTargetNodes: Obtain targets of a `NodeCollection` of sources in a given target
        population.

    Notes
    -----
    * For distributed simulations, this function only returns targets on the
      local MPI process.

    Example
    -------
        ::

            import nest

            # create a spatial population
            s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[11, 11], extent=[11., 11.]))

            # connectivity specifications with a mask
            conndict = {'rule': 'pairwise_bernoulli', 'p': 1.,
                        'mask': {'rectangular': {'lower_left' : [-2.0, -1.0],
                                                 'upper_right': [2.0, 1.0]}}}

            # connect population s_nodes with itself according to the given
            # specifications
            nest.Connect(s_nodes, s_nodes, conndict)

            # get the positions of the targets of a source neuron
            nest.GetTargetPositions(s_nodes[5], s_nodes)
    """
    if not isinstance(sources, NodeCollection):
        raise TypeError("sources must be a NodeCollection.")

    if not isinstance(tgt_layer, NodeCollection):
        raise TypeError("tgt_layer must be a NodeCollection")

    # Find positions to all nodes in target layer
    pos_all_tgts = GetPosition(tgt_layer)
    # For a single node, GetPosition returns one position instead of a
    # sequence of positions; wrap it so every node ID maps to a full position.
    if np.ndim(pos_all_tgts) == 1:
        pos_all_tgts = (pos_all_tgts, )

    # Look positions up by node ID rather than by offset from the first node
    # ID, so target collections with non-contiguous node IDs work as well.
    pos_by_node_id = dict(zip(tgt_layer.tolist(), pos_all_tgts))

    connections = GetConnections(sources, tgt_layer,
                                 synapse_model=syn_model)
    srcs = connections.get('source')
    tgts = connections.get('target')
    # get() yields a bare int instead of a list for a single connection.
    if isinstance(srcs, int):
        srcs = [srcs]
    if isinstance(tgts, int):
        tgts = [tgts]

    # Make dictionary where the keys are the source node_ids, which is mapped to a
    # list with the positions of the targets connected to the source.
    src_tgt_pos_map = {snode_id: [] for snode_id in sources.tolist()}
    for i in range(len(connections)):
        src_tgt_pos_map[srcs[i]].append(pos_by_node_id[tgts[i]])

    # Turn dict into list in same order as sources
    return [src_tgt_pos_map[snode_id] for snode_id in sources.tolist()]
810
811
def SelectNodesByMask(layer, anchor, mask_obj):
    """
    Obtain the node IDs inside a masked area of a spatially distributed population.

    All node IDs of `layer` that fall inside `mask_obj` are collected and
    returned as a `NodeCollection`. Both 2- and 3-dimensional masks and layers
    are supported, and every mask type is allowed, including combined masks.

    Parameters
    ----------
    layer : NodeCollection
        `NodeCollection` with node IDs of the `layer` to select nodes from.
    anchor : tuple/list of double
        Position (list of coordinates) from where the search starts.
    mask_obj: object
        `Mask` object specifying chosen area.

    Returns
    -------
    NodeCollection:
        `NodeCollection` of nodes/elements inside the mask.

    Raises
    ------
    TypeError
        If `layer` is not a `NodeCollection`.
    """

    if not isinstance(layer, NodeCollection):
        raise TypeError("layer must be a NodeCollection.")

    selected_ids = sli_func('SelectNodesByMask', layer, anchor, mask_obj._datum)

    # A NodeCollection has to be built from node IDs in ascending order.
    ordered_ids = sorted(selected_ids)
    return NodeCollection(ordered_ids)
846
847
def _draw_extent(ax, xctr, yctr, xext, yext):
    """
    Outline the layer extent on `ax` and set aspect ratio, limits and ticks.

    A thin gray rectangle of size `xext` x `yext` centered on (`xctr`, `yctr`)
    is drawn. The axes get an equal aspect ratio and limits that reach 5 % of
    the extent beyond each edge, and all ticks are removed.
    """

    # import pyplot here and not at toplevel to avoid preventing users
    # from changing matplotlib backend after importing nest
    import matplotlib.pyplot as plt

    # corners of the extent
    llx = xctr - xext / 2.0
    lly = yctr - yext / 2.0
    urx = llx + xext
    ury = lly + yext

    outline = plt.Rectangle((llx, lly), xext, yext, fc='none', ec='0.5', lw=1, zorder=1)
    ax.add_patch(outline)

    # leave a small margin around the extent
    margin_x = 0.05 * xext
    margin_y = 0.05 * yext
    ax.set(aspect='equal',
           xlim=(llx - margin_x, urx + margin_x),
           ylim=(lly - margin_y, ury + margin_y),
           xticks=tuple(), yticks=tuple())
867
868
869def _shifted_positions(pos, ext):
870    """Get shifted positions corresponding to boundary conditions."""
871    return [[pos[0] + ext[0], pos[1]],
872            [pos[0] - ext[0], pos[1]],
873            [pos[0], pos[1] + ext[1]],
874            [pos[0], pos[1] - ext[1]],
875            [pos[0] + ext[0], pos[1] - ext[1]],
876            [pos[0] - ext[0], pos[1] + ext[1]],
877            [pos[0] + ext[0], pos[1] + ext[1]],
878            [pos[0] - ext[0], pos[1] - ext[1]]]
879
880
def PlotLayer(layer, fig=None, nodecolor='b', nodesize=20):
    """
    Plot all nodes in a `layer`.

    Parameters
    ----------
    layer : NodeCollection
        `NodeCollection` of spatially distributed nodes
    fig : [None | matplotlib.figure.Figure object], optional, default: None
        Matplotlib figure to plot to. If not given, a new figure is
        created.
    nodecolor : [None | any matplotlib color], optional, default: 'b'
        Color for nodes
    nodesize : float, optional, default: 20
        Marker size for nodes

    Returns
    -------
    `matplotlib.figure.Figure` object

    Raises
    ------
    ImportError
        If Matplotlib could not be imported.
    TypeError
        If `layer` is not a `NodeCollection`.
    ValueError
        If the layer extent is neither 2- nor 3-dimensional.

    See also
    --------
    PlotProbabilityParameter: Create a plot of the connection probability and/or mask.
    PlotTargets: Plot all targets of a given source.
    matplotlib.figure.Figure : matplotlib Figure class

    Notes
    -----
    * Do **not** use this function in distributed simulations.


    Example
    -------
        ::

            import nest
            import matplotlib.pyplot as plt

            # create a spatial population
            s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[11, 11], extent=[11., 11.]))

            # plot layer with all its nodes
            nest.PlotLayer(s_nodes)
            plt.show()
    """

    # Check availability first: importing pyplot below would otherwise raise
    # before the more informative error message could be reached.
    if not HAVE_MPL:
        raise ImportError('Matplotlib could not be imported')

    # import pyplot here and not at toplevel to avoid preventing users
    # from changing matplotlib backend after importing nest
    import matplotlib.pyplot as plt

    if not isinstance(layer, NodeCollection):
        raise TypeError('layer must be a NodeCollection.')

    # get layer extent; its length tells whether the layer is 2D or 3D
    ext = layer.spatial['extent']
    if len(ext) not in (2, 3):
        raise ValueError("unexpected dimension of layer")

    # GetPosition returns a flat coordinate tuple when the collection holds a
    # single node, and a sequence of coordinate tuples otherwise. atleast_2d
    # turns both into an (n_nodes, n_dims) array; its transpose holds one row
    # of coordinates per dimension (x, y[, z]).
    coords = np.atleast_2d(GetPosition(layer)).T

    if len(ext) == 2:
        # 2D layer: get layer extent and center, x and y
        xext, yext = ext
        xctr, yctr = layer.spatial['center']

        if fig is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        else:
            ax = fig.gca()

        ax.scatter(coords[0], coords[1], s=nodesize, facecolor=nodecolor)
        _draw_extent(ax, xctr, yctr, xext, yext)

    else:
        # 3D layer. Importing Axes3D registers the '3d' projection on older
        # matplotlib versions; the name itself is not used.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        if fig is None:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
        else:
            ax = fig.gca()

        ax.scatter(*coords, s=nodesize, c=nodecolor)

    # redraw in interactive mode for both 2D and 3D layers
    plt.draw_if_interactive()

    return fig
979
980
def PlotTargets(src_nrn, tgt_layer, syn_type=None, fig=None,
                mask=None, probability_parameter=None,
                src_color='red', src_size=50, tgt_color='blue', tgt_size=20,
                mask_color='yellow', probability_cmap='Greens'):
    """
    Plot all targets of source neuron `src_nrn` in a target layer `tgt_layer`.

    Parameters
    ----------
    src_nrn : NodeCollection
        `NodeCollection` of source neuron (as single-element NodeCollection)
    tgt_layer : NodeCollection
        `NodeCollection` of tgt_layer
    syn_type : [None | str], optional, default: None
        Show only targets connected with a given synapse type
    fig : [None | matplotlib.figure.Figure object], optional, default: None
        Matplotlib figure to plot to. If not given, a new figure is created.
    mask : [None | dict], optional, default: None
        Draw mask with targets; see :py:func:`.PlotProbabilityParameter` for details.
    probability_parameter : [None | Parameter], optional, default: None
        Draw connection probability with targets; see :py:func:`.PlotProbabilityParameter` for details.
    src_color : [None | any matplotlib color], optional, default: 'red'
        Color used to mark source node position
    src_size : float, optional, default: 50
        Size of source marker (see scatter for details)
    tgt_color : [None | any matplotlib color], optional, default: 'blue'
        Color used to mark target node positions
    tgt_size : float, optional, default: 20
        Size of target markers (see scatter for details)
    mask_color : [None | any matplotlib color], optional, default: 'yellow'
        Color used to fill the mask
    probability_cmap : [None | any matplotlib colormap], optional, default: 'Greens'
        Colormap used to show the connection probability.

    Returns
    -------
    matplotlib.figure.Figure object

    Raises
    ------
    ImportError
        If Matplotlib could not be imported.
    TypeError
        If `src_nrn` is not a single-element `NodeCollection` or `tgt_layer`
        is not a `NodeCollection`.

    See also
    --------
    GetTargetNodes: Obtain targets of a sources in a given target layer.
    GetTargetPositions: Obtain positions of targets of sources in a given target layer.
    PlotProbabilityParameter: Add indication of connection probability and mask to axes.
    PlotLayer: Plot all nodes in a spatially distributed population.
    matplotlib.pyplot.scatter : matplotlib scatter plot.

    Notes
    -----
    * Do **not** use this function in distributed simulations.

    **Example**
        ::

            import nest
            import matplotlib.pyplot as plt

            # create a spatial population
            s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[11, 11], extent=[11., 11.]))

            # connectivity specifications with a mask
            conndict = {'rule': 'pairwise_bernoulli', 'p': 1.,
                        'mask': {'rectangular': {'lower_left' : [-2.0, -1.0],
                                                 'upper_right': [2.0, 1.0]}}}

            # connect population s_nodes with itself according to the given
            # specifications
            nest.Connect(s_nodes, s_nodes, conndict)

            # plot the targets of a source neuron
            nest.PlotTargets(s_nodes[4], s_nodes)
            plt.show()
    """

    # Check availability first: importing pyplot below would otherwise raise
    # before the more informative error message could be reached.
    if not HAVE_MPL:
        raise ImportError("Matplotlib could not be imported")

    # import pyplot here and not at toplevel to avoid preventing users
    # from changing matplotlib backend after importing nest
    import matplotlib.pyplot as plt

    if not isinstance(src_nrn, NodeCollection) or len(src_nrn) != 1:
        raise TypeError("src_nrn must be a single element NodeCollection.")
    if not isinstance(tgt_layer, NodeCollection):
        raise TypeError("tgt_layer must be a NodeCollection.")

    # position of the single source, a flat coordinate tuple
    srcpos = GetPosition(src_nrn)

    # get layer extent
    ext = tgt_layer.spatial['extent']

    # Target positions are indexed by source; element 0 belongs to our single
    # source. That list may be empty when the source has no targets in
    # tgt_layer, in which case unpacking zip(*[]) would fail, so it is guarded.
    tgtpos = GetTargetPositions(src_nrn, tgt_layer, syn_type)
    src_targets = tgtpos[0] if tgtpos else []

    if len(ext) == 2:
        # 2D layer: get layer extent and center, x and y
        xext, yext = ext
        xctr, yctr = tgt_layer.spatial['center']

        if fig is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        else:
            ax = fig.gca()

        if src_targets:
            xpos, ypos = zip(*src_targets)
            ax.scatter(xpos, ypos, s=tgt_size, facecolor=tgt_color)

        ax.scatter(srcpos[:1], srcpos[1:2], s=src_size, facecolor=src_color, alpha=0.4, zorder=-10)

        if mask is not None or probability_parameter is not None:
            # evaluate over a region reaching one full extent from the center;
            # the visible area is restricted by _draw_extent below
            edges = [xctr - xext, xctr + xext, yctr - yext, yctr + yext]
            PlotProbabilityParameter(src_nrn, probability_parameter, mask=mask, edges=edges, ax=ax,
                                     prob_cmap=probability_cmap, mask_color=mask_color)

        _draw_extent(ax, xctr, yctr, xext, yext)

    else:
        # 3D layer. Importing Axes3D registers the '3d' projection on older
        # matplotlib versions; the name itself is not used.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        if fig is None:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
        else:
            ax = fig.gca()

        if src_targets:
            xpos, ypos, zpos = zip(*src_targets)
            ax.scatter3D(xpos, ypos, zpos, s=tgt_size, facecolor=tgt_color)

        ax.scatter3D(srcpos[:1], srcpos[1:2], srcpos[2:], s=src_size, facecolor=src_color, alpha=0.4, zorder=-10)

    plt.draw_if_interactive()

    return fig
1121
1122
def _create_mask_patches(mask, periodic, extent, source_pos, face_color='yellow'):
    """
    Create Matplotlib Patch objects representing the mask.

    The returned patches carry no axes-specific transform, so they can be
    added to any axes with ``ax.add_patch``.

    Parameters
    ----------
    mask : dict
        Mask specification containing one of the keys 'circular', 'doughnut',
        'rectangular' or 'elliptical' (checked in that order), and optionally
        a top-level 'anchor' offset relative to the source position.
    periodic : bool
        If true, eight additional copies of each patch are created, shifted
        by the layer extent, to show the mask wrapping around the edges.
    extent : sequence of float
        Layer extent [x, y] used to shift the periodic copies.
    source_pos : sequence of float
        Position [x, y] of the source node.
    face_color : any matplotlib color, optional, default: 'yellow'
        Fill color of the patches.

    Returns
    -------
    list of matplotlib.patches.Patch
        The mask patch followed by its periodic copies (if any).

    Raises
    ------
    ValueError
        If the mask type is not one of the supported ones.
    """

    # Import matplotlib here and not at toplevel, like the other plotting code.
    # Only patch classes are needed; no pyplot state (e.g. current axes) is used.
    from matplotlib.patches import Circle, Ellipse, PathPatch, Polygon
    from matplotlib.path import Path
    from matplotlib.transforms import Affine2D

    style = dict(fc=face_color, ec='black', alpha=0.2, lw=2)

    # optional offset of the mask relative to the source position
    offs = np.array(mask['anchor']) if 'anchor' in mask else np.array([0., 0.])
    center = np.asarray(source_pos) + offs

    def with_periodic_copies(make_patch, pos):
        """Return the patch at `pos`, followed by copies shifted by the extent if periodic."""
        patches = [make_patch(pos)]
        if periodic:
            patches.extend(make_patch(np.asarray(shifted))
                           for shifted in _shifted_positions(pos, extent))
        return patches

    if 'circular' in mask:
        radius = mask['circular']['radius']
        return with_periodic_copies(lambda pos: Circle(pos, radius=radius, **style), center)

    if 'doughnut' in mask:
        # Mmm... doughnut
        r_in = mask['doughnut']['inner_radius']
        r_out = mask['doughnut']['outer_radius']

        # Densely sampled unit circle; the endpoint is included so that each
        # ring closes on itself.
        t = np.linspace(0.0, 2.0 * np.pi, 629)
        unit_circle = np.column_stack((np.cos(t), np.sin(t)))
        # Outer ring traversed clockwise, inner ring counter-clockwise: the
        # opposite orientation turns the inner disc into a hole in the path.
        ring_verts = np.concatenate((r_out * unit_circle[::-1], r_in * unit_circle))
        ring_codes = np.full(len(t), Path.LINETO, dtype=Path.code_type)
        ring_codes[0] = Path.MOVETO
        codes = np.concatenate((ring_codes, ring_codes))

        def make_doughnut(pos):
            return PathPatch(Path(ring_verts + pos, codes), **style)

        return with_periodic_copies(make_doughnut, center)

    if 'rectangular' in mask:
        spec = mask['rectangular']
        ll = np.array(spec['lower_left'])
        ur = np.array(spec['upper_right'])
        width = ur[0] - ll[0]
        height = ur[1] - ll[1]
        angle = spec['azimuth_angle'] if 'azimuth_angle' in spec else 0.0
        corner_offsets = np.array([[0., 0.], [width, 0.], [width, height], [0., height]])

        def make_rectangle(lower_left):
            # The mask is rotated about its center. The rotation is applied to
            # the vertices directly instead of attaching a transform based on
            # the current axes, which would misplace the patch whenever it is
            # added to an axes that is not the current one.
            cx = lower_left[0] + width / 2
            cy = lower_left[1] + height / 2
            rotation = Affine2D().rotate_deg_around(cx, cy, angle)
            return Polygon(rotation.transform(lower_left + corner_offsets), closed=True, **style)

        return with_periodic_copies(make_rectangle, center + ll)

    if 'elliptical' in mask:
        spec = mask['elliptical']
        width = spec['major_axis']
        height = spec['minor_axis']
        angle = spec['azimuth_angle'] if 'azimuth_angle' in spec else 0.0
        # the elliptical mask carries its own anchor in addition to the top-level one
        anchor = spec['anchor'] if 'anchor' in spec else np.array([0., 0.])
        return with_periodic_copies(
            lambda pos: Ellipse(pos, width, height, angle=angle, **style), center + anchor)

    raise ValueError('Mask type cannot be plotted with this version of PyNEST.')
1235
1236
def PlotProbabilityParameter(source, parameter=None, mask=None, edges=(-0.5, 0.5, -0.5, 0.5), shape=(100, 100),
                             ax=None, prob_cmap='Greens', mask_color='yellow'):
    """
    Create a plot of the connection probability and/or mask.

    A probability plot is created based on a `Parameter` and a `source`. The
    `Parameter` should have a distance dependency. The `source` must be given
    as a `NodeCollection` with a single node ID. Optionally a `mask` can also be
    plotted.

    Parameters
    ----------
    source : NodeCollection
        Single node ID `NodeCollection` to use as source.
    parameter : Parameter
        `Parameter` the probability is based on.
    mask : Dictionary
        Optional specification of a connection mask. Connections will only
        be made to nodes inside the mask. See :py:func:`.CreateMask` for options on
        how to specify the mask.
    edges : list/tuple
        List of four edges of the region to plot. The values are given as
        [x_min, x_max, y_min, y_max].
    shape : list/tuple
        Number of `Parameter` values to calculate in each direction.
    ax : matplotlib.axes.AxesSubplot
        A matplotlib axes instance to plot in. If none is given,
        a new one is created.
    prob_cmap : [str | matplotlib colormap], optional, default: 'Greens'
        Colormap used to show the connection probability.
    mask_color : any matplotlib color, optional, default: 'yellow'
        Fill color of the mask.

    Returns
    -------
    None
        The plot is drawn into `ax`.

    Raises
    ------
    ImportError
        If Matplotlib could not be imported.
    ValueError
        If neither `parameter` nor `mask` is given.
    """

    # Check availability first: importing pyplot below would otherwise raise
    # before the more informative error message could be reached.
    if not HAVE_MPL:
        raise ImportError('Matplotlib could not be imported')

    # import pyplot here and not at toplevel to avoid preventing users
    # from changing matplotlib backend after importing nest
    import matplotlib.pyplot as plt

    if parameter is None and mask is None:
        raise ValueError('At least one of parameter or mask must be specified')
    if ax is None:
        _, ax = plt.subplots()
    ax.set_xlim(*edges[:2])
    ax.set_ylim(*edges[2:])

    if parameter is not None:
        # z[j, i] holds the value at the i-th x and j-th y sample (row = y),
        # which is the layout imshow expects with origin='lower'.
        z = np.zeros(shape[::-1])
        y_samples = np.linspace(edges[2], edges[3], shape[1])
        for i, x in enumerate(np.linspace(edges[0], edges[1], shape[0])):
            positions = [[x, y] for y in y_samples]
            values = parameter.apply(source, positions)
            z[:, i] = np.array(values)
        # values are clipped to the valid probability range [0, 1] for display
        img = ax.imshow(np.clip(z, 0.0, 1.0), extent=edges,
                        origin='lower', cmap=prob_cmap, vmin=0., vmax=1.)
        plt.colorbar(img, ax=ax, fraction=0.046, pad=0.04)

    if mask is not None:
        periodic = source.spatial['edge_wrap']
        extent = source.spatial['extent']
        source_pos = GetPosition(source)
        patches = _create_mask_patches(mask, periodic, extent, source_pos, face_color=mask_color)
        for patch in patches:
            # keep the mask above the probability image
            patch.set_zorder(0.5)
            ax.add_patch(patch)
1299