1 //--------------------------------------------------------------------------
2 // Copyright (C) 2014-2021 Cisco and/or its affiliates. All rights reserved.
3 // Copyright (C) 2005-2013 Sourcefire, Inc.
4 //
5 // This program is free software; you can redistribute it and/or modify it
6 // under the terms of the GNU General Public License Version 2 as published
7 // by the Free Software Foundation.  You may not use, modify or distribute
8 // this program under any other version of the GNU General Public License.
9 //
10 // This program is distributed in the hope that it will be useful, but
11 // WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 // General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License along
16 // with this program; if not, write to the Free Software Foundation, Inc.,
17 // 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
18 //--------------------------------------------------------------------------
19 
20 // service_mdns.cc author Sourcefire Inc.
21 
22 #ifdef HAVE_CONFIG_H
23 #include "config.h"
24 #endif
25 
26 #include "service_mdns.h"
27 
28 #include "app_info_table.h"
29 #include "appid_module.h"
30 #include "protocols/packet.h"
31 #include "search_engines/search_tool.h"
32 
33 using namespace snort;
34 
// UDP port compared against the packet's source/destination port in validate()
// (the well-known multicast DNS port).
#define MDNS_PORT   5353
// A name byte whose value >> SHIFT_BITS_REFERENCE_PTR equals this (top two bits set) is treated
// as the start of a REFERENCE_PTR_LENGTH-byte reference (name compression) pointer.
#define PATTERN_REFERENCE_PTR   3
// Length-prefixed name labels registered with the multi-pattern matcher in the constructor.
#define PATTERN_STR_LOCAL_1           "\005local"
#define PATTERN_STR_LOCAL_2           "\005LOCAL"
#define PATTERN_STR_ARPA_1           "\004arpa"
#define PATTERN_STR_ARPA_2           "\004ARPA"
// Separator searched for in names and record data to locate a user name.
#define PATTERN_USERNAME_1           '@'
// Payload prefixes accepted by validate_reply(). PATTERN1-3 cover the first 6 header bytes,
// PATTERN4 the first 4 (presumably ID, flags and question count of the DNS header — the
// exact flag meanings intended here are not documented; TODO confirm).
#define MDNS_PATTERN1 "\x00\x00\x84\x00\x00\x00"
#define MDNS_PATTERN2 "\x00\x00\x08\x00\x00\x00"
#define MDNS_PATTERN3 "\x00\x00\x04\x00\x00\x00"
#define MDNS_PATTERN4 "\x00\x00\x00\x00"
// Record TYPE bytes compared at resp_endptr in analyze_user() (0x0021, the SRV type).
#define SRV_RECORD "\x00\x21"
// Bytes skipped at the start of SRV record data before looking for a user name.
#define SRV_RECORD_OFFSET  6
// Offset from resp_endptr (the record TYPE) to the 16-bit record data length.
#define LENGTH_OFFSET 8
// Offset from resp_endptr to the first byte of the record data.
#define NEXT_MESSAGE_OFFSET 10
// Header offsets of the 16-bit question count and answer count.
#define QUERY_OFFSET 4
#define ANSWER_OFFSET 6
// Size of the fixed header; records start at this offset.
#define RECORD_OFFSET 12
// Shift used to combine two bytes into a big-endian 16-bit value.
#define SHIFT_BITS 8
#define SHIFT_BITS_REFERENCE_PTR  6
#define REFERENCE_PTR_LENGTH  2
// Capacity (including the NUL) of the local user name buffers in analyze_user().
#define MAX_LENGTH_SERVICE_NAME 256
57 
// Detector state kept per flow; validate() stores a ServiceMDNSData through data_add().
enum MDNSState { MDNS_STATE_CONNECTION, MDNS_STATE_CONNECTION_ERROR };

struct ServiceMDNSData
{
    MDNSState state;    // set to MDNS_STATE_CONNECTION when the flow data is created
};

// One literal registered with the search engine; the entry's own address is the match id.
struct MdnsPattern
{
    const uint8_t* pattern;    // pattern bytes
    unsigned length;           // number of bytes in pattern (terminating NUL excluded)
};

