/*
 *  source_table.cpp
 *
 *  This file is part of NEST.
 *
 *  Copyright (C) 2004 The NEST Initiative
 *
 *  NEST is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  NEST is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with NEST.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

// C++ includes:
#include <iostream>

// Includes from nestkernel:
#include "connection_manager.h"
#include "connection_manager_impl.h"
#include "kernel_manager.h"
#include "mpi_manager_impl.h"
#include "source_table.h"
#include "vp_manager_impl.h"

nest::SourceTable::SourceTable()
{
}

nest::SourceTable::~SourceTable()
{
}

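// Set up all per-thread data structures of the source table. The assertion
// on sizeof( Source ) guards against padding or growth of Source beyond a
// single 64-bit word.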
void
nest::SourceTable::initialize()
{
  assert( sizeof( Source ) == 8 );
  const thread num_threads = kernel().vp_manager.get_num_threads();
  sources_.resize( num_threads );
  is_cleared_.initialize( num_threads, false );
  saved_entry_point_.initialize( num_threads, false );
  current_positions_.resize( num_threads );
  saved_positions_.resize( num_threads );
  compressible_sources_.resize( num_threads );
  compressed_spike_data_map_.resize( num_threads );

#pragma omp parallel
  {
    const thread tid = kernel().vp_manager.get_thread_id();
    sources_[ tid ].resize( 0 );
    resize_sources( tid );
    compressible_sources_[ tid ].resize( 0 );
    compressed_spike_data_map_[ tid ].resize( 0 );
  } // of omp parallel
}

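// Release all memory held by the source table; threads that have not been
// cleared yet are cleared here before the containers themselves are freed.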
void
nest::SourceTable::finalize()
{
  for ( thread tid = 0; tid < static_cast< thread >( sources_.size() ); ++tid )
  {
    if ( is_cleared_[ tid ].is_false() )
    {
      clear( tid );
      compressible_sources_[ tid ].clear();
      compressed_spike_data_map_[ tid ].clear();
    }
  }

  sources_.clear();
  current_positions_.clear();
  saved_positions_.clear();
  compressible_sources_.clear();
  compressed_spike_data_map_.clear();
}

bool
nest::SourceTable::is_cleared() const
{
  return is_cleared_.all_true();
}

std::vector< BlockVector< nest::Source > >&
nest::SourceTable::get_thread_local_sources( const thread tid )
{
  return sources_[ tid ];
}

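// Returns the maximum of the positions saved by all threads. Entries of the
// source table at larger indices will not be touched again and can be
// removed safely (see clean()).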
nest::SourceTablePosition
nest::SourceTable::find_maximal_position() const
{
  SourceTablePosition max_position( -1, -1, -1 );
  for ( thread tid = 0; tid < kernel().vp_manager.get_num_threads(); ++tid )
  {
    if ( max_position < saved_positions_[ tid ] )
    {
      max_position = saved_positions_[ tid ];
    }
  }
  return max_position;
}

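// Removes source table entries that lie beyond the maximal saved position
// across all threads and will therefore not be touched again, allowing
// memory to be freed early.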
void
nest::SourceTable::clean( const thread tid )
{
  // Find maximal position in source table among threads to make sure
  // unprocessed entries are not removed. Given this maximal position,
  // we can safely delete all larger entries since they will not be
  // touched any more.
  const SourceTablePosition max_position = find_maximal_position();

  // If this thread corresponds to max_position's thread, we can only
  // delete the part of the sources table with indices larger than those
  // in max_position; if this thread is larger than max_position's
  // thread, we can delete all sources; otherwise we do nothing.
  if ( max_position.tid == tid )
  {
    for ( synindex syn_id = max_position.syn_id; syn_id < sources_[ tid ].size(); ++syn_id )
    {
      BlockVector< Source >& sources = sources_[ tid ][ syn_id ];
      if ( max_position.syn_id == syn_id )
      {
        // we need to add 2 to max_position.lcid since
        // max_position.lcid + 1 can contain a valid entry which we
        // do not want to delete.
        if ( max_position.lcid + 2 < static_cast< long >( sources.size() ) )
        {
          sources.erase( sources.begin() + max_position.lcid + 2, sources.end() );
        }
      }
      else
      {
        assert( max_position.syn_id < syn_id );
        sources.clear();
      }
    }
  }
  else if ( max_position.tid < tid )
  {
    sources_[ tid ].clear();
  }
  else
  {
    // do nothing
    assert( tid < max_position.tid );
  }
}

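// Returns the node ID of the source of the connection identified by
// ( tid, syn_id, lcid ). This is only possible while the source table is
// kept, i.e., while keep_source_table is set.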
nest::index
nest::SourceTable::get_node_id( const thread tid, const synindex syn_id, const index lcid ) const
{
  if ( not kernel().connection_manager.get_keep_source_table() )
  {
    throw KernelException( "Cannot use SourceTable::get_node_id when get_keep_source_table is false" );
  }
  return sources_[ tid ][ syn_id ][ lcid ].get_node_id();
}

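// Removes the trailing block of disabled sources for the given thread and
// synapse type. Returns the local connection ID of the first removed entry,
// or invalid_index if no entry was removed.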
nest::index
nest::SourceTable::remove_disabled_sources( const thread tid, const synindex syn_id )
{
  if ( sources_[ tid ].size() <= syn_id )
  {
    return invalid_index;
  }

  BlockVector< Source >& mysources = sources_[ tid ][ syn_id ];
  const index max_size = mysources.size();
  if ( max_size == 0 )
  {
    return invalid_index;
  }

  // lcid needs to be signed, to allow the lcid >= 0 check in the while
  // loop to fail; afterwards we can be certain that it is non-negative
  // and we can static_cast it to index
  long lcid = max_size - 1;
  while ( lcid >= 0 and mysources[ lcid ].is_disabled() )
  {
    --lcid;
  }
  ++lcid; // the while loop exits with lcid pointing at the last enabled
          // entry (or at -1), so we increase it by one to mark the first
          // disabled entry of the trailing block
  mysources.erase( mysources.begin() + lcid, mysources.end() );
  if ( static_cast< index >( lcid ) == max_size )
  {
    return invalid_index;
  }
  return static_cast< index >( lcid );
}

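// Determines, for every unique ( source node ID, synapse type ) pair with
// secondary (continuous-data) connections on this MPI rank, the position of
// its data within the per-rank chunk of the secondary-events receive buffer,
// and registers the resulting receive counts with the MPI manager.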
void
nest::SourceTable::compute_buffer_pos_for_unique_secondary_sources( const thread tid,
  std::map< index, size_t >& buffer_pos_of_source_node_id_syn_id )
{
  // set of unique sources & synapse types, required to determine
  // secondary events MPI buffer positions; allocated and deleted by a
  // single thread in this method
  static std::set< std::pair< index, size_t > >* unique_secondary_source_node_id_syn_id;
#pragma omp single
  {
    unique_secondary_source_node_id_syn_id = new std::set< std::pair< index, size_t > >();
  }

  // collect all unique pairs of source node ID and synapse-type id
  // corresponding to continuous-data connections on this MPI rank;
  // using a set makes sure secondary events are not duplicated for
  // targets on the same process, but different threads
  for ( size_t syn_id = 0; syn_id < sources_[ tid ].size(); ++syn_id )
  {
    if ( not kernel().model_manager.get_synapse_prototype( syn_id, tid ).is_primary() )
    {
      for ( BlockVector< Source >::const_iterator source_cit = sources_[ tid ][ syn_id ].begin();
            source_cit != sources_[ tid ][ syn_id ].end();
            ++source_cit )
      {
#pragma omp critical
        {
          ( *unique_secondary_source_node_id_syn_id ).insert( std::make_pair( source_cit->get_node_id(), syn_id ) );
        }
      }
    }
  }
#pragma omp barrier

#pragma omp single
  {
    // compute receive buffer positions for all unique pairs of source
    // node ID and synapse-type id on this MPI rank
    std::vector< int > recv_counts_secondary_events_in_int_per_rank( kernel().mpi_manager.get_num_processes(), 0 );

    for (
      std::set< std::pair< index, size_t > >::const_iterator cit = ( *unique_secondary_source_node_id_syn_id ).begin();
      cit != ( *unique_secondary_source_node_id_syn_id ).end();
      ++cit )
    {
      const thread source_rank = kernel().mpi_manager.get_process_id_of_node_id( cit->first );
      const size_t event_size = kernel().model_manager.get_secondary_event_prototype( cit->second, tid ).size();

      buffer_pos_of_source_node_id_syn_id.insert(
        std::make_pair( pack_source_node_id_and_syn_id( cit->first, cit->second ),
          recv_counts_secondary_events_in_int_per_rank[ source_rank ] ) );

      recv_counts_secondary_events_in_int_per_rank[ source_rank ] += event_size;
    }

    // each chunk needs to contain one additional int that can be used
    // to communicate whether waveform relaxation has converged
    for ( auto& recv_count : recv_counts_secondary_events_in_int_per_rank )
    {
      ++recv_count;
    }

    kernel().mpi_manager.set_recv_counts_secondary_events_in_int_per_rank(
      recv_counts_secondary_events_in_int_per_rank );
    delete unique_secondary_source_node_id_syn_id;
  } // of omp single
}

void
nest::SourceTable::resize_sources( const thread tid )
{
  sources_[ tid ].resize( kernel().model_manager.get_num_synapse_prototypes() );
}

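// Returns true if the given source still needs to be communicated by this
// thread, i.e., it is neither processed nor disabled and its rank lies in
// the half-open range [ rank_start, rank_end ) assigned to this thread.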
bool
nest::SourceTable::source_should_be_processed_( const thread rank_start,
  const thread rank_end,
  const Source& source ) const
{
  const thread source_rank = kernel().mpi_manager.get_process_id_of_node_id( source.get_node_id() );

  return not( source.is_processed() or source.is_disabled()
           // is this thread responsible for this part of the MPI
           // buffer?
           or source_rank < rank_start
           or rank_end <= source_rank );
}

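// Returns true if the entry following current_position refers to the same
// source node; used to set the "source has more targets" marker on the
// corresponding connection.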
bool
nest::SourceTable::next_entry_has_same_source_( const SourceTablePosition& current_position,
  const Source& current_source ) const
{
  assert( not current_position.is_invalid() );

  const auto& local_sources = sources_[ current_position.tid ][ current_position.syn_id ];
  const size_t next_lcid = current_position.lcid + 1;

  return (
    next_lcid < local_sources.size() and local_sources[ next_lcid ].get_node_id() == current_source.get_node_id() );
}

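// Returns true if the entry preceding current_position refers to the same,
// not yet processed, source node, in which case the current entry does not
// need to be communicated again.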
bool
nest::SourceTable::previous_entry_has_same_source_( const SourceTablePosition& current_position,
  const Source& current_source ) const
{
  assert( not current_position.is_invalid() );

  const auto& local_sources = sources_[ current_position.tid ][ current_position.syn_id ];
  const long previous_lcid = current_position.lcid - 1; // needs to be a signed type such that negative
                                                        // values can signal invalid indices

  return ( previous_lcid >= 0 and not local_sources[ previous_lcid ].is_processed()
    and local_sources[ previous_lcid ].get_node_id() == current_source.get_node_id() );
}

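// Fills the fields of next_target_data for the connection at
// current_position. Returns false if, for compressed spikes, another thread
// is responsible for communicating this source, in which case the entry is
// skipped by the caller.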
bool
nest::SourceTable::populate_target_data_fields_( const SourceTablePosition& current_position,
  const Source& current_source,
  const thread source_rank,
  TargetData& next_target_data ) const
{
  const auto node_id = current_source.get_node_id();

  // set values of next_target_data
  next_target_data.set_source_lid( kernel().vp_manager.node_id_to_lid( node_id ) );
  next_target_data.set_source_tid( kernel().vp_manager.vp_to_thread( kernel().vp_manager.node_id_to_vp( node_id ) ) );
  next_target_data.reset_marker();

  if ( current_source.is_primary() ) // primary connection, i.e., chemical synapses
  {
    next_target_data.set_is_primary( true );

    TargetDataFields& target_fields = next_target_data.target_data;
    target_fields.set_syn_id( current_position.syn_id );
    if ( kernel().connection_manager.use_compressed_spikes() )
    {
      // WARNING: we set the tid field here to zero just to make sure
      // it has a defined value; however, this value is _not_ used
      // anywhere when using compressed spikes
      target_fields.set_tid( 0 );
      auto it_idx = compressed_spike_data_map_.at( current_position.tid )
                      .at( current_position.syn_id )
                      .find( current_source.get_node_id() );
      if ( it_idx != compressed_spike_data_map_.at( current_position.tid ).at( current_position.syn_id ).end() )
      {
        // WARNING: no matter how tempting, do not try to remove this
        // entry from the compressed_spike_data_map_; if the MPI buffer
        // is already full, this entry will need to be communicated the
        // next MPI comm round, which, naturally, is not possible if it
        // has been removed
        target_fields.set_lcid( it_idx->second );
      }
      else // another thread is responsible for communicating this compressed source
      {
        return false;
      }
    }
    else
    {
      // we store the thread index of the source table, not our own tid!
      target_fields.set_tid( current_position.tid );
      target_fields.set_lcid( current_position.lcid );
    }
  }
  else // secondary connection, e.g., gap junctions
  {
    next_target_data.set_is_primary( false );

    // the source rank will write to a buffer position relative to the
    // beginning of its chunk, hence we subtract the rank's receive
    // displacement from the absolute position in the receive buffer
    const size_t relative_recv_buffer_pos = kernel().connection_manager.get_secondary_recv_buffer_position(
                                              current_position.tid, current_position.syn_id, current_position.lcid )
      - kernel().mpi_manager.get_recv_displacement_secondary_events_in_int( source_rank );

    SecondaryTargetDataFields& secondary_fields = next_target_data.secondary_data;
    secondary_fields.set_recv_buffer_pos( relative_recv_buffer_pos );
    secondary_fields.set_syn_id( current_position.syn_id );
  }

  return true;
}

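// Returns, via source_rank and next_target_data, the next entry of this
// thread's part of the source table whose source rank lies in
// [ rank_start, rank_end ), advancing current_positions_[ tid ] backwards.
// Returns false once the end of the source table has been reached.
//
// Illustrative call pattern (a sketch only; the actual driver lives in the
// connection infrastructure, and the loop and names below are hypothetical):
//
//   TargetData td;
//   thread source_rank;
//   while ( source_table.get_next_target_data( tid, rank_start, rank_end, source_rank, td ) )
//   {
//     // copy td into the send-buffer chunk destined for source_rank;
//     // if that chunk is full, save the current position and resume in
//     // the next communication round
//   }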
bool
nest::SourceTable::get_next_target_data( const thread tid,
  const thread rank_start,
  const thread rank_end,
  thread& source_rank,
  TargetData& next_target_data )
{
  SourceTablePosition& current_position = current_positions_[ tid ];

  if ( current_position.is_invalid() )
  {
    return false; // nothing to do here
  }

  // we stay in this loop until we can either return a valid
  // TargetData object or have reached the end of the sources table
  while ( true )
  {
    current_position.seek_to_next_valid_index( sources_ );
    if ( current_position.is_invalid() )
    {
      return false; // reached the end of the sources table
    }

    // the current position contains an entry, so we retrieve it
    Source& current_source = sources_[ current_position.tid ][ current_position.syn_id ][ current_position.lcid ];

    if ( not source_should_be_processed_( rank_start, rank_end, current_source ) )
    {
      current_position.decrease();
      continue;
    }

    // we need to set a marker stating whether the entry following this
    // entry, if it exists, has the same source
    kernel().connection_manager.set_source_has_more_targets( current_position.tid,
      current_position.syn_id,
      current_position.lcid,
      next_entry_has_same_source_( current_position, current_source ) );

    // no need to communicate this entry if the previous entry has the same source
    if ( previous_entry_has_same_source_( current_position, current_source ) )
    {
      current_source.set_processed( true ); // no need to look at this entry again
      current_position.decrease();
      continue;
    }

    // reaching this means we found an entry that should be
    // communicated via MPI, so we prepare to return the relevant data

    // set the source rank
    source_rank = kernel().mpi_manager.get_process_id_of_node_id( current_source.get_node_id() );

    if ( not populate_target_data_fields_( current_position, current_source, source_rank, next_target_data ) )
    {
      current_position.decrease();
      continue;
    }

    // we are about to return a valid entry, so mark it as processed
    current_source.set_processed( true );

    current_position.decrease();
    return true; // found a valid entry
  }
}

void
nest::SourceTable::resize_compressible_sources()
{
  for ( thread tid = 0; tid < static_cast< thread >( compressible_sources_.size() ); ++tid )
  {
    compressible_sources_[ tid ].clear();
    compressible_sources_[ tid ].resize(
      kernel().model_manager.get_num_synapse_prototypes(), std::map< index, SpikeData >() );
  }
}

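// For every synapse type, stores the first local connection ID of each
// distinct source node in compressible_sources_; assumes the sources of
// each thread are sorted by node ID.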
void
nest::SourceTable::collect_compressible_sources( const thread tid )
{
  for ( synindex syn_id = 0; syn_id < sources_[ tid ].size(); ++syn_id )
  {
    index lcid = 0;
    auto& syn_sources = sources_[ tid ][ syn_id ];
    while ( lcid < syn_sources.size() )
    {
      const index old_source_node_id = syn_sources[ lcid ].get_node_id();
      const std::pair< index, SpikeData > source_node_id_to_spike_data =
        std::make_pair( old_source_node_id, SpikeData( tid, syn_id, lcid, 0 ) );
      compressible_sources_[ tid ][ syn_id ].insert( source_node_id_to_spike_data );

      // find next source with different node_id (assumes sorted sources)
      ++lcid;
      while ( ( lcid < syn_sources.size() ) and ( syn_sources[ lcid ].get_node_id() == old_source_node_id ) )
      {
        ++lcid;
      }
    }
  }
}

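// Merges the per-thread compressible_sources_ into process-global entries:
// for every ( source node ID, synapse type ) pair, the target positions of
// all threads are collected into a single vector in compressed_spike_data,
// and the index of that vector is stored in compressed_spike_data_map_ on
// one of the threads that hosts targets of this source.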
void
nest::SourceTable::fill_compressed_spike_data(
  std::vector< std::vector< std::vector< SpikeData > > >& compressed_spike_data )
{
  compressed_spike_data.clear();
  compressed_spike_data.resize( kernel().model_manager.get_num_synapse_prototypes() );

  for ( thread tid = 0; tid < static_cast< thread >( compressible_sources_.size() ); ++tid )
  {
    compressed_spike_data_map_[ tid ].clear();
    compressed_spike_data_map_[ tid ].resize(
      kernel().model_manager.get_num_synapse_prototypes(), std::map< index, size_t >() );
  }

  // pseudo-random thread selector used to balance memory usage of
  // compressed_spike_data_map_ across threads
  size_t thread_idx = 0;

  // for each local thread and each synapse type we populate this
  // vector with spike data containing information about all
  // process-local targets
  std::vector< SpikeData > spike_data;

  for ( thread tid = 0; tid < static_cast< thread >( compressible_sources_.size() ); ++tid )
  {
    for ( synindex syn_id = 0; syn_id < compressible_sources_[ tid ].size(); ++syn_id )
    {
      for ( auto it = compressible_sources_[ tid ][ syn_id ].begin();
            it != compressible_sources_[ tid ][ syn_id ].end(); )
      {
        spike_data.clear();

        // add target position on this thread
        spike_data.push_back( it->second );

        // add target positions on all other threads
        for ( thread other_tid = tid + 1; other_tid < static_cast< thread >( compressible_sources_.size() );
              ++other_tid )
        {
          auto other_it = compressible_sources_[ other_tid ][ syn_id ].find( it->first );
          if ( other_it != compressible_sources_[ other_tid ][ syn_id ].end() )
          {
            spike_data.push_back( other_it->second );
            compressible_sources_[ other_tid ][ syn_id ].erase( other_it );
          }
        }

        // WARNING: store the source-node-id -> process-global-synapse
        // association in compressed_spike_data_map_ on a
        // pseudo-randomly selected thread which houses targets for
        // this source; this tries to balance memory usage of this
        // data structure across threads
        const thread responsible_tid = spike_data[ thread_idx % spike_data.size() ].get_tid();
        ++thread_idx;

        compressed_spike_data_map_[ responsible_tid ][ syn_id ].insert(
          std::make_pair( it->first, compressed_spike_data[ syn_id ].size() ) );
        compressed_spike_data[ syn_id ].push_back( spike_data );

        it = compressible_sources_[ tid ][ syn_id ].erase( it );
      }
      compressible_sources_[ tid ][ syn_id ].clear();
    }
  }
}