1"""
Detect and replace instances of g_malloc() and wmem_alloc() with
g_new() and wmem_new(), to improve the readability of Wireshark's code.
4
5Also detect and replace instances of
6g_malloc(sizeof(struct myobj) * foo)
7with:
8g_new(struct myobj, foo)
9to better prevent integer overflows
10
11SPDX-License-Identifier: MIT
12"""
13
14import os
15import re
16import sys
17
# When True, replace_file() prints every allocation it rewrites.
print_replacement_info = True

# Regex fragments shared by the patterns below.  Capture groups are numbered
# the same way in every pattern: \1 is the type, \2 is the optional "0"
# suffix of the allocator name, and \3 (when present) is the element count
# or the wmem scope expression.

# A C type name without the trailing '*', optionally introduced by the
# struct/union/enum keyword.  Captured as \1.
_TYPE = r'((?:(?:struct|union|enum)\s+)?[^\s\*]+)'

# The pointer cast in front of the allocation: "(type *)".
_CAST = r'\(\s*' + _TYPE + r'\s*\*\s*\)\s*'

# "sizeof(type)", where the type must repeat the one used in the cast.
_SIZEOF = r'sizeof\s*\(\s*\1\s*\)'

# Element count: any run of non-whitespace characters.
# NOTE(review): the count is copied verbatim, so "sizeof(T)*a+b" is rewritten
# to "g_new(T, a+b)"; rewrites with such compound counts should be checked
# by hand.
_COUNT = r'([^\s]+)'

# wmem scope expression such as "scope", "wmem_file_scope()" or
# "pinfo->pool".  '-' is placed last so it is a literal and not a range.
_SCOPE = r'([\w().:*>-]+)'

# List of (compiled regex, replacement template) pairs, applied in order.
patterns = [
    # Replace (myobj *)g_malloc(sizeof(myobj)) with g_new(myobj, 1)
    # Replace (struct myobj *)g_malloc(sizeof(struct myobj)) with g_new(struct myobj, 1)
    (re.compile(_CAST + r'g_malloc(0?)\s*\(\s*' + _SIZEOF + r'\s*\)'),
     r'g_new\2(\1, 1)'),

    # Replace (myobj *)g_malloc(sizeof(myobj) * foo) with g_new(myobj, foo)
    # Replace (struct myobj *)g_malloc(sizeof(struct myobj) * foo) with g_new(struct myobj, foo)
    (re.compile(_CAST + r'g_malloc(0?)\s*\(\s*' + _SIZEOF + r'\s*\*\s*' + _COUNT + r'\s*\)'),
     r'g_new\2(\1, \3)'),

    # Replace (myobj *)g_malloc(foo * sizeof(myobj)) with g_new(myobj, foo)
    # Replace (struct myobj *)g_malloc(foo * sizeof(struct myobj)) with g_new(struct myobj, foo)
    (re.compile(_CAST + r'g_malloc(0?)\s*\(\s*' + _COUNT + r'\s*\*\s*' + _SIZEOF + r'\s*\)'),
     r'g_new\2(\1, \3)'),

    # Replace (myobj *)wmem_alloc(wmem_file_scope(), sizeof(myobj)) with wmem_new(wmem_file_scope(), myobj)
    # Replace (struct myobj *)wmem_alloc(wmem_file_scope(), sizeof(struct myobj)) with wmem_new(wmem_file_scope(), struct myobj)
    (re.compile(_CAST + r'wmem_alloc(0?)\s*\(\s*' + _SCOPE + r',\s*' + _SIZEOF + r'\s*\)'),
     r'wmem_new\2(\3, \1)'),
]
37
def replace_file(fpath):
    """Apply every entry of ``patterns`` to the file at *fpath*, in place.

    The patterns are applied in list order, each one to the output of the
    previous one.  When ``print_replacement_info`` is true, one line is
    printed for every match that gets rewritten.  The file is only written
    back when its text actually changed.

    Returns the text of the file after all replacements.
    """
    # newline='' disables newline translation so the file keeps its own line
    # endings when it is written back; surrogateescape lets bytes that are
    # not valid in the default encoding round-trip instead of raising.
    with open(fpath, 'r', newline='', errors='surrogateescape') as fh:
        fdata_orig = fh.read()

    fdata = fdata_orig
    for pattern, replacewith in patterns:
        # A single substitution pass both rewrites and reports each match.
        def rewrite(match, template=replacewith):
            replacement = match.expand(template)
            if print_replacement_info:
                print("Bad malloc pattern in %s: Replace '%s' with '%s'" % (fpath, match.group(0), replacement))
            return replacement

        fdata = pattern.sub(rewrite, fdata)

    # Compare against fdata (always bound), so an empty pattern list is a
    # harmless no-op.
    if fdata != fdata_orig:
        with open(fpath, 'w', newline='', errors='surrogateescape') as fh:
            fh.write(fdata)
    return fdata
