1# -*- coding: utf-8 -*-
2#
3# store_restore_network.py
4#
5# This file is part of NEST.
6#
7# Copyright (C) 2004 The NEST Initiative
8#
9# NEST is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 2 of the License, or
12# (at your option) any later version.
13#
14# NEST is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with NEST.  If not, see <http://www.gnu.org/licenses/>.
21
22"""
23Store and restore a network simulation
24--------------------------------------
25
26This example shows how to store user-specified aspects of a network
27to file and how to later restore the network for further simulation.
28This may be used, e.g., to train weights in a network up to a certain
29point, store those weights and later perform diverse experiments on
30the same network using the stored weights.
31
32.. admonition:: Only user-specified aspects are stored
33
34   NEST does not support storing the complete state of a simulation
35   in a way that would allow one to continue a simulation as if one had
36   made a new ``Simulate()`` call on an existing network. Such complete
37   checkpointing would be very difficult to implement.
38
39   NEST's explicit approach to storing and restoring network state makes
40   clear to all which aspects of a network are carried from one simulation
41   to another and thus contributes to good scientific practice.
42
43   Storing and restoring is currently not supported for MPI-parallel simulations.
44
45"""
46
47###############################################################################
48# Import necessary modules.
49
50import nest
51import pickle
52
53###############################################################################
54# These modules are only needed for illustrative plotting.
55
56import matplotlib.pyplot as plt
57from matplotlib import gridspec
58import numpy as np
59import pandas as pd
60import textwrap
61
62###############################################################################
63# Implement network as class.
64#
65# Implementing the network as a class makes network properties available to
66# the initial network builder, the storer and the restorer, thus reducing the
67# amount of data that needs to be stored.
68
69
class EINetwork:
    """
    A simple balanced random network with plastic excitatory synapses.

    This simple Brunel-style balanced random network has an excitatory
    and inhibitory population, both driven by external excitatory Poisson
    input. Excitatory connections are plastic (STDP). Spike activity of
    the excitatory population is recorded.

    The model is provided as a non-trivial example for storing and restoring.

    The constructor registers the synapse models ``e_syn`` and ``i_syn`` with
    the NEST kernel. The script below therefore calls ``nest.ResetKernel()``
    before creating each new ``EINetwork`` instance.
    """

    def __init__(self):
        # Population sizes: four excitatory neurons per inhibitory neuron.
        self.nI = 500
        self.nE = 4 * self.nI
        self.n = self.nE + self.nI

        # Synaptic weights (inhibition is four times stronger than excitation)
        # and the number of incoming connections each neuron receives from
        # the excitatory and the inhibitory population, respectively.
        self.JE = 1.0
        self.JI = -4 * self.JE
        self.indeg_e = 200
        self.indeg_i = 50

        self.neuron_model = "iaf_psc_delta"

        # Create synapse models so we can extract specific connection information
        nest.CopyModel("stdp_synapse_hom", "e_syn", {"Wmax": 2 * self.JE})
        nest.CopyModel("static_synapse", "i_syn")

        # Initial membrane potentials are drawn from a normal distribution
        # when the neurons are created in build().
        self.nrn_params = {"V_m": nest.random.normal(-65., 5.)}
        self.poisson_rate = 800.

    def build(self):
        """
        Construct network from scratch, including instrumentation.

        Creates the excitatory and inhibitory populations, then connects them
        as follows:

        - every neuron receives ``indeg_e`` plastic ``e_syn`` inputs;
        - every neuron receives ``indeg_i`` static ``i_syn`` inputs;
        - every neuron receives input from one shared Poisson generator;
        - the excitatory population is recorded by a spike recorder.
        """

        self.e_neurons = nest.Create(self.neuron_model, n=self.nE, params=self.nrn_params)
        self.i_neurons = nest.Create(self.neuron_model, n=self.nI, params=self.nrn_params)
        self.neurons = self.e_neurons + self.i_neurons

        self.pg = nest.Create("poisson_generator", {"rate": self.poisson_rate})
        self.sr = nest.Create("spike_recorder")

        nest.Connect(self.e_neurons, self.neurons,
                     {"rule": "fixed_indegree", "indegree": self.indeg_e},
                     {"synapse_model": "e_syn", "weight": self.JE})
        nest.Connect(self.i_neurons, self.neurons,
                     {"rule": "fixed_indegree", "indegree": self.indeg_i},
                     {"synapse_model": "i_syn", "weight": self.JI})
        nest.Connect(self.pg, self.neurons, "all_to_all", {"weight": self.JE})
        nest.Connect(self.e_neurons, self.sr)

    def store(self, dump_filename):
        """
        Store neuron membrane potential and synaptic weights to given file.

        Parameters
        ----------
        dump_filename : str
            Path of the pickle file to write (overwritten if it exists).

        Raises
        ------
        AssertionError
            If NEST runs with more than one MPI process.
        """

        assert nest.NumProcesses() == 1, "Cannot dump MPI parallel"

        ###############################################################################
        # Build dictionary with relevant network information:
        #   - membrane potential for all neurons in each population
        #   - source, target and weight of all connections
        # Dictionary entries are Pandas Dataframes.
        #
        # Strictly speaking, we would not need to store the weight of the inhibitory
        # synapses since they are fixed, but we do so out of symmetry and to make it
        # easier to add plasticity for inhibitory connections later.

        network = {}
        network["n_vp"] = nest.total_num_virtual_procs
        # Store each population separately: restore() recreates nE excitatory
        # and nI inhibitory neurons and needs exactly one V_m value per neuron.
        network["e_nrns"] = self.e_neurons.get(["V_m"], output="pandas")
        network["i_nrns"] = self.i_neurons.get(["V_m"], output="pandas")

        network["e_syns"] = nest.GetConnections(synapse_model="e_syn").get(
            ("source", "target", "weight"), output="pandas")
        network["i_syns"] = nest.GetConnections(synapse_model="i_syn").get(
            ("source", "target", "weight"), output="pandas")

        with open(dump_filename, "wb") as f:
            pickle.dump(network, f, pickle.HIGHEST_PROTOCOL)

    def restore(self, dump_filename):
        """
        Restore network from data in file combined with base information in the class.

        Neurons, instrumentation and connections are recreated in the same
        order as in build(). In a freshly reset kernel they therefore receive
        the same node IDs, so the stored source/target IDs can be reused
        directly.

        Parameters
        ----------
        dump_filename : str
            Pickle file previously written by :meth:`store`.

        Raises
        ------
        AssertionError
            Raised in any of these cases:

            - NEST runs with more than one MPI process;
            - the number of virtual processes differs from the stored one;
            - the stored population sizes do not match ``nE`` / ``nI``.
        """

        assert nest.NumProcesses() == 1, "Cannot load MPI parallel"

        # pickle can execute arbitrary code while loading: only restore files
        # written by store() from a trusted source.
        with open(dump_filename, "rb") as f:
            network = pickle.load(f)

        assert network["n_vp"] == nest.total_num_virtual_procs, "N_VP must match"
        # Guards against dumps whose V_m tables do not hold one value per neuron
        # of each population (e.g. files written with whole-network V_m tables).
        assert len(network["e_nrns"]) == self.nE and len(network["i_nrns"]) == self.nI, \
            "Stored population sizes must match nE and nI"

        ###############################################################################
        # Reconstruct neurons
        # Since NEST does not understand Pandas Series, we must pass the values as
        # NumPy arrays
        self.e_neurons = nest.Create(self.neuron_model, n=self.nE,
                                     params={"V_m": network["e_nrns"].V_m.to_numpy()})
        self.i_neurons = nest.Create(self.neuron_model, n=self.nI,
                                     params={"V_m": network["i_nrns"].V_m.to_numpy()})
        self.neurons = self.e_neurons + self.i_neurons

        ###############################################################################
        # Reconstruct instrumentation
        self.pg = nest.Create("poisson_generator", {"rate": self.poisson_rate})
        self.sr = nest.Create("spike_recorder")

        ###############################################################################
        # Reconstruct connectivity
        e_syns = network["e_syns"]
        nest.Connect(e_syns.source.to_numpy(), e_syns.target.to_numpy(),
                     "one_to_one",
                     {"synapse_model": "e_syn", "weight": e_syns.weight.to_numpy()})

        i_syns = network["i_syns"]
        nest.Connect(i_syns.source.to_numpy(), i_syns.target.to_numpy(),
                     "one_to_one",
                     {"synapse_model": "i_syn", "weight": i_syns.weight.to_numpy()})

        ###############################################################################
        # Reconnect instruments
        nest.Connect(self.pg, self.neurons, "all_to_all", {"weight": self.JE})
        nest.Connect(self.e_neurons, self.sr)
