1 #include <string.h>
2 #include "tlsm.h"
3 #include "socketpuller.h"
4 #include "utils.h"
5 #include "logs.h"
6 #include "udpfrontend.h"
7 #include "timedtask.h"
8 #include "dnscache.h"
9 #include "ipmisc.h"
10 #include "domainstatistic.h"
11 #include "ptimer.h"
12 
/*
 * Callback handed to m->Context.Swep() by TlsM_Works().
 *
 * For each entry the sweep passes in (presumably a query that got no reply in
 * time - confirm against ModuleContext), prints a timeout message tagged 'S'
 * and records the domain under STATISTIC_TYPE_REFUSED.
 *
 * h      - header of the swept entry.
 * Number - supplied by the sweep; not consulted yet.
 * Module - the owning TLS module; not consulted yet.
 */
static void SwepWorks(IHeader *h, int Number, TlsM *Module)
{
    (void)Number;
    (void)Module;

    ShowTimeOutMessage(h, 'S');
    DomainStatistic_Add(h, STATISTIC_TYPE_REFUSED);

    /* TODO: special handling when Number == 1. */
}
23 
/*
 * Sends exactly Count bytes starting at Start over the TLS connection
 * m->Departure.
 *
 * curl_easy_send() may return CURLE_OK after writing only part of the data,
 * so it is called repeatedly until everything has been written. On
 * CURLE_AGAIN, SocketIsWritable(s, 2000) decides whether to keep waiting
 * (presumably a 2000 ms writability wait - confirm against its definition).
 *
 * Returns Count on success, -45 if the socket never became writable, and -51
 * on a libcurl error or a negative Count.
 */
static int TlsM_SendWrapper(TlsM *m, SOCKET s, char *Start, int Count)
{
    size_t Total;
    size_t nsentall = 0;

    if( Count < 0 )
    {
        return -51;
    }

    Total = (size_t)Count;

    while( nsentall < Total )
    {
        size_t ns = 0;
        CURLcode State = curl_easy_send(m->Departure,
                                        Start + nsentall,
                                        Total - nsentall,
                                        &ns
                                        );

        nsentall += ns;

        if( State == CURLE_AGAIN )
        {
            if( !SocketIsWritable(s, 2000) )
            {
                return -45;
            }
        } else if( State != CURLE_OK )
        {
            return -51;
        }
        /* CURLE_OK with a partial write: loop to send the remainder. */
    }

    /* nsentall == Total <= INT_MAX here. */
    return (int)nsentall;
}
63 
/*
 * Reads exactly BufferLength bytes of application data from the TLS
 * connection m->Departure into Buffer.
 *
 * curl_easy_recv() may deliver fewer bytes than requested (a reply can span
 * several TLS records), so it is called until the buffer is full. On
 * CURLE_AGAIN, SocketIsStillReadable(s, 2000) decides whether to keep waiting
 * (presumably a 2000 ms readability wait - confirm against its definition).
 *
 * Returns:
 *   BufferLength        - all bytes read;
 *   0..BufferLength-1   - the server closed the connection (curl reported
 *                         0 bytes) before the buffer was full;
 *   -81                 - the socket did not become readable;
 *   -79                 - libcurl reported an error;
 *   -86                 - BufferLength is negative.
 */
static int TlsM_RecvWrapper(TlsM *m, SOCKET s, char *Buffer, int BufferLength)
{
    size_t Wanted;
    size_t ReadAll = 0;

    if( BufferLength < 0 )
    {
        return -86;
    }

    Wanted = (size_t)BufferLength;

    while( ReadAll < Wanted )
    {
        size_t nr = 0;
        CURLcode State = curl_easy_recv(m->Departure,
                                        Buffer + ReadAll,
                                        Wanted - ReadAll,
                                        &nr
                                        );

        if( State == CURLE_AGAIN )
        {
            if( !SocketIsStillReadable(s, 2000) )
            {
                return -81;
            }
            continue;
        }

        if( State != CURLE_OK )
        {
            return -79;
        }

        if( nr == 0 )
        {
            /* 0 bytes means the peer closed the connection: report short read */
            break;
        }

        ReadAll += nr;
    }

    /* ReadAll <= Wanted <= INT_MAX here. */
    return (int)ReadAll;
}
90 
/*
 * Registers query h in m->Context and sends it to the TLS server, opening the
 * connection first when m->Departure is NULL.
 *
 * A new connection is made with CURLOPT_CONNECT_ONLY against m->Services[0],
 * requesting CURL_SSLVERSION_TLSv1_2, and its socket is added to m->Puller.
 * If setting up the connection fails, the easy handle is released and
 * m->Departure is reset to NULL, so the next call starts from scratch instead
 * of reusing a half-initialised handle.
 *
 * Wire format: a 2-byte big-endian length is written into the two bytes just
 * before the entity (IHEADER_TAIL(h) - 2, i.e. inside the header area, which
 * is therefore overwritten). That length prefix and the entity are sent in
 * one call.
 *
 * Returns 0 on success, or a negative code:
 *   -11 context registration failed   -41 curl_easy_init failed
 *   -52 connect/handshake failed       -58 no usable socket (new connection)
 *   -68 no usable socket (existing)    -120 the data was not fully sent
 */
