1#!/usr/bin/env python3
2# Wireshark - Network traffic analyzer
3# By Gerald Combs <gerald@wireshark.org>
4# Copyright 1998 Gerald Combs
5#
6# SPDX-License-Identifier: GPL-2.0-or-later
7
import argparse
import os
import re
import signal
import subprocess
import sys
13
14# This utility scans the dissector code for proto_tree_add_...() calls that constrain the type
15# or length of the item added, and checks that the used item is acceptable.
16#
17# Note that this can only work where the hf_item variable or length is passed in directly - where it
18# is assigned to a different variable or a macro is used, it isn't tracked.
19
20# TODO:
21# Currently assuming we'll find call + first 2 args in same line...
22# Attempt to check for allowed encoding types (most likely will be literal values |'d)?
23
24
# Try to exit soon after Ctrl-C is pressed.
# The flag is polled by the folder-scanning and file-checking loops.
should_exit = False

def signal_handler(sig, frame):
    """SIGINT handler: note that the user asked to stop.

    Only sets the should_exit flag and prints a notice; the loops that poll
    the flag are responsible for actually stopping.
    """
    global should_exit
    should_exit = True
    print('You pressed Ctrl+C - exiting')

signal.signal(signal.SIGINT, signal_handler)


# Running totals of the issues reported so far; printed as a summary at the end.
warnings_found = 0
errors_found = 0
38
39# A call is an individual call to an API we are interested in.
40# Internal to APICheck below.
class Call:
    """One call to a checked API: the hf variable used and, optionally, a literal length."""

    def __init__(self, hf_name, line_number=None, length=None):
        """Record a call.

        hf_name     -- name of the hf variable passed in the call.
        line_number -- line of the source file on which the call was found.
        length      -- length argument as captured from the source text, or
                       None.  It is stored as an int in self.length; a value
                       that is empty or cannot be parsed leaves self.length
                       as None.
        """
        self.hf_name = hf_name
        self.line_number = line_number
        self.length = None
        if length:
            try:
                self.length = int(length)
            except (TypeError, ValueError):
                # Not a plain number - treat the length as unknown.
                pass
51
52
53# A check for a particular API function.
class APICheck:
    """Type check for one API function that is passed an hf item.

    find_calls() records every call to fun_name found in a source file, and
    check_against_items() reports each call whose hf item has a field type
    that is not in allowed_types.

    Only calls where the function name and the leading arguments (up to the
    hf variable) are on one line and written as plain identifiers are seen.
    """

    def __init__(self, fun_name, allowed_types):
        """
        fun_name      -- name of the API function to look for.
        allowed_types -- collection of FT_... type names the function accepts.
        """
        self.fun_name = fun_name
        self.allowed_types = allowed_types
        self.calls = []

        # Raw strings are used so that '\(' and '\s' reach the regex engine
        # untouched rather than being (invalid) string escape sequences.
        if not fun_name.startswith('ptvcursor') and 'add_bitmask' in fun_name:
            # RE captures 1st + 4th args; the 4th is taken to be the hf variable.
            args = r'\(([a-zA-Z0-9_]+),\s*[a-zA-Z0-9_]+,\s*[a-zA-Z0-9_]+,\s*([a-zA-Z0-9_]+)'
        else:
            # RE captures 1st 2 args; the 2nd is taken to be the hf variable.
            args = r'\(([a-zA-Z0-9_]+),\s*([a-zA-Z0-9_]+)'
        self.p = re.compile('.*' + self.fun_name + args)

        # Set by find_calls(); used when reporting.
        self.file = None

    def find_calls(self, file):
        """Scan file line by line and store a Call for each matching call."""
        self.file = file
        self.calls = []
        with open(file, 'r') as f:
            for line_number, line in enumerate(f, start=1):
                m = self.p.match(line)
                if m:
                    self.calls.append(Call(m.group(2), line_number=line_number))

    def check_against_items(self, items):
        """Report calls whose hf item (looked up by name in items) has a disallowed type.

        Calls using an hf name that is not a key of items are ignored.
        Each problem is printed and counted in the global errors_found.
        """
        global errors_found
        for call in self.calls:
            if call.hf_name not in items:
                continue
            item_type = items[call.hf_name].item_type
            if item_type not in self.allowed_types:
                # Report this issue.
                print('Error: ' +  self.fun_name + '(.., ' + call.hf_name + ', ...) called at ' +
                      self.file + ':' + str(call.line_number) +
                      ' with type ' + item_type)
                print('    (allowed types are', self.allowed_types, ')\n')
                # Inc global count of issues found.
                errors_found += 1