193
194
class DemoPlot:
    """
    Demonstration figure comparing continued and restored simulations.

    The figure contains three kinds of panel:

    - five spike raster plots: the initial run at the top left, and the
      four follow-up runs stacked top to bottom in the right column;
    - one panel overlaying a histogram of excitatory synaptic weights
      after each run;
    - a text panel explaining the layout.

    Each call to :meth:`add_to_plot` fills the next raster panel in turn
    and adds one histogram to the weight panel.
    """

    def __init__(self):
        # One color per run, taken from the active Matplotlib color cycle.
        self._colors = [entry["color"] for entry in plt.rcParams["axes.prop_cycle"]]
        self._next_line = 0

        plt.rcParams.update({'font.size': 10})
        self.fig = plt.figure(figsize=(10, 7), constrained_layout=False)

        grid = gridspec.GridSpec(4, 2, bottom=0.08, top=0.9, left=0.07, right=0.98,
                                 wspace=0.2, hspace=0.4)
        # Raster 0 sits top left; rasters 1-4 fill the right column top to bottom.
        first_raster = self.fig.add_subplot(grid[0, 0])
        right_column = [self.fig.add_subplot(grid[row, 1]) for row in range(4)]
        self.rasters = [first_raster] + right_column
        self.weights = self.fig.add_subplot(grid[1, 0])
        self.comment = self.fig.add_subplot(grid[2:, 0])

        self.fig.suptitle("Storing and reloading a network simulation")
        self.comment.set_axis_off()
        caption = textwrap.dedent("""
            Storing, loading and continuing a simulation of a balanced E-I network
            with STDP in excitatory synapses.

            Top left: Raster plot of initial simulation for 1000ms (blue). Network state
            (connections, membrane potential, synaptic weights) is stored at the end of
            the initial simulation.

            Top right: Immediate continuation of the initial simulation from t=1000ms
            to t=2000ms (orange) by calling Simulate(1000) again after storing the network.
            This continues based on the full network state, including spikes in transit.

            Second row, right: Simulating for 1000ms after loading the stored network
            into a clean kernel (green). Time runs from 0ms and only connectivity, V_m and
            synaptic weights are restored. Dynamics differ somewhat from continuation.

            Third row, right: Same as in second row with identical random seed (red),
            resulting in identical spike patterns.

            Fourth row, right: Simulating for 1000ms from same stored network state as
            above but with different random seed yields different spike patterns (purple).

            Above: Distribution of excitatory synaptic weights at end of each sample
            simulation. Green and red curves are identical and overlay to form brown curve.""")
        self.comment.text(0, 1, caption, transform=self.comment.transAxes, fontsize=8,
                          verticalalignment='top')

    def add_to_plot(self, net, n_max=100, t_min=0, t_max=1000, lbl=""):
        """
        Draw the raster and weight histogram for one run into the next panel.

        Parameters
        ----------
        net : EINetwork
            Network whose spike recorder ``sr`` and excitatory population
            ``e_neurons`` are read.
        n_max : int
            Only spikes from senders with node ID below ``n_max`` are shown.
        t_min, t_max : float
            Spikes strictly inside ``(t_min, t_max)`` are shown; also used as x-limits.
        lbl : str
            Panel title and histogram legend label.
        """
        events = pd.DataFrame.from_dict(net.sr.get("events"))
        keep = (events.senders < n_max) & (t_min < events.times) & (events.times < t_max)
        events = events.loc[keep]

        slot = self._next_line
        raster_ax = self.rasters[slot]
        raster_ax.plot(events.times, events.senders, ".", color=self._colors[slot])
        raster_ax.set_xlim(t_min, t_max)
        raster_ax.set_title(lbl)
        if slot == 4:
            # Bottom panel of the right column carries the shared time axis label.
            raster_ax.set_xlabel('Time [ms]')
        elif slot in (2, 3):
            # Middle panels share the bottom panel's time axis; hide their labels.
            raster_ax.set_xticklabels([])

        # Only a subset of connections is histogrammed to keep plotting fast;
        # a prime stride (41) over the excitatory sources picks the subset.
        sampled_weights = nest.GetConnections(source=net.e_neurons[::41],
                                              synapse_model="e_syn").weight
        self.weights.hist(sampled_weights, bins=np.arange(0.7, 1.4, 0.01),
                          histtype="step", density=True, label=lbl,
                          color=self._colors[slot], alpha=0.7, lw=3)

        if slot == 0:
            # Axis labels are set once, on the first run.
            self.rasters[0].set_ylabel("neuron id")
            self.weights.set_ylabel("p(w)")
            self.weights.set_xlabel("Weight w [mV]")

        plt.draw()
        plt.pause(1e-3)  # give the interactive figure window a chance to redraw

        self._next_line = slot + 1