static int TlsM_Send_Actual(TlsM *m, IHeader *h /* Entity followed */)
{
    uint16_t TCPLength;
    curl_socket_t s = CURL_SOCKET_BAD;

    if( m->Context.Add(&(m->Context), h) != 0 )
    {
        return -11;
    }

    /* Set up connection */
    if( m->Departure == NULL )
    {
        m->Departure = curl_easy_init();
        if(  m->Departure == NULL )
        {
            ERRORMSG("Fatal error 40.\n");
            return -41;
        }

        curl_easy_setopt(m->Departure, CURLOPT_URL, m->Services[0]);
        curl_easy_setopt(m->Departure, CURLOPT_CONNECT_ONLY, 1L);
        curl_easy_setopt(m->Departure,
                         CURLOPT_SSLVERSION,
                         CURL_SSLVERSION_TLSv1_2
                         );

        if( curl_easy_perform(m->Departure) != CURLE_OK )
        {
            /* Don't keep a handle that never connected. */
            curl_easy_cleanup(m->Departure);
            m->Departure = NULL;
            return -52;
        }

        if( curl_easy_getinfo(m->Departure, CURLINFO_ACTIVESOCKET, &s)
                != CURLE_OK ||
            s == CURL_SOCKET_BAD
           )
        {
            curl_easy_cleanup(m->Departure);
            m->Departure = NULL;
            return -58;
        }

        m->Puller.Add(&(m->Puller), s, NULL, 0);
    } else {
        if( curl_easy_getinfo(m->Departure, CURLINFO_ACTIVESOCKET, &s)
                != CURLE_OK ||
            s == CURL_SOCKET_BAD
           )
        {
            /* The caller is expected to close and reconnect. */
            return -68;
        }
    }

    /* Preparing content: 2-byte length prefix right before the entity */
    TCPLength = htons(h->EntityLength);
    memcpy((char *)(IHEADER_TAIL(h)) - 2, &TCPLength, 2);

    /* Sending content */
    if( TlsM_SendWrapper(m,
                         s,
                         (char *)(IHEADER_TAIL(h)) - 2,
                         h->EntityLength + 2
                         )
        != h->EntityLength + 2 )
    {
        return -120;
    }

    return 0;
}
157 
/*
 * Tears down the TLS connection, if any.
 *
 * The connection's socket is removed from m->Puller when libcurl can still
 * report it. The easy handle is always released and m->Departure reset to
 * NULL, even when the socket cannot be queried. Otherwise a broken handle
 * would leak and be reused by every subsequent send. Safe to call when no
 * connection exists.
 */
