1 /*
2 * source_table.cpp
3 *
4 * This file is part of NEST.
5 *
6 * Copyright (C) 2004 The NEST Initiative
7 *
8 * NEST is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 2 of the License, or
11 * (at your option) any later version.
12 *
13 * NEST is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with NEST. If not, see <http://www.gnu.org/licenses/>.
20 *
21 */
22
23 // C++ includes:
24 #include <iostream>
25
26 // Includes from nestkernel:
27 #include "connection_manager.h"
28 #include "connection_manager_impl.h"
29 #include "kernel_manager.h"
30 #include "mpi_manager_impl.h"
31 #include "source_table.h"
32 #include "vp_manager_impl.h"
33
SourceTable()34 nest::SourceTable::SourceTable()
35 {
36 }
37
~SourceTable()38 nest::SourceTable::~SourceTable()
39 {
40 }
41
void
nest::SourceTable::initialize()
{
  // The code relies on Source being packed into a single 64-bit word.
  assert( sizeof( Source ) == 8 );
  const thread num_threads = kernel().vp_manager.get_num_threads();
  // Outer, per-thread containers are sized serially by the calling thread.
  sources_.resize( num_threads );
  is_cleared_.initialize( num_threads, false );
  saved_entry_point_.initialize( num_threads, false );
  current_positions_.resize( num_threads );
  saved_positions_.resize( num_threads );
  compressible_sources_.resize( num_threads );
  compressed_spike_data_map_.resize( num_threads );

#pragma omp parallel
  {
    // Each thread (re)initializes its own inner containers; presumably done
    // inside the parallel region for first-touch memory locality -- TODO confirm.
    const thread tid = kernel().vp_manager.get_thread_id();
    sources_[ tid ].resize( 0 );
    resize_sources( tid ); // one inner vector per synapse prototype
    compressible_sources_[ tid ].resize( 0 );
    compressed_spike_data_map_[ tid ].resize( 0 );
  } // of omp parallel
}
64
65 void
finalize()66 nest::SourceTable::finalize()
67 {
68 for ( thread tid = 0; tid < static_cast< thread >( sources_.size() ); ++tid )
69 {
70 if ( is_cleared_[ tid ].is_false() )
71 {
72 clear( tid );
73 compressible_sources_[ tid ].clear();
74 compressed_spike_data_map_[ tid ].clear();
75 }
76 }
77
78 sources_.clear();
79 current_positions_.clear();
80 saved_positions_.clear();
81 compressible_sources_.clear();
82 compressed_spike_data_map_.clear();
83 }
84
85 bool
is_cleared() const86 nest::SourceTable::is_cleared() const
87 {
88 return is_cleared_.all_true();
89 }
90
91 std::vector< BlockVector< nest::Source > >&
get_thread_local_sources(const thread tid)92 nest::SourceTable::get_thread_local_sources( const thread tid )
93 {
94 return sources_[ tid ];
95 }
96
97 nest::SourceTablePosition
find_maximal_position() const98 nest::SourceTable::find_maximal_position() const
99 {
100 SourceTablePosition max_position( -1, -1, -1 );
101 for ( thread tid = 0; tid < kernel().vp_manager.get_num_threads(); ++tid )
102 {
103 if ( max_position < saved_positions_[ tid ] )
104 {
105 max_position = saved_positions_[ tid ];
106 }
107 }
108 return max_position;
109 }
110
void
nest::SourceTable::clean( const thread tid )
{
  // Find maximal position in source table among threads to make sure
  // unprocessed entries are not removed. Given this maximal position,
  // we can safely delete all larger entries since they will not be
  // touched any more.
  const SourceTablePosition max_position = find_maximal_position();

  // If this thread corresponds to max_position's thread, we can only
  // delete part of the sources table, with indices larger than those
  // in max_position; if this thread is larger than max_positions's
  // thread, we can delete all sources; otherwise we do nothing.
  if ( max_position.tid == tid )
  {
    for ( synindex syn_id = max_position.syn_id; syn_id < sources_[ tid ].size(); ++syn_id )
    {
      BlockVector< Source >& sources = sources_[ tid ][ syn_id ];
      if ( max_position.syn_id == syn_id )
      {
        // we need to add 2 to max_position.lcid since
        // max_position.lcid + 1 can contain a valid entry which we
        // do not want to delete.
        if ( max_position.lcid + 2 < static_cast< long >( sources.size() ) )
        {
          sources.erase( sources.begin() + max_position.lcid + 2, sources.end() );
        }
      }
      else
      {
        // all synapse types beyond max_position.syn_id can be dropped entirely
        assert( max_position.syn_id < syn_id );
        sources.clear();
      }
    }
  }
  else if ( max_position.tid < tid )
  {
    // this thread's whole table lies beyond the maximal position
    sources_[ tid ].clear();
  }
  else
  {
    // do nothing
    assert( tid < max_position.tid );
  }
}
156
157 nest::index
get_node_id(const thread tid,const synindex syn_id,const index lcid) const158 nest::SourceTable::get_node_id( const thread tid, const synindex syn_id, const index lcid ) const
159 {
160 if ( not kernel().connection_manager.get_keep_source_table() )
161 {
162 throw KernelException( "Cannot use SourceTable::get_node_id when get_keep_source_table is false" );
163 }
164 return sources_[ tid ][ syn_id ][ lcid ].get_node_id();
165 }
166
nest::index
nest::SourceTable::remove_disabled_sources( const thread tid, const synindex syn_id )
{
  // Erases the trailing run of disabled sources for the given thread and
  // synapse type; returns the lcid of the first erased entry, or
  // invalid_index if nothing was erased.
  if ( sources_[ tid ].size() <= syn_id )
  {
    return invalid_index;
  }

  BlockVector< Source >& mysources = sources_[ tid ][ syn_id ];
  const index max_size = mysources.size();
  if ( max_size == 0 )
  {
    return invalid_index;
  }

  // lcid needs to be signed, to allow lcid >= 0 check in while loop
  // to fail; afterwards we can be certain that it is non-negative and
  // we can static_cast it to index
  long lcid = max_size - 1;
  while ( lcid >= 0 and mysources[ lcid ].is_disabled() )
  {
    --lcid;
  }
  ++lcid; // the loop exits with lcid at the last non-disabled entry (or -1
          // if all are disabled); incrementing makes lcid point at the
          // first disabled entry of the trailing run
  mysources.erase( mysources.begin() + lcid, mysources.end() );
  if ( static_cast< index >( lcid ) == max_size )
  {
    // no trailing disabled entries were found
    return invalid_index;
  }
  return static_cast< index >( lcid );
}
200
void
nest::SourceTable::compute_buffer_pos_for_unique_secondary_sources( const thread tid,
  std::map< index, size_t >& buffer_pos_of_source_node_id_syn_id )
{
  // NOTE(review): the omp single/critical/barrier usage implies this must be
  // called by all threads of an OpenMP parallel region -- confirm at call sites.
  //
  // set of unique sources & synapse types, required to determine
  // secondary events MPI buffer positions
  // initialized and deleted by thread 0 in this method
  static std::set< std::pair< index, size_t > >* unique_secondary_source_node_id_syn_id;
#pragma omp single
  {
    // one thread allocates the shared set; omp single implies a barrier, so
    // all threads see the allocation before proceeding
    unique_secondary_source_node_id_syn_id = new std::set< std::pair< index, size_t > >();
  }

  // collect all unique pairs of source node ID and synapse-type id
  // corresponding to continuous-data connections on this MPI rank;
  // using a set makes sure secondary events are not duplicated for
  // targets on the same process, but different threads
  for ( size_t syn_id = 0; syn_id < sources_[ tid ].size(); ++syn_id )
  {
    if ( not kernel().model_manager.get_synapse_prototype( syn_id, tid ).is_primary() )
    {
      for ( BlockVector< Source >::const_iterator source_cit = sources_[ tid ][ syn_id ].begin();
            source_cit != sources_[ tid ][ syn_id ].end();
            ++source_cit )
      {
        // the set is shared across threads, so insertion must be serialized
#pragma omp critical
        {
          ( *unique_secondary_source_node_id_syn_id ).insert( std::make_pair( source_cit->get_node_id(), syn_id ) );
        }
      }
    }
  }
  // wait until all threads have contributed their entries
#pragma omp barrier

#pragma omp single
  {
    // compute receive buffer positions for all unique pairs of source
    // node ID and synapse-type id on this MPI rank
    std::vector< int > recv_counts_secondary_events_in_int_per_rank( kernel().mpi_manager.get_num_processes(), 0 );

    for (
      std::set< std::pair< index, size_t > >::const_iterator cit = ( *unique_secondary_source_node_id_syn_id ).begin();
      cit != ( *unique_secondary_source_node_id_syn_id ).end();
      ++cit )
    {
      const thread source_rank = kernel().mpi_manager.get_process_id_of_node_id( cit->first );
      const size_t event_size = kernel().model_manager.get_secondary_event_prototype( cit->second, tid ).size();

      // map the packed (source node ID, syn id) key to its position within
      // the source rank's chunk of the receive buffer
      buffer_pos_of_source_node_id_syn_id.insert(
        std::make_pair( pack_source_node_id_and_syn_id( cit->first, cit->second ),
          recv_counts_secondary_events_in_int_per_rank[ source_rank ] ) );

      recv_counts_secondary_events_in_int_per_rank[ source_rank ] += event_size;
    }

    // each chunk needs to contain one additional int that can be used
    // to communicate whether waveform relaxation has converged
    for ( auto& recv_count : recv_counts_secondary_events_in_int_per_rank )
    {
      ++recv_count;
    }

    kernel().mpi_manager.set_recv_counts_secondary_events_in_int_per_rank(
      recv_counts_secondary_events_in_int_per_rank );
    delete unique_secondary_source_node_id_syn_id;
  } // of omp single
}
268
269 void
resize_sources(const thread tid)270 nest::SourceTable::resize_sources( const thread tid )
271 {
272 sources_[ tid ].resize( kernel().model_manager.get_num_synapse_prototypes() );
273 }
274
275 bool
source_should_be_processed_(const thread rank_start,const thread rank_end,const Source & source) const276 nest::SourceTable::source_should_be_processed_( const thread rank_start,
277 const thread rank_end,
278 const Source& source ) const
279 {
280 const thread source_rank = kernel().mpi_manager.get_process_id_of_node_id( source.get_node_id() );
281
282 return not( source.is_processed() or source.is_disabled()
283 // is this thread responsible for this part of the MPI
284 // buffer?
285 or source_rank < rank_start
286 or rank_end <= source_rank );
287 }
288
289 bool
next_entry_has_same_source_(const SourceTablePosition & current_position,const Source & current_source) const290 nest::SourceTable::next_entry_has_same_source_( const SourceTablePosition& current_position,
291 const Source& current_source ) const
292 {
293 assert( not current_position.is_invalid() );
294
295 const auto& local_sources = sources_[ current_position.tid ][ current_position.syn_id ];
296 const size_t next_lcid = current_position.lcid + 1;
297
298 return (
299 next_lcid < local_sources.size() and local_sources[ next_lcid ].get_node_id() == current_source.get_node_id() );
300 }
301
302 bool
previous_entry_has_same_source_(const SourceTablePosition & current_position,const Source & current_source) const303 nest::SourceTable::previous_entry_has_same_source_( const SourceTablePosition& current_position,
304 const Source& current_source ) const
305 {
306 assert( not current_position.is_invalid() );
307
308 const auto& local_sources = sources_[ current_position.tid ][ current_position.syn_id ];
309 const long previous_lcid = current_position.lcid - 1; // needs to be a signed type such that negative
310 // values can signal invalid indices
311
312 return ( previous_lcid >= 0 and not local_sources[ previous_lcid ].is_processed()
313 and local_sources[ previous_lcid ].get_node_id() == current_source.get_node_id() );
314 }
315
bool
nest::SourceTable::populate_target_data_fields_( const SourceTablePosition& current_position,
  const Source& current_source,
  const thread source_rank,
  TargetData& next_target_data ) const
{
  // Fills next_target_data from the source-table entry at current_position.
  // Returns false only for compressed primary connections whose map entry
  // lives on another thread (that thread communicates the source instead).
  const auto node_id = current_source.get_node_id();

  // set values of next_target_data
  next_target_data.set_source_lid( kernel().vp_manager.node_id_to_lid( node_id ) );
  next_target_data.set_source_tid( kernel().vp_manager.vp_to_thread( kernel().vp_manager.node_id_to_vp( node_id ) ) );
  next_target_data.reset_marker();

  if ( current_source.is_primary() ) // primary connection, i.e., chemical synapses
  {
    next_target_data.set_is_primary( true );

    TargetDataFields& target_fields = next_target_data.target_data;
    target_fields.set_syn_id( current_position.syn_id );
    if ( kernel().connection_manager.use_compressed_spikes() )
    {
      // WARNING: we set the tid field here to zero just to make sure
      // it has a defined value; however, this value is _not_ used
      // anywhere when using compressed spikes
      target_fields.set_tid( 0 );
      // look up the index of this source's entry in the process-global
      // compressed spike data for this synapse type
      auto it_idx = compressed_spike_data_map_.at( current_position.tid )
                      .at( current_position.syn_id )
                      .find( current_source.get_node_id() );
      if ( it_idx != compressed_spike_data_map_.at( current_position.tid ).at( current_position.syn_id ).end() )
      {
        // WARNING: no matter how tempting, do not try to remove this
        // entry from the compressed_spike_data_map_; if the MPI buffer
        // is already full, this entry will need to be communicated the
        // next MPI comm round, which, naturally, is not possible if it
        // has been removed
        target_fields.set_lcid( it_idx->second );
      }
      else // another thread is responsible for communicating this compressed source
      {
        return false;
      }
    }
    else
    {
      // we store the thread index of the source table, not our own tid!
      target_fields.set_tid( current_position.tid );
      target_fields.set_lcid( current_position.lcid );
    }
  }
  else // secondary connection, e.g., gap junctions
  {
    next_target_data.set_is_primary( false );

    // the source rank will write to the buffer position relative to
    // the first position from the absolute position in the receive
    // buffer
    const size_t relative_recv_buffer_pos = kernel().connection_manager.get_secondary_recv_buffer_position(
                                              current_position.tid, current_position.syn_id, current_position.lcid )
      - kernel().mpi_manager.get_recv_displacement_secondary_events_in_int( source_rank );

    SecondaryTargetDataFields& secondary_fields = next_target_data.secondary_data;
    secondary_fields.set_recv_buffer_pos( relative_recv_buffer_pos );
    secondary_fields.set_syn_id( current_position.syn_id );
  }

  return true;
}
383
bool
nest::SourceTable::get_next_target_data( const thread tid,
  const thread rank_start,
  const thread rank_end,
  thread& source_rank,
  TargetData& next_target_data )
{
  // Scans the source table starting from this thread's current position,
  // restricted to sources whose rank lies in [rank_start, rank_end).
  // On success, fills source_rank and next_target_data, marks the entry as
  // processed, and returns true; returns false when the table is exhausted.
  // Iteration proceeds via decrease(), i.e. towards smaller indices.
  SourceTablePosition& current_position = current_positions_[ tid ];

  if ( current_position.is_invalid() )
  {
    return false; // nothing to do here
  }

  // we stay in this loop either until we can return a valid
  // TargetData object or we have reached the end of the sources table
  while ( true )
  {
    current_position.seek_to_next_valid_index( sources_ );
    if ( current_position.is_invalid() )
    {
      return false; // reached the end of the sources table
    }

    // the current position contains an entry, so we retrieve it
    Source& current_source = sources_[ current_position.tid ][ current_position.syn_id ][ current_position.lcid ];

    if ( not source_should_be_processed_( rank_start, rank_end, current_source ) )
    {
      // already processed, disabled, or another thread's rank slice
      current_position.decrease();
      continue;
    }

    // we need to set a marker stating whether the entry following this
    // entry, if existent, has the same source
    kernel().connection_manager.set_source_has_more_targets( current_position.tid,
      current_position.syn_id,
      current_position.lcid,
      next_entry_has_same_source_( current_position, current_source ) );

    // no need to communicate this entry if the previous entry has the same source
    if ( previous_entry_has_same_source_( current_position, current_source ) )
    {
      current_source.set_processed( true ); // no need to look at this entry again
      current_position.decrease();
      continue;
    }

    // reaching this means we found an entry that should be
    // communicated via MPI, so we prepare to return the relevant data

    // set the source rank
    source_rank = kernel().mpi_manager.get_process_id_of_node_id( current_source.get_node_id() );

    if ( not populate_target_data_fields_( current_position, current_source, source_rank, next_target_data ) )
    {
      // a different thread communicates this compressed source
      current_position.decrease();
      continue;
    }

    // we are about to return a valid entry, so mark it as processed
    current_source.set_processed( true );

    current_position.decrease();
    return true; // found a valid entry
  }
}
451
452 void
resize_compressible_sources()453 nest::SourceTable::resize_compressible_sources()
454 {
455 for ( thread tid = 0; tid < static_cast< thread >( compressible_sources_.size() ); ++tid )
456 {
457 compressible_sources_[ tid ].clear();
458 compressible_sources_[ tid ].resize(
459 kernel().model_manager.get_num_synapse_prototypes(), std::map< index, SpikeData >() );
460 }
461 }
462
463 void
collect_compressible_sources(const thread tid)464 nest::SourceTable::collect_compressible_sources( const thread tid )
465 {
466 for ( synindex syn_id = 0; syn_id < sources_[ tid ].size(); ++syn_id )
467 {
468 index lcid = 0;
469 auto& syn_sources = sources_[ tid ][ syn_id ];
470 while ( lcid < syn_sources.size() )
471 {
472 const index old_source_node_id = syn_sources[ lcid ].get_node_id();
473 const std::pair< index, SpikeData > source_node_id_to_spike_data =
474 std::make_pair( old_source_node_id, SpikeData( tid, syn_id, lcid, 0 ) );
475 compressible_sources_[ tid ][ syn_id ].insert( source_node_id_to_spike_data );
476
477 // find next source with different node_id (assumes sorted sources)
478 ++lcid;
479 while ( ( lcid < syn_sources.size() ) and ( syn_sources[ lcid ].get_node_id() == old_source_node_id ) )
480 {
481 ++lcid;
482 }
483 }
484 }
485 }
486
void
nest::SourceTable::fill_compressed_spike_data(
  std::vector< std::vector< std::vector< SpikeData > > >& compressed_spike_data )
{
  // Merges the per-thread compressible_sources_ maps into one process-global
  // structure: for each unique source node and synapse type,
  // compressed_spike_data[ syn_id ] receives one vector with the target
  // positions on all local threads, and compressed_spike_data_map_ records
  // the source node's index into that vector. compressible_sources_ is
  // emptied in the process.
  compressed_spike_data.clear();
  compressed_spike_data.resize( kernel().model_manager.get_num_synapse_prototypes() );

  for ( thread tid = 0; tid < static_cast< thread >( compressible_sources_.size() ); ++tid )
  {
    compressed_spike_data_map_[ tid ].clear();
    compressed_spike_data_map_[ tid ].resize(
      kernel().model_manager.get_num_synapse_prototypes(), std::map< index, size_t >() );
  }

  // pseudo-random thread selector to balance memory usage across
  // threads of compressed_spike_data_map_
  size_t thread_idx = 0;

  // for each local thread and each synapse type we will populate this
  // vector with spike data containing information about all process
  // local targets
  std::vector< SpikeData > spike_data;

  for ( thread tid = 0; tid < static_cast< thread >( compressible_sources_.size() ); ++tid )
  {
    for ( synindex syn_id = 0; syn_id < compressible_sources_[ tid ].size(); ++syn_id )
    {
      // no ++it here: the iterator is advanced via erase() at loop end
      for ( auto it = compressible_sources_[ tid ][ syn_id ].begin();
            it != compressible_sources_[ tid ][ syn_id ].end(); )
      {
        spike_data.clear();

        // add target position on this thread
        spike_data.push_back( it->second );

        // add target positions on all other threads
        for ( thread other_tid = tid + 1; other_tid < static_cast< thread >( compressible_sources_.size() );
              ++other_tid )
        {
          auto other_it = compressible_sources_[ other_tid ][ syn_id ].find( it->first );
          if ( other_it != compressible_sources_[ other_tid ][ syn_id ].end() )
          {
            spike_data.push_back( other_it->second );
            // remove so later iterations over other_tid do not see this source again
            compressible_sources_[ other_tid ][ syn_id ].erase( other_it );
          }
        }

        // WARNING: store source-node-id -> process-global-synapse
        // association in compressed_spike_data_map on a
        // pseudo-randomly selected thread which houses targets for
        // this source; this tries to balance memory usage of this
        // data structure across threads
        const thread responsible_tid = spike_data[ thread_idx % spike_data.size() ].get_tid();
        ++thread_idx;

        compressed_spike_data_map_[ responsible_tid ][ syn_id ].insert(
          std::make_pair( it->first, compressed_spike_data[ syn_id ].size() ) );
        compressed_spike_data[ syn_id ].push_back( spike_data );

        // erase returns the iterator to the next element
        it = compressible_sources_[ tid ][ syn_id ].erase( it );
      }
      compressible_sources_[ tid ][ syn_id ].clear();
    }
  }
}
552