1 #include "hosts.h"
2 #include "addresslist.h"
3 #include "udpfrontend.h"
4 #include "hcontext.h"
5 #include "socketpuller.h"
6 #include "goodiplist.h"
7 #include "logs.h"
8 #include "domainstatistic.h"
9 
/* When TRUE, Hosts_Try refuses AAAA queries for domains that have an A entry
 * or a good-IP-list entry in the hosts tables. Set in Hosts_Init from the
 * "BlockIpv6WhenIpv4Exists" configuration option; FALSE until then. */
static BOOL BlockIpv6WhenIpv4Exists = FALSE;

/* Local socket (bound in Hosts_Init) and its bound address. Hosts_Try sends
 * recursed hosts hits to it; Hosts_SocketLoop receives them from it. */
static SOCKET	IncomeSocket;
static Address_Type	IncomeAddress;
14 
Hosts_TypeExisting(const char * Domain,HostsRecordType Type)15 BOOL Hosts_TypeExisting(const char *Domain, HostsRecordType Type)
16 {
17     return StaticHosts_TypeExisting(Domain, Type) ||
18            DynamicHosts_TypeExisting(Domain, Type);
19 }
20 
Hosts_Try_Inner(IHeader * Header,int BufferLength)21 static HostsUtilsTryResult Hosts_Try_Inner(IHeader *Header, int BufferLength)
22 {
23     HostsUtilsTryResult ret;
24 
25     ret = StaticHosts_Try(Header, BufferLength);
26     if( ret != HOSTSUTILS_TRY_NONE )
27     {
28         return ret;
29     }
30 
31     return DynamicHosts_Try(Header, BufferLength);
32 }
33 
/*
 * Look up the CNAME target of Domain, static table first.
 *
 * @param Domain  Domain name to look up.
 * @param Buffer  Output buffer handed to the per-table lookups.
 * @return 0 as soon as one table's lookup returns 0; 1 if neither does.
 */
static int Hosts_GetCName(const char *Domain, char *Buffer)
{
    if( StaticHosts_GetCName(Domain, Buffer) == 0 )
    {
        return 0;
    }

    if( DynamicHosts_GetCName(Domain, Buffer) == 0 )
    {
        return 0;
    }

    return 1;
}
39 
Hosts_Try(IHeader * Header,int BufferLength)40 HostsUtilsTryResult Hosts_Try(IHeader *Header, int BufferLength)
41 {
42     HostsUtilsTryResult ret;
43 
44     if( BlockIpv6WhenIpv4Exists )
45     {
46         if( Header->Type == DNS_TYPE_AAAA &&
47             (Hosts_TypeExisting(Header->Domain, HOSTS_TYPE_A) ||
48              Hosts_TypeExisting(Header->Domain, HOSTS_TYPE_GOOD_IP_LIST)
49              )
50             )
51         {
52             /** TODO: Show blocked message */
53             return HOSTSUTILS_TRY_BLOCKED;
54         }
55     }
56 
57     if( Hosts_TypeExisting(Header->Domain, HOSTS_TYPE_EXCLUEDE) )
58     {
59         return HOSTSUTILS_TRY_NONE;
60     }
61 
62     ret = Hosts_Try_Inner(Header, BufferLength);
63 
64     if( ret == HOSTSUTILS_TRY_RECURSED )
65     {
66         if( sendto(IncomeSocket,
67                    (const char *)Header, /* Only send header and identifier */
68                    sizeof(IHeader) + sizeof(uint16_t), /* Only send header and identifier */
69                    MSG_NOSIGNAL,
70                    (const struct sockaddr *)&(IncomeAddress.Addr),
71                    GetAddressLength(IncomeAddress.family)
72                    )
73             < 0 )
74         {
75             return HOSTSUTILS_TRY_NONE;
76         }
77     }
78 
79     return ret;
80 }
81 
/*
 * Answer a query from the hosts tables if possible, doing the reporting.
 *
 * On a block, a refused reply is sent back and the refusal is logged and
 * counted. On a direct hit, the hit is logged and counted. Recursed hits
 * are completed later by the Hosts_SocketLoop thread.
 *
 * @param Header        Query (IHeader followed by the DNS entity).
 * @param BufferLength  Size of the buffer Header lives in.
 * @return 0 if the query was handled here (blocked, answered, or handed off
 *         for recursion); -126 if the hosts tables do not apply; -139 for
 *         any other result from Hosts_Try.
 */