93
94
class ProtoTreeAddItemCheck(APICheck):
    """Length check for proto_tree_add_item() or, when ptv is set, ptvcursor_add().

    Only calls that pass a literal decimal length are recorded.  A warning is
    given when that length is larger than the size implied by the item's type.
    """

    def __init__(self, ptv=None):
        """
        ptv -- falsy to check proto_tree_add_item(), truthy to check ptvcursor_add().
        """
        # RE will capture whole call.  N.B. only looking at calls with literal numerical length field.
        if not ptv:
            # proto_item *
            # proto_tree_add_item(proto_tree *tree, int hfindex, tvbuff_t *tvb,
            #                     const gint start, gint length, const guint encoding)
            fun_name = 'proto_tree_add_item'
            args = r'\([a-zA-Z0-9_]+,\s*([a-zA-Z0-9_]+),\s*[a-zA-Z0-9_]+,\s*[a-zA-Z0-9_]+,\s*([0-9]+),\s*([a-zA-Z0-9_]+)'
        else:
            # proto_item *
            # ptvcursor_add(ptvcursor_t *ptvc, int hfindex, gint length,
            #               const guint encoding)
            fun_name = 'ptvcursor_add'
            args = r'\([a-zA-Z0-9_]+,\s*([a-zA-Z0-9_]+),\s*([0-9]+),\s*([a-zA-Z0-9_]+)'

        # Run the base initialiser so that calls/file/allowed_types always
        # exist, then replace its pattern with one that also captures the
        # length.  This check does not use allowed_types.
        super().__init__(fun_name, None)
        self.p = re.compile('.*' + self.fun_name + args)

        # Size in bytes implied by each fixed-size type.
        self.lengths = {
            'FT_CHAR':   1,
            'FT_UINT8':  1,
            'FT_INT8':   1,
            'FT_UINT16': 2,
            'FT_INT16':  2,
            'FT_UINT24': 3,
            'FT_INT24':  3,
            'FT_UINT32': 4,
            'FT_INT32':  4,
            'FT_UINT40': 5,
            'FT_INT40':  5,
            'FT_UINT48': 6,
            'FT_INT48':  6,
            'FT_UINT56': 7,
            'FT_INT56':  7,
            'FT_UINT64': 8,
            'FT_INT64':  8,
            # TODO: for FT_BOOLEAN, could take length from 2nd arg (which is in bits...)
            'FT_ETHER':  6,
            # TODO: other types...
        }

    def find_calls(self, file):
        """Scan file line by line and store a Call (hf name + literal length) per match."""
        self.file = file
        self.calls = []
        with open(file, 'r') as f:
            # TODO: would be better to just iterate over those found in whole file,
            # but extra effort would be needed to still know line number.
            for line_number, line in enumerate(f, start=1):
                m = self.p.match(line)
                if m:
                    self.calls.append(Call(m.group(1), line_number=line_number, length=m.group(2)))

    def check_against_items(self, items):
        """Warn about calls whose literal length exceeds the size implied by the item type.

        Calls with an unknown hf name, a zero/unknown length, or an item type
        that has no entry in self.lengths are ignored.  Each warning is
        printed and counted in the global warnings_found.
        """
        # For now, only complaining if length of call is longer than the item type implies.
        #
        # Could also be bugs where the length is always less than the type allows.
        # Would involve keeping track (in the item) of whether any call had used the full length.
        global warnings_found
        for call in self.calls:
            if call.hf_name not in items or not call.length:
                continue
            item_type = items[call.hf_name].item_type
            if item_type in self.lengths and self.lengths[item_type] < call.length:
                print('Warning:', self.file + ':' + str(call.line_number),
                      self.fun_name + ' called for', call.hf_name, ' - ',
                      'item type is', item_type, 'but call has len', call.length)
                warnings_found += 1
