1 /*
2  *  Copyright 2006  Serge van den Boom <svdb@stack.nl>
3  *
4  *  This program is free software; you can redistribute it and/or modify
5  *  it under the terms of the GNU General Public License as published by
6  *  the Free Software Foundation; either version 2 of the License, or
7  *  (at your option) any later version.
8  *
9  *  This program is distributed in the hope that it will be useful,
10  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
11  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  *  GNU General Public License for more details.
13  *
14  *  You should have received a copy of the GNU General Public License
15  *  along with this program; if not, write to the Free Software
16  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17  */
18 
19 #define PORT_WANT_ERRNO
20 #include "port.h"
21 
22 #define CONNECT_INTERNAL
23 #define SOCKET_INTERNAL
24 #include "connect.h"
25 
26 #include "resolve.h"
27 #include "libs/alarm.h"
28 #include "../socket/socket.h"
29 #include "libs/misc.h"
30 #include "libs/log.h"
31 
32 #include <assert.h>
33 #include <errno.h>
34 #include <stdlib.h>
35 #include <string.h>
36 #ifdef USE_WINSOCK
37 #	include <winsock2.h>
38 #	include <ws2tcpip.h>
39 #	include "../wspiapiwrap.h"
40 #else
41 #	include <netdb.h>
42 #endif
43 
44 #define DEBUG_CONNECT_REF
45 #ifdef DEBUG_CONNECT_REF
46 #	include "types.h"
47 #endif
48 
49 
50 static void connectHostNext(ConnectState *connectState);
51 static void doConnectCallback(ConnectState *connectState, NetDescriptor *nd,
52 		const struct sockaddr *addr, socklen_t addrLen);
53 static void doConnectErrorCallback(ConnectState *connectState,
54 		const ConnectError *error);
55 
56 
57 static ConnectState *
ConnectState_alloc(void)58 ConnectState_alloc(void) {
59 	return (ConnectState *) malloc(sizeof (ConnectState));
60 }
61 
62 static void
ConnectState_free(ConnectState * connectState)63 ConnectState_free(ConnectState *connectState) {
64 	free(connectState);
65 }
66 
67 static void
ConnectState_delete(ConnectState * connectState)68 ConnectState_delete(ConnectState *connectState) {
69 	assert(connectState->nd == NULL);
70 	assert(connectState->alarm == NULL);
71 	assert(connectState->info == NULL);
72 	assert(connectState->infoPtr == NULL);
73 	ConnectState_free(connectState);
74 }
75 
76 void
ConnectState_incRef(ConnectState * connectState)77 ConnectState_incRef(ConnectState *connectState) {
78 	assert(connectState->refCount < REFCOUNT_MAX);
79 	connectState->refCount++;
80 #ifdef DEBUG_CONNECT_REF
81 	log_add(log_Debug, "ConnectState %08" PRIxPTR ": ref++ (%d)",
82 			(uintptr_t) connectState, connectState->refCount);
83 #endif
84 }
85 
86 bool
ConnectState_decRef(ConnectState * connectState)87 ConnectState_decRef(ConnectState *connectState) {
88 	assert(connectState->refCount > 0);
89 	connectState->refCount--;
90 #ifdef DEBUG_CONNECT_REF
91 	log_add(log_Debug, "ConnectState %08" PRIxPTR ": ref-- (%d)",
92 			(uintptr_t) connectState, connectState->refCount);
93 #endif
94 	if (connectState->refCount == 0) {
95 		ConnectState_delete(connectState);
96 		return true;
97 	}
98 	return false;
99 }
100 
101 // decrements ref count by 1
102 void
ConnectState_close(ConnectState * connectState)103 ConnectState_close(ConnectState *connectState) {
104 	if (connectState->resolveState != NULL) {
105 		Resolve_close(connectState->resolveState);
106 		connectState->resolveState = NULL;
107 	}
108 	if (connectState->alarm != NULL) {
109 		Alarm_remove(connectState->alarm);
110 		connectState->alarm = NULL;
111 	}
112 	if (connectState->nd != NULL) {
113 		NetDescriptor_close(connectState->nd);
114 		connectState->nd = NULL;
115 	}
116 	if (connectState->info != NULL) {
117 		freeaddrinfo(connectState->info);
118 		connectState->info = NULL;
119 		connectState->infoPtr = NULL;
120 	}
121 	connectState->state = Connect_closed;
122 	ConnectState_decRef(connectState);
123 }
124 
125 void
ConnectState_setExtra(ConnectState * connectState,void * extra)126 ConnectState_setExtra(ConnectState *connectState, void *extra) {
127 	connectState->extra = extra;
128 }
129 
130 void *
ConnectState_getExtra(ConnectState * connectState)131 ConnectState_getExtra(ConnectState *connectState) {
132 	return connectState->extra;
133 }
134 
135 static void
connectCallback(NetDescriptor * nd)136 connectCallback(NetDescriptor *nd) {
137 	// Called by the NetManager when a connection has been established.
138 	ConnectState *connectState =
139 			(ConnectState *) NetDescriptor_getExtra(nd);
140 	int err;
141 
142 	if (connectState->alarm != NULL) {
143 		Alarm_remove(connectState->alarm);
144 		connectState->alarm = NULL;
145 	}
146 
147 	if (connectState->state == Connect_closed) {
148 		// The connection attempt has been aborted.
149 #ifdef DEBUG
150 		log_add(log_Debug, "Connection attempt was aborted.");
151 #endif
152 		ConnectState_decRef(connectState);
153 		return;
154 	}
155 
156 	if (Socket_getError(NetDescriptor_getSocket(nd), &err) == -1) {
157 		log_add(log_Fatal, "Socket_getError() failed: %s.",
158 				strerror(errno));
159 		explode();
160 	}
161 	if (err != 0) {
162 #ifdef DEBUG
163 		log_add(log_Debug, "connect() failed: %s.", strerror(err));
164 #endif
165 		NetDescriptor_close(nd);
166 		connectState->nd = NULL;
167 		connectState->infoPtr = connectState->infoPtr->ai_next;
168 		connectHostNext(connectState);
169 		return;
170 	}
171 
172 #ifdef DEBUG
173 	log_add(log_Debug, "Connection established.");
174 #endif
175 
176 	// Notify the higher layer.
177 	connectState->nd = NULL;
178 			// The callback function takes over ownership of the
179 			// NetDescriptor.
180 	NetDescriptor_setWriteCallback(nd, NULL);
181 	// Note that connectState->info and connectState->infoPtr are cleaned up
182 	// when ConnectState_close() is called by the callback function.
183 
184 	ConnectState_incRef(connectState);
185 	doConnectCallback(connectState, nd, connectState->infoPtr->ai_addr,
186 			connectState->infoPtr->ai_addrlen);
187 	{
188 		// The callback called should release the last reference to
189 		// the connectState, by calling ConnectState_close().
190 		bool released = ConnectState_decRef(connectState);
191 		assert(released);
192 		(void) released;  // In case assert() evaluates to nothing.
193 	}
194 }
195 
196 static void
connectTimeoutCallback(ConnectState * connectState)197 connectTimeoutCallback(ConnectState *connectState) {
198 	connectState->alarm = NULL;
199 
200 	NetDescriptor_close(connectState->nd);
201 	connectState->nd = NULL;
202 
203 	connectState->infoPtr = connectState->infoPtr->ai_next;
204 	connectHostNext(connectState);
205 }
206 
207 static void
setConnectTimeout(ConnectState * connectState)208 setConnectTimeout(ConnectState *connectState) {
209 	assert(connectState->alarm == NULL);
210 
211 	connectState->alarm =
212 			Alarm_addRelativeMs(connectState->flags.timeout,
213 			(AlarmCallback) connectTimeoutCallback, connectState);
214 }
215 
216 // Try connecting to the next address.
217 static Socket *
tryConnectHostNext(ConnectState * connectState)218 tryConnectHostNext(ConnectState *connectState) {
219 	struct addrinfo *info;
220 	Socket *sock;
221 	int connectResult;
222 
223 	assert(connectState->nd == NULL);
224 
225 	info = connectState->infoPtr;
226 
227 	sock = Socket_openNative(info->ai_family, info->ai_socktype,
228 			info->ai_protocol);
229 	if (sock == Socket_noSocket) {
230 		int savedErrno = errno;
231 		log_add(log_Error, "socket() failed: %s.", strerror(errno));
232 		errno = savedErrno;
233 		return Socket_noSocket;
234 	}
235 
236 	if (Socket_setNonBlocking(sock) == -1) {
237 		int savedErrno = errno;
238 		log_add(log_Error, "Could not make socket non-blocking: %s.",
239 				strerror(errno));
240 		errno = savedErrno;
241 		return Socket_noSocket;
242 	}
243 
244 	(void) Socket_setReuseAddr(sock);
245 			// Ignore errors; it's not a big deal.
246 	(void) Socket_setInlineOOB(sock);
247 			// Ignore errors; it's not a big deal as the other party is not
248 			// not supposed to send any OOB data.
249 	(void) Socket_setKeepAlive(sock);
250 			// Ignore errors; it's not a big deal.
251 
252 	connectResult = Socket_connect(sock, info->ai_addr, info->ai_addrlen);
253 	if (connectResult == 0) {
254 		// Connection has already succeeded.
255 		// We just wait for the writability callback anyhow, so that
256 		// we can use one code path.
257 		return sock;
258 	}
259 
260 	switch (errno) {
261 		case EINPROGRESS:
262 			// Connection in progress; wait for the write callback.
263 			return sock;
264 	}
265 
266 	// Connection failed immediately. This is just for one of the addresses,
267 	// so this does not have to be final.
268 	// Note that as the socket is non-blocking, most failed connection
269 	// errors will usually not be reported immediately.
270 	{
271 		int savedErrno = errno;
272 		Socket_close(sock);
273 #ifdef DEBUG
274 		log_add(log_Debug, "connect() immediately failed for one address: "
275 				"%s.", strerror(errno));
276 				// TODO: add the address in the status message.
277 #endif
278 		errno = savedErrno;
279 	}
280 	return Socket_noSocket;
281 }
282 
283 static void
connectRetryCallback(ConnectState * connectState)284 connectRetryCallback(ConnectState *connectState) {
285 	connectState->alarm = NULL;
286 
287 	connectState->infoPtr = connectState->info;
288 	connectHostNext(connectState);
289 }
290 
291 static void
setConnectRetryAlarm(ConnectState * connectState)292 setConnectRetryAlarm(ConnectState *connectState) {
293 	assert(connectState->alarm == NULL);
294 	assert(connectState->flags.retryDelayMs != Connect_noRetry);
295 
296 	connectState->alarm =
297 			Alarm_addRelativeMs(connectState->flags.retryDelayMs,
298 			(AlarmCallback) connectRetryCallback, connectState);
299 }
300 
301 static void
connectHostReportAllFailed(ConnectState * connectState)302 connectHostReportAllFailed(ConnectState *connectState) {
303 	// Could not connect to any host.
304 	ConnectError error;
305 	freeaddrinfo(connectState->info);
306 	connectState->info = NULL;
307 	connectState->infoPtr = NULL;
308 	connectState->state = Connect_closed;
309 	error.state = Connect_connecting;
310 	error.err = ETIMEDOUT;
311 			// No errno code is exactly suitable. We have been unable
312 			// to connect to any host, but the reasons may vary
313 			// (unreachable, refused, ...).
314 			// ETIMEDOUT is the least specific portable errno code that
315 			// seems appropriate.
316 	doConnectErrorCallback(connectState, &error);
317 }
318 
319 static void
connectHostNext(ConnectState * connectState)320 connectHostNext(ConnectState *connectState) {
321 	Socket *sock;
322 
323 	while (connectState->infoPtr != NULL) {
324 		sock = tryConnectHostNext(connectState);
325 
326 		if (sock != Socket_noSocket) {
327 			// Connection succeeded or connection in progress
328 			connectState->nd =
329 					NetDescriptor_new(sock, (void *) connectState);
330 			if (connectState->nd == NULL) {
331 				ConnectError error;
332 				int savedErrno = errno;
333 
334 				log_add(log_Error, "NetDescriptor_new() failed: %s.",
335 						strerror(errno));
336 				Socket_close(sock);
337 				freeaddrinfo(connectState->info);
338 				connectState->info = NULL;
339 				connectState->infoPtr = NULL;
340 				connectState->state = Connect_closed;
341 				error.state = Connect_connecting;
342 				error.err = savedErrno;
343 				doConnectErrorCallback(connectState, &error);
344 				return;
345 			}
346 
347 			NetDescriptor_setWriteCallback(connectState->nd, connectCallback);
348 			setConnectTimeout(connectState);
349 			return;
350 		}
351 
352 		connectState->infoPtr = connectState->infoPtr->ai_next;
353 	}
354 
355 	// Connect failed to all addresses.
356 
357 	if (connectState->flags.retryDelayMs == Connect_noRetry) {
358 		connectHostReportAllFailed(connectState);
359 		return;
360 	}
361 
362 	setConnectRetryAlarm(connectState);
363 }
364 
365 static void
connectHostResolveCallback(ResolveState * resolveState,struct addrinfo * info)366 connectHostResolveCallback(ResolveState *resolveState,
367 		struct addrinfo *info) {
368 	ConnectState *connectState =
369 			(ConnectState *) ResolveState_getExtra(resolveState);
370 
371 	connectState->state = Connect_connecting;
372 
373 	Resolve_close(resolveState);
374 	connectState->resolveState = NULL;
375 
376 	if (connectState->flags.familyPrefer != PF_UNSPEC) {
377 		// Reorganise the 'info' list to put the structures of the
378 		// prefered family in front.
379 		struct addrinfo *preferred;
380 		struct addrinfo **preferredEnd;
381 		struct addrinfo *rest;
382 		struct addrinfo **restEnd;
383 		splitAddrInfoOnFamily(info, connectState->flags.familyPrefer,
384 				&preferred, &preferredEnd, &rest, &restEnd);
385 		info = preferred;
386 		*preferredEnd = rest;
387 	}
388 
389 	connectState->info = info;
390 	connectState->infoPtr = info;
391 
392 	connectHostNext(connectState);
393 }
394 
395 static void
connectHostResolveErrorCallback(ResolveState * resolveState,const ResolveError * resolveError)396 connectHostResolveErrorCallback(ResolveState *resolveState,
397 		const ResolveError *resolveError) {
398 	ConnectState *connectState =
399 			(ConnectState *) ResolveState_getExtra(resolveState);
400 	ConnectError connectError;
401 
402 	assert(resolveError->gaiRes != 0);
403 
404 	Resolve_close(resolveState);
405 	connectState->resolveState = NULL;
406 
407 	connectError.state = Connect_resolving;
408 	connectError.resolveError = resolveError;
409 	connectError.err = resolveError->err;
410 	doConnectErrorCallback(connectState, &connectError);
411 }
412 
413 ConnectState *
connectHostByName(const char * host,const char * service,Protocol proto,const ConnectFlags * flags,ConnectConnectCallback connectCallback,ConnectErrorCallback errorCallback,void * extra)414 connectHostByName(const char *host, const char *service, Protocol proto,
415 		const ConnectFlags *flags, ConnectConnectCallback connectCallback,
416 		ConnectErrorCallback errorCallback, void *extra) {
417 	struct addrinfo	hints;
418 	ConnectState *connectState;
419 	ResolveFlags resolveFlags;
420 			// Structure is empty (for now).
421 
422 	assert(flags->familyDemand == PF_inet ||
423 			flags->familyDemand == PF_inet6 ||
424 			flags->familyDemand == PF_unspec);
425 	assert(flags->familyPrefer == PF_inet ||
426 			flags->familyPrefer == PF_inet6 ||
427 			flags->familyPrefer == PF_unspec);
428 	assert(proto == IPProto_tcp || proto == IPProto_udp);
429 
430 	memset(&hints, '\0', sizeof hints);
431 	hints.ai_family = protocolFamilyTranslation[flags->familyDemand];
432 	hints.ai_protocol = protocolTranslation[proto];
433 
434 	if (proto == IPProto_tcp) {
435 		hints.ai_socktype = SOCK_STREAM;
436 	} else {
437 		assert(proto == IPProto_udp);
438 		hints.ai_socktype = SOCK_DGRAM;
439 	}
440 	hints.ai_flags = 0;
441 
442 	connectState = ConnectState_alloc();
443 	connectState->refCount = 1;
444 #ifdef DEBUG_CONNECT_REF
445 	log_add(log_Debug, "ConnectState %08" PRIxPTR ": ref=1 (%d)",
446 			(uintptr_t) connectState, connectState->refCount);
447 #endif
448 	connectState->state = Connect_resolving;
449 	connectState->flags = *flags;
450 	connectState->connectCallback = connectCallback;
451 	connectState->errorCallback = errorCallback;
452 	connectState->extra = extra;
453 	connectState->info = NULL;
454 	connectState->infoPtr = NULL;
455 	connectState->nd = NULL;
456 	connectState->alarm = NULL;
457 
458 	connectState->resolveState = getaddrinfoAsync(
459 			host, service, &hints, &resolveFlags,
460 			(ResolveCallback) connectHostResolveCallback,
461 			(ResolveErrorCallback) connectHostResolveErrorCallback,
462 			(ResolveCallbackArg) connectState);
463 
464 	return connectState;
465 }
466 
467 // NB: The callback function becomes the owner of nd
468 static void
doConnectCallback(ConnectState * connectState,NetDescriptor * nd,const struct sockaddr * addr,socklen_t addrLen)469 doConnectCallback(ConnectState *connectState, NetDescriptor *nd,
470 		const struct sockaddr *addr, socklen_t addrLen) {
471 	assert(connectState->connectCallback != NULL);
472 
473 	ConnectState_incRef(connectState);
474 	// No need to increment nd as the callback function takes over ownership.
475 	(*connectState->connectCallback)(connectState, nd, addr, addrLen);
476 	ConnectState_decRef(connectState);
477 }
478 
479 static void
doConnectErrorCallback(ConnectState * connectState,const ConnectError * error)480 doConnectErrorCallback(ConnectState *connectState,
481 		const ConnectError *error) {
482 	assert(connectState->errorCallback != NULL);
483 
484 	ConnectState_incRef(connectState);
485 	(*connectState->errorCallback)(connectState, error);
486 	ConnectState_decRef(connectState);
487 }
488 
489 
490 
491