int Hosts_Get(IHeader *Header, int BufferLength)
{
    HostsUtilsTryResult Result = Hosts_Try(Header, BufferLength);

    if( Result == HOSTSUTILS_TRY_BLOCKED )
    {
        IHeader_SendBackRefusedMessage(Header);
        ShowRefusingMessage(Header, "Disabled because of existing IPv4 host");
        DomainStatistic_Add(Header, STATISTIC_TYPE_REFUSED);
        return 0;
    }

    if( Result == HOSTSUTILS_TRY_NONE )
    {
        return -126;
    }

    if( Result == HOSTSUTILS_TRY_RECURSED )
    {
        /** TODO: Show hosts message */
        return 0;
    }

    if( Result == HOSTSUTILS_TRY_OK )
    {
        ShowNormalMessage(Header, 'H');
        DomainStatistic_Add(Header, STATISTIC_TYPE_HOSTS);
        return 0;
    }

    return -139;
}
112 
Hosts_SocketLoop(void * Unused)113 static int Hosts_SocketLoop(void *Unused)
114 {
115 	static HostsContext	Context;
116 	static SocketPuller Puller;
117 
118     static SOCKET	OutcomeSocket;
119     static Address_Type	OutcomeAddress;
120 
121 	static const struct timeval	LongTime = {3600, 0};
122 	static const struct timeval	ShortTime = {10, 0};
123 
124 	struct timeval	TimeLimit = LongTime;
125 
126 	#define LEFT_LENGTH_SL (sizeof(RequestBuffer) - sizeof(IHeader))
127 	static char		RequestBuffer[2048];
128 	IHeader         *Header = (IHeader *)RequestBuffer;
129 	char		    *RequestEntity = RequestBuffer + sizeof(IHeader);
130 
131 	OutcomeSocket = TryBindLocal(Ipv6_Aviliable(), 10300, &OutcomeAddress);
132 
133 	if( OutcomeSocket == INVALID_SOCKET )
134 	{
135 		return -416;
136 	}
137 
138     if( SocketPuller_Init(&Puller) != 0 )
139     {
140         return -423;
141     }
142 
143     Puller.Add(&Puller, IncomeSocket, NULL, 0);
144     Puller.Add(&Puller, OutcomeSocket, NULL, 0);
145 
146     if( HostsContext_Init(&Context) != 0 )
147     {
148         return -431;
149     }
150 
151     srand(time(NULL));
152 
153 	while( TRUE )
154 	{
155 	    SOCKET  Pulled;
156 
157 	    Pulled = Puller.Select(&Puller, &TimeLimit, NULL, TRUE, FALSE);
158 	    if( Pulled == INVALID_SOCKET )
159         {
160             TimeLimit = LongTime;
161             Context.Swep(&Context);
162         } else if( Pulled == IncomeSocket )
163         {
164             /* Recursive query */
165             int State;
166             char RecursedDomain[DOMAIN_NAME_LENGTH_MAX + 1];
167             uint16_t NewIdentifier;
168 
169             TimeLimit = ShortTime;
170 
171             State = recvfrom(IncomeSocket,
172                              RequestBuffer, /* Receiving a header */
173                              sizeof(RequestBuffer),
174                              0,
175                              NULL,
176                              NULL
177                              );
178 
179             if( State < 1 )
180             {
181                 continue;
182             }
183 
184             if( Hosts_GetCName(Header->Domain, RecursedDomain) != 0 )
185             {
186                 ERRORMSG("Fatal error 221.\n");
187                 continue;
188             }
189 
190             NewIdentifier = rand();
191 
192             if( Context.Add(&Context, Header, RecursedDomain, NewIdentifier)
193                 != 0 )
194             {
195                 ERRORMSG("Fatal error 230.\n");
196                 continue;
197             }
198 
199             if( HostsUtils_Query(OutcomeSocket,
200                                  &OutcomeAddress,
201                                  NewIdentifier,
202                                  RecursedDomain,
203                                  Header->Type
204                                  )
205                 != 0 )
206             {
207                 /** TODO: Show an error */
208                 continue;
209             }
210 
211         } else if( Pulled == OutcomeSocket )
212         {
213             int State;
214 
215             #define LEFT_LENGTH_SL_N (sizeof(NewRequest) - sizeof(IHeader));
216             static char NewRequest[2048];
217             IHeader *NewHeader = (IHeader *)NewRequest;
218 
219             TimeLimit = ShortTime;
220 
221             State = recvfrom(OutcomeSocket,
222                              RequestBuffer, /* Receiving a header */
223                              sizeof(RequestBuffer),
224                              0,
225                              NULL,
226                              NULL
227                              );
228 
229             if( State < 1 )
230             {
231                 continue;
232             }
233 
234             if( Context.FindAndRemove(&Context, Header, NewHeader) != 0 )
235             {
236                 ERRORMSG("Fatal error 267.\n");
237                 continue;
238             }
239 
240             if( HostsUtils_CombineRecursedResponse(NewRequest,
241                                                    sizeof(NewRequest),
242                                                    RequestEntity,
243                                                    State,
244                                                    Header->Domain
245                                                    )
246                 != 0 )
247             {
248                 ERRORMSG("Fatal error 279.\n");
249                 continue;
250             }
251 
252             if( IHeader_SendBack(NewHeader) != 0 )
253             {
254                 ERRORMSG("Fatal error 285.\n");
255                 continue;
256             }
257 
258             ShowNormalMessage(NewHeader, 'H');
259         } else {}
260 	}
261 
262 	return 0;
263 }
264 
/*
 * Initialise the hosts module from the configuration.
 *
 * Initialises the static and dynamic hosts tables and the good-IP list, and
 * reads the "BlockIpv6WhenIpv4Exists" option. It then binds IncomeSocket
 * and starts the detached Hosts_SocketLoop thread. Return values of the
 * table initialisers are not checked here.
 *
 * @param ConfigInfo  Parsed configuration.
 * @return 0 on success, -25 if IncomeSocket could not be bound (in which
 *         case the thread is not started).
 */
int Hosts_Init(ConfigFileInfo *ConfigInfo)
{
    ThreadHandle LoopThread;

    StaticHosts_Init(ConfigInfo);
    DynamicHosts_Init(ConfigInfo);
    GoodIpList_Init(ConfigInfo);

    BlockIpv6WhenIpv4Exists =
        ConfigGetBoolean(ConfigInfo, "BlockIpv6WhenIpv4Exists");

    IncomeSocket = TryBindLocal(Ipv6_Aviliable(), 10200, &IncomeAddress);
    if( IncomeSocket != INVALID_SOCKET )
    {
        CREATE_THREAD(Hosts_SocketLoop, NULL, LoopThread);
        DETACH_THREAD(LoopThread);
        return 0;
    }

    return -25;
}
289