# -*- coding: utf-8 -*-
#
# store_restore_network.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

"""
Store and restore a network simulation
--------------------------------------

This example shows how to store user-specified aspects of a network
to file and how to later restore the network for further simulation.
This may be used, e.g., to train weights in a network up to a certain
point, store those weights and later perform diverse experiments on
the same network using the stored weights.

.. admonition:: Only user-specified aspects are stored

   NEST does not support storing the complete state of a simulation
   in a way that would allow one to continue a simulation as if one had
   made a new ``Simulate()`` call on an existing network. Such complete
   checkpointing would be very difficult to implement.

   NEST's explicit approach to storing and restoring network state makes
   clear to all which aspects of a network are carried from one simulation
   to another and thus contributes to good scientific practice.

   Storing and restoring is currently not supported for MPI-parallel simulations.

"""

###############################################################################
# Import necessary modules.
import nest
import pickle

###############################################################################
# These modules are only needed for illustrative plotting.

import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
import pandas as pd
import textwrap

###############################################################################
# Implement network as class.
#
# Implementing the network as a class makes network properties available to
# the initial network builder, the storer and the restorer, thus reducing the
# amount of data that needs to be stored.


class EINetwork:
    """
    A simple balanced random network with plastic excitatory synapses.

    This simple Brunel-style balanced random network has an excitatory
    and inhibitory population, both driven by external excitatory poisson
    input. Excitatory connections are plastic (STDP). Spike activity of
    the excitatory population is recorded.

    The model is provided as a non-trivial example for storing and restoring.
    """

    def __init__(self):
        """
        Set network parameters and register the synapse models.

        Must be called after a kernel reset, because it registers the
        synapse models ``e_syn`` and ``i_syn`` via ``nest.CopyModel``.
        """
        self.nI = 500
        self.nE = 4 * self.nI
        self.n = self.nE + self.nI

        self.JE = 1.0
        self.JI = -4 * self.JE
        self.indeg_e = 200
        self.indeg_i = 50

        self.neuron_model = "iaf_psc_delta"

        # Create synapse models so we can extract specific connection information
        nest.CopyModel("stdp_synapse_hom", "e_syn", {"Wmax": 2 * self.JE})
        nest.CopyModel("static_synapse", "i_syn")

        self.nrn_params = {"V_m": nest.random.normal(-65., 5.)}
        self.poisson_rate = 800.

    def build(self):
        """
        Construct network from scratch, including instrumentation.
        """

        # Excitatory neurons are created before inhibitory ones; restore()
        # recreates them in the same order so that node IDs match.
        self.e_neurons = nest.Create(self.neuron_model, n=self.nE, params=self.nrn_params)
        self.i_neurons = nest.Create(self.neuron_model, n=self.nI, params=self.nrn_params)
        self.neurons = self.e_neurons + self.i_neurons

        self.pg = nest.Create("poisson_generator", {"rate": self.poisson_rate})
        self.sr = nest.Create("spike_recorder")

        nest.Connect(self.e_neurons, self.neurons,
                     {"rule": "fixed_indegree", "indegree": self.indeg_e},
                     {"synapse_model": "e_syn", "weight": self.JE})
        nest.Connect(self.i_neurons, self.neurons,
                     {"rule": "fixed_indegree", "indegree": self.indeg_i},
                     {"synapse_model": "i_syn", "weight": self.JI})
        nest.Connect(self.pg, self.neurons, "all_to_all", {"weight": self.JE})
        nest.Connect(self.e_neurons, self.sr)

    def store(self, dump_filename):
        """
        Store neuron membrane potential and synaptic weights to given file.

        Parameters
        ----------
        dump_filename : str
            Path of the pickle file to write.
        """

        assert nest.NumProcesses() == 1, "Cannot dump MPI parallel"

        ###############################################################################
        # Build dictionary with relevant network information:
        # - number of virtual processes (checked on restore)
        # - membrane potential for all neurons in each population
        # - source, target and weight of all connections
        # Apart from the number of virtual processes, dictionary entries are
        # Pandas DataFrames.
        #
        # Strictly speaking, we would not need to store the weight of the inhibitory
        # synapses since they are fixed, but we do so out of symmetry and to make it
        # easier to add plasticity for inhibitory connections later.

        network = {}
        network["n_vp"] = nest.total_num_virtual_procs
        # Each population's V_m must come from that population only: restore()
        # creates nE excitatory and nI inhibitory neurons from these tables.
        network["e_nrns"] = self.e_neurons.get(["V_m"], output="pandas")
        network["i_nrns"] = self.i_neurons.get(["V_m"], output="pandas")

        network["e_syns"] = nest.GetConnections(synapse_model="e_syn").get(
            ("source", "target", "weight"), output="pandas")
        network["i_syns"] = nest.GetConnections(synapse_model="i_syn").get(
            ("source", "target", "weight"), output="pandas")

        with open(dump_filename, "wb") as f:
            pickle.dump(network, f, pickle.HIGHEST_PROTOCOL)

    def restore(self, dump_filename):
        """
        Restore network from data in file combined with base information in the class.

        Parameters
        ----------
        dump_filename : str
            Path of a pickle file written by :meth:`store`.
        """

        assert nest.NumProcesses() == 1, "Cannot load MPI parallel"

        # NOTE: pickle.load can execute arbitrary code; only load files that
        # were written by store() from a trusted source.
        with open(dump_filename, "rb") as f:
            network = pickle.load(f)

        # Refuse files written with a different number of virtual processes.
        assert network["n_vp"] == nest.total_num_virtual_procs, "N_VP must match"
        assert len(network["e_nrns"]) == self.nE and len(network["i_nrns"]) == self.nI, \
            "Stored population sizes must match the network definition"

        ###############################################################################
        # Reconstruct neurons
        # Since NEST does not understand Pandas Series, we must pass the values as
        # NumPy arrays
        self.e_neurons = nest.Create(self.neuron_model, n=self.nE,
                                     params={"V_m": network["e_nrns"].V_m.values})
        self.i_neurons = nest.Create(self.neuron_model, n=self.nI,
                                     params={"V_m": network["i_nrns"].V_m.values})
        self.neurons = self.e_neurons + self.i_neurons

        ###############################################################################
        # Reconstruct instrumentation
        self.pg = nest.Create("poisson_generator", {"rate": self.poisson_rate})
        self.sr = nest.Create("spike_recorder")

        ###############################################################################
        # Reconstruct connectivity
        # The stored sources/targets are node IDs. They refer to the right neurons
        # only because the neurons above are created in the same order as in
        # build(); this assumes they are the first nodes created after a kernel
        # reset, as in the demo below.
        nest.Connect(network["e_syns"].source.values, network["e_syns"].target.values,
                     "one_to_one",
                     {"synapse_model": "e_syn", "weight": network["e_syns"].weight.values})

        nest.Connect(network["i_syns"].source.values, network["i_syns"].target.values,
                     "one_to_one",
                     {"synapse_model": "i_syn", "weight": network["i_syns"].weight.values})

        ###############################################################################
        # Reconnect instruments
        nest.Connect(self.pg, self.neurons, "all_to_all", {"weight": self.JE})
        nest.Connect(self.e_neurons, self.sr)


