1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include "plasma/io.h"
19 
20 #include <cstdint>
21 #include <memory>
22 #include <sstream>
23 
24 #include "arrow/status.h"
25 #include "arrow/util/endian.h"
26 #include "arrow/util/logging.h"
27 
28 #include "plasma/common.h"
29 #include "plasma/plasma_generated.h"
30 
31 using arrow::Status;
32 
/// Default number of retries ConnectIpcSocketRetry performs after its initial
/// connection attempt; used when the caller passes a negative num_retries.
constexpr int64_t kNumConnectAttempts = 80;
/// Default delay, in milliseconds, between connection attempts; used by
/// ConnectIpcSocketRetry when the caller passes a negative timeout.
constexpr int64_t kConnectTimeoutMs = 100;
37 
38 namespace plasma {
39 
40 using flatbuf::MessageType;
41 
WriteBytes(int fd,uint8_t * cursor,size_t length)42 Status WriteBytes(int fd, uint8_t* cursor, size_t length) {
43   ssize_t nbytes = 0;
44   size_t bytesleft = length;
45   size_t offset = 0;
46   while (bytesleft > 0) {
47     // While we haven't written the whole message, write to the file descriptor,
48     // advance the cursor, and decrease the amount left to write.
49     nbytes = write(fd, cursor + offset, bytesleft);
50     if (nbytes < 0) {
51       if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
52         continue;
53       }
54       return Status::IOError(strerror(errno));
55     } else if (nbytes == 0) {
56       return Status::IOError("Encountered unexpected EOF");
57     }
58     ARROW_CHECK(nbytes > 0);
59     bytesleft -= nbytes;
60     offset += nbytes;
61   }
62 
63   return Status::OK();
64 }
65 
WriteMessage(int fd,MessageType type,int64_t length,uint8_t * bytes)66 Status WriteMessage(int fd, MessageType type, int64_t length, uint8_t* bytes) {
67   int64_t version = arrow::BitUtil::ToLittleEndian(kPlasmaProtocolVersion);
68   assert(sizeof(MessageType) == sizeof(int64_t));
69   type = static_cast<MessageType>(
70       arrow::BitUtil::ToLittleEndian(static_cast<int64_t>(type)));
71   int64_t length_le = arrow::BitUtil::ToLittleEndian(length);
72   RETURN_NOT_OK(WriteBytes(fd, reinterpret_cast<uint8_t*>(&version), sizeof(version)));
73   RETURN_NOT_OK(WriteBytes(fd, reinterpret_cast<uint8_t*>(&type), sizeof(type)));
74   RETURN_NOT_OK(WriteBytes(fd, reinterpret_cast<uint8_t*>(&length_le), sizeof(length)));
75   return WriteBytes(fd, bytes, length * sizeof(char));
76 }
77 
ReadBytes(int fd,uint8_t * cursor,size_t length)78 Status ReadBytes(int fd, uint8_t* cursor, size_t length) {
79   ssize_t nbytes = 0;
80   // Termination condition: EOF or read 'length' bytes total.
81   size_t bytesleft = length;
82   size_t offset = 0;
83   while (bytesleft > 0) {
84     nbytes = read(fd, cursor + offset, bytesleft);
85     if (nbytes < 0) {
86       if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
87         continue;
88       }
89       return Status::IOError(strerror(errno));
90     } else if (0 == nbytes) {
91       return Status::IOError("Encountered unexpected EOF");
92     }
93     ARROW_CHECK(nbytes > 0);
94     bytesleft -= nbytes;
95     offset += nbytes;
96   }
97 
98   return Status::OK();
99 }
100 
ReadMessage(int fd,MessageType * type,std::vector<uint8_t> * buffer)101 Status ReadMessage(int fd, MessageType* type, std::vector<uint8_t>* buffer) {
102   int64_t version;
103   RETURN_NOT_OK_ELSE(ReadBytes(fd, reinterpret_cast<uint8_t*>(&version), sizeof(version)),
104                      *type = MessageType::PlasmaDisconnectClient);
105   version = arrow::BitUtil::FromLittleEndian(version);
106   ARROW_CHECK(version == kPlasmaProtocolVersion) << "version = " << version;
107   RETURN_NOT_OK_ELSE(ReadBytes(fd, reinterpret_cast<uint8_t*>(type), sizeof(*type)),
108                      *type = MessageType::PlasmaDisconnectClient);
109   assert(sizeof(MessageType) == sizeof(int64_t));
110   *type = static_cast<MessageType>(
111       arrow::BitUtil::FromLittleEndian(static_cast<int64_t>(*type)));
112   int64_t length_temp;
113   RETURN_NOT_OK_ELSE(
114       ReadBytes(fd, reinterpret_cast<uint8_t*>(&length_temp), sizeof(length_temp)),
115       *type = MessageType::PlasmaDisconnectClient);
116   // The length must be read as an int64_t, but it should be used as a size_t.
117   size_t length = static_cast<size_t>(arrow::BitUtil::FromLittleEndian(length_temp));
118   if (length > buffer->size()) {
119     buffer->resize(length);
120   }
121   RETURN_NOT_OK_ELSE(ReadBytes(fd, buffer->data(), length),
122                      *type = MessageType::PlasmaDisconnectClient);
123   return Status::OK();
124 }
125 
BindIpcSock(const std::string & pathname,bool shall_listen)126 int BindIpcSock(const std::string& pathname, bool shall_listen) {
127   struct sockaddr_un socket_address;
128   int socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
129   if (socket_fd < 0) {
130     ARROW_LOG(ERROR) << "socket() failed for pathname " << pathname;
131     return -1;
132   }
133   // Tell the system to allow the port to be reused.
134   int on = 1;
135   if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&on),
136                  sizeof(on)) < 0) {
137     ARROW_LOG(ERROR) << "setsockopt failed for pathname " << pathname;
138     close(socket_fd);
139     return -1;
140   }
141 
142   unlink(pathname.c_str());
143   memset(&socket_address, 0, sizeof(socket_address));
144   socket_address.sun_family = AF_UNIX;
145   if (pathname.size() + 1 > sizeof(socket_address.sun_path)) {
146     ARROW_LOG(ERROR) << "Socket pathname is too long.";
147     close(socket_fd);
148     return -1;
149   }
150   strncpy(socket_address.sun_path, pathname.c_str(), pathname.size() + 1);
151 
152   if (bind(socket_fd, reinterpret_cast<struct sockaddr*>(&socket_address),
153            sizeof(socket_address)) != 0) {
154     ARROW_LOG(ERROR) << "Bind failed for pathname " << pathname;
155     close(socket_fd);
156     return -1;
157   }
158   if (shall_listen && listen(socket_fd, 128) == -1) {
159     ARROW_LOG(ERROR) << "Could not listen to socket " << pathname;
160     close(socket_fd);
161     return -1;
162   }
163   return socket_fd;
164 }
165 
ConnectIpcSocketRetry(const std::string & pathname,int num_retries,int64_t timeout,int * fd)166 Status ConnectIpcSocketRetry(const std::string& pathname, int num_retries,
167                              int64_t timeout, int* fd) {
168   // Pick the default values if the user did not specify.
169   if (num_retries < 0) {
170     num_retries = kNumConnectAttempts;
171   }
172   if (timeout < 0) {
173     timeout = kConnectTimeoutMs;
174   }
175   *fd = ConnectIpcSock(pathname);
176   while (*fd < 0 && num_retries > 0) {
177     ARROW_LOG(ERROR) << "Connection to IPC socket failed for pathname " << pathname
178                      << ", retrying " << num_retries << " more times";
179     // Sleep for timeout milliseconds.
180     usleep(static_cast<int>(timeout * 1000));
181     *fd = ConnectIpcSock(pathname);
182     --num_retries;
183   }
184 
185   // If we could not connect to the socket, exit.
186   if (*fd == -1) {
187     return Status::IOError("Could not connect to socket ", pathname);
188   }
189 
190   return Status::OK();
191 }
192 
ConnectIpcSock(const std::string & pathname)193 int ConnectIpcSock(const std::string& pathname) {
194   struct sockaddr_un socket_address;
195   int socket_fd;
196 
197   socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
198   if (socket_fd < 0) {
199     ARROW_LOG(ERROR) << "socket() failed for pathname " << pathname;
200     return -1;
201   }
202 
203   memset(&socket_address, 0, sizeof(socket_address));
204   socket_address.sun_family = AF_UNIX;
205   if (pathname.size() + 1 > sizeof(socket_address.sun_path)) {
206     ARROW_LOG(ERROR) << "Socket pathname is too long.";
207     close(socket_fd);
208     return -1;
209   }
210   strncpy(socket_address.sun_path, pathname.c_str(), pathname.size() + 1);
211 
212   if (connect(socket_fd, reinterpret_cast<struct sockaddr*>(&socket_address),
213               sizeof(socket_address)) != 0) {
214     close(socket_fd);
215     return -1;
216   }
217 
218   return socket_fd;
219 }
220 
AcceptClient(int socket_fd)221 int AcceptClient(int socket_fd) {
222   int client_fd = accept(socket_fd, NULL, NULL);
223   if (client_fd < 0) {
224     ARROW_LOG(ERROR) << "Error reading from socket.";
225     return -1;
226   }
227   return client_fd;
228 }
229 
ReadMessageAsync(int sock)230 std::unique_ptr<uint8_t[]> ReadMessageAsync(int sock) {
231   int64_t size;
232   Status s = ReadBytes(sock, reinterpret_cast<uint8_t*>(&size), sizeof(int64_t));
233   if (!s.ok()) {
234     // The other side has closed the socket.
235     ARROW_LOG(DEBUG) << "Socket has been closed, or some other error has occurred.";
236     close(sock);
237     return NULL;
238   }
239   auto message = std::unique_ptr<uint8_t[]>(new uint8_t[size]);
240   s = ReadBytes(sock, message.get(), size);
241   if (!s.ok()) {
242     // The other side has closed the socket.
243     ARROW_LOG(DEBUG) << "Socket has been closed, or some other error has occurred.";
244     close(sock);
245     return NULL;
246   }
247   return message;
248 }
249 
250 }  // namespace plasma
251