163
164
165
##################################################################################################
# Filter names of items whose bitmask is known to be non-contiguous but is
# nevertheless believed to be correct.  These are exempt from the
# non-contiguous-bits warning.
known_non_contiguous_fields = {
    'wlan.fixed.capabilities.cfpoll.sta',
    'wlan.wfa.ie.wme.qos_info.sta.reserved',
    # https://os.itec.kit.edu/downloads/sa_2006_roehricht-martin_flow-control-in-bluez.pdf
    'btrfcomm.frame_type',
    # RFC 5415
    'capwap.control.message_element.ac_descriptor.dtls_policy.r',
    'couchbase.extras.subdoc.flags.reserved',
    # These are 3 separate bits...
    'wlan.fixed.capabilities.cfpoll.ap',
    # The next four match other fields in the same sequence.
    'wlan.wfa.ie.wme.tspec.ts_info.reserved',
    'zbee_zcl_se.pp.attr.payment_control_configuration.reserved',
    'zbee_zcl_se.pp.snapshot_payload_cause.reserved',
    'ebhscr.eth.rsv',
    # non-contiguous field (http://www.acacia-net.com/wwwcla/protocol/v120_l2.htm)
    'v120.lli',
    'stun.type.class',
    'bssgp.csg_id',
}
##################################################################################################
185
186
# Width in bits of each field type that may carry a bitmask.
field_widths = {
    # TODO: Width depends upon 'display' field, not checked.
    'FT_BOOLEAN': 64,
    'FT_CHAR': 8,
    'FT_UINT8': 8,
    'FT_INT8': 8,
    'FT_UINT16': 16,
    'FT_INT16': 16,
    'FT_UINT24': 24,
    'FT_INT24': 24,
    'FT_UINT32': 32,
    'FT_INT32': 32,
    'FT_UINT40': 40,
    'FT_INT40': 40,
    'FT_UINT48': 48,
    'FT_INT48': 48,
    'FT_UINT56': 56,
    'FT_INT56': 56,
    'FT_UINT64': 64,
    'FT_INT64': 64,
}
207
208
209# The relevant parts of an hf item.  Used as value in dict where hf variable name is key.
class Item:
    """The relevant parts of one hf item, with optional sanity checks run on construction.

    Attributes: filename, filter, label, item_type, type_modifier, mask (the
    text as passed in) and mask_value (the mask parsed as an int, or 0 when it
    cannot be parsed as a number).
    """

    # Most recently created item that was built with check_consecutive set.
    previousItem = None

    def __init__(self, filename, filter, label, item_type, type_modifier, mask=None, check_mask=False, check_label=False, check_consecutive=False):
        """
        filename          -- file the item was found in (used in messages).
        filter            -- filter name of the item.
        label             -- label text of the item.
        item_type         -- FT_... type name.
        type_modifier     -- text of the field that follows the type.
        mask              -- mask text, or None.
        check_mask        -- run the mask checks.
        check_label       -- run the label checks.
        check_consecutive -- compare with the previous item built with this flag set.
        """
        self.filename = filename
        self.filter = filter
        self.label = label
        self.item_type = item_type
        self.type_modifier = type_modifier

        self.mask = mask
        self.set_mask_value()

        if check_consecutive:
            self._check_against_previous()
            Item.previousItem = self

        # Optionally check label.
        if check_label:
            self._check_label()

        # Optionally check that mask bits are contiguous, etc.
        # A missing mask (None) is treated like an explicit 'no mask'.
        if check_mask and mask is not None and mask not in {'NULL', '0x0', '0', '0x00'}:
            self.check_contiguous_bits(mask)
            self.check_mask_too_long(mask)
            self.check_num_digits(mask)
            self.check_digits_all_zeros(mask)

    # Warn if this item repeats the filter or mask of the previous item under a different label.
    def _check_against_previous(self):
        global warnings_found
        previous = Item.previousItem
        if previous is None or self.label == previous.label:
            return
        if previous.filter == self.filter:
            print('Warning: ' + self.filename + ': - filter "' + self.filter +
                  '" appears consecutively - labels are "' + previous.label + '" and "' + self.label + '"')
            warnings_found += 1
        if self.mask_value and previous.mask_value == self.mask_value:
            print('Warning: ' + self.filename + ': - mask ' + self.mask +
                  ' appears consecutively - labels are "' + previous.label + '" and "' + self.label + '"')
            warnings_found += 1

    # Warn about stray spaces, unbalanced brackets and trailing colons in the label.
    def _check_label(self):
        global warnings_found
        label = self.label
        prefix = 'Warning: ' + self.filename + ': - filter "' + self.filter + '" label'
        if label.startswith(' ') or label.endswith(' '):
            print(prefix, '"' + label + '"', 'begins or ends with a space')
            warnings_found += 1
        if (label.count('(') != label.count(')') or
            label.count('[') != label.count(']') or
            label.count('{') != label.count('}')):
            print(prefix, '"' + label + '"', 'has unbalanced parens/braces/brackets')
            warnings_found += 1
        if self.item_type != 'FT_NONE' and label.endswith(':'):
            print(prefix, '"' + label + '"', 'ends with an unnecessary colon')
            warnings_found += 1

    def set_mask_value(self):
        """Parse self.mask into self.mask_value (0 if it is not a number)."""
        try:
            # Read according to the appropriate base.
            if self.mask.startswith('0x'):
                self.mask_value = int(self.mask, 16)
            elif self.mask.startswith('0'):
                self.mask_value = int(self.mask, 8)
            else:
                self.mask_value = int(self.mask, 10)
        except (AttributeError, TypeError, ValueError):
            # No mask at all, or a symbolic value such as NULL.
            self.mask_value = 0

    # Return true if bit position n is set in value.
    def check_bit(self, value, n):
        return (value & (0x1 << n)) != 0

    # Output a warning if non-contiguous bits are found in the mask.
    # Note that this legitimately happens in several dissectors where multiple reserved/unassigned
    # bits are conflated into one field.
    def check_contiguous_bits(self, mask):
        """Check the lowest run of set bits in the mask against the field width,
        and warn if any further bits are set above that run.

        mask is the mask text (used in messages); the bits examined are those
        of self.mask_value.
        """
        global warnings_found
        if not self.mask_value:
            return

        # Position of the least-significant set bit.
        mask_start = (self.mask_value & -self.mask_value).bit_length() - 1
        run = self.mask_value >> mask_start
        # (run + 1) & ~run isolates the lowest clear bit of run; its position
        # is the number of consecutive set bits at the bottom of run.
        mask_width = ((run + 1) & ~run).bit_length() - 1

        # Look up the field width
        if self.item_type not in field_widths:
            print('unexpected item_type is ', self.item_type)
            field_width = 64
        else:
            field_width = self.get_field_width_in_bits()

        # It's a problem if the mask_width is > field_width - some of the bits won't get looked at!?
        if mask_width > field_width:
            # N.B. No call, so no line number.
            print(self.filename + ':', 'filter=', self.filter, self.item_type, 'so field_width=', field_width,
                  'but mask is', mask, 'which is', mask_width, 'bits wide!')
            warnings_found += 1

        # Now, any more set bits are an error!
        if self.filter in known_non_contiguous_fields or self.filter.startswith('rtpmidi'):
            # Don't report if we know this one is Ok.
            return
        if run >> mask_width:
            print('Warning:', self.filename, 'filter=', self.filter, ' - mask with non-contiguous bits', mask)
            warnings_found += 1

    def get_field_width_in_bits(self):
        """Return the width in bits to assume for this item's type.

        For FT_BOOLEAN the width comes from type_modifier; for every other type
        it is looked up in field_widths (KeyError if the type is not listed).
        """
        if self.item_type != 'FT_BOOLEAN':
            return field_widths[self.item_type]

        if self.type_modifier in ('NULL', 'BASE_NONE'):
            return 8  # i.e. 1 byte
        if self.type_modifier == 'SEP_DOT':
            return 64
        try:
            # Round up to next nibble.
            return int(self.type_modifier) + 3
        except (TypeError, ValueError):
            # Not a literal number of bits - assume the widest possible
            # field rather than aborting the whole scan.
            return 64

    def check_mask_too_long(self, mask):
        """Warn if the mask text has a leading or trailing zero byte."""
        if not self.mask_value:
            return
        if mask.startswith('0x00') or mask.endswith('00'):
            # There may be good reasons for having a wider field/mask, e.g. if there are 32 related flags, showing them
            # all lined up as part of the same word may make it clearer.  But some cases have been found
            # where the grouping does not seem to be natural..
            print('Warning:', self.filename, 'filter=', self.filter, ' - mask with leading or trailing 0 bytes suggests field', self.item_type, 'may be wider than necessary?', mask)
            global warnings_found
            warnings_found += 1

    def check_num_digits(self, mask):
        """Check the number of hex digits in the mask text against the item type."""
        global warnings_found
        global errors_found
        if not (mask.startswith('0x') and len(mask) > 3):
            return

        num_digits = len(mask) - 2
        odd_digits = num_digits % 2 != 0

        if self.item_type not in field_widths:
            # No width is known for this type, so no expected maximum can be quoted.
            if odd_digits:
                print('Warning:', self.filename, 'filter=', self.filter, ' - mask has odd number of digits', mask)
                warnings_found += 1
            print('Warning:', self.filename, 'filter=', self.filter, ' - item has type', self.item_type, 'but mask set:', mask)
            warnings_found += 1
            return

        max_digits = int(self.get_field_width_in_bits() / 4)
        if odd_digits:
            print('Warning:', self.filename, 'filter=', self.filter, ' - mask has odd number of digits', mask,
                  'expected max for', self.item_type, 'is', max_digits)
            warnings_found += 1

        if num_digits > max_digits:
            extra_digits = mask[2:2 + (num_digits - max_digits)]
            # It's an error if any of these are non-zero, as they won't have any effect!
            if extra_digits.strip('0'):
                print('Error:', self.filename, 'filter=', self.filter, self.mask, "with len is", num_digits,
                      "but type", self.item_type, " indicates max of", max_digits,
                      "and extra digits are non-zero (" + extra_digits + ")")
                errors_found += 1
            else:
                # If has leading zeros, still confusing, so warn.
                print('Warning:', self.filename, 'filter=', self.filter, self.mask, "with len is", num_digits,
                      "but type", self.item_type, " indicates max of", max_digits)
                warnings_found += 1

    def check_digits_all_zeros(self, mask):
        """Warn if a hex mask of two or more digits consists only of zeros."""
        if mask.startswith('0x') and len(mask) > 3:
            if mask[2:] == '0'*(len(mask)-2):
                print('Warning: ', self.filename, 'filter=', self.filter, ' - item has all zeros - this is confusing! :', mask)
                global warnings_found
                warnings_found += 1
