1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD
4 license as described in the file LICENSE.
5 */
// This implements an MPI-style allreduce (the operation MPI_Allreduce provides) over a binary tree of sockets; it does not use MPI itself.
7
8 #pragma once
#include <algorithm>
#include <cerrno>
#include <cstring>
#include <exception>
#include <iostream>
#include <string>
#ifdef _WIN32
#include <WinSock2.h>
#include <WS2tcpip.h>
typedef unsigned int uint32_t;
typedef unsigned short uint16_t;
typedef int socklen_t;
typedef SOCKET socket_t;
#define CLOSESOCK closesocket
#else
#include <sys/socket.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
typedef int socket_t;
#define CLOSESOCK close
#endif
29
30 using namespace std;
31
// Upper bound, in bytes, on a single send()/recv() issued by pass_up/reduce;
// it also sizes reduce's per-child receive staging buffers. 64 KiB.
const size_t ar_buf_size = 64 * 1024;
33
34 struct node_socks {
35 std::string current_master;
36 socket_t parent;
37 socket_t children[2];
~node_socksnode_socks38 ~node_socks()
39 {
40 if(current_master != "") {
41 if(parent != -1)
42 CLOSESOCK(this->parent);
43 if(children[0] != -1)
44 CLOSESOCK(this->children[0]);
45 if(children[1] != -1)
46 CLOSESOCK(this->children[1]);
47 }
48 }
node_socksnode_socks49 node_socks ()
50 {
51 current_master = "";
52 }
53 };
54
55
/* Element-wise combine: for each of the first n positions, applies
   f(buf1[k], buf2[k]), letting f update the element of buf1 in place.
   buf2 is only read. */
