1 /** 2 * This program is free software; you can redistribute it and/or modify 3 * it under the terms of the GNU General Public License as published by 4 * the Free Software Foundation; either version 2 of the License, or 5 * (at your option) any later version. 6 * 7 * net/socket.h 8 * (c) 2005-2008 Murat Deligonul 9 * 10 * Encapsulates a TCP stream socket, supports IPv6 & SSL 11 */ 12 13 #ifndef __NET_SOCKET_H 14 # define __NET_SOCKET_H 15 16 #include <string> 17 #include <iosfwd> 18 #include <cstdarg> 19 #include <cerrno> 20 #include <sys/types.h> 21 #include <sys/socket.h> 22 #include <unistd.h> 23 #include <fcntl.h> 24 25 #ifdef HAVE_SSL 26 # define OPENSSL_NO_KRB5 /* prevent compilation error on RH Linux 9, et al */ 27 # include <openssl/ssl.h> 28 # include <openssl/err.h> 29 # include <openssl/rand.h> 30 #endif 31 32 #include "util/exception.h" 33 #include "util/strings.h" 34 #include "io/engine.h" 35 #include "io/pollable.h" 36 #include "io/buffer.h" 37 #include "net/types.h" 38 #include "net/resolver.h" 39 #include "net/radaptor.h" 40 #include "debug.h" 41 42 namespace net { 43 44 struct socket_exception_tag { }; 45 typedef util::exception<socket_exception_tag, std::runtime_error> socket_exception; 46 47 class socket : public io::pollable { 48 private: 49 /** Common to all sockets **/ 50 static class io::engine * engine; 51 static class resolver * resolver; 52 static class radaptor * radaptor; 53 54 #ifdef HAVE_SSL 55 static SSL_CTX *ssl_ctx; 56 static char *tls_rand_file; 57 static int seed_PRNG(); 58 SSL *ssl; 59 #endif 60 public: 61 static const size_t BUFFER_MINIMUM = io::buffer::MINIMUM; 62 63 static const size_t BUFFER_MAXIMUM = io::buffer::MAXIMUM; 64 65 /** Static member functions **/ assign_resolver(class resolver * r)66 static void assign_resolver(class resolver *r) { 67 socket::resolver = r; 68 } assign_engine(class io::engine * e)69 static void assign_engine(class io::engine * e) { 70 delete radaptor; 71 socket::engine = e; 72 socket::radaptor = new class radaptor(*resolver, *engine); 73 } set_nonblocking(int f)74 static int set_nonblocking(int f) { 75 return fcntl(f, F_SETFL, fcntl(f, F_GETFL,0) | O_NONBLOCK); 76 } get_resolver()77 static class resolver * get_resolver() { 78 return socket::resolver; 79 } compress()80 static void compress() { 81 engine->compress(); 82 } 83 84 /** Socket state **/ 85 enum state { 86 NEW = 0, 87 ACCEPTING, 88 OPEN, 89 CONNECT_RESOLVING, 90 CONNECTING, 91 CONNECTED, 92 LISTEN_RESOLVING, 93 LISTENING, 94 CLOSED 95 }; 96 97 enum sock_option { 98 SOCK_DEFAULT = 0, 99 SOCK_SSL = 2 100 }; 101 102 private: 103 /** resolver_callback functions **/ 104 int async_lookup_finished(const resolver::result *); 105 int async_lookup_failed(const resolver::result *); 106 107 private: 108 /** basic properties **/ 109 int family; 110 enum state state; 111 int options; 112 int lookup_id; /* id number of current async. name resolution */ 113 114 /** resolver callback wrapper **/ 115 const net::resolver_callback_wrapper<class socket, 116 &socket::async_lookup_finished, 117 &socket::async_lookup_failed> r_callback; 118 119 /** I/O buffers **/ 120 io::buffer ibuff; 121 io::buffer obuff; 122 123 /** address information **/ 124 ap_pair local, peer; 125 126 /** resolver helper data **/ 127 resolver::request * tmp_req; 128 std::pair<unsigned char *, size_t> * interface_data; 129 130 public: 131 /** 132 * Create a socket of protocol with socket() 133 */ 134 socket(int family, int options, size_t, size_t); 135 136 /** 137 * accept() a connection from another socket` 138 */ 139 socket(socket * source, size_t, size_t); 140 ~socket()141 virtual ~socket() { 142 close(); 143 } 144 145 /** re-create a socket **/ 146 int open(int, int = IPPROTO_TCP); 147 148 /** accessors **/ get_state()149 int get_state() const { return state; } get_family()150 int get_family() const { return family; } get_options()151 int get_options() const { return options; } 152 153 /** peer/local address & port information **/ get_peer()154 const ap_pair& get_peer() const { return peer; } get_local()155 const ap_pair& get_local() const { return local; } peer_addr()156 const char * peer_addr() const { return peer.first.c_str(); } local_addr()157 const char * local_addr() const { return local.first.c_str(); } peer_port()158 unsigned short peer_port() const { return peer.second; } local_port()159 unsigned short local_port() const { return local.second; } 160 161 /** buffer information **/ get_ibuff()162 io::buffer& get_ibuff() { return ibuff; } get_obuff()163 io::buffer& get_obuff() { return obuff; } ibuff_size()164 size_t ibuff_size() const { return ibuff.size(); } obuff_size()165 size_t obuff_size() const { return obuff.size(); } 166 167 /** make it do stuff **/ 168 int async_connect(const char *, const char *, unsigned short, int = SOCK_DEFAULT); 169 int listen(const char *, unsigned short, int, int = SOCK_DEFAULT); 170 int listen(const char *, const char *, int, int = SOCK_DEFAULT); 171 int close(); 172 173 private: 174 /** 'pollable' interface **/ 175 virtual int event_callback(int); 176 177 int update_addr_info(bool); 178 int update_addr_info(bool, const struct sockaddr *, size_t); 179 int check_for_close() const; 180 181 static int recv_test(int); 182 183 public: 184 /** 185 * Event handlers: 186 * NOTE: Any handler that causes the closing of the socket MUST return < 0. 187 * This is to halt processing of additional events and avoid invoking additional callbacks 188 * on potentially deleted objects. 189 * The error related callbacks (disconnects, connect fail, error) are always assumed to have 190 * closed the socket. 191 */ 192 virtual int on_readable() = 0; 193 virtual int on_writeable() = 0; 194 virtual void on_disconnect(int); 195 virtual void on_connect(); 196 virtual void on_connect_fail(int); 197 virtual void on_connecting(); 198 199 /** 200 * SSL functions... 201 */ 202 #ifdef HAVE_SSL 203 private: 204 int switch_to_ssl(); 205 int accept_to_ssl(); 206 207 public: 208 static int init_ssl(const char *); 209 static int shutdown_ssl(); 210 #endif 211 set_events(int events)212 void set_events(int events) { 213 if (get_fd() > -1) { 214 engine->set_events(this, events); 215 } 216 } 217 events()218 int events() const { 219 return get_events(); 220 } 221 222 /** I/O functions ** 223 224 *** (buffered) reads and writes *** 225 int read(int = 0); 226 int queue(const void *, int); 227 int write(const void * , int); 228 229 *** Flush input buffer ... *** 230 int flush(socket *); ** ... to another socket (output buff) ** 231 int flush(int, size_t); ** ... to a raw file descriptor ** 232 233 *** Flush output buffer ... *** 234 int flushO(socket *); ** ... to another socket (input buff) ** 235 int flushO(size_t ); ** ... out our own fd 236 237 *** To clear buffers *** 238 void clear(); 239 void clearO(); 240 **/ 241 242 int printf(const char *, ...); 243 int printfQ(const char *, ...); 244 int printf_raw(const char *, va_list); 245 246 /* 247 * read() from file descriptor & add to buffer 248 * return: > 0 = success (bytes added) 249 * 0 = nothing to add 250 * < 0 = error condition 251 * [... error codes same as io::buffer error codes ...] 252 */ 253 int read(size_t len = 0) { 254 #ifdef HAVE_SSL 255 if (ssl) { 256 return ibuff.insert(ssl, len); 257 } 258 #endif 259 return ibuff.insert(get_fd(), len); 260 } 261 262 /** 263 * Buffer size bytes, and attempt to send as much as possible 264 * return # bytes sent 265 */ write(const void * data,size_t size)266 int write(const void * data , size_t size) { 267 int r = queue(data, size); 268 if (r < 0) { 269 return r; 270 } 271 return obuff.flush(get_fd()); 272 } 273 274 /** 275 * Buffer 'size' bytes, but do not send. 276 * Return # bytes queued. 277 */ queue(const void * data,size_t size)278 int queue(const void * data, size_t size) { 279 return obuff.insert(data, size); 280 } 281 282 /** 283 * Flush input buffer to output buffer of socket. 284 */ flush(socket * target)285 int flush(socket * target) { 286 return ibuff.flush(&target->obuff); 287 } 288 289 /** 290 * Flush input buffer to a file descriptor. 291 */ flush(int fd,size_t size)292 int flush(int fd, size_t size) { 293 return ibuff.flush(fd, size); 294 } 295 296 /** 297 * Flush input buffer to a ostream. 298 */ flush(std::ostream & out,size_t size)299 int flush(std::ostream& out, size_t size) { 300 return ibuff.flush(out, size); 301 } 302 303 /** 304 * Flush output buffer to input buffer of another socket 305 */ flushO(socket * target)306 int flushO(socket * target) { 307 return obuff.flush(&target->ibuff); 308 } 309 310 /** 311 * Flush output buffer out our own fd. 312 */ flushO(size_t size)313 int flushO(size_t size) { 314 #ifdef HAVE_SSL 315 if (ssl) { 316 return obuff.flush(ssl, size); 317 } 318 #endif 319 return obuff.flush(get_fd(), size); 320 } flushO()321 int flushO() { 322 return flushO(obuff_size()); 323 } 324 325 /* determine input/output buffer capacity */ buffer_available()326 size_t buffer_available() { 327 return ibuff.available(); 328 } 329 buffer_availableO()330 size_t buffer_availableO() { 331 return obuff.available(); 332 } 333 334 /** 335 * Clear buffers. 336 */ clear()337 void clear() { 338 ibuff.clear(); 339 } clearO()340 void clearO() { 341 obuff.clear(); 342 } 343 optimize_buffers()344 void optimize_buffers() { 345 obuff.optimize(); 346 ibuff.optimize(); 347 } 348 349 private: 350 // non-copyable 351 socket(const socket &); 352 socket & operator=(socket &); 353 }; 354 355 } /* namespace net */ 356 #endif /* __NET_SOCKET_H */ 357 358