387
388
# Types accepted by the proto_tree_add_bitmask...() family.
bitmask_types = { 'FT_CHAR', 'FT_UINT8', 'FT_UINT16', 'FT_UINT24', 'FT_UINT32',
                  'FT_INT8', 'FT_INT16', 'FT_INT24', 'FT_INT32',
                  'FT_UINT40', 'FT_UINT48', 'FT_UINT56', 'FT_UINT64',
                  'FT_INT40', 'FT_INT48', 'FT_INT56', 'FT_INT64',
                  'FT_BOOLEAN'}

# These are APIs in proto.c that check a set of types at runtime and can print '.. is not of type ..' to the console
# if the type is not suitable.
# Each function appears exactly once: a check that is registered twice would
# report (and count) every problem twice.
apiChecks = [
    APICheck('proto_tree_add_item_ret_uint', { 'FT_CHAR', 'FT_UINT8', 'FT_UINT16', 'FT_UINT24', 'FT_UINT32'}),
    APICheck('proto_tree_add_item_ret_int', { 'FT_INT8', 'FT_INT16', 'FT_INT24', 'FT_INT32'}),
    APICheck('ptvcursor_add_ret_uint', { 'FT_CHAR', 'FT_UINT8', 'FT_UINT16', 'FT_UINT24', 'FT_UINT32'}),
    APICheck('ptvcursor_add_ret_int', { 'FT_INT8', 'FT_INT16', 'FT_INT24', 'FT_INT32'}),
    APICheck('ptvcursor_add_ret_string', { 'FT_STRING', 'FT_STRINGZ', 'FT_UINT_STRING', 'FT_STRINGZPAD', 'FT_STRINGZTRUNC'}),
    APICheck('ptvcursor_add_ret_boolean', { 'FT_BOOLEAN'}),
    APICheck('proto_tree_add_item_ret_uint64', { 'FT_UINT40', 'FT_UINT48', 'FT_UINT56', 'FT_UINT64'}),
    APICheck('proto_tree_add_item_ret_int64', { 'FT_INT40', 'FT_INT48', 'FT_INT56', 'FT_INT64'}),
    APICheck('proto_tree_add_item_ret_boolean', { 'FT_BOOLEAN'}),
    APICheck('proto_tree_add_item_ret_string_and_length', { 'FT_STRING', 'FT_STRINGZ', 'FT_UINT_STRING', 'FT_STRINGZPAD', 'FT_STRINGZTRUNC'}),
    APICheck('proto_tree_add_item_ret_display_string_and_length', { 'FT_STRING', 'FT_STRINGZ', 'FT_UINT_STRING',
                                                                     'FT_STRINGZPAD', 'FT_STRINGZTRUNC', 'FT_BYTES', 'FT_UINT_BYTES'}),
    APICheck('proto_tree_add_item_ret_time_string', { 'FT_ABSOLUTE_TIME', 'FT_RELATIVE_TIME'}),
    APICheck('proto_tree_add_uint', {  'FT_CHAR', 'FT_UINT8', 'FT_UINT16', 'FT_UINT24', 'FT_UINT32', 'FT_FRAMENUM'}),
    APICheck('proto_tree_add_uint_format_value', {  'FT_CHAR', 'FT_UINT8', 'FT_UINT16', 'FT_UINT24', 'FT_UINT32', 'FT_FRAMENUM'}),
    APICheck('proto_tree_add_uint_format', {  'FT_CHAR', 'FT_UINT8', 'FT_UINT16', 'FT_UINT24', 'FT_UINT32', 'FT_FRAMENUM'}),
    APICheck('proto_tree_add_uint64', { 'FT_UINT40', 'FT_UINT48', 'FT_UINT56', 'FT_UINT64', 'FT_FRAMENUM'}),
    APICheck('proto_tree_add_int64', { 'FT_INT40', 'FT_INT48', 'FT_INT56', 'FT_INT64'}),
    APICheck('proto_tree_add_int64_format_value', { 'FT_INT40', 'FT_INT48', 'FT_INT56', 'FT_INT64'}),
    APICheck('proto_tree_add_int64_format', { 'FT_INT40', 'FT_INT48', 'FT_INT56', 'FT_INT64'}),
    APICheck('proto_tree_add_int', { 'FT_INT8', 'FT_INT16', 'FT_INT24', 'FT_INT32'}),
    APICheck('proto_tree_add_int_format_value', { 'FT_INT8', 'FT_INT16', 'FT_INT24', 'FT_INT32'}),
    APICheck('proto_tree_add_int_format', { 'FT_INT8', 'FT_INT16', 'FT_INT24', 'FT_INT32'}),
    APICheck('proto_tree_add_boolean', { 'FT_BOOLEAN'}),
    APICheck('proto_tree_add_boolean64', { 'FT_BOOLEAN'}),
    APICheck('proto_tree_add_float', { 'FT_FLOAT'}),
    APICheck('proto_tree_add_float_format', { 'FT_FLOAT'}),
    APICheck('proto_tree_add_float_format_value', { 'FT_FLOAT'}),
    APICheck('proto_tree_add_double', { 'FT_DOUBLE'}),
    APICheck('proto_tree_add_double_format', { 'FT_DOUBLE'}),
    APICheck('proto_tree_add_double_format_value', { 'FT_DOUBLE'}),
    APICheck('proto_tree_add_string', { 'FT_STRING', 'FT_STRINGZ', 'FT_STRINGZPAD', 'FT_STRINGZTRUNC'}),
    APICheck('proto_tree_add_string_format', { 'FT_STRING', 'FT_STRINGZ', 'FT_STRINGZPAD', 'FT_STRINGZTRUNC'}),
    APICheck('proto_tree_add_string_format_value', { 'FT_STRING', 'FT_STRINGZ', 'FT_STRINGZPAD', 'FT_STRINGZTRUNC'}),
    APICheck('proto_tree_add_guid', { 'FT_GUID'}),
    APICheck('proto_tree_add_oid', { 'FT_OID'}),
    APICheck('proto_tree_add_none_format', { 'FT_NONE'}),
    APICheck('proto_tree_add_item_ret_varint', { 'FT_INT8', 'FT_INT16', 'FT_INT24', 'FT_INT32', 'FT_INT40', 'FT_INT48', 'FT_INT56', 'FT_INT64',
                                                  'FT_CHAR', 'FT_UINT8', 'FT_UINT16', 'FT_UINT24', 'FT_UINT32', 'FT_FRAMENUM',
                                                  'FT_UINT40', 'FT_UINT48', 'FT_UINT56', 'FT_UINT64',}),
    APICheck('proto_tree_add_boolean_bits_format_value', { 'FT_BOOLEAN'}),
    APICheck('proto_tree_add_boolean_bits_format_value64', { 'FT_BOOLEAN'}),
    APICheck('proto_tree_add_ascii_7bits_item', { 'FT_STRING'}),
    APICheck('proto_tree_add_checksum', { 'FT_UINT8', 'FT_UINT16', 'FT_UINT24', 'FT_UINT32'}),
    APICheck('proto_tree_add_int64_bits_format_value', { 'FT_INT40', 'FT_INT48', 'FT_INT56', 'FT_INT64'}),
]

