1# -*- coding: utf-8 -*- 2# 3# hl_api_spatial.py 4# 5# This file is part of NEST. 6# 7# Copyright (C) 2004 The NEST Initiative 8# 9# NEST is free software: you can redistribute it and/or modify 10# it under the terms of the GNU General Public License as published by 11# the Free Software Foundation, either version 2 of the License, or 12# (at your option) any later version. 13# 14# NEST is distributed in the hope that it will be useful, 15# but WITHOUT ANY WARRANTY; without even the implied warranty of 16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17# GNU General Public License for more details. 18# 19# You should have received a copy of the GNU General Public License 20# along with NEST. If not, see <http://www.gnu.org/licenses/>. 21 22""" 23Functions relating to spatial properties of nodes 24""" 25 26 27import numpy as np 28 29from ..ll_api import * 30from .. import pynestkernel as kernel 31from .hl_api_helper import * 32from .hl_api_connections import GetConnections 33from .hl_api_parallel_computing import NumProcesses, Rank 34from .hl_api_types import NodeCollection 35 36try: 37 import matplotlib as mpl 38 import matplotlib.path as mpath 39 import matplotlib.patches as mpatches 40 HAVE_MPL = True 41except ImportError: 42 HAVE_MPL = False 43 44__all__ = [ 45 'CreateMask', 46 'Displacement', 47 'Distance', 48 'DumpLayerConnections', 49 'DumpLayerNodes', 50 'FindCenterElement', 51 'FindNearestElement', 52 'GetPosition', 53 'GetTargetNodes', 54 'GetTargetPositions', 55 'PlotLayer', 56 'PlotProbabilityParameter', 57 'PlotTargets', 58 'SelectNodesByMask', 59] 60 61 62def CreateMask(masktype, specs, anchor=None): 63 """ 64 Create a spatial mask for connections. 65 66 Masks are used when creating connections. A mask describes the area of 67 the pool population that is searched for to connect for any given 68 node in the driver population. Several mask types are available. Examples 69 are the grid region, the rectangular, circular or doughnut region. 70 71 The command :py:func:`.CreateMask` creates a `Mask` object which may be combined 72 with other `Mask` objects using Boolean operators. The mask is specified 73 in a dictionary. 74 75 ``Mask`` objects can be passed to :py:func:`.Connect` in a connection dictionary with the key `'mask'`. 76 77 Parameters 78 ---------- 79 masktype : str, ['rectangular' | 'circular' | 'doughnut' | 'elliptical'] 80 for 2D masks, ['box' | 'spherical' | 'ellipsoidal] for 3D masks, 81 ['grid'] only for grid-based layers in 2D. 82 The mask name corresponds to the geometrical shape of the mask. There 83 are different types for 2- and 3-dimensional layers. 84 specs : dict 85 Dictionary specifying the parameters of the provided `masktype`, 86 see **Mask types**. 87 anchor : [tuple/list of floats | dict with the keys `'column'` and \ 88 `'row'` (for grid masks only)], optional, default: None 89 By providing anchor coordinates, the location of the mask relative to 90 the driver node can be changed. The list of coordinates has a length 91 of 2 or 3 dependent on the number of dimensions. 92 93 Returns 94 ------- 95 Mask: 96 Object representing the mask 97 98 See also 99 -------- 100 Connect 101 102 Notes 103 ----- 104 - All angles must be given in degrees. 105 106 **Mask types** 107 108 Available mask types (`masktype`) and their corresponding parameter 109 dictionaries: 110 111 * 2D free and grid-based layers 112 :: 113 114 'rectangular' : 115 {'lower_left' : [float, float], 116 'upper_right' : [float, float], 117 'azimuth_angle': float # default:0.0} 118 #or 119 'circular' : 120 {'radius' : float} 121 #or 122 'doughnut' : 123 {'inner_radius' : float, 124 'outer_radius' : float} 125 #or 126 'elliptical' : 127 {'major_axis' : float, 128 'minor_axis' : float, 129 'azimuth_angle' : float, # default: 0.0, 130 'anchor' : [float, float], # default: [0.0, 0.0]} 131 132 * 3D free and grid-based layers 133 :: 134 135 'box' : 136 {'lower_left' : [float, float, float], 137 'upper_right' : [float, float, float], 138 'azimuth_angle: float # default: 0.0, 139 'polar_angle : float # defualt: 0.0} 140 #or 141 'spherical' : 142 {'radius' : float} 143 #or 144 'ellipsoidal' : 145 {'major_axis' : float, 146 'minor_axis' : float, 147 'polar_axis' : float 148 'azimuth_angle' : float, # default: 0.0, 149 'polar_angle' : float, # default: 0.0, 150 'anchor' : [float, float, float], # default: [0.0, 0.0, 0.0]}} 151 152 * 2D grid-based layers only 153 :: 154 155 'grid' : 156 {'rows' : float, 157 'columns' : float} 158 159 By default the top-left corner of a grid mask, i.e., the grid 160 mask element with grid index [0, 0], is aligned with the driver 161 node. It can be changed by means of the 'anchor' parameter: 162 :: 163 164 'anchor' : 165 {'row' : float, 166 'column' : float} 167 168 **Example** 169 :: 170 171 import nest 172 173 # create a grid-based layer 174 l = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5])) 175 176 # create a circular mask 177 m = nest.CreateMask('circular', {'radius': 0.2}) 178 179 # connectivity specifications 180 conndict = {'rule': 'pairwise_bernoulli', 181 'p': 1.0, 182 'mask': m} 183 184 # connect layer l with itself according to the specifications 185 nest.Connect(l, l, conndict) 186 """ 187 if anchor is None: 188 return sli_func('CreateMask', {masktype: specs}) 189 else: 190 return sli_func('CreateMask', 191 {masktype: specs, 'anchor': anchor}) 192 193 194def GetPosition(nodes): 195 """ 196 Return the spatial locations of nodes. 197 198 Parameters 199 ---------- 200 nodes : NodeCollection 201 `NodeCollection` of nodes we want the positions to 202 203 Returns 204 ------- 205 tuple or tuple of tuple(s): 206 Tuple of position with 2- or 3-elements or list of positions 207 208 See also 209 -------- 210 Displacement: Get vector of lateral displacement between nodes. 211 Distance: Get lateral distance between nodes. 212 DumpLayerConnections: Write connectivity information to file. 213 DumpLayerNodes: Write node positions to file. 214 215 Notes 216 ----- 217 - The functions :py:func:`.GetPosition`, :py:func:`.Displacement` and :py:func:`.Distance` 218 only works for nodes local to the current MPI process, if used in a 219 MPI-parallel simulation. 220 221 Example 222 ------- 223 :: 224 225 import nest 226 227 # Reset kernel 228 nest.ResetKernel 229 230 # create a NodeCollection with spatial extent 231 s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5])) 232 233 # retrieve positions of all (local) nodes belonging to the population 234 pos = nest.GetPosition(s_nodes) 235 236 # retrieve positions of the first node in the NodeCollection 237 pos = nest.GetPosition(s_nodes[0]) 238 239 # retrieve positions of a subset of nodes in the population 240 pos = nest.GetPosition(s_nodes[2:18]) 241 """ 242 if not isinstance(nodes, NodeCollection): 243 raise TypeError("nodes must be a NodeCollection with spatial extent") 244 245 return sli_func('GetPosition', nodes) 246 247 248def Displacement(from_arg, to_arg): 249 """ 250 Get vector of lateral displacement from node(s)/Position(s) `from_arg` 251 to node(s) `to_arg`. 252 253 Displacement is the shortest displacement, taking into account 254 periodic boundary conditions where applicable. If explicit positions 255 are given in the `from_arg` list, they are interpreted in the `to_arg` 256 population. 257 258 - If one of `from_arg` or `to_arg` has length 1, and the other is longer, 259 the displacement from/to the single item to all other items is given. 260 - If `from_arg` and `to_arg` both have more than two elements, they have 261 to be of the same length and the displacement between each 262 pair is returned. 263 264 Parameters 265 ---------- 266 from_arg : NodeCollection or tuple/list with tuple(s)/list(s) of floats 267 `NodeCollection` of node IDs or tuple/list of position(s) 268 to_arg : NodeCollection 269 `NodeCollection` of node IDs 270 271 Returns 272 ------- 273 tuple: 274 Displacement vectors between pairs of nodes in `from_arg` and `to_arg` 275 276 See also 277 -------- 278 Distance: Get lateral distances between nodes. 279 DumpLayerConnections: Write connectivity information to file. 280 GetPosition: Return the spatial locations of nodes. 281 282 Notes 283 ----- 284 - The functions :py:func:`.GetPosition`, :py:func:`.Displacement` and :py:func:`.Distance` 285 only works for nodes local to the current MPI process, if used in a 286 MPI-parallel simulation. 287 288 **Example** 289 :: 290 291 import nest 292 293 # create a spatial population 294 s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5])) 295 296 # displacement between node 2 and 3 297 print(nest.Displacement(s_nodes[1], s_nodes[2])) 298 299 # displacment between the position (0.0., 0.0) and node 2 300 print(nest.Displacement([(0.0, 0.0)], s_nodes[1])) 301 """ 302 if not isinstance(to_arg, NodeCollection): 303 raise TypeError("to_arg must be a NodeCollection") 304 305 if isinstance(from_arg, np.ndarray): 306 from_arg = (from_arg, ) 307 308 if (len(from_arg) > 1 and len(to_arg) > 1 and not 309 len(from_arg) == len(to_arg)): 310 raise ValueError("to_arg and from_arg must have same size unless one have size 1.") 311 312 return sli_func('Displacement', from_arg, to_arg) 313 314 315def Distance(from_arg, to_arg): 316 """ 317 Get lateral distances from node(s)/position(s) `from_arg` to node(s) `to_arg`. 318 319 The distance between two nodes is the length of its displacement. 320 321 If explicit positions are given in the `from_arg` list, they are 322 interpreted in the `to_arg` population. Distance is the shortest distance, 323 taking into account periodic boundary conditions where applicable. 324 325 - If one of `from_arg` or `to_arg` has length 1, and the other is longer, 326 the displacement from/to the single item to all other items is given. 327 - If `from_arg` and `to_arg` both have more than two elements, they have 328 to be of the same length and the distance for each pair is 329 returned. 330 331 Parameters 332 ---------- 333 from_arg : NodeCollection or tuple/list with tuple(s)/list(s) of floats 334 `NodeCollection` of node IDs or tuple/list of position(s) 335 to_arg : NodeCollection 336 `NodeCollection` of node IDs 337 338 Returns 339 ------- 340 tuple: 341 Distances between `from` and `to` 342 343 See also 344 -------- 345 Displacement: Get vector of lateral displacements between nodes. 346 DumpLayerConnections: Write connectivity information to file. 347 GetPosition: Return the spatial locations of nodes. 348 349 Notes 350 ----- 351 - The functions :py:func:`.GetPosition`, :py:func:`.Displacement` and :py:func:`.Distance` 352 only works for nodes local to the current MPI process, if used in a 353 MPI-parallel simulation. 354 355 Example 356 ------- 357 :: 358 359 import nest 360 361 # create a spatial population 362 s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5])) 363 364 # distance between node 2 and 3 365 print(nest.Distance(s_nodes[1], s_nodes[2])) 366 367 # distance between the position (0.0., 0.0) and node 2 368 print(nest.Distance([(0.0, 0.0)], s_nodes[1])) 369 """ 370 if not isinstance(to_arg, NodeCollection): 371 raise TypeError("to_arg must be a NodeCollection") 372 373 if isinstance(from_arg, np.ndarray): 374 from_arg = (from_arg, ) 375 376 if (len(from_arg) > 1 and len(to_arg) > 1 and not 377 len(from_arg) == len(to_arg)): 378 raise ValueError("to_arg and from_arg must have same size unless one have size 1.") 379 380 return sli_func('Distance', from_arg, to_arg) 381 382 383def FindNearestElement(layer, locations, find_all=False): 384 """ 385 Return the node(s) closest to the `locations` in the given `layer`. 386 387 This function works for fixed grid layer only. 388 389 * If `locations` is a single 2-element array giving a grid location, return a 390 `NodeCollection` of `layer` elements at the given location. 391 * If `locations` is a list of coordinates, the function returns a list of `NodeCollection` of the nodes at all 392 locations. 393 394 Parameters 395 ---------- 396 layer : NodeCollection 397 `NodeCollection` of spatially distributed node IDs 398 locations : tuple(s)/list(s) of tuple(s)/list(s) 399 2-element list with coordinates of a single position, or list of 400 2-element list of positions 401 find_all : bool, default: False 402 If there are several nodes with same minimal distance, return only the 403 first found, if `False`. 404 If `True`, instead of returning a single `NodeCollection`, return a list of `NodeCollection` 405 containing all nodes with minimal distance. 406 407 Returns 408 ------- 409 NodeCollection: 410 `NodeCollection` of node IDs if locations is a 2-element list with coordinates of a single position 411 list: 412 list of `NodeCollection` if find_all is True or locations contains more than one position 413 414 See also 415 -------- 416 FindCenterElement: Return NodeCollection of node closest to center of layers. 417 GetPosition: Return the spatial locations of nodes. 418 419 Example 420 ------- 421 :: 422 423 import nest 424 425 # create a spatial population 426 s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5])) 427 428 # get node ID of element closest to some location 429 nest.FindNearestElement(s_nodes, [3.0, 4.0], True) 430 """ 431 432 if not isinstance(layer, NodeCollection): 433 raise TypeError("layer must be a NodeCollection") 434 435 if not len(layer) > 0: 436 raise ValueError("layer cannot be empty") 437 438 if not is_iterable(locations): 439 raise TypeError("locations must be coordinate array or list of coordinate arrays") 440 441 # Ensure locations is sequence, keeps code below simpler 442 if not is_iterable(locations[0]): 443 locations = (locations, ) 444 445 result = [] 446 447 for loc in locations: 448 d = Distance(np.array(loc), layer) 449 450 if not find_all: 451 dx = np.argmin(d) # finds location of one minimum 452 result.append(layer[dx]) 453 else: 454 minnode = list(layer[:1]) 455 minval = d[0] 456 for idx in range(1, len(layer)): 457 if d[idx] < minval: 458 minnode = [layer[idx]] 459 minval = d[idx] 460 elif np.abs(d[idx] - minval) <= 1e-14 * minval: 461 minnode.append(layer[idx]) 462 result.append(minnode) 463 464 if len(result) == 1: 465 result = result[0] 466 467 return result 468 469 470def _rank_specific_filename(basename): 471 """Returns file name decorated with rank.""" 472 473 if NumProcesses() == 1: 474 return basename 475 else: 476 np = NumProcesses() 477 np_digs = len(str(np - 1)) # for pretty formatting 478 rk = Rank() 479 dot = basename.find('.') 480 if dot < 0: 481 return '%s-%0*d' % (basename, np_digs, rk) 482 else: 483 return '%s-%0*d%s' % (basename[:dot], np_digs, rk, basename[dot:]) 484 485 486def DumpLayerNodes(layer, outname): 487 """ 488 Write `node ID` and position data of `layer` to file. 489 490 Write `node ID` and position data to `outname` file. For each node in `layer`, 491 a line with the following information is written: 492 :: 493 494 node ID x-position y-position [z-position] 495 496 If `layer` contains several `node IDs`, data for all nodes in `layer` will be written to a 497 single file. 498 499 Parameters 500 ---------- 501 layer : NodeCollection 502 `NodeCollection` of spatially distributed node IDs 503 outname : str 504 Name of file to write to (existing files are overwritten) 505 506 See also 507 -------- 508 DumpLayerConnections: Write connectivity information to file. 509 GetPosition: Return the spatial locations of nodes. 510 511 Notes 512 ----- 513 * If calling this function from a distributed simulation, this function 514 will write to one file per MPI rank. 515 * File names are formed by adding the MPI Rank into the file name before 516 the file name suffix. 517 * Each file stores data for nodes local to that file. 518 519 Example 520 ------- 521 :: 522 523 import nest 524 525 # create a spatial population 526 s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5])) 527 528 # write layer node positions to file 529 nest.DumpLayerNodes(s_nodes, 'positions.txt') 530 531 """ 532 if not isinstance(layer, NodeCollection): 533 raise TypeError("layer must be a NodeCollection") 534 535 sli_func(""" 536 (w) file exch DumpLayerNodes close 537 """, 538 layer, _rank_specific_filename(outname)) 539 540 541def DumpLayerConnections(source_layer, target_layer, synapse_model, outname): 542 """ 543 Write connectivity information to file. 544 545 This function writes connection information to file for all outgoing 546 connections from the given layers with the given synapse model. 547 548 For each connection, one line is stored, in the following format: 549 :: 550 551 source_node_id target_node_id weight delay dx dy [dz] 552 553 where (dx, dy [, dz]) is the displacement from source to target node. 554 If targets do not have positions (eg spike recorders outside any layer), 555 NaN is written for each displacement coordinate. 556 557 Parameters 558 ---------- 559 source_layers : NodeCollection 560 `NodeCollection` of spatially distributed node IDs 561 target_layers : NodeCollection 562 `NodeCollection` of (spatially distributed) node IDs 563 synapse_model : str 564 NEST synapse model 565 outname : str 566 Name of file to write to (will be overwritten if it exists) 567 568 See also 569 -------- 570 DumpLayerNodes: Write layer node positions to file. 571 GetPosition: Return the spatial locations of nodes. 572 GetConnections: Return connection identifiers between 573 sources and targets 574 575 Notes 576 ----- 577 * If calling this function from a distributed simulation, this function 578 will write to one file per MPI rank. 579 * File names are formed by inserting 580 the MPI Rank into the file name before the file name suffix. 581 * Each file stores data for local nodes. 582 583 **Example** 584 :: 585 586 import nest 587 588 # create a spatial population 589 s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5])) 590 591 nest.Connect(s_nodes, s_nodes, 592 {'rule': 'pairwise_bernoulli', 'p': 1.0}, 593 {'synapse_model': 'static_synapse'}) 594 595 # write connectivity information to file 596 nest.DumpLayerConnections(s_nodes, s_nodes, 'static_synapse', 'conns.txt') 597 """ 598 if not isinstance(source_layer, NodeCollection): 599 raise TypeError("source_layer must be a NodeCollection") 600 if not isinstance(target_layer, NodeCollection): 601 raise TypeError("target_layer must be a NodeCollection") 602 603 sli_func(""" 604 /oname Set 605 cvlit /synmod Set 606 /lyr_target Set 607 /lyr_source Set 608 oname (w) file lyr_source lyr_target synmod 609 DumpLayerConnections close 610 """, 611 source_layer, target_layer, synapse_model, 612 _rank_specific_filename(outname)) 613 614 615def FindCenterElement(layer): 616 """ 617 Return `NodeCollection` of node closest to center of `layer`. 618 619 Parameters 620 ---------- 621 layer : NodeCollection 622 `NodeCollection` with spatially distributed node IDs 623 624 Returns 625 ------- 626 NodeCollection: 627 `NodeCollection` of the node closest to the center of the `layer`, as specified by `layer` 628 parameters given in ``layer.spatial``. If several nodes are equally close to the center, 629 an arbitrary one of them is returned. 630 631 See also 632 -------- 633 FindNearestElement: Return the node(s) closest to the location(s) in the given `layer`. 634 GetPosition: Return the spatial locations of nodes. 635 636 Example 637 ------- 638 :: 639 640 import nest 641 642 # create a spatial population 643 s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[5, 5])) 644 645 # get NodeCollection of the element closest to the center of the layer 646 nest.FindCenterElement(s_nodes) 647 """ 648 649 if not isinstance(layer, NodeCollection): 650 raise TypeError("layer must be a NodeCollection") 651 nearest_to_center = FindNearestElement(layer, layer.spatial['center'])[0] 652 index = layer.index(nearest_to_center.get('global_id')) 653 return layer[index:index+1] 654 655 656def GetTargetNodes(sources, tgt_layer, syn_model=None): 657 """ 658 Obtain targets of `sources` in given `target` population. 659 660 For each neuron in `sources`, this function finds all target elements 661 in `tgt_layer`. If `syn_model` is not given (default), all targets are 662 returned, otherwise only targets of specific type. 663 664 Parameters 665 ---------- 666 sources : NodeCollection 667 NodeCollection with node IDs of `sources` 668 tgt_layer : NodeCollection 669 NodeCollection with node IDs of `tgt_layer` 670 syn_model : [None | str], optional, default: None 671 Return only target positions for a given synapse model. 672 673 Returns 674 ------- 675 tuple of NodeCollection: 676 Tuple of `NodeCollections` of target neurons fulfilling the given criteria, one `NodeCollection` per 677 source node ID in `sources`. 678 679 See also 680 -------- 681 GetTargetPositions: Obtain positions of targets in a given target layer connected to given source. 682 GetConnections: Return connection identifiers between 683 sources and targets 684 685 Notes 686 ----- 687 * For distributed simulations, this function only returns targets on the 688 local MPI process. 689 690 Example 691 ------- 692 :: 693 694 import nest 695 696 # create a spatial population 697 s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[11, 11], extent=[11., 11.])) 698 699 # connectivity specifications with a mask 700 conndict = {'rule': 'pairwise_bernoulli', 'p': 1., 701 'mask': {'rectangular': {'lower_left' : [-2.0, -1.0], 702 'upper_right': [2.0, 1.0]}}} 703 704 # connect population s_nodes with itself according to the given 705 # specifications 706 nest.Connect(s_nodes, s_nodes, conndict) 707 708 # get the node IDs of the targets of a source neuron 709 nest.GetTargetNodes(s_nodes[4], s_nodes) 710 """ 711 if not isinstance(sources, NodeCollection): 712 raise TypeError("sources must be a NodeCollection.") 713 714 if not isinstance(tgt_layer, NodeCollection): 715 raise TypeError("tgt_layer must be a NodeCollection") 716 717 conns = GetConnections(sources, tgt_layer, synapse_model=syn_model) 718 719 # Re-organize conns into one list per source, containing only target node IDs. 720 src_tgt_map = dict((snode_id, []) for snode_id in sources.tolist()) 721 for src, tgt in zip(conns.sources(), conns.targets()): 722 src_tgt_map[src].append(tgt) 723 724 for src in src_tgt_map.keys(): 725 src_tgt_map[src] = NodeCollection(list(np.unique(src_tgt_map[src]))) 726 727 # convert dict to nested list in same order as sources 728 return tuple(src_tgt_map[snode_id] for snode_id in sources.tolist()) 729 730 731def GetTargetPositions(sources, tgt_layer, syn_model=None): 732 """ 733 Obtain positions of targets to a given `NodeCollection` of `sources`. 734 735 For each neuron in `sources`, this function finds all target elements 736 in `tgt_layer`. If `syn_model` is not given (default), all targets are 737 returned, otherwise only targets of specific type. 738 739 Parameters 740 ---------- 741 sources : NodeCollection 742 `NodeCollection` with node ID(s) of source neurons 743 tgt_layer : NodeCollection 744 `NodeCollection` of tgt_layer 745 syn_type : [None | str], optional, default: None 746 Return only target positions for a given synapse model. 747 748 Returns 749 ------- 750 list of list(s) of tuple(s) of floats: 751 Positions of target neurons fulfilling the given criteria as a nested 752 list, containing one list of positions per node in sources. 753 754 See also 755 -------- 756 GetTargetNodes: Obtain targets of a `NodeCollection` of sources in a given target 757 population. 758 759 Notes 760 ----- 761 * For distributed simulations, this function only returns targets on the 762 local MPI process. 763 764 Example 765 ------- 766 :: 767 768 import nest 769 770 # create a spatial population 771 s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[11, 11], extent=[11., 11.])) 772 773 # connectivity specifications with a mask 774 conndict = {'rule': 'pairwise_bernoulli', 'p': 1., 775 'mask': {'rectangular': {'lower_left' : [-2.0, -1.0], 776 'upper_right': [2.0, 1.0]}}} 777 778 # connect population s_nodes with itself according to the given 779 # specifications 780 nest.Connect(s_nodes, s_nodes, conndict) 781 782 # get the positions of the targets of a source neuron 783 nest.GetTargetPositions(s_nodes[5], s_nodes) 784 """ 785 if not isinstance(sources, NodeCollection): 786 raise TypeError("sources must be a NodeCollection.") 787 788 # Find positions to all nodes in target layer 789 pos_all_tgts = GetPosition(tgt_layer) 790 first_tgt_node_id = tgt_layer[0].get('global_id') 791 792 connections = GetConnections(sources, tgt_layer, 793 synapse_model=syn_model) 794 srcs = connections.get('source') 795 tgts = connections.get('target') 796 if isinstance(srcs, int): 797 srcs = [srcs] 798 if isinstance(tgts, int): 799 tgts = [tgts] 800 801 # Make dictionary where the keys are the source node_ids, which is mapped to a 802 # list with the positions of the targets connected to the source. 803 src_tgt_pos_map = dict((snode_id, []) for snode_id in sources.tolist()) 804 for i in range(len(connections)): 805 tgt_indx = tgts[i] - first_tgt_node_id 806 src_tgt_pos_map[srcs[i]].append(pos_all_tgts[tgt_indx]) 807 808 # Turn dict into list in same order as sources 809 return [src_tgt_pos_map[snode_id] for snode_id in sources.tolist()] 810 811 812def SelectNodesByMask(layer, anchor, mask_obj): 813 """ 814 Obtain the node IDs inside a masked area of a spatially distributed population. 815 816 The function finds and returns all the node IDs inside a given mask of a 817 `layer`. The node IDs are returned as a `NodeCollection`. The function works on both 2-dimensional and 818 3-dimensional masks and layers. All mask types are allowed, including combined masks. 819 820 Parameters 821 ---------- 822 layer : NodeCollection 823 `NodeCollection` with node IDs of the `layer` to select nodes from. 824 anchor : tuple/list of double 825 List containing center position of the layer. This is the point from 826 where we start to search. 827 mask_obj: object 828 `Mask` object specifying chosen area. 829 830 Returns 831 ------- 832 NodeCollection: 833 `NodeCollection` of nodes/elements inside the mask. 834 """ 835 836 if not isinstance(layer, NodeCollection): 837 raise TypeError("layer must be a NodeCollection.") 838 839 mask_datum = mask_obj._datum 840 841 node_id_list = sli_func('SelectNodesByMask', 842 layer, anchor, mask_datum) 843 844 # When creating a NodeCollection, the input list of nodes IDs must be sorted. 845 return NodeCollection(sorted(node_id_list)) 846 847 848def _draw_extent(ax, xctr, yctr, xext, yext): 849 """Draw extent and set aspect ration, limits""" 850 851 # import pyplot here and not at toplevel to avoid preventing users 852 # from changing matplotlib backend after importing nest 853 import matplotlib.pyplot as plt 854 855 # thin gray line indicating extent 856 llx, lly = xctr - xext / 2.0, yctr - yext / 2.0 857 urx, ury = llx + xext, lly + yext 858 ax.add_patch( 859 plt.Rectangle((llx, lly), xext, yext, fc='none', ec='0.5', lw=1, 860 zorder=1)) 861 862 # set limits slightly outside extent 863 ax.set(aspect='equal', 864 xlim=(llx - 0.05 * xext, urx + 0.05 * xext), 865 ylim=(lly - 0.05 * yext, ury + 0.05 * yext), 866 xticks=tuple(), yticks=tuple()) 867 868 869def _shifted_positions(pos, ext): 870 """Get shifted positions corresponding to boundary conditions.""" 871 return [[pos[0] + ext[0], pos[1]], 872 [pos[0] - ext[0], pos[1]], 873 [pos[0], pos[1] + ext[1]], 874 [pos[0], pos[1] - ext[1]], 875 [pos[0] + ext[0], pos[1] - ext[1]], 876 [pos[0] - ext[0], pos[1] + ext[1]], 877 [pos[0] + ext[0], pos[1] + ext[1]], 878 [pos[0] - ext[0], pos[1] - ext[1]]] 879 880 881def PlotLayer(layer, fig=None, nodecolor='b', nodesize=20): 882 """ 883 Plot all nodes in a `layer`. 884 885 Parameters 886 ---------- 887 layer : NodeCollection 888 `NodeCollection` of spatially distributed nodes 889 fig : [None | matplotlib.figure.Figure object], optional, default: None 890 Matplotlib figure to plot to. If not given, a new figure is 891 created. 892 nodecolor : [None | any matplotlib color], optional, default: 'b' 893 Color for nodes 894 nodesize : float, optional, default: 20 895 Marker size for nodes 896 897 Returns 898 ------- 899 `matplotlib.figure.Figure` object 900 901 See also 902 -------- 903 PlotProbabilityParameter: Create a plot of the connection probability and/or mask. 904 PlotTargets: Plot all targets of a given source. 905 matplotlib.figure.Figure : matplotlib Figure class 906 907 Notes 908 ----- 909 * Do **not** use this function in distributed simulations. 910 911 912 Example 913 ------- 914 :: 915 916 import nest 917 import matplotlib.pyplot as plt 918 919 # create a spatial population 920 s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[11, 11], extent=[11., 11.])) 921 922 # plot layer with all its nodes 923 nest.PlotLayer(s_nodes) 924 plt.show() 925 """ 926 927 # import pyplot here and not at toplevel to avoid preventing users 928 # from changing matplotlib backend after importing nest 929 import matplotlib.pyplot as plt 930 931 if not HAVE_MPL: 932 raise ImportError('Matplotlib could not be imported') 933 934 if not isinstance(layer, NodeCollection): 935 raise TypeError('layer must be a NodeCollection.') 936 937 # get layer extent 938 ext = layer.spatial['extent'] 939 940 if len(ext) == 2: 941 # 2D layer 942 943 # get layer extent and center, x and y 944 xext, yext = ext 945 xctr, yctr = layer.spatial['center'] 946 947 # extract position information, transpose to list of x and y pos 948 xpos, ypos = zip(*GetPosition(layer)) 949 950 if fig is None: 951 fig = plt.figure() 952 ax = fig.add_subplot(111) 953 else: 954 ax = fig.gca() 955 956 ax.scatter(xpos, ypos, s=nodesize, facecolor=nodecolor) 957 _draw_extent(ax, xctr, yctr, xext, yext) 958 959 elif len(ext) == 3: 960 # 3D layer 961 from mpl_toolkits.mplot3d import Axes3D 962 963 # extract position information, transpose to list of x,y,z pos 964 pos = zip(*GetPosition(layer)) 965 966 if fig is None: 967 fig = plt.figure() 968 ax = fig.add_subplot(111, projection='3d') 969 else: 970 ax = fig.gca() 971 972 ax.scatter(*pos, s=nodesize, c=nodecolor) 973 plt.draw_if_interactive() 974 975 else: 976 raise ValueError("unexpected dimension of layer") 977 978 return fig 979 980 981def PlotTargets(src_nrn, tgt_layer, syn_type=None, fig=None, 982 mask=None, probability_parameter=None, 983 src_color='red', src_size=50, tgt_color='blue', tgt_size=20, 984 mask_color='yellow', probability_cmap='Greens'): 985 """ 986 Plot all targets of source neuron `src_nrn` in a target layer `tgt_layer`. 987 988 Parameters 989 ---------- 990 src_nrn : NodeCollection 991 `NodeCollection` of source neuron (as single-element NodeCollection) 992 tgt_layer : NodeCollection 993 `NodeCollection` of tgt_layer 994 syn_type : [None | str], optional, default: None 995 Show only targets connected with a given synapse type 996 fig : [None | matplotlib.figure.Figure object], optional, default: None 997 Matplotlib figure to plot to. If not given, a new figure is created. 998 mask : [None | dict], optional, default: None 999 Draw mask with targets; see :py:func:`.PlotProbabilityParameter` for details. 1000 probability_parameter : [None | Parameter], optional, default: None 1001 Draw connection probability with targets; see :py:func:`.PlotProbabilityParameter` for details. 1002 src_color : [None | any matplotlib color], optional, default: 'red' 1003 Color used to mark source node position 1004 src_size : float, optional, default: 50 1005 Size of source marker (see scatter for details) 1006 tgt_color : [None | any matplotlib color], optional, default: 'blue' 1007 Color used to mark target node positions 1008 tgt_size : float, optional, default: 20 1009 Size of target markers (see scatter for details) 1010 mask_color : [None | any matplotlib color], optional, default: 'red' 1011 Color used for line marking mask 1012 probability_cmap : [None | any matplotlib cmap color], optional, default: 'Greens' 1013 Color used for lines marking probability parameter. 1014 1015 Returns 1016 ------- 1017 matplotlib.figure.Figure object 1018 1019 See also 1020 -------- 1021 GetTargetNodes: Obtain targets of a sources in a given target layer. 1022 GetTargetPositions: Obtain positions of targets of sources in a given target layer. 1023 probability_parameter: Add indication of connection probability and mask to axes. 1024 PlotLayer: Plot all nodes in a spatially distributed population. 1025 matplotlib.pyplot.scatter : matplotlib scatter plot. 1026 1027 Notes 1028 ----- 1029 * Do **not** use this function in distributed simulations. 1030 1031 **Example** 1032 :: 1033 1034 import nest 1035 import matplotlib.pyplot as plt 1036 1037 # create a spatial population 1038 s_nodes = nest.Create('iaf_psc_alpha', positions=nest.spatial.grid(shape=[11, 11], extent=[11., 11.])) 1039 1040 # connectivity specifications with a mask 1041 conndict = {'rule': 'pairwise_bernoulli', 'p': 1., 1042 'mask': {'rectangular': {'lower_left' : [-2.0, -1.0], 1043 'upper_right': [2.0, 1.0]}}} 1044 1045 # connect population s_nodes with itself according to the given 1046 # specifications 1047 nest.Connect(s_nodes, s_nodes, conndict) 1048 1049 # plot the targets of a source neuron 1050 nest.PlotTargets(s_nodes[4], s_nodes) 1051 plt.show() 1052 """ 1053 1054 # import pyplot here and not at toplevel to avoid preventing users 1055 # from changing matplotlib backend after importing nest 1056 import matplotlib.pyplot as plt 1057 1058 if not HAVE_MPL: 1059 raise ImportError("Matplotlib could not be imported") 1060 1061 if not isinstance(src_nrn, NodeCollection) or len(src_nrn) != 1: 1062 raise TypeError("src_nrn must be a single element NodeCollection.") 1063 if not isinstance(tgt_layer, NodeCollection): 1064 raise TypeError("tgt_layer must be a NodeCollection.") 1065 1066 # get position of source 1067 srcpos = GetPosition(src_nrn) 1068 1069 # get layer extent 1070 ext = tgt_layer.spatial['extent'] 1071 1072 if len(ext) == 2: 1073 # 2D layer 1074 1075 # get layer extent and center, x and y 1076 xext, yext = ext 1077 xctr, yctr = tgt_layer.spatial['center'] 1078 1079 if fig is None: 1080 fig = plt.figure() 1081 ax = fig.add_subplot(111) 1082 else: 1083 ax = fig.gca() 1084 1085 # get positions, reorganize to x and y vectors 1086 tgtpos = GetTargetPositions(src_nrn, tgt_layer, syn_type) 1087 if tgtpos: 1088 xpos, ypos = zip(*tgtpos[0]) 1089 ax.scatter(xpos, ypos, s=tgt_size, facecolor=tgt_color) 1090 1091 ax.scatter(srcpos[:1], srcpos[1:], s=src_size, facecolor=src_color, alpha=0.4, zorder=-10) 1092 1093 if mask is not None or probability_parameter is not None: 1094 edges = [xctr - xext, xctr + xext, yctr - yext, yctr + yext] 1095 PlotProbabilityParameter(src_nrn, probability_parameter, mask=mask, edges=edges, ax=ax, 1096 prob_cmap=probability_cmap, mask_color=mask_color) 1097 1098 _draw_extent(ax, xctr, yctr, xext, yext) 1099 1100 else: 1101 # 3D layer 1102 from mpl_toolkits.mplot3d import Axes3D 1103 1104 if fig is None: 1105 fig = plt.figure() 1106 ax = fig.add_subplot(111, projection='3d') 1107 else: 1108 ax = fig.gca() 1109 1110 # get positions, reorganize to x,y,z vectors 1111 tgtpos = GetTargetPositions(src_nrn, tgt_layer, syn_type) 1112 if tgtpos: 1113 xpos, ypos, zpos = zip(*tgtpos[0]) 1114 ax.scatter3D(xpos, ypos, zpos, s=tgt_size, facecolor=tgt_color) 1115 1116 ax.scatter3D(srcpos[:1], srcpos[1:2], srcpos[2:], s=src_size, facecolor=src_color, alpha=0.4, zorder=-10) 1117 1118 plt.draw_if_interactive() 1119 1120 return fig 1121 1122 1123def _create_mask_patches(mask, periodic, extent, source_pos, face_color='yellow'): 1124 """Create Matplotlib Patch objects representing the mask""" 1125 1126 # import pyplot here and not at toplevel to avoid preventing users 1127 # from changing matplotlib backend after importing nest 1128 import matplotlib.pyplot as plt 1129 import matplotlib as mtpl 1130 1131 edge_color = 'black' 1132 alpha = 0.2 1133 line_width = 2 1134 mask_patches = [] 1135 1136 if 'anchor' in mask: 1137 offs = np.array(mask['anchor']) 1138 else: 1139 offs = np.array([0., 0.]) 1140 1141 if 'circular' in mask: 1142 r = mask['circular']['radius'] 1143 1144 patch = plt.Circle(source_pos + offs, radius=r, 1145 fc=face_color, ec=edge_color, alpha=alpha, lw=line_width) 1146 mask_patches.append(patch) 1147 1148 if periodic: 1149 for pos in _shifted_positions(source_pos + offs, extent): 1150 patch = plt.Circle(pos, radius=r, 1151 fc=face_color, ec=edge_color, alpha=alpha, lw=line_width) 1152 mask_patches.append(patch) 1153 elif 'doughnut' in mask: 1154 # Mmm... doughnut 1155 def make_doughnut_patch(pos, r_out, r_in, ec, fc, alpha): 1156 def make_circle(r): 1157 t = np.arange(0, np.pi * 2.0, 0.01) 1158 t = t.reshape((len(t), 1)) 1159 x = r * np.cos(t) 1160 y = r * np.sin(t) 1161 return np.hstack((x, y)) 1162 outside_verts = make_circle(r_out)[::-1] 1163 inside_verts = make_circle(r_in) 1164 codes = np.ones(len(inside_verts), dtype=mpath.Path.code_type) * mpath.Path.LINETO 1165 codes[0] = mpath.Path.MOVETO 1166 vertices = np.concatenate([outside_verts, inside_verts]) 1167 vertices += pos 1168 all_codes = np.concatenate((codes, codes)) 1169 path = mpath.Path(vertices, all_codes) 1170 return mpatches.PathPatch(path, fc=fc, ec=ec, alpha=alpha, lw=line_width) 1171 1172 r_in = mask['doughnut']['inner_radius'] 1173 r_out = mask['doughnut']['outer_radius'] 1174 pos = source_pos + offs 1175 patch = make_doughnut_patch(pos, r_in, r_out, edge_color, face_color, alpha) 1176 mask_patches.append(patch) 1177 if periodic: 1178 for pos in _shifted_positions(source_pos + offs, extent): 1179 patch = make_doughnut_patch(pos, r_in, r_out, edge_color, face_color, alpha) 1180 mask_patches.append(patch) 1181 elif 'rectangular' in mask: 1182 ll = np.array(mask['rectangular']['lower_left']) 1183 ur = np.array(mask['rectangular']['upper_right']) 1184 width = ur[0] - ll[0] 1185 height = ur[1] - ll[1] 1186 pos = source_pos + ll + offs 1187 cntr = [pos[0] + width/2, pos[1] + height/2] 1188 1189 if 'azimuth_angle' in mask['rectangular']: 1190 angle = mask['rectangular']['azimuth_angle'] 1191 else: 1192 angle = 0.0 1193 1194 patch = plt.Rectangle(pos, width, height, 1195 fc=face_color, ec=edge_color, alpha=alpha, lw=line_width) 1196 # Need to rotate about center 1197 trnsf = mtpl.transforms.Affine2D().rotate_deg_around(cntr[0], cntr[1], angle) + plt.gca().transData 1198 patch.set_transform(trnsf) 1199 mask_patches.append(patch) 1200 1201 if periodic: 1202 for pos in _shifted_positions(source_pos + ll + offs, extent): 1203 patch = plt.Rectangle(pos, width, height, 1204 fc=face_color, ec=edge_color, alpha=alpha, lw=line_width) 1205 1206 cntr = [pos[0] + width/2, pos[1] + height/2] 1207 # Need to rotate about center 1208 trnsf = mtpl.transforms.Affine2D().rotate_deg_around(cntr[0], cntr[1], angle) + plt.gca().transData 1209 patch.set_transform(trnsf) 1210 mask_patches.append(patch) 1211 elif 'elliptical' in mask: 1212 width = mask['elliptical']['major_axis'] 1213 height = mask['elliptical']['minor_axis'] 1214 if 'azimuth_angle' in mask['elliptical']: 1215 angle = mask['elliptical']['azimuth_angle'] 1216 else: 1217 angle = 0.0 1218 if 'anchor' in mask['elliptical']: 1219 anchor = mask['elliptical']['anchor'] 1220 else: 1221 anchor = np.array([0., 0.]) 1222 patch = mpl.patches.Ellipse(source_pos + offs + anchor, width, height, 1223 angle=angle, fc=face_color, 1224 ec=edge_color, alpha=alpha, lw=line_width) 1225 mask_patches.append(patch) 1226 1227 if periodic: 1228 for pos in _shifted_positions(source_pos + offs + anchor, extent): 1229 patch = mpl.patches.Ellipse(pos, width, height, angle=angle, fc=face_color, 1230 ec=edge_color, alpha=alpha, lw=line_width) 1231 mask_patches.append(patch) 1232 else: 1233 raise ValueError('Mask type cannot be plotted with this version of PyNEST.') 1234 return mask_patches 1235 1236 1237def PlotProbabilityParameter(source, parameter=None, mask=None, edges=[-0.5, 0.5, -0.5, 0.5], shape=[100, 100], 1238 ax=None, prob_cmap='Greens', mask_color='yellow'): 1239 """ 1240 Create a plot of the connection probability and/or mask. 1241 1242 A probability plot is created based on a `Parameter` and a `source`. The 1243 `Parameter` should have a distance dependency. The `source` must be given 1244 as a `NodeCollection` with a single node ID. Optionally a `mask` can also be 1245 plotted. 1246 1247 Parameters 1248 ---------- 1249 source : NodeCollection 1250 Single node ID `NodeCollection` to use as source. 1251 parameter : Parameter 1252 `Parameter` the probability is based on. 1253 mask : Dictionary 1254 Optional specification of a connection mask. Connections will only 1255 be made to nodes inside the mask. See :py:func:`.CreateMask` for options on 1256 how to specify the mask. 1257 edges : list/tuple 1258 List of four edges of the region to plot. The values are given as 1259 [x_min, x_max, y_min, y_max]. 1260 shape : list/tuple 1261 Number of `Parameter` values to calculate in each direction. 1262 ax : matplotlib.axes.AxesSubplot, 1263 A matplotlib axes instance to plot in. If none is given, 1264 a new one is created. 1265 """ 1266 1267 # import pyplot here and not at toplevel to avoid preventing users 1268 # from changing matplotlib backend after importing nest 1269 import matplotlib.pyplot as plt 1270 1271 if not HAVE_MPL: 1272 raise ImportError('Matplotlib could not be imported') 1273 1274 if parameter is None and mask is None: 1275 raise ValueError('At least one of parameter or mask must be specified') 1276 if ax is None: 1277 fig, ax = plt.subplots() 1278 ax.set_xlim(*edges[:2]) 1279 ax.set_ylim(*edges[2:]) 1280 1281 if parameter is not None: 1282 z = np.zeros(shape[::-1]) 1283 for i, x in enumerate(np.linspace(edges[0], edges[1], shape[0])): 1284 positions = [[x, y] for y in np.linspace(edges[2], edges[3], shape[1])] 1285 values = parameter.apply(source, positions) 1286 z[:, i] = np.array(values) 1287 img = ax.imshow(np.minimum(np.maximum(z, 0.0), 1.0), extent=edges, 1288 origin='lower', cmap=prob_cmap, vmin=0., vmax=1.) 1289 plt.colorbar(img, ax=ax, fraction=0.046, pad=0.04) 1290 1291 if mask is not None: 1292 periodic = source.spatial['edge_wrap'] 1293 extent = source.spatial['extent'] 1294 source_pos = GetPosition(source) 1295 patches = _create_mask_patches(mask, periodic, extent, source_pos, face_color=mask_color) 1296 for patch in patches: 1297 patch.set_zorder(0.5) 1298 ax.add_patch(patch) 1299