1 // Copyright (C) 2012  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 #ifndef DLIB_BSP_CPph_
4 #define DLIB_BSP_CPph_
5 
6 #include "bsp.h"
7 #include <memory>
8 #include <stack>
9 
10 // ----------------------------------------------------------------------------------------
11 // ----------------------------------------------------------------------------------------
12 
13 namespace dlib
14 {
15 
16     namespace impl1
17     {
18 
connect_all(map_id_to_con & cons,const std::vector<network_address> & hosts,unsigned long node_id)19         void connect_all (
20             map_id_to_con& cons,
21             const std::vector<network_address>& hosts,
22             unsigned long node_id
23         )
24         {
25             cons.clear();
26             for (unsigned long i = 0; i < hosts.size(); ++i)
27             {
28                 std::unique_ptr<bsp_con> con(new bsp_con(hosts[i]));
29                 dlib::serialize(node_id, con->stream); // tell the other end our node_id
30                 unsigned long id = i+1;
31                 cons.add(id, con);
32             }
33         }
34 
connect_all_hostinfo(map_id_to_con & cons,const std::vector<hostinfo> & hosts,unsigned long node_id,std::string & error_string)35         void connect_all_hostinfo (
36             map_id_to_con& cons,
37             const std::vector<hostinfo>& hosts,
38             unsigned long node_id,
39             std::string& error_string
40         )
41         {
42             cons.clear();
43             for (unsigned long i = 0; i < hosts.size(); ++i)
44             {
45                 try
46                 {
47                     std::unique_ptr<bsp_con> con(new bsp_con(hosts[i].addr));
48                     dlib::serialize(node_id, con->stream); // tell the other end our node_id
49                     con->stream.flush();
50                     unsigned long id = hosts[i].node_id;
51                     cons.add(id, con);
52                 }
53                 catch (std::exception&)
54                 {
55                     std::ostringstream sout;
56                     sout << "Could not connect to " << hosts[i].addr;
57                     error_string = sout.str();
58                     break;
59                 }
60             }
61         }
62 
63 
send_out_connection_orders(map_id_to_con & cons,const std::vector<network_address> & hosts)64         void send_out_connection_orders (
65             map_id_to_con& cons,
66             const std::vector<network_address>& hosts
67         )
68         {
69             // tell everyone their node ids
70             cons.reset();
71             while (cons.move_next())
72             {
73                 dlib::serialize(cons.element().key(), cons.element().value()->stream);
74             }
75 
76             // now tell them who to connect to
77             std::vector<hostinfo> targets;
78             for (unsigned long i = 0; i < hosts.size(); ++i)
79             {
80                 hostinfo info(hosts[i], i+1);
81 
82                 dlib::serialize(targets, cons[info.node_id]->stream);
83                 targets.push_back(info);
84 
85                 // let the other host know how many incoming connections to expect
86                 const unsigned long num = hosts.size()-targets.size();
87                 dlib::serialize(num, cons[info.node_id]->stream);
88                 cons[info.node_id]->stream.flush();
89             }
90         }
91 
92     // ------------------------------------------------------------------------------------
93 
94 
95     }
96 
97 // ----------------------------------------------------------------------------------------
98 // ----------------------------------------------------------------------------------------
99 // ----------------------------------------------------------------------------------------
100 
101     namespace impl2
102     {
103         // These control bytes are sent before each message between nodes.  Note that many
104         // of these are only sent between the control node (node 0) and the other nodes.
105         // This is because the controller node is responsible for handling the
106         // synchronization that needs to happen when all nodes block on calls to
107         // receive_data()
108         // at the same time.
109 
110         // denotes a normal content message.
111         const static char MESSAGE_HEADER            = 0;
112 
113         // sent to the controller node when someone receives a message via receive_data().
114         const static char GOT_MESSAGE               = 1;
115 
116         // sent to the controller node when someone sends a message via send().
117         const static char SENT_MESSAGE              = 2;
118 
119         // sent to the controller node when someone enters a call to receive_data()
120         const static char IN_WAITING_STATE          = 3;
121 
122         // broadcast when a node terminates itself.
123         const static char NODE_TERMINATE            = 5;
124 
125         // broadcast by the controller node when it determines that all nodes are blocked
126         // on calls to receive_data() and there aren't any messages in flight.  This is also
127         // what makes us go to the next epoch.
128         const static char SEE_ALL_IN_WAITING_STATE  = 6;
129 
130         // This isn't ever transmitted between nodes.  It is used internally to indicate
131         // that an error occurred.
132         const static char READ_ERROR                = 7;
133 
134     // ------------------------------------------------------------------------------------
135 
read_thread(impl1::bsp_con * con,unsigned long node_id,unsigned long sender_id,impl1::thread_safe_message_queue & msg_buffer)136         void read_thread (
137             impl1::bsp_con* con,
138             unsigned long node_id,
139             unsigned long sender_id,
140             impl1::thread_safe_message_queue& msg_buffer
141         )
142         {
143             try
144             {
145                 while(true)
146                 {
147                     impl1::msg_data msg;
148                     deserialize(msg.msg_type, con->stream);
149                     msg.sender_id = sender_id;
150 
151                     if (msg.msg_type == MESSAGE_HEADER)
152                     {
153                         msg.data.reset(new std::vector<char>);
154                         deserialize(msg.epoch, con->stream);
155                         deserialize(*msg.data, con->stream);
156                     }
157 
158                     msg_buffer.push_and_consume(msg);
159 
160                     if (msg.msg_type == NODE_TERMINATE)
161                         break;
162                 }
163             }
164             catch (std::exception& e)
165             {
166                 impl1::msg_data msg;
167                 msg.data.reset(new std::vector<char>);
168                 vectorstream sout(*msg.data);
169                 sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
170                 sout << "  Sending processing node address:   " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
171                 sout << "  Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
172                 sout << "  Receiving processing node id:      " << node_id << std::endl;
173                 sout << "  Error message in the exception:    " << e.what() << std::endl;
174 
175                 msg.sender_id = sender_id;
176                 msg.msg_type = READ_ERROR;
177 
178                 msg_buffer.push_and_consume(msg);
179             }
180             catch (...)
181             {
182                 impl1::msg_data msg;
183                 msg.data.reset(new std::vector<char>);
184                 vectorstream sout(*msg.data);
185                 sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
186                 sout << "  Sending processing node address:   " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
187                 sout << "  Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
188                 sout << "  Receiving processing node id:      " << node_id << std::endl;
189 
190                 msg.sender_id = sender_id;
191                 msg.msg_type = READ_ERROR;
192 
193                 msg_buffer.push_and_consume(msg);
194             }
195         }
196 
197     // ------------------------------------------------------------------------------------
198 
199     }
200 
201 // ----------------------------------------------------------------------------------------
202 // ----------------------------------------------------------------------------------------
203 //                          IMPLEMENTATION OF bsp_context OBJECT MEMBERS
204 // ----------------------------------------------------------------------------------------
205 // ----------------------------------------------------------------------------------------
206 
207     void bsp_context::
close_all_connections_gracefully()208     close_all_connections_gracefully(
209     )
210     {
211         if (node_id() != 0)
212         {
213             _cons.reset();
214             while (_cons.move_next())
215             {
216                 // tell the other end that we are intentionally dropping the connection
217                 serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
218                 _cons.element().value()->stream.flush();
219             }
220         }
221 
222         impl1::msg_data msg;
223         // now wait for all the other nodes to terminate
224         while (num_terminated_nodes < _cons.size() )
225         {
226             if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
227             {
228                 num_waiting_nodes = 0;
229                 broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
230                 ++current_epoch;
231             }
232 
233             if (!msg_buffer.pop(msg))
234                 throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
235 
236             if (msg.msg_type == impl2::NODE_TERMINATE)
237             {
238                 ++num_terminated_nodes;
239                 _cons[msg.sender_id]->terminated = true;
240             }
241             else if (msg.msg_type == impl2::READ_ERROR)
242             {
243                 throw dlib::socket_error(msg.data_to_string());
244             }
245             else if (msg.msg_type == impl2::MESSAGE_HEADER)
246             {
247                 throw dlib::socket_error("A BSP node received a message after it has terminated.");
248             }
249             else if (msg.msg_type == impl2::GOT_MESSAGE)
250             {
251                 --num_waiting_nodes;
252                 --outstanding_messages;
253             }
254             else if (msg.msg_type == impl2::SENT_MESSAGE)
255             {
256                 ++outstanding_messages;
257             }
258             else if (msg.msg_type == impl2::IN_WAITING_STATE)
259             {
260                 ++num_waiting_nodes;
261             }
262         }
263 
264         if (node_id() == 0)
265         {
266             _cons.reset();
267             while (_cons.move_next())
268             {
269                 // tell the other end that we are intentionally dropping the connection
270                 serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
271                 _cons.element().value()->stream.flush();
272             }
273 
274             if (outstanding_messages != 0)
275             {
276                 std::ostringstream sout;
277                 sout << "A BSP job was allowed to terminate before all sent messages have been received.\n";
278                 sout << "There are at least " << outstanding_messages << " messages still in flight.   Make sure all sent messages\n";
279                 sout << "have a corresponding call to receive().";
280                 throw dlib::socket_error(sout.str());
281             }
282         }
283     }
284 
285 // ----------------------------------------------------------------------------------------
286 
287     bsp_context::
~bsp_context()288     ~bsp_context()
289     {
290         _cons.reset();
291         while (_cons.move_next())
292         {
293             _cons.element().value()->con->shutdown();
294         }
295 
296         msg_buffer.disable();
297 
298         // this will wait for all the threads to terminate
299         threads.clear();
300     }
301 
302 // ----------------------------------------------------------------------------------------
303 
304     bsp_context::
bsp_context(unsigned long node_id_,impl1::map_id_to_con & cons_)305     bsp_context(
306         unsigned long node_id_,
307         impl1::map_id_to_con& cons_
308     ) :
309         outstanding_messages(0),
310         num_waiting_nodes(0),
311         num_terminated_nodes(0),
312         current_epoch(1),
313         _cons(cons_),
314         _node_id(node_id_)
315     {
316         // spawn a bunch of read threads, one for each connection
317         _cons.reset();
318         while (_cons.move_next())
319         {
320             std::unique_ptr<thread_function> ptr(new thread_function(&impl2::read_thread,
321                                                                 _cons.element().value().get(),
322                                                                 _node_id,
323                                                                 _cons.element().key(),
324                                                                 ref(msg_buffer)));
325             threads.push_back(ptr);
326         }
327 
328     }
329 
330 // ----------------------------------------------------------------------------------------
331 
332     bool bsp_context::
receive_data(std::shared_ptr<std::vector<char>> & item,unsigned long & sending_node_id)333     receive_data (
334         std::shared_ptr<std::vector<char> >& item,
335         unsigned long& sending_node_id
336     )
337     {
338         notify_control_node(impl2::IN_WAITING_STATE);
339 
340         while (true)
341         {
342             // If there aren't any nodes left to give us messages then return right now.
343             // We need to check the msg_buffer size to make sure there aren't any
344             // unprocessed message there.  Recall that this can happen because status
345             // messages always jump to the front of the message buffer.  So we might have
346             // learned about the node terminations before processing their messages for us.
347             if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0)
348             {
349                 return false;
350             }
351 
352             // if all running nodes are currently blocking forever on receive_data()
353             if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size())
354             {
355                 num_waiting_nodes = 0;
356                 broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
357 
358                 // Note that the reason we have this epoch counter is so we can tell if a
359                 // sent message is from before or after one of these "all nodes waiting"
360                 // synchronization events.  If we didn't have the epoch count we would have
361                 // a race condition where one node gets the SEE_ALL_IN_WAITING_STATE
362                 // message before others and then sends out a message to another node
363                 // before that node got the SEE_ALL_IN_WAITING_STATE message.  Then that
364                 // node would think the normal message came before SEE_ALL_IN_WAITING_STATE
365                 // which would be bad.
366                 ++current_epoch;
367                 return false;
368             }
369 
370             impl1::msg_data data;
371             if (!msg_buffer.pop(data, current_epoch))
372                 throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
373 
374 
375             switch(data.msg_type)
376             {
377                 case impl2::MESSAGE_HEADER: {
378                     item = data.data;
379                     sending_node_id = data.sender_id;
380                     notify_control_node(impl2::GOT_MESSAGE);
381                     return true;
382                 } break;
383 
384                 case impl2::IN_WAITING_STATE: {
385                     ++num_waiting_nodes;
386                 } break;
387 
388                 case impl2::GOT_MESSAGE: {
389                     --outstanding_messages;
390                     --num_waiting_nodes;
391                 } break;
392 
393                 case impl2::SENT_MESSAGE: {
394                     ++outstanding_messages;
395                 } break;
396 
397                 case impl2::NODE_TERMINATE: {
398                     ++num_terminated_nodes;
399                     _cons[data.sender_id]->terminated = true;
400                 } break;
401 
402                 case impl2::SEE_ALL_IN_WAITING_STATE: {
403                     ++current_epoch;
404                     return false;
405                 } break;
406 
407                 case impl2::READ_ERROR: {
408                     throw dlib::socket_error(data.data_to_string());
409                 } break;
410 
411                 default: {
412                     throw dlib::socket_error("Unknown message received by dlib::bsp_context");
413                 } break;
414             } // end switch()
415         } // end while (true)
416     }
417 
418 // ----------------------------------------------------------------------------------------
419 
420     void bsp_context::
notify_control_node(char val)421     notify_control_node (
422         char val
423     )
424     {
425         if (node_id() == 0)
426         {
427             using namespace impl2;
428             switch(val)
429             {
430                 case SENT_MESSAGE: {
431                     ++outstanding_messages;
432                 } break;
433 
434                 case GOT_MESSAGE: {
435                     --outstanding_messages;
436                 } break;
437 
438                 case IN_WAITING_STATE: {
439                     // nothing to do in this case
440                 } break;
441 
442                 default:
443                     DLIB_CASSERT(false,"This should never happen");
444             }
445         }
446         else
447         {
448             serialize(val, _cons[0]->stream);
449             _cons[0]->stream.flush();
450         }
451     }
452 
453 // ----------------------------------------------------------------------------------------
454 
455     void bsp_context::
broadcast_byte(char val)456     broadcast_byte (
457         char val
458     )
459     {
460         for (unsigned long i = 0; i < number_of_nodes(); ++i)
461         {
462             // don't send to yourself or to terminated nodes
463             if (i == node_id() || _cons[i]->terminated)
464                 continue;
465 
466             serialize(val, _cons[i]->stream);
467             _cons[i]->stream.flush();
468         }
469     }
470 
471 // ----------------------------------------------------------------------------------------
472 
473     void bsp_context::
send_data(const std::vector<char> & item,unsigned long target_node_id)474     send_data(
475         const std::vector<char>& item,
476         unsigned long target_node_id
477     )
478     {
479         using namespace impl2;
480         if (_cons[target_node_id]->terminated)
481             throw socket_error("Attempt to send a message to a node that has terminated.");
482 
483         serialize(MESSAGE_HEADER, _cons[target_node_id]->stream);
484         serialize(current_epoch, _cons[target_node_id]->stream);
485         serialize(item, _cons[target_node_id]->stream);
486         _cons[target_node_id]->stream.flush();
487 
488         notify_control_node(SENT_MESSAGE);
489     }
490 
491 // ----------------------------------------------------------------------------------------
492 
493 }
494 
495 #endif // DLIB_BSP_CPph_
496 
497