class DemoPlot:
    """
    Create demonstration figure for effect of storing and restoring a network.

    The figure shows raster plots for five different runs (the initial 1 s
    simulation and four 1 s continuations) and histograms of a subset of
    excitatory synaptic weights at the end of each run.
    """

    def __init__(self):
        self._colors = [c["color"] for c in plt.rcParams["axes.prop_cycle"]]
        self._next_line = 0

        plt.rcParams.update({'font.size': 10})
        self.fig = plt.figure(figsize=(10, 7), constrained_layout=False)

        # Left column: initial raster, weight histogram, explanatory text.
        # Right column: one raster per continuation run.
        gs = gridspec.GridSpec(4, 2, bottom=0.08, top=0.9, left=0.07, right=0.98, wspace=0.2, hspace=0.4)
        self.rasters = ([self.fig.add_subplot(gs[0, 0])] +
                        [self.fig.add_subplot(gs[n, 1]) for n in range(4)])
        self.weights = self.fig.add_subplot(gs[1, 0])
        self.comment = self.fig.add_subplot(gs[2:, 0])

        self.fig.suptitle("Storing and reloading a network simulation")
        self.comment.set_axis_off()
        self.comment.text(0, 1, textwrap.dedent("""
            Storing, loading and continuing a simulation of a balanced E-I network
            with STDP in excitatory synapses.

            Top left: Raster plot of initial simulation for 1000ms (blue). Network state
            (connections, membrane potential, synaptic weights) is stored at the end of
            the initial simulation.

            Top right: Immediate continuation of the initial simulation from t=1000ms
            to t=2000ms (orange) by calling Simulate(1000) again after storing the network.
            This continues based on the full network state, including spikes in transit.

            Second row, right: Simulating for 1000ms after loading the stored network
            into a clean kernel (green). Time runs from 0ms and only connectivity, V_m and
            synaptic weights are restored. Dynamics differ somewhat from continuation.

            Third row, right: Same as in second row with identical random seed (red),
            resulting in identical spike patterns.

            Fourth row, right: Simulating for 1000ms from same stored network state as
            above but with different random seed yields different spike patterns (purple).

            Above: Distribution of excitatory synaptic weights at end of each sample
            simulation. Green and red curves are identical and overlay to form brown curve."""),
                          transform=self.comment.transAxes, fontsize=8,
                          verticalalignment='top')

    def add_to_plot(self, net, n_max=100, t_min=0, t_max=1000, lbl=""):
        """
        Add raster plot and weight histogram of the current run to the figure.

        Parameters
        ----------
        net : EINetwork
            Network whose spike recorder and e_syn connections are plotted.
        n_max : int
            Only spikes from senders with node ID below ``n_max`` are shown.
        t_min, t_max : float
            Only spikes strictly inside ``(t_min, t_max)`` are shown; also the
            x-limits of the raster plot.
        lbl : str
            Title of the raster panel and label of the histogram curve.
        """
        spks = pd.DataFrame.from_dict(net.sr.get("events"))
        spks = spks.loc[(spks.senders < n_max) & (t_min < spks.times) & (spks.times < t_max)]

        self.rasters[self._next_line].plot(spks.times, spks.senders, ".",
                                           color=self._colors[self._next_line])
        self.rasters[self._next_line].set_xlim(t_min, t_max)
        self.rasters[self._next_line].set_title(lbl)
        # Panel 1 (top right) keeps its tick labels because it may show a
        # different time window; panels 2-3 hide them; panel 4 (bottom right)
        # carries the time-axis label.
        if 1 < self._next_line < 4:
            self.rasters[self._next_line].set_xticklabels([])
        elif self._next_line == 4:
            self.rasters[self._next_line].set_xlabel('Time [ms]')

        # To save time while plotting, we extract only a subset of connections.
        # For simplicity, we just use a prime-number stepping.
        w = nest.GetConnections(source=net.e_neurons[::41], synapse_model="e_syn").weight
        wbins = np.arange(0.7, 1.4, 0.01)
        self.weights.hist(w, bins=wbins,
                          histtype="step", density=True, label=lbl,
                          color=self._colors[self._next_line],
                          alpha=0.7, lw=3)

        if self._next_line == 0:
            self.rasters[0].set_ylabel("neuron id")
            self.weights.set_ylabel("p(w)")
            self.weights.set_xlabel("Weight w [mV]")

        plt.draw()
        plt.pause(1e-3)  # allow figure window to draw figure

        self._next_line += 1


