"""
Detect and replace instances of g_malloc() and wmem_alloc() with
g_new() and wmem_new(), to improve the readability of Wireshark's code.

Also detect and replace instances of
g_malloc(sizeof(struct myobj) * foo)
with:
g_new(struct myobj, foo)
to better prevent integer overflows

SPDX-License-Identifier: MIT
"""

import os
import re
import sys

print_replacement_info = True

# --- Regular-expression building blocks shared by the patterns below. ---

# A C type name, optionally preceded by the "struct" keyword (capture group 1).
# Written as an explicit optional keyword: a character class such as
# "[struct]{0,6}" would accept any run of the letters s/t/r/u/c instead.
_TYPE = r'((?:struct\s+)?[^\s*]+)'

# The "(type *)" cast in front of the allocation call.
_CAST = r'\(\s*' + _TYPE + r'\s*\*\s*\)\s*'

# "sizeof(type)", where the type must be spelled exactly as in the cast.
_SIZEOF_TYPE = r'sizeof\s*\(\s*\1\s*\)'

# An element count (capture group 3): identifiers, numbers, "->" / "." member
# access, subscripts, dereferences and parenthesised sub-expressions.
# Operators that bind more loosely than "*" ("+", "-", "<<", "?:", ...) as
# well as "/" and "%" are rejected, because moving them into g_new()'s count
# argument would change the allocation size ("sizeof(T)*n+1" is NOT
# "g_new(T, n+1)").
_COUNT = r'((?:\([^()]*\)|->|[^\s()+\-/%<>=!&|^?:,;])+)'

# A wmem allocator scope expression (capture group 3), e.g. "scope", "NULL",
# "pinfo->pool" or "wmem_file_scope()". "->" is matched as a token; written
# as ")->" inside a character class it would be the range ')'..'>' and let
# ",", "+", digits etc. through.
_SCOPE = r'((?:->|[\w().\[\]])+)'

patterns = [
    # Replace (myobj *)g_malloc(sizeof(myobj)) with g_new(myobj, 1)
    # Replace (struct myobj *)g_malloc(sizeof(struct myobj)) with g_new(struct myobj, 1)
    (re.compile(_CAST + r'g_malloc(0?)\s*\(\s*' + _SIZEOF_TYPE + r'\s*\)'),
     r'g_new\2(\1, 1)'),

    # Replace (myobj *)g_malloc(sizeof(myobj) * foo) with g_new(myobj, foo)
    # Replace (struct myobj *)g_malloc(sizeof(struct myobj) * foo) with g_new(struct myobj, foo)
    (re.compile(_CAST + r'g_malloc(0?)\s*\(\s*' + _SIZEOF_TYPE + r'\s*\*\s*' + _COUNT + r'\s*\)'),
     r'g_new\2(\1, \3)'),

    # Replace (myobj *)g_malloc(foo * sizeof(myobj)) with g_new(myobj, foo)
    # Replace (struct myobj *)g_malloc(foo * sizeof(struct myobj)) with g_new(struct myobj, foo)
    (re.compile(_CAST + r'g_malloc(0?)\s*\(\s*' + _COUNT + r'\s*\*\s*' + _SIZEOF_TYPE + r'\s*\)'),
     r'g_new\2(\1, \3)'),

    # Replace (myobj *)wmem_alloc(wmem_file_scope(), sizeof(myobj)) with wmem_new(wmem_file_scope(), myobj)
    # Replace (struct myobj *)wmem_alloc(wmem_file_scope(), sizeof(struct myobj)) with wmem_new(wmem_file_scope(), struct myobj)
    (re.compile(_CAST + r'wmem_alloc(0?)\s*\(\s*' + _SCOPE + r'\s*,\s*' + _SIZEOF_TYPE + r'\s*\)'),
     r'wmem_new\2(\3, \1)'),
]


def _apply_pattern(pattern, replacewith, text, fpath):
    """Return *text* with every match of *pattern* replaced by *replacewith*.

    When ``print_replacement_info`` is true, each replacement is reported on
    stdout together with *fpath*.
    """
    def expand(match):
        # Match.expand applies the template to exactly this match, so the
        # reported replacement is the one actually written.
        replacement = match.expand(replacewith)
        if print_replacement_info:
            print("Bad malloc pattern in %s: Replace '%s' with '%s'"
                  % (fpath, match.group(0), replacement))
        return replacement

    return pattern.sub(expand, text)


