1 // Copyright (C) 2012 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 #ifndef DLIB_BSP_CPph_ 4 #define DLIB_BSP_CPph_ 5 6 #include "bsp.h" 7 #include <memory> 8 #include <stack> 9 10 // ---------------------------------------------------------------------------------------- 11 // ---------------------------------------------------------------------------------------- 12 13 namespace dlib 14 { 15 16 namespace impl1 17 { 18 connect_all(map_id_to_con & cons,const std::vector<network_address> & hosts,unsigned long node_id)19 void connect_all ( 20 map_id_to_con& cons, 21 const std::vector<network_address>& hosts, 22 unsigned long node_id 23 ) 24 { 25 cons.clear(); 26 for (unsigned long i = 0; i < hosts.size(); ++i) 27 { 28 std::unique_ptr<bsp_con> con(new bsp_con(hosts[i])); 29 dlib::serialize(node_id, con->stream); // tell the other end our node_id 30 unsigned long id = i+1; 31 cons.add(id, con); 32 } 33 } 34 connect_all_hostinfo(map_id_to_con & cons,const std::vector<hostinfo> & hosts,unsigned long node_id,std::string & error_string)35 void connect_all_hostinfo ( 36 map_id_to_con& cons, 37 const std::vector<hostinfo>& hosts, 38 unsigned long node_id, 39 std::string& error_string 40 ) 41 { 42 cons.clear(); 43 for (unsigned long i = 0; i < hosts.size(); ++i) 44 { 45 try 46 { 47 std::unique_ptr<bsp_con> con(new bsp_con(hosts[i].addr)); 48 dlib::serialize(node_id, con->stream); // tell the other end our node_id 49 con->stream.flush(); 50 unsigned long id = hosts[i].node_id; 51 cons.add(id, con); 52 } 53 catch (std::exception&) 54 { 55 std::ostringstream sout; 56 sout << "Could not connect to " << hosts[i].addr; 57 error_string = sout.str(); 58 break; 59 } 60 } 61 } 62 63 send_out_connection_orders(map_id_to_con & cons,const std::vector<network_address> & hosts)64 void send_out_connection_orders ( 65 map_id_to_con& cons, 66 const std::vector<network_address>& hosts 67 ) 68 { 69 // tell everyone their node ids 70 cons.reset(); 71 while (cons.move_next()) 72 { 73 dlib::serialize(cons.element().key(), cons.element().value()->stream); 74 } 75 76 // now tell them who to connect to 77 std::vector<hostinfo> targets; 78 for (unsigned long i = 0; i < hosts.size(); ++i) 79 { 80 hostinfo info(hosts[i], i+1); 81 82 dlib::serialize(targets, cons[info.node_id]->stream); 83 targets.push_back(info); 84 85 // let the other host know how many incoming connections to expect 86 const unsigned long num = hosts.size()-targets.size(); 87 dlib::serialize(num, cons[info.node_id]->stream); 88 cons[info.node_id]->stream.flush(); 89 } 90 } 91 92 // ------------------------------------------------------------------------------------ 93 94 95 } 96 97 // ---------------------------------------------------------------------------------------- 98 // ---------------------------------------------------------------------------------------- 99 // ---------------------------------------------------------------------------------------- 100 101 namespace impl2 102 { 103 // These control bytes are sent before each message between nodes. Note that many 104 // of these are only sent between the control node (node 0) and the other nodes. 105 // This is because the controller node is responsible for handling the 106 // synchronization that needs to happen when all nodes block on calls to 107 // receive_data() 108 // at the same time. 109 110 // denotes a normal content message. 111 const static char MESSAGE_HEADER = 0; 112 113 // sent to the controller node when someone receives a message via receive_data(). 114 const static char GOT_MESSAGE = 1; 115 116 // sent to the controller node when someone sends a message via send(). 117 const static char SENT_MESSAGE = 2; 118 119 // sent to the controller node when someone enters a call to receive_data() 120 const static char IN_WAITING_STATE = 3; 121 122 // broadcast when a node terminates itself. 123 const static char NODE_TERMINATE = 5; 124 125 // broadcast by the controller node when it determines that all nodes are blocked 126 // on calls to receive_data() and there aren't any messages in flight. This is also 127 // what makes us go to the next epoch. 128 const static char SEE_ALL_IN_WAITING_STATE = 6; 129 130 // This isn't ever transmitted between nodes. It is used internally to indicate 131 // that an error occurred. 132 const static char READ_ERROR = 7; 133 134 // ------------------------------------------------------------------------------------ 135 read_thread(impl1::bsp_con * con,unsigned long node_id,unsigned long sender_id,impl1::thread_safe_message_queue & msg_buffer)136 void read_thread ( 137 impl1::bsp_con* con, 138 unsigned long node_id, 139 unsigned long sender_id, 140 impl1::thread_safe_message_queue& msg_buffer 141 ) 142 { 143 try 144 { 145 while(true) 146 { 147 impl1::msg_data msg; 148 deserialize(msg.msg_type, con->stream); 149 msg.sender_id = sender_id; 150 151 if (msg.msg_type == MESSAGE_HEADER) 152 { 153 msg.data.reset(new std::vector<char>); 154 deserialize(msg.epoch, con->stream); 155 deserialize(*msg.data, con->stream); 156 } 157 158 msg_buffer.push_and_consume(msg); 159 160 if (msg.msg_type == NODE_TERMINATE) 161 break; 162 } 163 } 164 catch (std::exception& e) 165 { 166 impl1::msg_data msg; 167 msg.data.reset(new std::vector<char>); 168 vectorstream sout(*msg.data); 169 sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; 170 sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; 171 sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; 172 sout << " Receiving processing node id: " << node_id << std::endl; 173 sout << " Error message in the exception: " << e.what() << std::endl; 174 175 msg.sender_id = sender_id; 176 msg.msg_type = READ_ERROR; 177 178 msg_buffer.push_and_consume(msg); 179 } 180 catch (...) 181 { 182 impl1::msg_data msg; 183 msg.data.reset(new std::vector<char>); 184 vectorstream sout(*msg.data); 185 sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; 186 sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; 187 sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; 188 sout << " Receiving processing node id: " << node_id << std::endl; 189 190 msg.sender_id = sender_id; 191 msg.msg_type = READ_ERROR; 192 193 msg_buffer.push_and_consume(msg); 194 } 195 } 196 197 // ------------------------------------------------------------------------------------ 198 199 } 200 201 // ---------------------------------------------------------------------------------------- 202 // ---------------------------------------------------------------------------------------- 203 // IMPLEMENTATION OF bsp_context OBJECT MEMBERS 204 // ---------------------------------------------------------------------------------------- 205 // ---------------------------------------------------------------------------------------- 206 207 void bsp_context:: close_all_connections_gracefully()208 close_all_connections_gracefully( 209 ) 210 { 211 if (node_id() != 0) 212 { 213 _cons.reset(); 214 while (_cons.move_next()) 215 { 216 // tell the other end that we are intentionally dropping the connection 217 serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); 218 _cons.element().value()->stream.flush(); 219 } 220 } 221 222 impl1::msg_data msg; 223 // now wait for all the other nodes to terminate 224 while (num_terminated_nodes < _cons.size() ) 225 { 226 if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0) 227 { 228 num_waiting_nodes = 0; 229 broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); 230 ++current_epoch; 231 } 232 233 if (!msg_buffer.pop(msg)) 234 throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); 235 236 if (msg.msg_type == impl2::NODE_TERMINATE) 237 { 238 ++num_terminated_nodes; 239 _cons[msg.sender_id]->terminated = true; 240 } 241 else if (msg.msg_type == impl2::READ_ERROR) 242 { 243 throw dlib::socket_error(msg.data_to_string()); 244 } 245 else if (msg.msg_type == impl2::MESSAGE_HEADER) 246 { 247 throw dlib::socket_error("A BSP node received a message after it has terminated."); 248 } 249 else if (msg.msg_type == impl2::GOT_MESSAGE) 250 { 251 --num_waiting_nodes; 252 --outstanding_messages; 253 } 254 else if (msg.msg_type == impl2::SENT_MESSAGE) 255 { 256 ++outstanding_messages; 257 } 258 else if (msg.msg_type == impl2::IN_WAITING_STATE) 259 { 260 ++num_waiting_nodes; 261 } 262 } 263 264 if (node_id() == 0) 265 { 266 _cons.reset(); 267 while (_cons.move_next()) 268 { 269 // tell the other end that we are intentionally dropping the connection 270 serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); 271 _cons.element().value()->stream.flush(); 272 } 273 274 if (outstanding_messages != 0) 275 { 276 std::ostringstream sout; 277 sout << "A BSP job was allowed to terminate before all sent messages have been received.\n"; 278 sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n"; 279 sout << "have a corresponding call to receive()."; 280 throw dlib::socket_error(sout.str()); 281 } 282 } 283 } 284 285 // ---------------------------------------------------------------------------------------- 286 287 bsp_context:: ~bsp_context()288 ~bsp_context() 289 { 290 _cons.reset(); 291 while (_cons.move_next()) 292 { 293 _cons.element().value()->con->shutdown(); 294 } 295 296 msg_buffer.disable(); 297 298 // this will wait for all the threads to terminate 299 threads.clear(); 300 } 301 302 // ---------------------------------------------------------------------------------------- 303 304 bsp_context:: bsp_context(unsigned long node_id_,impl1::map_id_to_con & cons_)305 bsp_context( 306 unsigned long node_id_, 307 impl1::map_id_to_con& cons_ 308 ) : 309 outstanding_messages(0), 310 num_waiting_nodes(0), 311 num_terminated_nodes(0), 312 current_epoch(1), 313 _cons(cons_), 314 _node_id(node_id_) 315 { 316 // spawn a bunch of read threads, one for each connection 317 _cons.reset(); 318 while (_cons.move_next()) 319 { 320 std::unique_ptr<thread_function> ptr(new thread_function(&impl2::read_thread, 321 _cons.element().value().get(), 322 _node_id, 323 _cons.element().key(), 324 ref(msg_buffer))); 325 threads.push_back(ptr); 326 } 327 328 } 329 330 // ---------------------------------------------------------------------------------------- 331 332 bool bsp_context:: receive_data(std::shared_ptr<std::vector<char>> & item,unsigned long & sending_node_id)333 receive_data ( 334 std::shared_ptr<std::vector<char> >& item, 335 unsigned long& sending_node_id 336 ) 337 { 338 notify_control_node(impl2::IN_WAITING_STATE); 339 340 while (true) 341 { 342 // If there aren't any nodes left to give us messages then return right now. 343 // We need to check the msg_buffer size to make sure there aren't any 344 // unprocessed message there. Recall that this can happen because status 345 // messages always jump to the front of the message buffer. So we might have 346 // learned about the node terminations before processing their messages for us. 347 if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0) 348 { 349 return false; 350 } 351 352 // if all running nodes are currently blocking forever on receive_data() 353 if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size()) 354 { 355 num_waiting_nodes = 0; 356 broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); 357 358 // Note that the reason we have this epoch counter is so we can tell if a 359 // sent message is from before or after one of these "all nodes waiting" 360 // synchronization events. If we didn't have the epoch count we would have 361 // a race condition where one node gets the SEE_ALL_IN_WAITING_STATE 362 // message before others and then sends out a message to another node 363 // before that node got the SEE_ALL_IN_WAITING_STATE message. Then that 364 // node would think the normal message came before SEE_ALL_IN_WAITING_STATE 365 // which would be bad. 366 ++current_epoch; 367 return false; 368 } 369 370 impl1::msg_data data; 371 if (!msg_buffer.pop(data, current_epoch)) 372 throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); 373 374 375 switch(data.msg_type) 376 { 377 case impl2::MESSAGE_HEADER: { 378 item = data.data; 379 sending_node_id = data.sender_id; 380 notify_control_node(impl2::GOT_MESSAGE); 381 return true; 382 } break; 383 384 case impl2::IN_WAITING_STATE: { 385 ++num_waiting_nodes; 386 } break; 387 388 case impl2::GOT_MESSAGE: { 389 --outstanding_messages; 390 --num_waiting_nodes; 391 } break; 392 393 case impl2::SENT_MESSAGE: { 394 ++outstanding_messages; 395 } break; 396 397 case impl2::NODE_TERMINATE: { 398 ++num_terminated_nodes; 399 _cons[data.sender_id]->terminated = true; 400 } break; 401 402 case impl2::SEE_ALL_IN_WAITING_STATE: { 403 ++current_epoch; 404 return false; 405 } break; 406 407 case impl2::READ_ERROR: { 408 throw dlib::socket_error(data.data_to_string()); 409 } break; 410 411 default: { 412 throw dlib::socket_error("Unknown message received by dlib::bsp_context"); 413 } break; 414 } // end switch() 415 } // end while (true) 416 } 417 418 // ---------------------------------------------------------------------------------------- 419 420 void bsp_context:: notify_control_node(char val)421 notify_control_node ( 422 char val 423 ) 424 { 425 if (node_id() == 0) 426 { 427 using namespace impl2; 428 switch(val) 429 { 430 case SENT_MESSAGE: { 431 ++outstanding_messages; 432 } break; 433 434 case GOT_MESSAGE: { 435 --outstanding_messages; 436 } break; 437 438 case IN_WAITING_STATE: { 439 // nothing to do in this case 440 } break; 441 442 default: 443 DLIB_CASSERT(false,"This should never happen"); 444 } 445 } 446 else 447 { 448 serialize(val, _cons[0]->stream); 449 _cons[0]->stream.flush(); 450 } 451 } 452 453 // ---------------------------------------------------------------------------------------- 454 455 void bsp_context:: broadcast_byte(char val)456 broadcast_byte ( 457 char val 458 ) 459 { 460 for (unsigned long i = 0; i < number_of_nodes(); ++i) 461 { 462 // don't send to yourself or to terminated nodes 463 if (i == node_id() || _cons[i]->terminated) 464 continue; 465 466 serialize(val, _cons[i]->stream); 467 _cons[i]->stream.flush(); 468 } 469 } 470 471 // ---------------------------------------------------------------------------------------- 472 473 void bsp_context:: send_data(const std::vector<char> & item,unsigned long target_node_id)474 send_data( 475 const std::vector<char>& item, 476 unsigned long target_node_id 477 ) 478 { 479 using namespace impl2; 480 if (_cons[target_node_id]->terminated) 481 throw socket_error("Attempt to send a message to a node that has terminated."); 482 483 serialize(MESSAGE_HEADER, _cons[target_node_id]->stream); 484 serialize(current_epoch, _cons[target_node_id]->stream); 485 serialize(item, _cons[target_node_id]->stream); 486 _cons[target_node_id]->stream.flush(); 487 488 notify_control_node(SENT_MESSAGE); 489 } 490 491 // ---------------------------------------------------------------------------------------- 492 493 } 494 495 #endif // DLIB_BSP_CPph_ 496 497