53
def run_specific_files(fpaths):
    """Run replace_file() on each path in *fpaths* ending in '.c' or '.cpp'.

    Paths with any other suffix are skipped.
    """
    source_paths = (path for path in fpaths if path.endswith(('.c', '.cpp')))
    for source_path in source_paths:
        replace_file(source_path)
59
def run_recursive(root_dir):
    """Walk *root_dir* and hand the files of each directory to run_specific_files().

    run_specific_files() is called once per directory visited, with the
    full paths of the files directly inside it.
    """
    for dirpath, _subdirs, filenames in os.walk(root_dir):
        run_specific_files([os.path.join(dirpath, name) for name in filenames])
67
def test_replacements():
    """Self-test: check ``patterns`` against known allocation snippets.

    All samples are joined into one newline-separated text and every pattern
    is applied to that whole text, the same way replace_file() processes a
    file.

    Raises AssertionError (explicitly, so the check is not stripped by
    ``python -O``) describing the first sample whose rewrite is wrong.
    """
    # (source snippet, expected rewrite) pairs.
    cases = [
        ("(if_info_t*) g_malloc0(sizeof(if_info_t))",
         "g_new0(if_info_t, 1)"),
        ("(oui_info_t *)g_malloc(sizeof (oui_info_t))",
         "g_new(oui_info_t, 1)"),
        ("(guint8 *)g_malloc(16 * sizeof(guint8))",
         "g_new(guint8, 16)"),
        ("(guint32 *)g_malloc(sizeof(guint32)*2)",
         "g_new(guint32, 2)"),
        ("(struct imf_field *)g_malloc (sizeof (struct imf_field))",
         "g_new(struct imf_field, 1)"),
        ("(rtspstat_t *)g_malloc( sizeof(rtspstat_t) )",
         "g_new(rtspstat_t, 1)"),
        ("(proto_data_t *)wmem_alloc(scope, sizeof(proto_data_t))",
         "wmem_new(scope, proto_data_t)"),
        ("(giop_sub_handle_t *)wmem_alloc(wmem_epan_scope(), sizeof (giop_sub_handle_t))",
         "wmem_new(wmem_epan_scope(), giop_sub_handle_t)"),
        ("(mtp3_addr_pc_t *)wmem_alloc0(pinfo->pool, sizeof(mtp3_addr_pc_t))",
         "wmem_new0(pinfo->pool, mtp3_addr_pc_t)"),
        ("(dcerpc_bind_value *)wmem_alloc(wmem_file_scope(), sizeof (dcerpc_bind_value))",
         "wmem_new(wmem_file_scope(), dcerpc_bind_value)"),
        ("(dcerpc_matched_key *)wmem_alloc(wmem_file_scope(), sizeof (dcerpc_matched_key));",
         "wmem_new(wmem_file_scope(), dcerpc_matched_key);"),
        ("(struct smtp_session_state *)wmem_alloc0(wmem_file_scope(), sizeof(struct smtp_session_state))",
         "wmem_new0(wmem_file_scope(), struct smtp_session_state)"),
        ("(struct batman_packet_v5 *)wmem_alloc(pinfo->pool, sizeof(struct batman_packet_v5))",
         "wmem_new(pinfo->pool, struct batman_packet_v5)"),
        ("(struct knx_keyring_mca_keys*) wmem_alloc( wmem_epan_scope(), sizeof( struct knx_keyring_mca_keys ) )",
         "wmem_new(wmem_epan_scope(), struct knx_keyring_mca_keys)"),
    ]

    expected_output = "".join(expected + "\n" for _, expected in cases)
    output = "".join(source + "\n" for source, _ in cases)
    for pattern, replacewith in patterns:
        output = pattern.sub(replacewith, output)

    if output == expected_output:
        return

    # Point at the first line that differs to make the failure actionable.
    got_lines = output.split("\n")
    want_lines = expected_output.split("\n")
    for (source, _), got, want in zip(cases, got_lines, want_lines):
        if got != want:
            raise AssertionError(
                "Self-test failed for %r: got %r, expected %r" % (source, got, want))
    # Every compared line matched, so the outputs differ in line count.
    raise AssertionError(
        "Self-test failed: got %r, expected %r" % (output, expected_output))
105
def main():
    """Run the self-test, then rewrite files chosen by the command line.

    With exactly one command-line argument, that argument is treated as a
    directory and walked recursively.  Otherwise file paths are read from
    standard input, one per line; blank lines are ignored.
    """
    test_replacements()

    if len(sys.argv) != 2:
        stripped_lines = (raw.strip() for raw in sys.stdin)
        run_specific_files([path for path in stripped_lines if path])
        return

    run_recursive(sys.argv[1])


# Only run when executed as a script, not when imported.
if __name__ == "__main__":
    main()
121