template <class T, void (*f)(T&, const T&)> void addbufs(T* buf1, const T* buf2, const size_t n) {
  for (const T* const stop = buf2 + n; buf2 != stop; ++buf1, ++buf2)
    f(*buf1, *buf2);
}
60
// Sets up the socket tree used by all_reduce; defined elsewhere.
// Within this file it is only called by all_reduce, and only when
// master_location differs from socks.current_master.
// NOTE(review): presumably it contacts the master at master_location, connects
// this node (index `node` of `total`, job `unique_id`) to its parent/children
// in socks, and records master_location in socks.current_master -- confirm
// against its definition.
void all_reduce_init(const string master_location, const size_t unique_id, const size_t total, const size_t node, node_socks& socks);
62
pass_up(char * buffer,size_t left_read_pos,size_t right_read_pos,size_t & parent_sent_pos,socket_t parent_sock,size_t n)63 template <class T> void pass_up(char* buffer, size_t left_read_pos, size_t right_read_pos, size_t& parent_sent_pos, socket_t parent_sock, size_t n) {
64 size_t my_bufsize = min(ar_buf_size, min(left_read_pos, right_read_pos) / sizeof(T) * sizeof(T) - parent_sent_pos);
65
66 if(my_bufsize > 0) {
67 //going to pass up this chunk of data to the parent
68 int write_size = send(parent_sock, buffer+parent_sent_pos, (int)my_bufsize, 0);
69 if(write_size < 0) {
70 cerr<<"Write to parent failed "<<my_bufsize<<" "<<write_size<<" "<<parent_sent_pos<<" "<<left_read_pos<<" "<<right_read_pos<<endl ;
71 throw exception();
72 }
73 parent_sent_pos += write_size;
74 }
75
76 }
77
reduce(char * buffer,const size_t n,const socket_t parent_sock,const socket_t * child_sockets)78 template <class T, void (*f)(T&, const T&)>void reduce(char* buffer, const size_t n, const socket_t parent_sock, const socket_t* child_sockets) {
79
80 fd_set fds;
81 FD_ZERO(&fds);
82 if(child_sockets[0] != -1)
83 FD_SET(child_sockets[0],&fds);
84 if(child_sockets[1] != -1)
85 FD_SET(child_sockets[1],&fds);
86
87 socket_t max_fd = max(child_sockets[0],child_sockets[1])+1;
88 size_t child_read_pos[2] = {0,0}; //First unread float from left and right children
89 int child_unprocessed[2] = {0,0}; //The number of bytes sent by the child but not yet added to the buffer
90 char child_read_buf[2][ar_buf_size+sizeof(T)-1];
91 size_t parent_sent_pos = 0; //First unsent float to parent
92 //parent_sent_pos <= left_read_pos
93 //parent_sent_pos <= right_read_pos
94
95 if(child_sockets[0] == -1) {
96 child_read_pos[0] = n;
97 }
98 if(child_sockets[1] == -1) {
99 child_read_pos[1] = n;
100 }
101
102 while (parent_sent_pos < n || child_read_pos[0] < n || child_read_pos[1] < n)
103 {
104 if(parent_sock != -1)
105 pass_up<T>(buffer, child_read_pos[0], child_read_pos[1], parent_sent_pos, parent_sock, n);
106
107 if(parent_sent_pos >= n && child_read_pos[0] >= n && child_read_pos[1] >= n) break;
108
109 if(child_read_pos[0] < n || child_read_pos[1] < n) {
110 if (max_fd > 0 && select((int)max_fd,&fds,NULL, NULL, NULL) == -1)
111 {
112 cerr << "select: " << strerror(errno) << endl;
113 throw exception();
114 }
115
116 for(int i = 0;i < 2;i++) {
117 if(child_sockets[i] != -1 && FD_ISSET(child_sockets[i],&fds)) {
118 //there is data to be left from left child
119 if(child_read_pos[i] == n) {
120 cerr<<"I think child has no data to send but he thinks he has "<<FD_ISSET(child_sockets[0],&fds)<<" "<<FD_ISSET(child_sockets[1],&fds)<<endl;
121 throw exception();
122 }
123
124
125 size_t count = min(ar_buf_size,n - child_read_pos[i]);
126 int read_size = recv(child_sockets[i], child_read_buf[i] + child_unprocessed[i], (int)count, 0);
127 if(read_size == -1) {
128 cerr << "recv from child: " << strerror(errno) << endl;
129 throw exception();
130 }
131
132 addbufs<T, f>((T*)buffer + child_read_pos[i]/sizeof(T), (T*)child_read_buf[i], (child_read_pos[i] + read_size)/sizeof(T) - child_read_pos[i]/sizeof(T));
133
134 child_read_pos[i] += read_size;
135 int old_unprocessed = child_unprocessed[i];
136 child_unprocessed[i] = child_read_pos[i] % (int)sizeof(T);
137 for(int j = 0;j < child_unprocessed[i];j++) {
138 child_read_buf[i][j] = child_read_buf[i][((old_unprocessed + read_size)/(int)sizeof(T))*sizeof(T)+j];
139 }
140
141 if(child_read_pos[i] == n) //Done reading parent
142 FD_CLR(child_sockets[i],&fds);
143 }
144 else if(child_sockets[i] != -1 && child_read_pos[i] != n)
145 FD_SET(child_sockets[i],&fds);
146 }
147 }
148 if(parent_sock == -1 && child_read_pos[0] == n && child_read_pos[1] == n)
149 parent_sent_pos = n;
150
151 }
152
153 }
154
// Broadcast phase of all_reduce; defined elsewhere. all_reduce calls it right
// after reduce() with the same buffer, byte count n and tree sockets
// (parent_sock is -1 at the root, child entries are -1 for missing children).
// NOTE(review): presumably it sends the root's fully reduced buffer down the
// tree so every node ends with the same n bytes -- confirm against its definition.
void broadcast(char* buffer, const size_t n, const socket_t parent_sock, const socket_t * child_sockets);
156
157
all_reduce(T * buffer,const size_t n,const std::string master_location,const size_t unique_id,const size_t total,const size_t node,node_socks & socks)158 template <class T, void (*f)(T&, const T&)> void all_reduce(T* buffer, const size_t n, const std::string master_location, const size_t unique_id, const size_t total, const size_t node, node_socks& socks)
159 {
160 if(master_location != socks.current_master)
161 all_reduce_init(master_location, unique_id, total, node, socks);
162 reduce<T, f>((char*)buffer, n*sizeof(T), socks.parent, socks.children);
163 broadcast((char*)buffer, n*sizeof(T), socks.parent, socks.children);
164 }
165