if __name__ == "__main__":

    plt.ion()

    T_sim = 1000

    dplot = DemoPlot()

    ###############################################################################
    # Ensure clean slate and make NEST less chatty
    nest.set_verbosity("M_WARNING")
    nest.ResetKernel()

    ###############################################################################
    # Create network from scratch and simulate 1s.
    nest.local_num_threads = 4
    nest.print_time = True
    ein = EINetwork()
    print("*** Initial simulation ***")
    ein.build()
    nest.Simulate(T_sim)
    dplot.add_to_plot(ein, lbl="Initial simulation")

    ###############################################################################
    # Store network state to file with state after 1s.
    print("\n*** Storing simulation ...", end="", flush=True)
    ein.store("ein_1000.pkl")
    print(" done ***\n")

    ###############################################################################
    # Continue simulation by another 1s.
    print("\n*** Continuing simulation ***")
    nest.Simulate(T_sim)
    dplot.add_to_plot(ein, lbl="Continued simulation", t_min=T_sim, t_max=2 * T_sim)

    ###############################################################################
    # Clear kernel, restore network from file and simulate for 1s.
    print("\n*** Reloading and resuming simulation ***")
    nest.ResetKernel()
    nest.local_num_threads = 4
    ein2 = EINetwork()
    ein2.restore("ein_1000.pkl")
    nest.Simulate(T_sim)
    dplot.add_to_plot(ein2, lbl="Reloaded simulation")

    ###############################################################################
    # Repeat previous step. This shall result in *exactly* the same results as
    # the previous run because we use the same random seed.
    print("\n*** Reloading and resuming simulation (same seed) ***")
    nest.ResetKernel()
    nest.local_num_threads = 4
    ein2 = EINetwork()
    ein2.restore("ein_1000.pkl")
    nest.Simulate(T_sim)
    dplot.add_to_plot(ein2, lbl="Reloaded simulation (same seed)")

    ###############################################################################
    # Clear, restore and simulate again, but now with different random seed.
    # Details in results shall differ from previous run.
    print("\n*** Reloading and resuming simulation (different seed) ***")
    nest.ResetKernel()
    nest.local_num_threads = 4
    nest.rng_seed = 987654321
    ein2 = EINetwork()
    ein2.restore("ein_1000.pkl")
    nest.Simulate(T_sim)
    dplot.add_to_plot(ein2, lbl="Reloaded simulation (different seed)")

    dplot.fig.savefig("store_restore_network.png")

    input("Press ENTER to close figure!")