static void TlsM_CloseConnection(TlsM *m)
{
    curl_socket_t s = CURL_SOCKET_BAD;

    if( m->Departure == NULL )
    {
        return;
    }

    if( curl_easy_getinfo(m->Departure, CURLINFO_ACTIVESOCKET, &s)
            == CURLE_OK &&
        s != CURL_SOCKET_BAD
       )
    {
        m->Puller.Del(&(m->Puller), s);
    }
    /* NOTE(review): if libcurl already dropped the socket (CURL_SOCKET_BAD),
     * the descriptor registered earlier may remain in m->Puller; storing it
     * in TlsM at connect time would allow removing it reliably. */

    curl_easy_cleanup(m->Departure);
    m->Departure = NULL;
}
174 
TlsM_Send(TlsM * m,IHeader * h,int BufferLength)175 PUBFUNC int TlsM_Send(TlsM *m,
176                       IHeader *h, /* Entity followed */
177                       int BufferLength
178                       )
179 {
180     int State;
181 
182     State = sendto(m->Incoming,
183                    (const char *)h,
184                    sizeof(IHeader) + h->EntityLength,
185                    MSG_NOSIGNAL,
186                    (const struct sockaddr *)&(m->IncomingAddr.Addr),
187                    GetAddressLength(m->IncomingAddr.family)
188                    );
189 
190     return !(State > 0);
191 }
192 
/*
 * Worker thread of the TLS module.
 *
 * Waits on m->Puller for two kinds of sockets:
 *   - m->Incoming: receives an IHeader followed by a DNS query, which is
 *     forwarded to the TLS server by TlsM_Send_Actual();
 *   - the TLS connection socket: receives 2-byte-length-prefixed replies,
 *     which are matched against m->Context, passed through
 *     IPMiscSingleton_Process() and sent back with IHeader_SendBack().
 *
 * When Select() yields INVALID_SOCKET (presumably the 5 s timeout), or after
 * more than 1024 queries, the context is swept with SwepWorks().
 *
 * A single buffer holds both the last query and the last reply. Once a reply
 * is read into it, the query it held is gone and must not be resent.
 *
 * Returns -128 if the buffer cannot be allocated; otherwise it never returns.
 */
static int TlsM_Works(TlsM *m)
{
    #define BUF_LENGTH  2048
    #define LEFT_LENGTH  (BUF_LENGTH - sizeof(IHeader))

    SOCKET  s;

    char *ReceiveBuffer;
    IHeader *Header;
    char *Entity;

    static const struct timeval TimeLimit = {5, 0};
    struct timeval TimeOut;

    /* Time of the last successful send to, or reply from, the TLS server.
     * If the connection has been idle for more than 5 s, it is reopened
     * before sending, instead of writing into a connection the server has
     * likely dropped. */
    time_t  LastServerActivity = 0;

    /* TRUE when the query in the buffer must not be resent (already retried,
     * or the buffer no longer holds a query). */
    BOOL Retried = FALSE;

    int NumberOfCumulated = 0;

    ReceiveBuffer = SafeMalloc(BUF_LENGTH);
    if( ReceiveBuffer == NULL )
    {
        ERRORMSG("Fatal error 127.\n");
        return -128;
    }

    Header = (IHeader *)ReceiveBuffer;
    Entity = ReceiveBuffer + sizeof(IHeader);

    while( TRUE )
    {
        TimeOut = TimeLimit;
        s = m->Puller.Select(&(m->Puller), &TimeOut, NULL, TRUE, FALSE);

        if( s == INVALID_SOCKET )
        {
            m->Context.Swep(&(m->Context), (SwepCallback)SwepWorks, m);
            NumberOfCumulated = 0;
        } else if( s == m->Incoming )
        {
            int State;

            if( NumberOfCumulated > 1024 )
            {
                m->Context.Swep(&(m->Context), (SwepCallback)SwepWorks, m);
                NumberOfCumulated = 0;
            }

            State = recvfrom(s,
                             ReceiveBuffer, /* Receiving a header */
                             BUF_LENGTH,
                             0,
                             NULL,
                             NULL
                             );

            if( State <= 0 )
            {
                Retried = TRUE;
                continue;
            }

            ++NumberOfCumulated;

            Retried = FALSE;

            if( m->Departure != NULL &&
                time(NULL) - LastServerActivity > 5 )
            {
                TlsM_CloseConnection(m);
            }

            if( TlsM_Send_Actual(m, Header) == 0 )
            {
                LastServerActivity = time(NULL);
            } else {
                TlsM_CloseConnection(m);

                /* Try again on a fresh connection */
                if( TlsM_Send_Actual(m, Header) == 0 )
                {
                    LastServerActivity = time(NULL);
                } else {
                    TlsM_CloseConnection(m);
                }

                Retried = TRUE;
            }

        } else /* Departure socket */ {
            int State;
            uint16_t TcpLength;

            if( TlsM_RecvWrapper(m, s, (char *)&TcpLength, 2) != 2 )
            {
                TlsM_CloseConnection(m);
                INFO("TLS server closed the connection.\n");

                if( !Retried )
                {
                    INFO("TLS query retrying...\n");

                    if( TlsM_Send_Actual(m, Header) == 0 )
                    {
                        LastServerActivity = time(NULL);
                    } else {
                        TlsM_CloseConnection(m);
                    }

                    Retried = TRUE;
                }

                continue;
            }

            /* A reply is arriving and will overwrite the shared buffer, so the
             * query it currently holds must never be resent after this. */
            Retried = TRUE;

            TcpLength = ntohs(TcpLength);

            if( TcpLength > LEFT_LENGTH )
            {
                WARNING("TLS segment is too large, discarded.\n");
                TlsM_CloseConnection(m);
                continue;
            }

            if( TlsM_RecvWrapper(m,
                                 s,
                                 Entity,
                                 TcpLength
                                 )
                != TcpLength )
            {
                TlsM_CloseConnection(m);
                continue;
            }

            LastServerActivity = time(NULL);

            /* The header must describe the reply entity now in the buffer. */
            Header->EntityLength = TcpLength;

            if( m->Context.FindAndRemove(&(m->Context), Header, Header) != 0 )
            {
                continue;
            }

            /* Re-assert after the header was restored from the context. */
            Header->EntityLength = TcpLength;

            switch( IPMiscSingleton_Process(Header) )
            {
            case IP_MISC_ACTION_NOTHING:
                break;

            case IP_MISC_ACTION_BLOCK:
                ShowBlockedMessage(Header, "Bad package, discarded");
                continue;

            default:
                ERRORMSG("Fatal error 298.\n");
                continue;
            }

            State = IHeader_SendBack(Header);

            if( State != 0 )
            {
                ShowErrorMessage(Header, 'S');
                continue;
            }

            ShowNormalMessage(Header, 'S');
            DNSCache_AddItemsToCache(Header);
            DomainStatistic_Add(Header, STATISTIC_TYPE_TCP);
        }
    }

    #undef LEFT_LENGTH
    #undef BUF_LENGTH
}
358 
/*
 * Initialises a TLS module and starts its worker thread.
 *
 * Steps:
 *   - initialises the context and the socket puller;
 *   - binds the incoming socket via TryBindLocal(..., 10500, ...) and
 *     registers it with the puller;
 *   - turns the comma-separated Services string into "https://<service>"
 *     entries in m->ServiceList, exposed as m->Services;
 *   - installs TlsM_Send and starts TlsM_Works on a new thread.
 * Service entries that do not fit the 512-byte URL buffer are skipped with a
 * warning rather than used truncated.
 *
 * Returns 0 on success, or a negative code:
 *   -7   NULL argument                 -12  context init failed
 *   -389 puller init failed            -357 binding the incoming socket failed
 *   -386 service list init failed      -170 splitting Services failed
 *   -175 iterator init failed          -316 building m->Services failed
 * NOTE(review): resources acquired before a failure are not released.
 */