# The bitmask functions all accept the same set of types.
apiChecks.extend(APICheck(fun_name, bitmask_types) for fun_name in (
    'proto_tree_add_bitmask',
    'proto_tree_add_bitmask_tree',
    'proto_tree_add_bitmask_ret_uint64',
    'proto_tree_add_bitmask_with_flags',
    'proto_tree_add_bitmask_with_flags_ret_uint64',
    'proto_tree_add_bitmask_value',
    'proto_tree_add_bitmask_value_with_flags',
    'proto_tree_add_bitmask_len',
    'proto_tree_add_bitmask_text',
))

# Also try to check proto_tree_add_item() calls (for length)
apiChecks.append(ProtoTreeAddItemCheck())
apiChecks.append(ProtoTreeAddItemCheck(True)) # for ptvcursor_add()
463
464
465
# Matches, in one left-to-right pass, either a string/character literal or a
# comment.  Literals are matched so that comment markers inside them (e.g. the
# '//' of a URL) are not mistaken for the start of a comment.
_COMMENT_OR_LITERAL_RE = re.compile(
    r'"(?:\\.|[^"\\\n])*"'        # string literal
    r"|'(?:\\.|[^'\\\n])*'"       # character literal
    r'|/\*.*?\*/'                 # C-style comment
    r'|//[^\n]*(?:\n|\Z)',        # C++-style comment, including its newline
    re.DOTALL)