277
278
if __name__ == "__main__":

    plt.ion()

    T_sim = 1000

    dplot = DemoPlot()

    ###############################################################################
    # Start from a freshly reset kernel and only show warnings and errors.
    nest.set_verbosity("M_WARNING")
    nest.ResetKernel()

    ###############################################################################
    # Build the network from scratch and simulate it for the first second.
    nest.local_num_threads = 4
    nest.print_time = True
    ein = EINetwork()
    print("*** Initial simulation ***")
    ein.build()
    nest.Simulate(T_sim)
    dplot.add_to_plot(ein, lbl="Initial simulation")

    ###############################################################################
    # Write the network state reached after one second to disk.
    print("\n*** Storing simulation ...", end="", flush=True)
    ein.store("ein_1000.pkl")
    print(" done ***\n")

    ###############################################################################
    # Keep the original network running for a second second.
    print("\n*** Continuing simulation ***")
    nest.Simulate(T_sim)
    dplot.add_to_plot(ein, lbl="Continued simulation", t_min=T_sim, t_max=2 * T_sim)

    ###############################################################################
    # Each of the remaining runs wipes the kernel, rebuilds the network from the
    # stored file, simulates one second and adds the result to the figure.

    def reload_and_simulate(banner, label, seed=None):
        """Reset NEST, restore the stored network, simulate and plot one run."""
        print(banner)
        nest.ResetKernel()
        nest.local_num_threads = 4
        if seed is not None:
            nest.rng_seed = seed
        restored = EINetwork()
        restored.restore("ein_1000.pkl")
        nest.Simulate(T_sim)
        dplot.add_to_plot(restored, lbl=label)
        return restored

    ###############################################################################
    # First reload with the seed the kernel has after a reset.
    ein2 = reload_and_simulate("\n*** Reloading and resuming simulation ***",
                               "Reloaded simulation")

    ###############################################################################
    # Reloading again with the same seed must reproduce the previous run *exactly*.
    ein2 = reload_and_simulate("\n*** Reloading and resuming simulation (same seed) ***",
                               "Reloaded simulation (same seed)")

    ###############################################################################
    # A different random seed changes the details of the dynamics.
    ein2 = reload_and_simulate("\n*** Reloading and resuming simulation (different seed) ***",
                               "Reloaded simulation (different seed)", seed=987654321)

    dplot.fig.savefig("store_restore_network.png")

    input("Press ENTER to close figure!")
350