1 /*
2  * Copyright (c) 2011 NLNet Labs. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
14  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
15  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16  * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
17  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
18  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
19  * GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
21  * IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
22  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
23  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  *
25  */
26 
27 /**
28  * TCP connections.
29  *
30  */
31 
32 #include "config.h"
33 #include "wire/tcpset.h"
34 
35 #include <string.h>
36 
/** Component tag printed as the "[%s]" prefix of this file's log messages. */
static const char* const tcp_str = "tcp";
38 
39 
40 /**
41  * Create a tcp connection.
42  *
43  */
44 tcp_conn_type*
tcp_conn_create()45 tcp_conn_create()
46 {
47     tcp_conn_type* tcp_conn = NULL;
48     CHECKALLOC(tcp_conn = (tcp_conn_type*) malloc(sizeof(tcp_conn_type)));
49     memset(tcp_conn, 0, sizeof(tcp_conn_type));
50     tcp_conn->packet = buffer_create(PACKET_BUFFER_SIZE);
51     if (!tcp_conn->packet) {
52         free(tcp_conn);
53         return NULL;
54     }
55     tcp_conn->msglen = 0;
56     tcp_conn->total_bytes = 0;
57     tcp_conn->fd = -1;
58     return tcp_conn;
59 }
60 
61 
62 /**
63  * Create a set of tcp connections.
64  *
65  */
66 tcp_set_type*
tcp_set_create()67 tcp_set_create()
68 {
69     size_t i = 0;
70     tcp_set_type* tcp_set = NULL;
71     CHECKALLOC(tcp_set = (tcp_set_type*) malloc(sizeof(tcp_set_type)));
72     memset(tcp_set, 0, sizeof(tcp_set_type));
73     tcp_set->tcp_count = 0;
74     for (i=0; i < TCPSET_MAX; i++) {
75         tcp_set->tcp_conn[i] = tcp_conn_create();
76     }
77     tcp_set->tcp_waiting_first = NULL;
78     tcp_set->tcp_waiting_last = NULL;
79     return tcp_set;
80 }
81 
82 
83 /**
84  * Make tcp connection ready for reading.
85  * \param[in] tcp tcp connection
86  *
87  */
88 void
tcp_conn_ready(tcp_conn_type * tcp)89 tcp_conn_ready(tcp_conn_type* tcp)
90 {
91     ods_log_assert(tcp);
92     tcp->total_bytes = 0;
93     tcp->msglen = 0;
94     buffer_clear(tcp->packet);
95 }
96 
97 
98 /*
99  * Read from a tcp connection.
100  *
101  */
102 int
tcp_conn_read(tcp_conn_type * tcp)103 tcp_conn_read(tcp_conn_type* tcp)
104 {
105     ssize_t received = 0;
106     ods_log_assert(tcp);
107     ods_log_assert(tcp->fd != -1);
108     /* receive leading packet length bytes */
109     if (tcp->total_bytes < sizeof(tcp->msglen)) {
110         received = read(tcp->fd, (char*) &tcp->msglen + tcp->total_bytes,
111             sizeof(tcp->msglen) - tcp->total_bytes);
112         if (received == -1) {
113             if (errno == EAGAIN || errno == EINTR) {
114                 /* read would block, try later */
115                 return 0;
116             } else {
117                 if (errno != ECONNRESET) {
118                     ods_log_error("[%s] error read() sz: %s", tcp_str,
119                         strerror(errno));
120                 }
121                 return -1;
122             }
123         } else if (received == 0) {
124             /* EOF */
125             return -1;
126         }
127         tcp->total_bytes += received;
128         if (tcp->total_bytes < sizeof(tcp->msglen)) {
129             /* not complete yet, try later */
130             return 0;
131         }
132         ods_log_assert(tcp->total_bytes == sizeof(tcp->msglen));
133         tcp->msglen = ntohs(tcp->msglen);
134         if (tcp->msglen > buffer_capacity(tcp->packet)) {
135             /* packet to big, drop connection */
136             ods_log_error("[%s] packet too big, dropping connection", tcp_str);
137             return 0;
138         }
139         buffer_set_limit(tcp->packet, tcp->msglen);
140     }
141     ods_log_assert(buffer_remaining(tcp->packet) > 0);
142 
143     received = read(tcp->fd, buffer_current(tcp->packet),
144         buffer_remaining(tcp->packet));
145     if (received == -1) {
146         if (errno == EAGAIN || errno == EINTR) {
147             /* read would block, try later */
148             return 0;
149         } else {
150             if (errno != ECONNRESET) {
151                 ods_log_error("[%s] error read(): %s", tcp_str,
152                     strerror(errno));
153             }
154             return -1;
155         }
156     } else if (received == 0) {
157         /* EOF */
158         return -1;
159     }
160     tcp->total_bytes += received;
161     buffer_skip(tcp->packet, received);
162     if (buffer_remaining(tcp->packet) > 0) {
163         /* not complete yet, wait for more */
164         return 0;
165     }
166     /* completed */
167     ods_log_assert(buffer_position(tcp->packet) == tcp->msglen);
168     return 1;
169 }
170 
171 
172 /*
173  * Write to a tcp connection.
174  *
175  */
176 int
tcp_conn_write(tcp_conn_type * tcp)177 tcp_conn_write(tcp_conn_type* tcp)
178 {
179     ssize_t sent = 0;
180     ods_log_assert(tcp);
181     ods_log_assert(tcp->fd != -1);
182     if (tcp->total_bytes < sizeof(tcp->msglen)) {
183         uint16_t sendlen = htons(tcp->msglen);
184         sent = write(tcp->fd, (const char*)&sendlen + tcp->total_bytes,
185             sizeof(tcp->msglen) - tcp->total_bytes);
186         if (sent == -1) {
187             if (errno == EAGAIN || errno == EINTR) {
188                 /* write would block, try later */
189                 return 0;
190             } else {
191                 return -1;
192             }
193         }
194         tcp->total_bytes += sent;
195         if (tcp->total_bytes < sizeof(tcp->msglen)) {
196             /* incomplete write, resume later */
197             return 0;
198         }
199         ods_log_assert(tcp->total_bytes == sizeof(tcp->msglen));
200     }
201     ods_log_assert(tcp->total_bytes < tcp->msglen + sizeof(tcp->msglen));
202     sent = write(tcp->fd, buffer_current(tcp->packet),
203         buffer_remaining(tcp->packet));
204     if (sent == -1) {
205         if (errno == EAGAIN || errno == EINTR) {
206             /* write would block, try later */
207             return 0;
208         } else {
209             return -1;
210         }
211     }
212     buffer_skip(tcp->packet, sent);
213     tcp->total_bytes += sent;
214     if (tcp->total_bytes < tcp->msglen + sizeof(tcp->msglen)) {
215         /* more to write when socket becomes writable again */
216         return 0;
217     }
218     ods_log_assert(tcp->total_bytes == tcp->msglen + sizeof(tcp->msglen));
219     return 1;
220 }
221 
222 
223 /**
224  * Clean up tcp connection.
225  *
226  */
227 static void
tcp_conn_cleanup(tcp_conn_type * conn)228 tcp_conn_cleanup(tcp_conn_type* conn)
229 {
230     if (!conn) {
231         return;
232     }
233     buffer_cleanup(conn->packet);
234     free(conn);
235 }
236 
237 /**
238  * Clean up set of tcp connections.
239  *
240  */
241 void
tcp_set_cleanup(tcp_set_type * set)242 tcp_set_cleanup(tcp_set_type* set)
243 {
244     size_t i = 0;
245     if (!set) {
246         return;
247     }
248     for (i=0; i < TCPSET_MAX; i++) {
249         tcp_conn_cleanup(set->tcp_conn[i]);
250     }
251     free(set);
252 }
253