int TlsM_Init(TlsM *m, const char *Services)
{
    StringList l;
    StringListIterator i;
    const char *one;

    if( m == NULL || Services == NULL )
    {
        return -7;
    }

    if( ModuleContext_Init(&(m->Context)) != 0 )
    {
        return -12;
    }

    if( SocketPuller_Init(&(m->Puller)) != 0 )
    {
        return -389;
    }

    m->Incoming = TryBindLocal(Ipv6_Aviliable(), 10500, &(m->IncomingAddr));
    if( m->Incoming == INVALID_SOCKET )
    {
        return -357;
    }

    m->Puller.Add(&(m->Puller), m->Incoming, NULL, 0);

    m->Departure = NULL;

    /* StringList_Init returns 0 on success (as checked for `l` below). */
    if( StringList_Init(&(m->ServiceList), NULL, NULL) != 0 )
    {
        return -386;
    }

    if( StringList_Init(&l, Services, ",") != 0 )
    {
        return -170;
    }

    if( StringListIterator_Init(&i, &l) != 0 )
    {
        l.Free(&l);
        return -175;
    }

    while( (one = i.Next(&i)) != NULL )
    {
        char n[512];
        int Written = snprintf(n, sizeof(n), "https://%s", one);

        if( Written < 0 || (size_t)Written >= sizeof(n) )
        {
            WARNING("TLS service name is too long, ignored.\n");
            continue;
        }

        m->ServiceList.Add(&(m->ServiceList), n, NULL);
    }

    l.Free(&l);

    m->Services = m->ServiceList.ToCharPtrArray(&(m->ServiceList));
    if( m->Services == NULL )
    {
        return -316;
    }

    m->Send = TlsM_Send;

    CREATE_THREAD(TlsM_Works, m, m->WorkThread);

    return 0;
}
427