def removeComments(code_string):
    """Return code_string with C-style and C++-style comments removed.

    String and character literals are left untouched.  A C++-style comment is
    removed together with the newline that ends it (or up to the end of the
    text when there is no final newline).
    """
    def keep_literal(match):
        text = match.group(0)
        # Comments start with '/'; anything else that matched is a literal.
        return '' if text.startswith('/') else text

    return _COMMENT_OR_LITERAL_RE.sub(keep_literal, code_string)
470
471# Test for whether the given file was automatically generated.
def isGeneratedFile(filename):
    """Return True if one of the first 11 lines of the file says it was generated.

    The comment to say that a file is generated is near the top, so the search
    gives up after those lines.
    """
    markers = ('Generated automatically',
               'Autogenerated from',
               'is autogenerated',
               'automatically generated by Pidl',
               'Created by: The Qt Meta Object Compiler',
               'This file was generated',
               'This filter was automatically generated')

    # The with-block closes the file on every path, including when reading raises.
    with open(filename, 'r') as f_read:
        for lines_tested, line in enumerate(f_read):
            if lines_tested > 10:
                return False
            if any(marker in line for marker in markers):
                return True

    # OK, looks like a hand-written file!
    return False
498
499# Look for hf items in a dissector file.
def find_items(filename, check_mask=False, check_label=False, check_consecutive=False):
    """Look for hf items in a dissector file.

    Returns a dict mapping each hf variable name found to its Item.  The
    check_* flags are passed on to Item, except that the consecutive check is
    not done for generated files.
    """
    is_generated = isGeneratedFile(filename)

    # The 'consecutive' check compares neighbouring items, which only makes
    # sense within one file - forget the last item of any earlier file.
    Item.previousItem = None

    with open(filename, 'r') as f:
        # Remove comments so as not to trip up RE.
        contents = removeComments(f.read())

    # NOTE(review): the last group (the mask) only accepts lower-case letters,
    # digits and '_', so an entry whose mask is written with upper-case hex
    # digits or as NULL does not appear to be matched - confirm whether that
    # is intended.
    matches = re.finditer(r'.*\{\s*\&(hf_.*),\s*{\s*\"(.+)\",\s*\"([a-zA-Z0-9_\-\.]+)\",\s*([A-Z0-9_]*),\s*(.*),\s*([&A-Za-z0-9x_\(\)]*),\s*([a-z0-9x_]*),', contents)

    items = {}
    for m in matches:
        # Store this item, keyed by the hf variable name.
        items[m.group(1)] = Item(filename, filter=m.group(3), label=m.group(2), item_type=m.group(4), mask=m.group(7),
                                 type_modifier=m.group(5),
                                 check_mask=check_mask,
                                 check_label=check_label,
                                 check_consecutive=(not is_generated and check_consecutive))
    return items