// Node of a singly linked list of pattern hits, kept ordered by match_start_pos.
// Nodes are allocated with snort_calloc(), so this must stay a plain aggregate.
struct MatchedPatterns
{
    MdnsPattern* mpattern;      // which pattern matched
    int match_start_pos;        // offset of the first matched byte in the scanned buffer
    MatchedPatterns* next;      // next hit, or nullptr
};
81 
82 static MdnsPattern patterns[] =
83 {
84     { (const uint8_t*)PATTERN_STR_LOCAL_1, sizeof(PATTERN_STR_LOCAL_1) - 1 },
85     { (const uint8_t*)PATTERN_STR_LOCAL_2, sizeof(PATTERN_STR_LOCAL_2) - 1 },
86     { (const uint8_t*)PATTERN_STR_ARPA_1, sizeof(PATTERN_STR_ARPA_1) - 1 },
87     { (const uint8_t*)PATTERN_STR_ARPA_2, sizeof(PATTERN_STR_ARPA_2) - 1 },
88 };
89 
MdnsServiceDetector(ServiceDiscovery * sd)90 MdnsServiceDetector::MdnsServiceDetector(ServiceDiscovery* sd)
91 {
92     handler = sd;
93     name = "MDNS";
94     proto = IpProtocol::UDP;
95     detectorType = DETECTOR_TYPE_DECODER;
96 
97     appid_registry =
98     {
99         { APP_ID_MDNS, APPINFO_FLAG_SERVICE_ADDITIONAL }
100     };
101 
102     service_ports =
103     {
104         { 5353, IpProtocol::UDP, false },
105     };
106 
107     for (unsigned i = 0; i < sizeof(patterns) / sizeof(*patterns); i++)
108         matcher.add((const char*)patterns[i].pattern, patterns[i].length, &patterns[i]);
109     matcher.prep();
110 
111     handler->register_detector(name, this, proto);
112 }
113 
do_custom_reload()114 void MdnsServiceDetector::do_custom_reload()
115 {
116     matcher.reload();
117 }
118 
validate(AppIdDiscoveryArgs & args)119 int MdnsServiceDetector::validate(AppIdDiscoveryArgs& args)
120 {
121     int ret_val;
122 
123     ServiceMDNSData* fd = (ServiceMDNSData*)data_get(args.asd);
124     if (!fd)
125     {
126         fd = (ServiceMDNSData*)snort_calloc(sizeof(ServiceMDNSData));
127         data_add(args.asd, fd, &snort_free);
128         fd->state = MDNS_STATE_CONNECTION;
129     }
130 
131     if (args.pkt->ptrs.dp == MDNS_PORT || args.pkt->ptrs.sp == MDNS_PORT )
132     {
133         ret_val = validate_reply(args.data, args.size);
134         if (ret_val == 1)
135         {
136             if (args.asd.get_odp_ctxt().mdns_user_reporting)
137             {
138                 MatchedPatterns* pattern_list = nullptr;
139                 analyze_user(args.asd, args.pkt, args.size, args.change_bits, pattern_list);
140                 destroy_match_list(pattern_list);
141                 goto success;
142             }
143             goto success;
144         }
145         else
146             goto fail;
147     }
148     else
149         goto fail;
150 
151 success:
152     return add_service(args.change_bits, args.asd, args.pkt, args.dir, APP_ID_MDNS);
153 
154 fail:
155     fail_service(args.asd, args.pkt, args.dir);
156     return APPID_NOMATCH;
157 }
158 
validate_reply(const uint8_t * data,uint16_t size)159 int MdnsServiceDetector::validate_reply(const uint8_t* data, uint16_t size)
160 {
161     int ret_val;
162 
163     /* Check for the pattern match*/
164     if (size >= 6 && memcmp(data, MDNS_PATTERN1, sizeof(MDNS_PATTERN1)-1) == 0)
165         ret_val = 1;
166     else if (size >= 6 && memcmp(data, MDNS_PATTERN2,  sizeof(MDNS_PATTERN2)-1) == 0)
167         ret_val = 1;
168     else if (size >= 6 && memcmp(data,MDNS_PATTERN3, sizeof(MDNS_PATTERN3)-1) == 0)
169         ret_val = 1;
170     else if (size >= 4 && memcmp(data,MDNS_PATTERN4, sizeof(MDNS_PATTERN4)-1) == 0)
171         ret_val = 1;
172     else
173         ret_val = 0;
174 
175     return ret_val;
176 }
177 
178 /* Input to this function is start_ptr and data_size.
179    Output is resp_endptr, start_index and user_name_len
180    Returns 0 or 1 for successful/unsuccessful hit for pattern '@'
181    Returns -1 for invalid address pointer or past the data_size */
reference_pointer(const char * start_ptr,const char ** resp_endptr,int * start_index,uint16_t data_size,uint8_t * user_name_len,unsigned size,MatchedPatterns * & pattern_list)182 int MdnsServiceDetector::reference_pointer(const char* start_ptr, const char** resp_endptr,
183     int* start_index, uint16_t data_size, uint8_t* user_name_len, unsigned size,
184     MatchedPatterns*& pattern_list)
185 {
186     int index = 0;
187     int pattern_length = 0;
188 
189     while (index< data_size &&  (start_ptr[index] == ' ' ))
190         index++;
191 
192     if (index >= data_size)
193         return -1;
194     *start_index = index;
195 
196     const char* temp_start_ptr;
197     temp_start_ptr  = start_ptr+index;
198 
199     // FIXIT-M - This code needs review to ensure it works correctly with the new semantics of the
200     //           index returned by the SearchTool find_all pattern matching function
201     scan_matched_patterns(start_ptr, size - data_size + index, resp_endptr, &pattern_length, pattern_list);
202     /* Contains reference pointer */
203     while ((index < data_size) && !(*resp_endptr) && ((uint8_t )temp_start_ptr[index]  >>
204         SHIFT_BITS_REFERENCE_PTR  != PATTERN_REFERENCE_PTR))
205     {
206         if (temp_start_ptr[index] == PATTERN_USERNAME_1)
207         {
208             *user_name_len = index - *start_index;
209             index++;
210             break;
211         }
212         index++;
213         scan_matched_patterns(start_ptr, size - data_size + index, resp_endptr, &pattern_length, pattern_list);
214     }
215     if (index >= data_size)
216         *user_name_len = 0;
217     else if ((uint8_t )temp_start_ptr[index]  >> SHIFT_BITS_REFERENCE_PTR == PATTERN_REFERENCE_PTR)
218         pattern_length = REFERENCE_PTR_LENGTH;
219     else if (!(*resp_endptr) && ((uint8_t )temp_start_ptr[index]  >>SHIFT_BITS_REFERENCE_PTR !=
220         PATTERN_REFERENCE_PTR ))
221     {
222         while ((index < data_size) && !(*resp_endptr) && ((uint8_t )temp_start_ptr[index]  >>
223             SHIFT_BITS_REFERENCE_PTR != PATTERN_REFERENCE_PTR))
224         {
225             index++;
226             scan_matched_patterns(start_ptr,  size - data_size + index, resp_endptr,
227                 &pattern_length, pattern_list);
228         }
229         if (index >= data_size)
230             *user_name_len = 0;
231         else if ((uint8_t )temp_start_ptr[index]  >> SHIFT_BITS_REFERENCE_PTR ==
232             PATTERN_REFERENCE_PTR)
233             pattern_length = REFERENCE_PTR_LENGTH;
234     }
235 
236     /* Add reference pointer bytes */
237     if ( index+ pattern_length < data_size)
238         *resp_endptr = start_ptr + index+ pattern_length;
239     else
240         return -1;
241 
242     if (*user_name_len > 0)
243         return 1;
244     else
245         return 0;
246 }
247 
248 /* Input to this Function is pkt and size
249    Processing: 1. Parses Multiple MDNS response packet
250                2. Calls the function which scans for pattern to identify the user
251                3. Calls the function which does the Username reporting along with the host
252   MDNS User Analysis*/
analyze_user(AppIdSession & asd,const Packet * pkt,uint16_t size,AppidChangeBits & change_bits,MatchedPatterns * & pattern_list)253 int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint16_t size,
254     AppidChangeBits& change_bits, MatchedPatterns*& pattern_list)
255 {
256     int start_index = 0;
257     uint16_t data_size = size;
258 
259     /* Scan for MDNS response, decided on Query value */
260     const char* query_val = (const char*)pkt->data + QUERY_OFFSET;
261     int query_val_int = (short)(query_val[0]<<SHIFT_BITS  | query_val[1]);
262     const char* answers = (const char*)pkt->data + ANSWER_OFFSET;
263     int ans_count =  (short)(answers[0]<< SHIFT_BITS | (answers[1] ));
264 
265     if ( query_val_int == 0)
266     {
267         const char* resp_endptr;
268         const char* user_original;
269 
270         const char* srv_original  = (const char*)pkt->data + RECORD_OFFSET;
271         pattern_list = create_match_list(srv_original, size - RECORD_OFFSET);
272         const char* end_srv_original  = (const char*)pkt->data + RECORD_OFFSET + data_size;
273         for (int processed_ans = 0; processed_ans < ans_count && data_size <= size && size > 0;
274             processed_ans++ )
275         {
276             // Call Decode Reference pointer function if referenced value instead of direct value
277             uint8_t user_name_len = 0;
278             int ret_value = reference_pointer(srv_original, &resp_endptr,  &start_index, data_size,
279                 &user_name_len, size, pattern_list);
280             int user_index =0;
281 
282             if (ret_value == -1)
283                 return -1;
284             else if (ret_value)
285             {
286                 while (start_index < data_size && (!isprint(srv_original[start_index])  ||
287                     srv_original[start_index] == '"' || srv_original[start_index] =='\''))
288                 {
289                     start_index++;
290                     user_index++;
291                 }
292                 user_name_len -=user_index;
293 
294                 char user_name[MAX_LENGTH_SERVICE_NAME] = "";
295                 memcpy(user_name, srv_original + start_index, user_name_len);
296                 user_name[user_name_len] = '\0';
297 
298                 user_index =0;
299                 while (user_index < user_name_len)
300                 {
301                     if (!isprint(user_name[user_index]))
302                         return 1;
303 
304                     user_index++;
305                 }
306 
307                 add_user(asd, user_name, APP_ID_MDNS, true, change_bits);
308                 break;
309             }
310 
311             // Find the  length to Jump to the next response
312             if ((resp_endptr  + NEXT_MESSAGE_OFFSET  ) < (srv_original + data_size))
313             {
314                 const uint8_t* data_len_str = (const uint8_t*)(resp_endptr+ LENGTH_OFFSET);
315                 uint16_t data_len =  (short)( data_len_str[0]<< SHIFT_BITS | ( data_len_str[1] ));
316                 data_size = data_size - (resp_endptr  + NEXT_MESSAGE_OFFSET + data_len -
317                     srv_original);
318                 /* Check if user name is available in the Domain Name field */
319                 if (data_size < size)
320                 {
321                     if (memcmp(resp_endptr, SRV_RECORD, sizeof(SRV_RECORD)-1)==0)
322                         start_index = SRV_RECORD_OFFSET;
323                     else
324                         start_index =0;
325 
326                     srv_original = resp_endptr  + NEXT_MESSAGE_OFFSET;
327                     user_original = (const char*)memchr((const uint8_t*)srv_original, PATTERN_USERNAME_1,
328                         data_len);
329 
330                     if (user_original )
331                     {
332                         user_name_len = user_original - srv_original - start_index;
333                         const char* user_name_bkp = srv_original + start_index;
334                         /* Non-Printable characters in the beginning */
335 
336                         while (user_index < user_name_len)
337                         {
338                             if (isprint(user_name_bkp[user_index]))
339                                 break;
340 
341                             user_index++;
342                         }
343 
344                         int user_printable_index = user_index;
345                         /* Non-Printable characters in the between  */
346 
347                         while (user_printable_index < user_name_len)
348                         {
349                             if (!isprint(user_name_bkp [user_printable_index ]))
350                                 return 0;
351 
352                             user_printable_index++;
353                         }
354                         /* Copy  the user name if available */
355                         if (( user_name_len - user_index ) < MAX_LENGTH_SERVICE_NAME )
356                         {
357                             char user_name[MAX_LENGTH_SERVICE_NAME];
358                             memcpy(user_name, user_name_bkp + user_index,
359                                 user_name_len - user_index);
360                             user_name[ user_name_len - user_index ] = '\0';
361                             add_user(asd, user_name, APP_ID_MDNS, true, change_bits);
362                             return 1;
363                         }
364                         else
365                             return 0;
366                     }
367 
368                     srv_original = srv_original +  data_len;
369                     if (srv_original > end_srv_original)
370                         return 0;
371                 }
372                 else
373                     return 0;
374             }
375             else
376                 return 0;
377         }
378     }
379     else
380         return 0;
381 
382     return 1;
383 }
384 
mdns_pattern_match(void * id,void *,int match_end_pos,void * data,void *)385 static int mdns_pattern_match(void* id, void*, int match_end_pos, void* data, void*)
386 {
387     MatchedPatterns* cm;
388     MatchedPatterns** matches = (MatchedPatterns**)data;
389     MdnsPattern* target = (MdnsPattern*)id;
390     MatchedPatterns* element;
391     MatchedPatterns* prevElement;
392 
393     cm = (MatchedPatterns*)snort_calloc(sizeof(MatchedPatterns));
394 
395     cm->mpattern = target;
396     cm->match_start_pos = match_end_pos - target->length;
397     for (prevElement = nullptr, element = *matches;
398         element;
399         prevElement = element, element = element->next)
400     {
401         if (element->match_start_pos > cm->match_start_pos)
402             break;
403     }
404 
405     if (prevElement)
406     {
407         cm->next = prevElement->next;
408         prevElement->next = cm;
409     }
410     else
411     {
412         cm->next = *matches;
413         *matches = cm;
414     }
415 
416     return 0;
417 }
418 
create_match_list(const char * data,uint16_t dataSize)419 MatchedPatterns* MdnsServiceDetector::create_match_list(const char* data, uint16_t dataSize)
420 {
421     MatchedPatterns* pattern_list = nullptr;
422     matcher.find_all((const char*)data, dataSize, mdns_pattern_match, false, (void*)&pattern_list);
423 
424     return pattern_list;
425 }
426 
scan_matched_patterns(const char * dataPtr,uint16_t index,const char ** resp_endptr,int * pattern_length,MatchedPatterns * & pattern_list)427 void MdnsServiceDetector::scan_matched_patterns(const char* dataPtr, uint16_t index, const
428     char** resp_endptr, int* pattern_length, MatchedPatterns*& pattern_list)
429 {
430     while (pattern_list)
431     {
432         if (pattern_list->match_start_pos == index)
433         {
434             *resp_endptr = dataPtr;
435             *pattern_length = pattern_list->mpattern->length;
436             return;
437         }
438 
439         if (pattern_list->match_start_pos > index)
440             break;
441 
442         MatchedPatterns* element = pattern_list;
443         pattern_list = pattern_list->next;
444         snort_free(element);
445     }
446     *resp_endptr = nullptr;
447     *pattern_length = 0;
448 }
449 
destroy_match_list(MatchedPatterns * & pattern_list)450 void MdnsServiceDetector::destroy_match_list(MatchedPatterns*& pattern_list)
451 {
452     while (pattern_list)
453     {
454         MatchedPatterns* element = pattern_list;
455         pattern_list = pattern_list->next;
456 
457         snort_free(element);
458     }
459 }
460