1 /**
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 2 of the License, or
5  * (at your option) any later version.
6  *
7  * net/socket.h
8  * (c) 2005-2008 Murat Deligonul
9  *
10  * Encapsulates a TCP stream socket, supports IPv6 & SSL
11  */
12 
13 #ifndef __NET_SOCKET_H
14 #   define __NET_SOCKET_H
15 
16 #include <string>
17 #include <iosfwd>
18 #include <cstdarg>
19 #include <cerrno>
20 #include <sys/types.h>
21 #include <sys/socket.h>
22 #include <unistd.h>
23 #include <fcntl.h>
24 
25 #ifdef HAVE_SSL
26 #	define OPENSSL_NO_KRB5		/* prevent compilation error on RH Linux 9, et al */
27 #	include <openssl/ssl.h>
28 #	include <openssl/err.h>
29 #	include <openssl/rand.h>
30 #endif
31 
32 #include "util/exception.h"
33 #include "util/strings.h"
34 #include "io/engine.h"
35 #include "io/pollable.h"
36 #include "io/buffer.h"
37 #include "net/types.h"
38 #include "net/resolver.h"
39 #include "net/radaptor.h"
40 #include "debug.h"
41 
42 namespace net {
43 
44 struct socket_exception_tag { };
45 typedef util::exception<socket_exception_tag, std::runtime_error>   socket_exception;
46 
47 class socket : public io::pollable {
48 private:
49 	/** Common to all sockets **/
50 	static class io::engine * engine;
51 	static class resolver * resolver;
52 	static class radaptor * radaptor;
53 
54 #ifdef HAVE_SSL
55 	static SSL_CTX *ssl_ctx;
56 	static char *tls_rand_file;
57 	static int seed_PRNG();
58 	SSL *ssl;
59 #endif
60 public:
61 	static const size_t BUFFER_MINIMUM = io::buffer::MINIMUM;
62 
63 	static const size_t BUFFER_MAXIMUM = io::buffer::MAXIMUM;
64 
65 	/** Static member functions **/
assign_resolver(class resolver * r)66 	static  void assign_resolver(class resolver *r)	{
67 		socket::resolver = r;
68 	}
assign_engine(class io::engine * e)69 	static  void assign_engine(class io::engine * e) {
70 		delete radaptor;
71 		socket::engine = e;
72 		socket::radaptor = new class radaptor(*resolver, *engine);
73 	}
set_nonblocking(int f)74 	static	int set_nonblocking(int f) {
75 		return fcntl(f, F_SETFL, fcntl(f, F_GETFL,0) | O_NONBLOCK);
76 	}
get_resolver()77 	static class resolver * get_resolver() {
78 		return socket::resolver;
79 	}
compress()80 	static void compress() {
81 		engine->compress();
82 	}
83 
84 	/** Socket state **/
85 	enum state {
86 		NEW		= 0,
87 		ACCEPTING,
88 		OPEN,
89 		CONNECT_RESOLVING,
90 		CONNECTING,
91 		CONNECTED,
92 		LISTEN_RESOLVING,
93 		LISTENING,
94 		CLOSED
95 	};
96 
97 	enum sock_option {
98 		SOCK_DEFAULT = 0,
99 		SOCK_SSL    = 2
100 	};
101 
102 private:
103 	/** resolver_callback functions **/
104 	int async_lookup_finished(const resolver::result *);
105 	int async_lookup_failed(const resolver::result *);
106 
107 private:
108 	/** basic properties **/
109 	int 		family;
110 	enum state	state;
111 	int 		options;
112 	int		lookup_id;	/* id number of current async. name resolution */
113 
114 	/** resolver callback wrapper **/
115 	const net::resolver_callback_wrapper<class socket,
116 			&socket::async_lookup_finished,
117 			&socket::async_lookup_failed>		r_callback;
118 
119 	/** I/O buffers **/
120 	io::buffer ibuff;
121 	io::buffer obuff;
122 
123 	/** address information **/
124 	ap_pair local, peer;
125 
126 	/** resolver helper data **/
127 	resolver::request * tmp_req;
128 	std::pair<unsigned char *, size_t>	* interface_data;
129 
130 public:
131 	/**
132 	 * Create a socket of protocol with socket()
133 	 */
134 	socket(int family, int options, size_t, size_t);
135 
136 	/**
137 	 * accept() a connection from another socket`
138 	 */
139 	socket(socket * source, size_t, size_t);
140 
~socket()141 	virtual ~socket() {
142 		close();
143 	}
144 
145 	/** re-create a socket **/
146 	int open(int, int = IPPROTO_TCP);
147 
148 	/** accessors **/
get_state()149 	int     get_state() const { return state; }
get_family()150 	int     get_family() const { return family; }
get_options()151 	int	get_options() const { return options; }
152 
153 	/** peer/local address & port information **/
get_peer()154 	const ap_pair& get_peer() const { return peer; }
get_local()155 	const ap_pair& get_local() const { return local; }
peer_addr()156 	const char * peer_addr() const { return peer.first.c_str(); }
local_addr()157 	const char * local_addr() const { return local.first.c_str(); }
peer_port()158 	unsigned short peer_port() const { return peer.second; }
local_port()159 	unsigned short local_port() const { return local.second; }
160 
161 	/** buffer information **/
get_ibuff()162 	io::buffer& get_ibuff()  { return ibuff; }
get_obuff()163 	io::buffer& get_obuff()  { return obuff; }
ibuff_size()164 	size_t ibuff_size() const { return ibuff.size(); }
obuff_size()165 	size_t obuff_size() const { return obuff.size(); }
166 
167 	/** make it do stuff **/
168 	int async_connect(const char *, const char *, unsigned short, int = SOCK_DEFAULT);
169 	int listen(const char *, unsigned short, int, int = SOCK_DEFAULT);
170 	int listen(const char *, const char *, int, int = SOCK_DEFAULT);
171 	int close();
172 
173 private:
174 	/** 'pollable' interface **/
175 	virtual int event_callback(int);
176 
177 	int update_addr_info(bool);
178 	int update_addr_info(bool, const struct sockaddr *, size_t);
179 	int check_for_close() const;
180 
181 	static 	int recv_test(int);
182 
183 public:
184 	/**
185 	  * Event handlers:
186 	  * NOTE: Any handler that causes the closing of the socket MUST return < 0.
187 	  *       This is to halt processing of additional events and avoid invoking additional callbacks
188 	  *	  on potentially deleted objects.
189 	  *       The error related callbacks (disconnects, connect fail, error) are always assumed to have
190 	  * 	  closed the socket.
191 	  */
192 	virtual int on_readable() = 0;
193 	virtual int on_writeable() = 0;
194 	virtual void on_disconnect(int);
195 	virtual void on_connect();
196 	virtual void on_connect_fail(int);
197 	virtual void on_connecting();
198 
199 	/**
200 	 * SSL functions...
201 	 */
202 #ifdef HAVE_SSL
203 private:
204 	int switch_to_ssl();
205 	int accept_to_ssl();
206 
207 public:
208 	static int init_ssl(const char *);
209 	static int shutdown_ssl();
210 #endif
211 
set_events(int events)212 	void set_events(int events) {
213 		if (get_fd() > -1) {
214 			engine->set_events(this, events);
215 		}
216 	}
217 
events()218 	int events() const {
219 		return get_events();
220 	}
221 
222     /** I/O functions **
223 
224     *** (buffered) reads and writes ***
225     int read(int = 0);
226     int queue(const void *,  int);
227     int write(const void * , int);
228 
229     *** Flush input buffer ... ***
230     int flush(socket *);             ** ... to another socket (output buff) **
231     int flush(int, size_t);                 ** ... to a raw file descriptor **
232 
233     *** Flush output buffer ... ***
234     int flushO(socket *);            ** ... to another socket (input buff) **
235     int flushO(size_t );                 ** ... out our own fd
236 
237     *** To clear buffers ***
238     void clear();
239     void clearO();
240     **/
241 
242 	int printf(const char *, ...);
243 	int printfQ(const char *, ...);
244 	int printf_raw(const char *, va_list);
245 
246 	/*
247 	* read() from file descriptor & add to buffer
248 	* return:  > 0 = success (bytes added)
249 	*            0 = nothing to add
250 	*          < 0 = error condition
251 	*     [... error codes same as io::buffer error codes ...]
252 	*/
253 	int read(size_t len = 0) {
254 	#ifdef HAVE_SSL
255 		if (ssl) {
256 			return ibuff.insert(ssl, len);
257 		}
258 	#endif
259 		return ibuff.insert(get_fd(), len);
260 	}
261 
262 	/**
263 	 * Buffer size bytes, and attempt to send as much as possible
264 	 * return # bytes sent
265 	 */
write(const void * data,size_t size)266 	int write(const void * data , size_t size) {
267 		int r = queue(data, size);
268 		if (r < 0) {
269 			return r;
270 		}
271 		return obuff.flush(get_fd());
272 	}
273 
274 	/**
275 	 * Buffer 'size' bytes, but do not send.
276 	 * Return # bytes queued.
277 	 */
queue(const void * data,size_t size)278 	int queue(const void * data, size_t size) {
279 		return obuff.insert(data, size);
280 	}
281 
282 	/**
283 	 * Flush input buffer to output buffer of socket.
284 	 */
flush(socket * target)285 	int flush(socket * target) {
286 		return ibuff.flush(&target->obuff);
287 	}
288 
289 	/**
290 	 * Flush input buffer to a file descriptor.
291 	 */
flush(int fd,size_t size)292 	int flush(int fd, size_t size) {
293 		return ibuff.flush(fd, size);
294 	}
295 
296 	/**
297 	 * Flush input buffer to a ostream.
298 	 */
flush(std::ostream & out,size_t size)299 	int flush(std::ostream& out, size_t size) {
300 		return ibuff.flush(out, size);
301 	}
302 
303 	/**
304 	 * Flush output buffer to input buffer of another socket
305 	 */
flushO(socket * target)306 	int flushO(socket * target) {
307 		return obuff.flush(&target->ibuff);
308 	}
309 
310 	/**
311 	 * Flush output buffer out our own fd.
312 	 */
flushO(size_t size)313 	int flushO(size_t size) {
314 	#ifdef HAVE_SSL
315 		if (ssl) {
316 			return obuff.flush(ssl, size);
317 		}
318 	#endif
319 		return obuff.flush(get_fd(), size);
320 	}
flushO()321 	int flushO() {
322 		return flushO(obuff_size());
323 	}
324 
325 	/* determine input/output buffer capacity */
buffer_available()326 	size_t buffer_available() {
327 		return ibuff.available();
328 	}
329 
buffer_availableO()330 	size_t buffer_availableO() {
331 		return obuff.available();
332 	}
333 
334 	/**
335 	 * Clear buffers.
336 	 */
clear()337 	void clear() {
338 		ibuff.clear();
339 	}
clearO()340 	void clearO() {
341 		obuff.clear();
342 	}
343 
optimize_buffers()344 	void optimize_buffers() {
345 		obuff.optimize();
346 		ibuff.optimize();
347 	}
348 
349 private:
350 	// non-copyable
351 	socket(const socket &);
352 	socket & operator=(socket &);
353 };
354 
355 } /* namespace net */
356 #endif /* __NET_SOCKET_H */
357 
358