517
518
519
def is_dissector_file(filename):
    """Return a match object if filename contains 'packet-' followed later by '.c', else None.

    The pattern is anchored at the start of filename only, so any text may
    follow the '.c'.
    """
    return re.match(r'.*packet-.*\.c', filename)
523
524
def findDissectorFilesInFolder(folder, dissector_files=None, recursive=False):
    """Return the dissector files found in folder, as a list.

    folder          -- folder to look in.
    dissector_files -- optional list of paths to start from.  Every path found
                       in folder is appended to this list; the returned list
                       is a new one holding only the entries accepted by
                       is_dissector_file().
    recursive       -- also look in all subfolders.  Without it, the entries
                       of folder are taken in sorted order.

    If Ctrl-C has been pressed (should_exit), the scan stops early and what
    has been gathered so far is returned - always a list, never None, so that
    callers can safely iterate the result.
    """
    if dissector_files is None:
        dissector_files = []

    def candidate_paths():
        if recursive:
            for root, _subfolders, names in os.walk(folder):
                for name in names:
                    yield os.path.join(root, name)
        else:
            for name in sorted(os.listdir(folder)):
                yield os.path.join(folder, name)

    for path in candidate_paths():
        if should_exit:
            break
        dissector_files.append(path)

    return [f for f in dissector_files if is_dissector_file(f)]
