/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved.  Released under a BSD
license as described in the file LICENSE.
 */
// This implements the allreduce function of MPI.
7 
8 #pragma once
9 #include <string>
10 #ifdef _WIN32
11 #include <WinSock2.h>
12 #include <WS2tcpip.h>
13 typedef unsigned int uint32_t;
14 typedef unsigned short uint16_t;
15 typedef int socklen_t;
16 typedef SOCKET socket_t;
17 #define CLOSESOCK closesocket
18 #else
19 #include <sys/socket.h>
20 #include <sys/socket.h>
21 #include <netinet/in.h>
22 #include <netinet/tcp.h>
23 #include <netdb.h>
24 #include <stdlib.h>
25 #include <stdio.h>
26 typedef int socket_t;
27 #define CLOSESOCK close
28 #endif
29 
30 using namespace std;
31 
32 const size_t ar_buf_size = 1<<16;
33 
34 struct node_socks {
35   std::string current_master;
36   socket_t parent;
37   socket_t children[2];
~node_socksnode_socks38   ~node_socks()
39   {
40     if(current_master != "") {
41       if(parent != -1)
42 	CLOSESOCK(this->parent);
43       if(children[0] != -1)
44 	CLOSESOCK(this->children[0]);
45       if(children[1] != -1)
46 	CLOSESOCK(this->children[1]);
47     }
48   }
node_socksnode_socks49   node_socks ()
50   {
51     current_master = "";
52   }
53 };
54 
55 
// Folds buf2 into buf1 element by element, in place.
//
//   T     element type of both arrays.
//   f     combiner, invoked as f(buf1[i], buf2[i]); it receives the first
//         operand by mutable reference and the second by const reference.
//   buf1  destination array with at least n elements.
//   buf2  source array with at least n elements.
//   n     number of elements to process; 0 makes the call a no-op.
template <class T, void (*f)(T&, const T&)> void addbufs(T* buf1, const T* buf2, const size_t n) {
  for(size_t i = 0;i < n;i++)
    f(buf1[i], buf2[i]);
}
60 
61 void all_reduce_init(const string master_location, const size_t unique_id, const size_t total, const size_t node, node_socks& socks);
62 
pass_up(char * buffer,size_t left_read_pos,size_t right_read_pos,size_t & parent_sent_pos,socket_t parent_sock,size_t n)63 template <class T> void pass_up(char* buffer, size_t left_read_pos, size_t right_read_pos, size_t& parent_sent_pos, socket_t parent_sock, size_t n) {
64   size_t my_bufsize = min(ar_buf_size, min(left_read_pos, right_read_pos) / sizeof(T) * sizeof(T) - parent_sent_pos);
65 
66   if(my_bufsize > 0) {
67     //going to pass up this chunk of data to the parent
68     int write_size = send(parent_sock, buffer+parent_sent_pos, (int)my_bufsize, 0);
69     if(write_size < 0) {
70       cerr<<"Write to parent failed "<<my_bufsize<<" "<<write_size<<" "<<parent_sent_pos<<" "<<left_read_pos<<" "<<right_read_pos<<endl ;
71       throw exception();
72     }
73     parent_sent_pos += write_size;
74   }
75 
76 }
77 
reduce(char * buffer,const size_t n,const socket_t parent_sock,const socket_t * child_sockets)78 template <class T, void (*f)(T&, const T&)>void reduce(char* buffer, const size_t n, const socket_t parent_sock, const socket_t* child_sockets) {
79 
80   fd_set fds;
81   FD_ZERO(&fds);
82   if(child_sockets[0] != -1)
83     FD_SET(child_sockets[0],&fds);
84   if(child_sockets[1] != -1)
85     FD_SET(child_sockets[1],&fds);
86 
87   socket_t max_fd = max(child_sockets[0],child_sockets[1])+1;
88   size_t child_read_pos[2] = {0,0}; //First unread float from left and right children
89   int child_unprocessed[2] = {0,0}; //The number of bytes sent by the child but not yet added to the buffer
90   char child_read_buf[2][ar_buf_size+sizeof(T)-1];
91   size_t parent_sent_pos = 0; //First unsent float to parent
92   //parent_sent_pos <= left_read_pos
93   //parent_sent_pos <= right_read_pos
94 
95   if(child_sockets[0] == -1) {
96     child_read_pos[0] = n;
97   }
98   if(child_sockets[1] == -1) {
99     child_read_pos[1] = n;
100   }
101 
102   while (parent_sent_pos < n || child_read_pos[0] < n || child_read_pos[1] < n)
103     {
104       if(parent_sock != -1)
105 	pass_up<T>(buffer, child_read_pos[0], child_read_pos[1], parent_sent_pos, parent_sock, n);
106 
107       if(parent_sent_pos >= n && child_read_pos[0] >= n && child_read_pos[1] >= n) break;
108 
109       if(child_read_pos[0] < n || child_read_pos[1] < n) {
110 	if (max_fd > 0 && select((int)max_fd,&fds,NULL, NULL, NULL) == -1)
111 	  {
112 	    cerr << "select: " << strerror(errno) << endl;
113 	    throw exception();
114 	  }
115 
116 	for(int i = 0;i < 2;i++) {
117 	  if(child_sockets[i] != -1 && FD_ISSET(child_sockets[i],&fds)) {
118 	    //there is data to be left from left child
119 	    if(child_read_pos[i] == n) {
120 	      cerr<<"I think child has no data to send but he thinks he has "<<FD_ISSET(child_sockets[0],&fds)<<" "<<FD_ISSET(child_sockets[1],&fds)<<endl;
121 	      throw exception();
122 	    }
123 
124 
125 	    size_t count = min(ar_buf_size,n - child_read_pos[i]);
126 	    int read_size = recv(child_sockets[i], child_read_buf[i] + child_unprocessed[i], (int)count, 0);
127 	    if(read_size == -1) {
128 	      cerr << "recv from child: " << strerror(errno) << endl;
129 	      throw exception();
130 	    }
131 
132 	    addbufs<T, f>((T*)buffer + child_read_pos[i]/sizeof(T), (T*)child_read_buf[i], (child_read_pos[i] + read_size)/sizeof(T) - child_read_pos[i]/sizeof(T));
133 
134 	    child_read_pos[i] += read_size;
135 	    int old_unprocessed = child_unprocessed[i];
136 	    child_unprocessed[i] = child_read_pos[i] % (int)sizeof(T);
137 	    for(int j = 0;j < child_unprocessed[i];j++) {
138 	      child_read_buf[i][j] = child_read_buf[i][((old_unprocessed + read_size)/(int)sizeof(T))*sizeof(T)+j];
139 	    }
140 
141 	    if(child_read_pos[i] == n) //Done reading parent
142 	      FD_CLR(child_sockets[i],&fds);
143 	  }
144 	  else if(child_sockets[i] != -1 && child_read_pos[i] != n)
145 	    FD_SET(child_sockets[i],&fds);
146 	}
147       }
148       if(parent_sock == -1 && child_read_pos[0] == n && child_read_pos[1] == n)
149 	parent_sent_pos = n;
150 
151     }
152 
153 }
154 
155 void broadcast(char* buffer, const size_t n, const socket_t parent_sock, const socket_t * child_sockets);
156 
157 
all_reduce(T * buffer,const size_t n,const std::string master_location,const size_t unique_id,const size_t total,const size_t node,node_socks & socks)158 template <class T, void (*f)(T&, const T&)> void all_reduce(T* buffer, const size_t n, const std::string master_location, const size_t unique_id, const size_t total, const size_t node, node_socks& socks)
159 {
160   if(master_location != socks.current_master)
161     all_reduce_init(master_location, unique_id, total, node, socks);
162   reduce<T, f>((char*)buffer, n*sizeof(T), socks.parent, socks.children);
163   broadcast((char*)buffer, n*sizeof(T), socks.parent, socks.children);
164 }
165