1# -*- coding: utf-8 -*-
2#
3# brunel_alpha_evolution_strategies.py
4#
5# This file is part of NEST.
6#
7# Copyright (C) 2004 The NEST Initiative
8#
9# NEST is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 2 of the License, or
12# (at your option) any later version.
13#
14# NEST is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with NEST.  If not, see <http://www.gnu.org/licenses/>.
21
22"""
23Use evolution strategies to find parameters for a random balanced network (alpha synapses)
24------------------------------------------------------------------------------------------
25
26This script uses an optimization algorithm to find the appropriate
27parameter values for the external drive "eta" and the relative ratio
28of excitation and inhibition "g" for a balanced random network that
29lead to particular population-averaged rates, coefficients of
30variation and correlations.
31
32From an initial Gaussian search distribution parameterized with mean
33and standard deviation network parameters are sampled. Network
34realizations of these parameters are simulated and evaluated according
35to an objective function that measures how close the activity
36statistics are to their desired values (~fitness). From these fitness
37values the approximate natural gradient of the fitness landscape is
38computed and used to update the parameters of the search
39distribution. This procedure is repeated until the maximal number of
40function evaluations is reached or the width of the search
41distribution becomes extremely small.  We use the following fitness
42function:
43
44.. math::
45
46    f = - alpha(r - r*)^2 - beta(cv - cv*)^2 - gamma(corr - corr*)^2
47
48where `alpha`, `beta` and `gamma` are weighting factors, and stars indicate
49target values.
50
51The network contains an excitatory and an inhibitory population on
52the basis of the network used in [1]_.
53
54The optimization algorithm (evolution strategies) is described in
55Wierstra et al. [2]_.
56
57
58References
59~~~~~~~~~~~~
60
61.. [1] Brunel N (2000). Dynamics of Sparsely Connected Networks of
62       Excitatory and Inhibitory Spiking Neurons. Journal of Computational
63       Neuroscience 8, 183-208.
64
65.. [2] Wierstra et al. (2014). Natural evolution strategies. Journal of
66       Machine Learning Research, 15(1), 949-980.
67
68See Also
69~~~~~~~~~~
70
71:doc:`brunel_alpha_nest`
72
73Authors
74~~~~~~~
75
76Jakob Jordan
77"""
78
79import matplotlib.pyplot as plt
80from matplotlib.patches import Ellipse
81import numpy as np
82import scipy.special as sp
83import nest
84
85###############################################################################
86# Analysis
87
88
89def cut_warmup_time(spikes, warmup_time):
90    # Removes initial warmup time from recorded spikes
91    spikes['senders'] = spikes['senders'][
92        spikes['times'] > warmup_time]
93    spikes['times'] = spikes['times'][
94        spikes['times'] > warmup_time]
95
96    return spikes
97
98
99def compute_rate(spikes, N_rec, sim_time):
100    # Computes average rate from recorded spikes
101    return (1. * len(spikes['times']) / N_rec / sim_time * 1e3)
102
103
104def sort_spikes(spikes):
105    # Sorts recorded spikes by node ID
106    unique_node_ids = sorted(np.unique(spikes['senders']))
107    spiketrains = []
108    for node_id in unique_node_ids:
109        spiketrains.append(spikes['times'][spikes['senders'] == node_id])
110    return unique_node_ids, spiketrains
111
112
113def compute_cv(spiketrains):
114    # Computes coefficient of variation from sorted spikes
115    if spiketrains:
116        isis = np.hstack([np.diff(st) for st in spiketrains])
117        if len(isis) > 1:
118            return np.std(isis) / np.mean(isis)
119        else:
120            return 0.
121    else:
122        return 0.
123
124
125def bin_spiketrains(spiketrains, t_min, t_max, t_bin):
126    # Bins sorted spikes
127    bins = np.arange(t_min, t_max, t_bin)
128    return bins, [np.histogram(s, bins=bins)[0] for s in spiketrains]
129
130
131def compute_correlations(binned_spiketrains):
132    # Computes correlations from binned spiketrains
133    n = len(binned_spiketrains)
134    if n > 1:
135        cc = np.corrcoef(binned_spiketrains)
136        return 1. / (n * (n - 1.)) * (np.sum(cc) - n)
137    else:
138        return 0.
139
140
141def compute_statistics(parameters, espikes, ispikes):
142    # Computes population-averaged rates coefficients of variation and
143    # correlations from recorded spikes of excitatory and inhibitory
144    # populations
145
146    espikes = cut_warmup_time(espikes, parameters['warmup_time'])
147    ispikes = cut_warmup_time(ispikes, parameters['warmup_time'])
148
149    erate = compute_rate(espikes, parameters['N_rec'], parameters['sim_time'])
150    irate = compute_rate(espikes, parameters['N_rec'], parameters['sim_time'])
151
152    enode_ids, espiketrains = sort_spikes(espikes)
153    inode_ids, ispiketrains = sort_spikes(ispikes)
154
155    ecv = compute_cv(espiketrains)
156    icv = compute_cv(ispiketrains)
157
158    ecorr = compute_correlations(
159        bin_spiketrains(espiketrains, 0., parameters['sim_time'], 1.)[1])
160    icorr = compute_correlations(
161        bin_spiketrains(ispiketrains, 0., parameters['sim_time'], 1.)[1])
162
163    return (np.mean([erate, irate]),
164            np.mean([ecv, icv]),
165            np.mean([ecorr, icorr]))
166
167
168###############################################################################
169# Network simulation
170
171
172def simulate(parameters):
173    # Simulates the network and returns recorded spikes for excitatory
174    # and inhibitory population
175
176    # Code taken from brunel_alpha_nest.py
177
178    def LambertWm1(x):
179        # Using scipy to mimic the gsl_sf_lambert_Wm1 function.
180        return sp.lambertw(x, k=-1 if x < 0 else 0).real
181
182    def ComputePSPnorm(tauMem, CMem, tauSyn):
183        a = (tauMem / tauSyn)
184        b = (1.0 / tauSyn - 1.0 / tauMem)
185
186        # time of maximum
187        t_max = 1.0 / b * (-LambertWm1(-np.exp(-1.0 / a) / a) - 1.0 / a)
188
189        # maximum of PSP for current of unit amplitude
190        return (np.exp(1.0) / (tauSyn * CMem * b) *
191                ((np.exp(-t_max / tauMem) - np.exp(-t_max / tauSyn)) / b -
192                 t_max * np.exp(-t_max / tauSyn)))
193
194    # number of excitatory neurons
195    NE = int(parameters['gamma'] * parameters['N'])
196    # number of inhibitory neurons
197    NI = parameters['N'] - NE
198
199    # number of excitatory synapses per neuron
200    CE = int(parameters['epsilon'] * NE)
201    # number of inhibitory synapses per neuron
202    CI = int(parameters['epsilon'] * NI)
203
204    tauSyn = 0.5  # synaptic time constant in ms
205    tauMem = 20.0  # time constant of membrane potential in ms
206    CMem = 250.0  # capacitance of membrane in in pF
207    theta = 20.0  # membrane threshold potential in mV
208    neuron_parameters = {
209        'C_m': CMem,
210        'tau_m': tauMem,
211        'tau_syn_ex': tauSyn,
212        'tau_syn_in': tauSyn,
213        't_ref': 2.0,
214        'E_L': 0.0,
215        'V_reset': 0.0,
216        'V_m': 0.0,
217        'V_th': theta
218    }
219    J = 0.1        # postsynaptic amplitude in mV
220    J_unit = ComputePSPnorm(tauMem, CMem, tauSyn)
221    J_ex = J / J_unit  # amplitude of excitatory postsynaptic current
222    # amplitude of inhibitory postsynaptic current
223    J_in = -parameters['g'] * J_ex
224
225    nu_th = (theta * CMem) / (J_ex * CE * np.exp(1) * tauMem * tauSyn)
226    nu_ex = parameters['eta'] * nu_th
227    p_rate = 1000.0 * nu_ex * CE
228
229    nest.ResetKernel()
230    nest.set_verbosity('M_FATAL')
231
232    nest.rng_seed = parameters['seed']
233    nest.resolution = parameters['dt']
234
235    nodes_ex = nest.Create('iaf_psc_alpha', NE, params=neuron_parameters)
236    nodes_in = nest.Create('iaf_psc_alpha', NI, params=neuron_parameters)
237    noise = nest.Create('poisson_generator', params={'rate': p_rate})
238    espikes = nest.Create('spike_recorder', params={'label': 'brunel-py-ex'})
239    ispikes = nest.Create('spike_recorder', params={'label': 'brunel-py-in'})
240
241    nest.CopyModel('static_synapse', 'excitatory',
242                   {'weight': J_ex, 'delay': parameters['delay']})
243    nest.CopyModel('static_synapse', 'inhibitory',
244                   {'weight': J_in, 'delay': parameters['delay']})
245
246    nest.Connect(noise, nodes_ex, syn_spec='excitatory')
247    nest.Connect(noise, nodes_in, syn_spec='excitatory')
248
249    if parameters['N_rec'] > NE:
250        raise ValueError(
251            f'Requested recording from {parameters["N_rec"]} neurons, but only {NE} in excitatory population')
252    if parameters['N_rec'] > NI:
253        raise ValueError(
254            f'Requested recording from {parameters["N_rec"]} neurons, but only {NI} in inhibitory population')
255    nest.Connect(nodes_ex[:parameters['N_rec']], espikes)
256    nest.Connect(nodes_in[:parameters['N_rec']], ispikes)
257
258    conn_parameters_ex = {'rule': 'fixed_indegree', 'indegree': CE}
259    nest.Connect(nodes_ex, nodes_ex + nodes_in, conn_parameters_ex, 'excitatory')
260
261    conn_parameters_in = {'rule': 'fixed_indegree', 'indegree': CI}
262    nest.Connect(nodes_in, nodes_ex + nodes_in, conn_parameters_in, 'inhibitory')
263
264    nest.Simulate(parameters['sim_time'])
265
266    return (espikes.events,
267            ispikes.events)
268
269
270###############################################################################
271# Optimization
272
273
274def default_population_size(dimensions):
275    # Returns a population size suited for the given number of dimensions
276    # See Wierstra et al. (2014)
277
278    return 4 + int(np.floor(3 * np.log(dimensions)))
279
280
281def default_learning_rate_mu():
282    # Returns a default learning rate for the mean of the search distribution
283    # See Wierstra et al. (2014)
284
285    return 1
286
287
288def default_learning_rate_sigma(dimensions):
289    # Returns a default learning rate for the standard deviation of the
290    # search distribution for the given number of dimensions
291    # See Wierstra et al. (2014)
292
293    return (3 + np.log(dimensions)) / (12. * np.sqrt(dimensions))
294
295
296def compute_utility(fitness):
297    # Computes utility and order used for fitness shaping
298    # See Wierstra et al. (2014)
299
300    n = len(fitness)
301    order = np.argsort(fitness)[::-1]
302    fitness = fitness[order]
303
304    utility = [
305        np.max([0, np.log((n / 2) + 1)]) - np.log(k + 1) for k in range(n)]
306    utility = utility / np.sum(utility) - 1. / n
307
308    return order, utility
309
310
311def optimize(func, mu, sigma, learning_rate_mu=None, learning_rate_sigma=None,
312             population_size=None, fitness_shaping=True,
313             mirrored_sampling=True, record_history=False,
314             max_generations=2000, min_sigma=1e-8, verbosity=0):
315
316    ###########################################################################
317    # Optimizes an objective function via evolution strategies using the
318    # natural gradient of multinormal search distributions in natural
319    # coordinates.  Does not consider covariances between parameters (
320    # "Separable natural evolution strategies").
321    # See Wierstra et al. (2014)
322    #
323    # Parameters
324    # ----------
325    # func: function
326    #     The function to be maximized.
327    # mu: float
328    #     Initial mean of the search distribution.
329    # sigma: float
330    #     Initial standard deviation of the search distribution.
331    # learning_rate_mu: float
332    #     Learning rate of mu.
333    # learning_rate_sigma: float
334    #     Learning rate of sigma.
335    # population_size: int
336    #     Number of individuals sampled in each generation.
337    # fitness_shaping: bool
338    #     Whether to use fitness shaping, compensating for large
339    #     deviations in fitness, see Wierstra et al. (2014).
340    # mirrored_sampling: bool
341    #     Whether to use mirrored sampling, i.e., evaluating a mirrored
342    #     sample for each sample, see Wierstra et al. (2014).
343    # record_history: bool
344    #     Whether to record history of search distribution parameters,
345    #     fitness values and individuals.
346    # max_generations: int
347    #     Maximal number of generations.
348    # min_sigma: float
349    #     Minimal value for standard deviation of search
350    #     distribution. If any dimension has a value smaller than this,
351    #     the search is stopped.
352    # verbosity: bool
353    #     Whether to continuously print progress information.
354    #
355    # Returns
356    # -------
357    # dict
358    #     Dictionary of final parameters of search distribution and
359    #     history.
360
361    if not isinstance(mu, np.ndarray):
362        raise TypeError('mu needs to be of type np.ndarray')
363    if not isinstance(sigma, np.ndarray):
364        raise TypeError('sigma needs to be of type np.ndarray')
365
366    if learning_rate_mu is None:
367        learning_rate_mu = default_learning_rate_mu()
368    if learning_rate_sigma is None:
369        learning_rate_sigma = default_learning_rate_sigma(mu.size)
370    if population_size is None:
371        population_size = default_population_size(mu.size)
372
373    generation = 0
374    mu_history = []
375    sigma_history = []
376    pop_history = []
377    fitness_history = []
378
379    while True:
380
381        # create new population using the search distribution
382        s = np.random.normal(0, 1, size=(population_size,) + np.shape(mu))
383        z = mu + sigma * s
384
385        # add mirrored perturbations if enabled
386        if mirrored_sampling:
387            z = np.vstack([z, mu - sigma * s])
388            s = np.vstack([s, -s])
389
390        # evaluate fitness for every individual in population
391        fitness = np.fromiter((func(*zi) for zi in z), np.float)
392
393        # print status if enabled
394        if verbosity > 0:
395            print(
396                f'# Generation {generation:d} | fitness {np.mean(fitness):.3f} | '
397                f'mu {", ".join(str(np.round(mu_i, 3)) for mu_i in mu)} | '
398                f'sigma {", ".join(str(np.round(sigma_i, 3)) for sigma_i in sigma)}'
399            )
400
401        # apply fitness shaping if enabled
402        if fitness_shaping:
403            order, utility = compute_utility(fitness)
404            s = s[order]
405            z = z[order]
406        else:
407            utility = fitness
408
409        # bookkeeping
410        if record_history:
411            mu_history.append(mu.copy())
412            sigma_history.append(sigma.copy())
413            pop_history.append(z.copy())
414            fitness_history.append(fitness)
415
416        # exit if max generations reached or search distributions are
417        # very narrow
418        if generation == max_generations or np.all(sigma < min_sigma):
419            break
420
421        # update parameter of search distribution via natural gradient
422        # descent in natural coordinates
423        mu += learning_rate_mu * sigma * np.dot(utility, s)
424        sigma *= np.exp(learning_rate_sigma / 2. * np.dot(utility, s**2 - 1))
425
426        generation += 1
427
428    return {
429        'mu': mu,
430        'sigma': sigma,
431        'fitness_history': np.array(fitness_history),
432        'mu_history': np.array(mu_history),
433        'sigma_history': np.array(sigma_history),
434        'pop_history': np.array(pop_history)
435    }
436
437
438def optimize_network(optimization_parameters, simulation_parameters):
439    # Searches for suitable network parameters to fulfill defined constraints
440
441    np.random.seed(simulation_parameters['seed'])
442
443    def objective_function(g, eta):
444        # Returns the fitness of a specific network parametrization
445
446        # create local copy of parameters that uses parameters given
447        # by optimization algorithm
448        simulation_parameters_local = simulation_parameters.copy()
449        simulation_parameters_local['g'] = g
450        simulation_parameters_local['eta'] = eta
451
452        # perform the network simulation
453        espikes, ispikes = simulate(simulation_parameters_local)
454
455        # analyse the result and compute fitness
456        rate, cv, corr = compute_statistics(
457            simulation_parameters, espikes, ispikes)
458        fitness = (
459            -optimization_parameters['fitness_weight_rate'] * (rate - optimization_parameters['target_rate'])**2 -
460            optimization_parameters['fitness_weight_cv'] * (cv - optimization_parameters['target_cv'])**2 -
461            optimization_parameters['fitness_weight_corr'] * (corr - optimization_parameters['target_corr'])**2
462        )
463
464        return fitness
465
466    return optimize(
467        objective_function,
468        np.array(optimization_parameters['mu']),
469        np.array(optimization_parameters['sigma']),
470        max_generations=optimization_parameters['max_generations'],
471        record_history=True,
472        verbosity=optimization_parameters['verbosity']
473    )
474
475###############################################################################
476# Main
477
478
479if __name__ == '__main__':
480    simulation_parameters = {
481        'seed': 123,
482        'dt': 0.1,            # (ms) simulation resolution
483        'sim_time': 1000.,    # (ms) simulation duration
484        'warmup_time': 300.,  # (ms) duration ignored during analysis
485        'delay': 1.5,         # (ms) synaptic delay
486        'g': None,            # relative ratio of excitation and inhibition
487        'eta': None,          # relative strength of external drive
488        'epsilon': 0.1,       # average connectivity of network
489        'N': 400,             # number of neurons in network
490        'gamma': 0.8,         # relative size of excitatory and
491                              # inhibitory population
492        'N_rec': 40,          # number of neurons to record activity from
493    }
494
495    optimization_parameters = {
496        'verbosity': 1,             # print progress over generations
497        'max_generations': 20,      # maximal number of generations
498        'target_rate': 1.89,        # (spikes/s) target rate
499        'target_corr': 0.0,         # target correlation
500        'target_cv': 1.,            # target coefficient of variation
501        'mu': [1., 3.],             # initial mean for search distribution
502                                    # (mu(g), mu(eta))
503        'sigma': [0.15, 0.05],      # initial sigma for search
504                                    # distribution (sigma(g), sigma(eta))
505
506        # hyperparameters of the fitness function; these are used to
507        # compensate for the different typical scales of the
508        # individual measures, rate ~ O(1), cv ~ (0.1), corr ~ O(0.01)
509        'fitness_weight_rate': 1.,    # relative weight of rate deviation
510        'fitness_weight_cv': 10.,     # relative weight of cv deviation
511        'fitness_weight_corr': 100.,  # relative weight of corr deviation
512    }
513
514    # optimize network parameters
515    optimization_result = optimize_network(optimization_parameters,
516                                           simulation_parameters)
517
518    simulation_parameters['g'] = optimization_result['mu'][0]
519    simulation_parameters['eta'] = optimization_result['mu'][1]
520
521    espikes, ispikes = simulate(simulation_parameters)
522
523    rate, cv, corr = compute_statistics(
524        simulation_parameters, espikes, ispikes)
525    print('Statistics after optimization:', end=' ')
526    print('Rate: {:.3f}, cv: {:.3f}, correlation: {:.3f}'.format(
527        rate, cv, corr))
528
529    # plot results
530    fig = plt.figure(figsize=(10, 4))
531    ax1 = fig.add_axes([0.06, 0.12, 0.25, 0.8])
532    ax2 = fig.add_axes([0.4, 0.12, 0.25, 0.8])
533    ax3 = fig.add_axes([0.74, 0.12, 0.25, 0.8])
534
535    ax1.set_xlabel('Time (ms)')
536    ax1.set_ylabel('Neuron id')
537
538    ax2.set_xlabel(r'Relative strength of inhibition $g$')
539    ax2.set_ylabel(r'Relative strength of external drive $\eta$')
540
541    ax3.set_xlabel('Generation')
542    ax3.set_ylabel('Fitness')
543
544    # raster plot
545    ax1.plot(espikes['times'], espikes['senders'], ls='', marker='.')
546
547    # search distributions and individuals
548    for mu, sigma in zip(optimization_result['mu_history'],
549                         optimization_result['sigma_history']):
550        ellipse = Ellipse(
551            xy=mu, width=2 * sigma[0], height=2 * sigma[1], alpha=0.5, fc='k')
552        ellipse.set_clip_box(ax2.bbox)
553        ax2.add_artist(ellipse)
554    ax2.plot(optimization_result['mu_history'][:, 0],
555             optimization_result['mu_history'][:, 1],
556             marker='.', color='k', alpha=0.5)
557    for generation in optimization_result['pop_history']:
558        ax2.scatter(generation[:, 0], generation[:, 1])
559
560    # fitness over generations
561    ax3.errorbar(np.arange(len(optimization_result['fitness_history'])),
562                 np.mean(optimization_result['fitness_history'], axis=1),
563                 yerr=np.std(optimization_result['fitness_history'], axis=1))
564
565    fig.savefig('brunel_alpha_evolution_strategies.pdf')
566