1 /*
2 * Copyright 2006 Serge van den Boom <svdb@stack.nl>
3 *
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation; either version 2 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, write to the Free Software
16 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 */
18
19 #define PORT_WANT_ERRNO
20 #include "port.h"
21
22 #define CONNECT_INTERNAL
23 #define SOCKET_INTERNAL
24 #include "connect.h"
25
26 #include "resolve.h"
27 #include "libs/alarm.h"
28 #include "../socket/socket.h"
29 #include "libs/misc.h"
30 #include "libs/log.h"
31
32 #include <assert.h>
33 #include <errno.h>
34 #include <stdlib.h>
35 #include <string.h>
36 #ifdef USE_WINSOCK
37 # include <winsock2.h>
38 # include <ws2tcpip.h>
39 # include "../wspiapiwrap.h"
40 #else
41 # include <netdb.h>
42 #endif
43
44 #define DEBUG_CONNECT_REF
45 #ifdef DEBUG_CONNECT_REF
46 # include "types.h"
47 #endif
48
49
50 static void connectHostNext(ConnectState *connectState);
51 static void doConnectCallback(ConnectState *connectState, NetDescriptor *nd,
52 const struct sockaddr *addr, socklen_t addrLen);
53 static void doConnectErrorCallback(ConnectState *connectState,
54 const ConnectError *error);
55
56
57 static ConnectState *
ConnectState_alloc(void)58 ConnectState_alloc(void) {
59 return (ConnectState *) malloc(sizeof (ConnectState));
60 }
61
62 static void
ConnectState_free(ConnectState * connectState)63 ConnectState_free(ConnectState *connectState) {
64 free(connectState);
65 }
66
67 static void
ConnectState_delete(ConnectState * connectState)68 ConnectState_delete(ConnectState *connectState) {
69 assert(connectState->nd == NULL);
70 assert(connectState->alarm == NULL);
71 assert(connectState->info == NULL);
72 assert(connectState->infoPtr == NULL);
73 ConnectState_free(connectState);
74 }
75
76 void
ConnectState_incRef(ConnectState * connectState)77 ConnectState_incRef(ConnectState *connectState) {
78 assert(connectState->refCount < REFCOUNT_MAX);
79 connectState->refCount++;
80 #ifdef DEBUG_CONNECT_REF
81 log_add(log_Debug, "ConnectState %08" PRIxPTR ": ref++ (%d)",
82 (uintptr_t) connectState, connectState->refCount);
83 #endif
84 }
85
86 bool
ConnectState_decRef(ConnectState * connectState)87 ConnectState_decRef(ConnectState *connectState) {
88 assert(connectState->refCount > 0);
89 connectState->refCount--;
90 #ifdef DEBUG_CONNECT_REF
91 log_add(log_Debug, "ConnectState %08" PRIxPTR ": ref-- (%d)",
92 (uintptr_t) connectState, connectState->refCount);
93 #endif
94 if (connectState->refCount == 0) {
95 ConnectState_delete(connectState);
96 return true;
97 }
98 return false;
99 }
100
101 // decrements ref count by 1
102 void
ConnectState_close(ConnectState * connectState)103 ConnectState_close(ConnectState *connectState) {
104 if (connectState->resolveState != NULL) {
105 Resolve_close(connectState->resolveState);
106 connectState->resolveState = NULL;
107 }
108 if (connectState->alarm != NULL) {
109 Alarm_remove(connectState->alarm);
110 connectState->alarm = NULL;
111 }
112 if (connectState->nd != NULL) {
113 NetDescriptor_close(connectState->nd);
114 connectState->nd = NULL;
115 }
116 if (connectState->info != NULL) {
117 freeaddrinfo(connectState->info);
118 connectState->info = NULL;
119 connectState->infoPtr = NULL;
120 }
121 connectState->state = Connect_closed;
122 ConnectState_decRef(connectState);
123 }
124
125 void
ConnectState_setExtra(ConnectState * connectState,void * extra)126 ConnectState_setExtra(ConnectState *connectState, void *extra) {
127 connectState->extra = extra;
128 }
129
130 void *
ConnectState_getExtra(ConnectState * connectState)131 ConnectState_getExtra(ConnectState *connectState) {
132 return connectState->extra;
133 }
134
135 static void
connectCallback(NetDescriptor * nd)136 connectCallback(NetDescriptor *nd) {
137 // Called by the NetManager when a connection has been established.
138 ConnectState *connectState =
139 (ConnectState *) NetDescriptor_getExtra(nd);
140 int err;
141
142 if (connectState->alarm != NULL) {
143 Alarm_remove(connectState->alarm);
144 connectState->alarm = NULL;
145 }
146
147 if (connectState->state == Connect_closed) {
148 // The connection attempt has been aborted.
149 #ifdef DEBUG
150 log_add(log_Debug, "Connection attempt was aborted.");
151 #endif
152 ConnectState_decRef(connectState);
153 return;
154 }
155
156 if (Socket_getError(NetDescriptor_getSocket(nd), &err) == -1) {
157 log_add(log_Fatal, "Socket_getError() failed: %s.",
158 strerror(errno));
159 explode();
160 }
161 if (err != 0) {
162 #ifdef DEBUG
163 log_add(log_Debug, "connect() failed: %s.", strerror(err));
164 #endif
165 NetDescriptor_close(nd);
166 connectState->nd = NULL;
167 connectState->infoPtr = connectState->infoPtr->ai_next;
168 connectHostNext(connectState);
169 return;
170 }
171
172 #ifdef DEBUG
173 log_add(log_Debug, "Connection established.");
174 #endif
175
176 // Notify the higher layer.
177 connectState->nd = NULL;
178 // The callback function takes over ownership of the
179 // NetDescriptor.
180 NetDescriptor_setWriteCallback(nd, NULL);
181 // Note that connectState->info and connectState->infoPtr are cleaned up
182 // when ConnectState_close() is called by the callback function.
183
184 ConnectState_incRef(connectState);
185 doConnectCallback(connectState, nd, connectState->infoPtr->ai_addr,
186 connectState->infoPtr->ai_addrlen);
187 {
188 // The callback called should release the last reference to
189 // the connectState, by calling ConnectState_close().
190 bool released = ConnectState_decRef(connectState);
191 assert(released);
192 (void) released; // In case assert() evaluates to nothing.
193 }
194 }
195
196 static void
connectTimeoutCallback(ConnectState * connectState)197 connectTimeoutCallback(ConnectState *connectState) {
198 connectState->alarm = NULL;
199
200 NetDescriptor_close(connectState->nd);
201 connectState->nd = NULL;
202
203 connectState->infoPtr = connectState->infoPtr->ai_next;
204 connectHostNext(connectState);
205 }
206
207 static void
setConnectTimeout(ConnectState * connectState)208 setConnectTimeout(ConnectState *connectState) {
209 assert(connectState->alarm == NULL);
210
211 connectState->alarm =
212 Alarm_addRelativeMs(connectState->flags.timeout,
213 (AlarmCallback) connectTimeoutCallback, connectState);
214 }
215
216 // Try connecting to the next address.
217 static Socket *
tryConnectHostNext(ConnectState * connectState)218 tryConnectHostNext(ConnectState *connectState) {
219 struct addrinfo *info;
220 Socket *sock;
221 int connectResult;
222
223 assert(connectState->nd == NULL);
224
225 info = connectState->infoPtr;
226
227 sock = Socket_openNative(info->ai_family, info->ai_socktype,
228 info->ai_protocol);
229 if (sock == Socket_noSocket) {
230 int savedErrno = errno;
231 log_add(log_Error, "socket() failed: %s.", strerror(errno));
232 errno = savedErrno;
233 return Socket_noSocket;
234 }
235
236 if (Socket_setNonBlocking(sock) == -1) {
237 int savedErrno = errno;
238 log_add(log_Error, "Could not make socket non-blocking: %s.",
239 strerror(errno));
240 errno = savedErrno;
241 return Socket_noSocket;
242 }
243
244 (void) Socket_setReuseAddr(sock);
245 // Ignore errors; it's not a big deal.
246 (void) Socket_setInlineOOB(sock);
247 // Ignore errors; it's not a big deal as the other party is not
248 // not supposed to send any OOB data.
249 (void) Socket_setKeepAlive(sock);
250 // Ignore errors; it's not a big deal.
251
252 connectResult = Socket_connect(sock, info->ai_addr, info->ai_addrlen);
253 if (connectResult == 0) {
254 // Connection has already succeeded.
255 // We just wait for the writability callback anyhow, so that
256 // we can use one code path.
257 return sock;
258 }
259
260 switch (errno) {
261 case EINPROGRESS:
262 // Connection in progress; wait for the write callback.
263 return sock;
264 }
265
266 // Connection failed immediately. This is just for one of the addresses,
267 // so this does not have to be final.
268 // Note that as the socket is non-blocking, most failed connection
269 // errors will usually not be reported immediately.
270 {
271 int savedErrno = errno;
272 Socket_close(sock);
273 #ifdef DEBUG
274 log_add(log_Debug, "connect() immediately failed for one address: "
275 "%s.", strerror(errno));
276 // TODO: add the address in the status message.
277 #endif
278 errno = savedErrno;
279 }
280 return Socket_noSocket;
281 }
282
283 static void
connectRetryCallback(ConnectState * connectState)284 connectRetryCallback(ConnectState *connectState) {
285 connectState->alarm = NULL;
286
287 connectState->infoPtr = connectState->info;
288 connectHostNext(connectState);
289 }
290
291 static void
setConnectRetryAlarm(ConnectState * connectState)292 setConnectRetryAlarm(ConnectState *connectState) {
293 assert(connectState->alarm == NULL);
294 assert(connectState->flags.retryDelayMs != Connect_noRetry);
295
296 connectState->alarm =
297 Alarm_addRelativeMs(connectState->flags.retryDelayMs,
298 (AlarmCallback) connectRetryCallback, connectState);
299 }
300
301 static void
connectHostReportAllFailed(ConnectState * connectState)302 connectHostReportAllFailed(ConnectState *connectState) {
303 // Could not connect to any host.
304 ConnectError error;
305 freeaddrinfo(connectState->info);
306 connectState->info = NULL;
307 connectState->infoPtr = NULL;
308 connectState->state = Connect_closed;
309 error.state = Connect_connecting;
310 error.err = ETIMEDOUT;
311 // No errno code is exactly suitable. We have been unable
312 // to connect to any host, but the reasons may vary
313 // (unreachable, refused, ...).
314 // ETIMEDOUT is the least specific portable errno code that
315 // seems appropriate.
316 doConnectErrorCallback(connectState, &error);
317 }
318
319 static void
connectHostNext(ConnectState * connectState)320 connectHostNext(ConnectState *connectState) {
321 Socket *sock;
322
323 while (connectState->infoPtr != NULL) {
324 sock = tryConnectHostNext(connectState);
325
326 if (sock != Socket_noSocket) {
327 // Connection succeeded or connection in progress
328 connectState->nd =
329 NetDescriptor_new(sock, (void *) connectState);
330 if (connectState->nd == NULL) {
331 ConnectError error;
332 int savedErrno = errno;
333
334 log_add(log_Error, "NetDescriptor_new() failed: %s.",
335 strerror(errno));
336 Socket_close(sock);
337 freeaddrinfo(connectState->info);
338 connectState->info = NULL;
339 connectState->infoPtr = NULL;
340 connectState->state = Connect_closed;
341 error.state = Connect_connecting;
342 error.err = savedErrno;
343 doConnectErrorCallback(connectState, &error);
344 return;
345 }
346
347 NetDescriptor_setWriteCallback(connectState->nd, connectCallback);
348 setConnectTimeout(connectState);
349 return;
350 }
351
352 connectState->infoPtr = connectState->infoPtr->ai_next;
353 }
354
355 // Connect failed to all addresses.
356
357 if (connectState->flags.retryDelayMs == Connect_noRetry) {
358 connectHostReportAllFailed(connectState);
359 return;
360 }
361
362 setConnectRetryAlarm(connectState);
363 }
364
365 static void
connectHostResolveCallback(ResolveState * resolveState,struct addrinfo * info)366 connectHostResolveCallback(ResolveState *resolveState,
367 struct addrinfo *info) {
368 ConnectState *connectState =
369 (ConnectState *) ResolveState_getExtra(resolveState);
370
371 connectState->state = Connect_connecting;
372
373 Resolve_close(resolveState);
374 connectState->resolveState = NULL;
375
376 if (connectState->flags.familyPrefer != PF_UNSPEC) {
377 // Reorganise the 'info' list to put the structures of the
378 // prefered family in front.
379 struct addrinfo *preferred;
380 struct addrinfo **preferredEnd;
381 struct addrinfo *rest;
382 struct addrinfo **restEnd;
383 splitAddrInfoOnFamily(info, connectState->flags.familyPrefer,
384 &preferred, &preferredEnd, &rest, &restEnd);
385 info = preferred;
386 *preferredEnd = rest;
387 }
388
389 connectState->info = info;
390 connectState->infoPtr = info;
391
392 connectHostNext(connectState);
393 }
394
395 static void
connectHostResolveErrorCallback(ResolveState * resolveState,const ResolveError * resolveError)396 connectHostResolveErrorCallback(ResolveState *resolveState,
397 const ResolveError *resolveError) {
398 ConnectState *connectState =
399 (ConnectState *) ResolveState_getExtra(resolveState);
400 ConnectError connectError;
401
402 assert(resolveError->gaiRes != 0);
403
404 Resolve_close(resolveState);
405 connectState->resolveState = NULL;
406
407 connectError.state = Connect_resolving;
408 connectError.resolveError = resolveError;
409 connectError.err = resolveError->err;
410 doConnectErrorCallback(connectState, &connectError);
411 }
412
413 ConnectState *
connectHostByName(const char * host,const char * service,Protocol proto,const ConnectFlags * flags,ConnectConnectCallback connectCallback,ConnectErrorCallback errorCallback,void * extra)414 connectHostByName(const char *host, const char *service, Protocol proto,
415 const ConnectFlags *flags, ConnectConnectCallback connectCallback,
416 ConnectErrorCallback errorCallback, void *extra) {
417 struct addrinfo hints;
418 ConnectState *connectState;
419 ResolveFlags resolveFlags;
420 // Structure is empty (for now).
421
422 assert(flags->familyDemand == PF_inet ||
423 flags->familyDemand == PF_inet6 ||
424 flags->familyDemand == PF_unspec);
425 assert(flags->familyPrefer == PF_inet ||
426 flags->familyPrefer == PF_inet6 ||
427 flags->familyPrefer == PF_unspec);
428 assert(proto == IPProto_tcp || proto == IPProto_udp);
429
430 memset(&hints, '\0', sizeof hints);
431 hints.ai_family = protocolFamilyTranslation[flags->familyDemand];
432 hints.ai_protocol = protocolTranslation[proto];
433
434 if (proto == IPProto_tcp) {
435 hints.ai_socktype = SOCK_STREAM;
436 } else {
437 assert(proto == IPProto_udp);
438 hints.ai_socktype = SOCK_DGRAM;
439 }
440 hints.ai_flags = 0;
441
442 connectState = ConnectState_alloc();
443 connectState->refCount = 1;
444 #ifdef DEBUG_CONNECT_REF
445 log_add(log_Debug, "ConnectState %08" PRIxPTR ": ref=1 (%d)",
446 (uintptr_t) connectState, connectState->refCount);
447 #endif
448 connectState->state = Connect_resolving;
449 connectState->flags = *flags;
450 connectState->connectCallback = connectCallback;
451 connectState->errorCallback = errorCallback;
452 connectState->extra = extra;
453 connectState->info = NULL;
454 connectState->infoPtr = NULL;
455 connectState->nd = NULL;
456 connectState->alarm = NULL;
457
458 connectState->resolveState = getaddrinfoAsync(
459 host, service, &hints, &resolveFlags,
460 (ResolveCallback) connectHostResolveCallback,
461 (ResolveErrorCallback) connectHostResolveErrorCallback,
462 (ResolveCallbackArg) connectState);
463
464 return connectState;
465 }
466
467 // NB: The callback function becomes the owner of nd
468 static void
doConnectCallback(ConnectState * connectState,NetDescriptor * nd,const struct sockaddr * addr,socklen_t addrLen)469 doConnectCallback(ConnectState *connectState, NetDescriptor *nd,
470 const struct sockaddr *addr, socklen_t addrLen) {
471 assert(connectState->connectCallback != NULL);
472
473 ConnectState_incRef(connectState);
474 // No need to increment nd as the callback function takes over ownership.
475 (*connectState->connectCallback)(connectState, nd, addr, addrLen);
476 ConnectState_decRef(connectState);
477 }
478
479 static void
doConnectErrorCallback(ConnectState * connectState,const ConnectError * error)480 doConnectErrorCallback(ConnectState *connectState,
481 const ConnectError *error) {
482 assert(connectState->errorCallback != NULL);
483
484 ConnectState_incRef(connectState);
485 (*connectState->errorCallback)(connectState, error);
486 ConnectState_decRef(connectState);
487 }
488
489
490
491