543
544
545
546# Run checks on the given dissector file.
def checkFile(filename, check_mask=False, check_label=False, check_consecutive=False):
    """Run every registered API check against one dissector file.

    A file that does not exist (e.g. deleted in a recent commit) is reported
    and skipped.
    """
    if not os.path.exists(filename):
        print(filename, 'does not exist!')
        return

    # Gather the hf items first; the Item checks run as they are found.
    hf_items = find_items(filename, check_mask, check_label, check_consecutive)

    # Then look at the calls made to each API of interest.
    for api_check in apiChecks:
        api_check.find_calls(filename)
        api_check.check_against_items(hf_items)
560
561
562
#################################################################
# Main logic.

# command-line args.  Controls which dissector files should be checked.
# If no args given, will scan the epan/dissectors folder and (recursively) plugins/epan.
parser = argparse.ArgumentParser(description='Check calls in dissectors')
parser.add_argument('--file', action='store', default='',
                    help='specify individual dissector file to test')
parser.add_argument('--folder', action='store', default='',
                    help='specify folder to test')
parser.add_argument('--commits', action='store',
                    help='last N commits to check')
parser.add_argument('--open', action='store_true',
                    help='check open files')
parser.add_argument('--mask', action='store_true',
                    help='when set, check mask field too')
parser.add_argument('--label', action='store_true',
                    help='when set, check label field too')
parser.add_argument('--consecutive', action='store_true',
                    help='when set, check for copy/paste errors between consecutive items')


args = parser.parse_args()
586
587
# Run a 'git diff --name-only ...' command and return the dissector files it lists.
def _git_changed_dissector_files(command):
    output = subprocess.check_output(command)
    changed = [f.decode('utf-8') for f in output.splitlines()]
    return [f for f in changed if is_dissector_file(f)]


# Get files from wherever command-line args indicate.
files = []
if args.file:
    # Add single specified file.  A name that does not exist as given (and
    # does not start with 'epan') is looked for in epan/dissectors.
    if not args.file.startswith('epan') and not os.path.exists(args.file):
        files.append(os.path.join('epan', 'dissectors', args.file))
    else:
        files.append(args.file)
elif args.folder:
    # Add all files from a given folder.
    folder = args.folder
    if not os.path.isdir(folder):
        print('Folder', folder, 'not found!')
        sys.exit(1)
    # Find files from folder.
    print('Looking for files in', folder)
    files = findDissectorFilesInFolder(folder, recursive=True)
elif args.commits:
    # Get dissector files affected by specified number of commits.
    files = _git_changed_dissector_files(
        ['git', 'diff', '--name-only', '--diff-filter=d', 'HEAD~' + args.commits])
elif args.open:
    # Unstaged changes.
    files = _git_changed_dissector_files(
        ['git', 'diff', '--name-only', '--diff-filter=d'])
    # Staged changes, skipping any file already listed.
    for f in _git_changed_dissector_files(
            ['git', 'diff', '--staged', '--name-only', '--diff-filter=d']):
        if f not in files:
            files.append(f)
else:
    # Find all dissector files.
    files = findDissectorFilesInFolder(os.path.join('epan', 'dissectors'))
    files = findDissectorFilesInFolder(os.path.join('plugins', 'epan'), recursive=True, dissector_files=files)
632
633
# If scanning a subset of files, list them here.
print('Examining:')
if args.file or args.commits or args.open:
    if files:
        print(' '.join(files), '\n')
    else:
        print('No files to check.\n')
else:
    print('All dissector modules\n')


# Now check the files, giving up promptly if Ctrl-C was pressed.
for f in files:
    if should_exit:
        sys.exit(1)
    checkFile(f, check_mask=args.mask, check_label=args.label, check_consecutive=args.consecutive)

# Show summary.  The exit status is non-zero only if errors were found.
print(warnings_found, 'warnings')
if errors_found:
    print(errors_found, 'errors')
    sys.exit(1)
656