def replace_file(fpath):
    """Apply all ``patterns`` to the file at *fpath*, rewriting it in place.

    The file is only written back when at least one replacement was made.

    :param fpath: path of the C/C++ source file to process.
    :returns: the resulting file contents (unchanged if nothing matched).
    """
    # newline='' keeps the file's own line endings (LF or CRLF) on the round
    # trip instead of normalising them, and surrogateescape passes any
    # non-UTF-8 bytes through untouched rather than aborting the run.
    io_options = {'encoding': 'utf-8', 'errors': 'surrogateescape', 'newline': ''}
    with open(fpath, 'r', **io_options) as fh:
        fdata_orig = fh.read()

    fdata = fdata_orig
    for pattern, replacewith in patterns:
        fdata = _apply_pattern(pattern, replacewith, fdata, fpath)

    if fdata != fdata_orig:
        with open(fpath, 'w', **io_options) as fh:
            fh.write(fdata)
    return fdata


def run_specific_files(fpaths):
    """Process each path in *fpaths* that names a .c or .cpp file."""
    for fpath in fpaths:
        if fpath.endswith(('.c', '.cpp')):
            replace_file(fpath)


def run_recursive(root_dir):
    """Process every .c/.cpp file in the directory tree below *root_dir*."""
    for root, _dirs, files in os.walk(root_dir):
        run_specific_files([os.path.join(root, fname) for fname in files])


def test_replacements():
    """Self-test the patterns; raise AssertionError if any rewrite is wrong.

    Uses an explicit raise rather than ``assert`` so the check still runs
    under ``python -O``.
    """
    test_string = """\
(if_info_t*) g_malloc0(sizeof(if_info_t))
(oui_info_t *)g_malloc(sizeof (oui_info_t))
(guint8 *)g_malloc(16 * sizeof(guint8))
(guint32 *)g_malloc(sizeof(guint32)*2)
(struct imf_field *)g_malloc (sizeof (struct imf_field))
(rtspstat_t *)g_malloc( sizeof(rtspstat_t) )
(proto_data_t *)wmem_alloc(scope, sizeof(proto_data_t))
(giop_sub_handle_t *)wmem_alloc(wmem_epan_scope(), sizeof (giop_sub_handle_t))
(mtp3_addr_pc_t *)wmem_alloc0(pinfo->pool, sizeof(mtp3_addr_pc_t))
(dcerpc_bind_value *)wmem_alloc(wmem_file_scope(), sizeof (dcerpc_bind_value))
(dcerpc_matched_key *)wmem_alloc(wmem_file_scope(), sizeof (dcerpc_matched_key));
(struct smtp_session_state *)wmem_alloc0(wmem_file_scope(), sizeof(struct smtp_session_state))
(struct batman_packet_v5 *)wmem_alloc(pinfo->pool, sizeof(struct batman_packet_v5))
(struct knx_keyring_mca_keys*) wmem_alloc( wmem_epan_scope(), sizeof( struct knx_keyring_mca_keys ) )
(guint8 *)g_malloc(sizeof(guint8)*n+1)
(ts foo *)g_malloc(sizeof(ts foo))
(foo_t *)wmem_alloc(NULL, sizeof(foo_t))
(foo_t *)g_malloc(sizeof(foo_t) * (n + 1))
(foo_t *)g_malloc0(cnt->num * sizeof(foo_t))
"""
    expected_output = """\
g_new0(if_info_t, 1)
g_new(oui_info_t, 1)
g_new(guint8, 16)
g_new(guint32, 2)
g_new(struct imf_field, 1)
g_new(rtspstat_t, 1)
wmem_new(scope, proto_data_t)
wmem_new(wmem_epan_scope(), giop_sub_handle_t)
wmem_new0(pinfo->pool, mtp3_addr_pc_t)
wmem_new(wmem_file_scope(), dcerpc_bind_value)
wmem_new(wmem_file_scope(), dcerpc_matched_key);
wmem_new0(wmem_file_scope(), struct smtp_session_state)
wmem_new(pinfo->pool, struct batman_packet_v5)
wmem_new(wmem_epan_scope(), struct knx_keyring_mca_keys)
(guint8 *)g_malloc(sizeof(guint8)*n+1)
(ts foo *)g_malloc(sizeof(ts foo))
wmem_new(NULL, foo_t)
g_new(foo_t, (n + 1))
g_new0(foo_t, cnt->num)
"""
    output = test_string
    for pattern, replacewith in patterns:
        output = pattern.sub(replacewith, output)
    if output != expected_output:
        raise AssertionError(
            "malloc pattern self-test failed\n--- expected ---\n%s--- got ---\n%s"
            % (expected_output, output))


def main():
    """Run the self-test, then process a directory tree or a list of files.

    With exactly one command-line argument, every .c/.cpp file below that
    directory is processed; otherwise file paths are read from stdin, one
    per line (blank lines are ignored).
    """
    test_replacements()
    if len(sys.argv) == 2:
        run_recursive(sys.argv[1])
    else:
        fpaths = [line.strip() for line in sys.stdin if line.strip()]
        run_specific_files(fpaths)


if __name__ == "__main__":
    main()