1#!/usr/bin/env python3
2
3import os, sys, fnmatch
4import xml.etree.ElementTree as ET
5import argparse
6
7"""
8A static protocol code generator.
9"""
10
# Maps a protocol argument type (the "type" attribute of an <arg> element)
# to the C type text used when declaring that argument in a do_* prototype.
# Every value ends in a space or "*" so the argument name can be appended
# directly. "array" has no entry: write_func emits a count/pointer pair for
# array arguments instead of looking them up here.
wltype_to_ctypes = {
    "uint": "uint32_t ",
    "fixed": "uint32_t ",
    "int": "int32_t ",
    "object": "struct wp_object *",
    "new_id": "struct wp_object *",
    "string": "const char *",
    "fd": "int ",
}
20
21
def superstring(a, b):
    """
    Merge two sequences into one that contains both as contiguous runs.

    If the shorter sequence already occurs inside the longer one, the
    longer one is returned as-is. Otherwise the two are joined end to
    end, sharing the widest run that is a suffix of one and a prefix of
    the other (nothing is shared when no such run exists). On a tie
    between the two join orders, the longer sequence (or `a`, when the
    lengths are equal) is placed first.

    Works on anything that supports len(), slicing, == and +, such as
    lists or strings.
    """
    longer, shorter = (b, a) if len(b) > len(a) else (a, b)
    short_len = len(shorter)

    # Containment: compare every window of the longer sequence.
    window_starts = range(len(longer) - short_len + 1)
    if any(longer[s : s + short_len] == shorter for s in window_starts):
        return longer

    def widest_seam(left, right):
        # Largest k < short_len such that the last k items of `left`
        # equal the first k items of `right`; 0 when there is none.
        seam = 0
        for k in range(1, short_len):
            if left[-k:] == right[:k]:
                seam = k
        return seam

    shorter_first = widest_seam(shorter, longer)
    longer_first = widest_seam(longer, shorter)

    if shorter_first > longer_first:
        return shorter + longer[shorter_first:]
    return longer + shorter[longer_first:]
47
48
def get_offset(haystack, needle):
    """
    Return the first index at which `needle` occurs as a contiguous run
    inside `haystack`, or None if it does not occur.
    """
    span = len(needle)
    starts = range(len(haystack) - span + 1)
    return next(
        (pos for pos in starts if haystack[pos : pos + span] == needle), None
    )
54
55
def shortest_superstring(strings):
    """
    Greedily approximate the shortest sequence that contains every
    element of `strings` as a contiguous run.

    Duplicates are dropped first; then the pair whose merge (see
    superstring) saves the most elements is repeatedly replaced by that
    merge until a single sequence is left. Returns None for empty input.
    Has O(n^3) runtime; O(n^2 polylog) is possible.
    """
    if len(strings) == 0:
        return None

    # De-duplicate while keeping first-seen order.
    candidates = []
    for seq in strings:
        if seq not in candidates:
            candidates.append(seq)

    while len(candidates) > 1:
        best_saving = 0
        chosen = None
        # Scan every unordered pair; `>=` means the last pair reaching the
        # best saving wins, so a pair is always chosen.
        for hi in range(len(candidates)):
            for lo in range(hi):
                merged_len = len(superstring(candidates[hi], candidates[lo]))
                saving = len(candidates[hi]) + len(candidates[lo]) - merged_len
                if saving >= best_saving:
                    best_saving = saving
                    chosen = (lo, hi)

        lo, hi = chosen
        merged = superstring(candidates[lo], candidates[hi])
        # Remove the higher index first so the lower one stays valid.
        candidates.pop(hi)
        candidates.pop(lo)
        candidates.append(merged)

    combined = candidates[0]
    for seq in strings:
        assert get_offset(combined, seq) is not None, (
            "substring property",
            combined,
            seq,
        )

    return combined
90
91
def write_enum(is_header, ostream, iface_name, enum):
    """
    Write a C enum declaration for one protocol <enum> element.

    Nothing is written unless `is_header` is true. The C enum is named
    "<iface_name>_<enum name>"; each <entry> child becomes an upper-cased
    "<ENUM>_<ENTRY> = <value>," line, with the value text copied verbatim.
    Children that are not <entry> elements are skipped, and the enum's
    "bitfield" attribute does not affect the output.
    """
    if not is_header:
        return

    c_name = iface_name + "_" + enum.attrib["name"]
    entry_prefix = c_name.upper()

    print("enum {} {{".format(c_name), file=ostream)
    for child in enum:
        if child.tag == "entry":
            label = child.attrib["name"]
            value = child.attrib["value"]
            print(
                "\t{}_{} = {},".format(entry_prefix, label.upper(), value),
                file=ostream,
            )
    print("};", file=ostream)
110
111
def is_exportable(func_name, export_list):
    """
    Report whether `func_name` matches at least one of the shell-style,
    case-sensitive patterns in `export_list`.
    """
    return any(fnmatch.fnmatchcase(func_name, pattern) for pattern in export_list)
117
118
def write_func(is_header, ostream, func_name, func):
    """
    Write the C declarations for one request/event element.

    A prototype "void do_<func_name>(struct context *ctx, ...);" is always
    written. When `is_header` is false, a "static void call_<func_name>"
    definition follows: it decodes each argument from `payload` (or from
    `fds` for file descriptors) into a local argN variable and forwards
    them all to do_<func_name>.

    Only <arg> children of `func` are considered. A "new_id" argument with
    no "interface" attribute is expanded into three wire arguments
    (interface string, version uint, the new object), which is the special
    case needed for wl_registry_bind. C parameter types for everything
    except arrays come from the module-level wltype_to_ctypes table.
    """

    def emit(text):
        print(text, file=ostream)

    params = ["struct context *ctx"]
    # Flattened wire-level arguments: (name, protocol type, interface or None).
    wire_args = []

    for arg in func:
        if arg.tag != "arg":
            continue

        name = arg.attrib["name"]
        kind = arg.attrib["type"]
        target = arg.attrib.get("interface")

        if kind == "new_id" and target is None:
            # Special case, for wl_registry_bind
            params.extend(
                [
                    "const char *interface",
                    "uint32_t version",
                    "struct wp_object *id",
                ]
            )
            wire_args.extend(
                [
                    ("interface", "string", None),
                    ("version", "uint", None),
                    (name, "new_id", None),
                ]
            )
            continue

        if kind == "array":
            # Arrays are passed as a byte count plus a pointer to the bytes.
            params.append("int " + name + "_count")
            params.append("const uint8_t *" + name + "_val")
        else:
            params.append(wltype_to_ctypes[kind] + name)
        wire_args.append((name, kind, target))

    emit("void do_{}({});".format(func_name, ", ".join(params)))
    if is_header:
        return

    fd_total = sum(1 for _, kind, _ in wire_args if kind == "fd")
    reg_total = len(wire_args) - fd_total
    has_objects = any(kind in ("object", "new_id") for _, kind, _ in wire_args)

    emit(
        "static void call_{}(struct context *ctx, const uint32_t *payload, const int *fds, struct message_tracker *mt)".format(
            func_name
        )
        + " {"
    )
    # `i` indexes payload words and `k` indexes fds; each is declared only
    # when some argument will actually use it.
    if reg_total > 0:
        emit("\tunsigned int i = 0;")
    if fd_total > 0:
        emit("\tunsigned int k = 0;")

    call_args = ["ctx"]
    fds_remaining = fd_total
    regs_remaining = reg_total
    for idx, (_, kind, target) in enumerate(wire_args):
        var = "arg{}".format(idx)

        if kind == "fd":
            fds_remaining -= 1
            call_args.append(var)
            # The index is advanced only if another fd follows.
            emit(
                "\tint {} = fds[{}];".format(
                    var, "k++" if fds_remaining > 0 else "k"
                )
            )
            continue

        regs_remaining -= 1
        more_follow = regs_remaining > 0

        if kind == "array":
            # payload[i] holds the byte length; the data starts at the next
            # word and is skipped in whole 4-byte words.
            emit(
                "\tconst uint8_t *{}_b = (const uint8_t *)&payload[i + 1];".format(
                    var
                )
            )
            emit("\tint {}_a = (int)payload[i];".format(var))
            if more_follow:
                emit(
                    "\ti += 1 + (unsigned int)(({}_a + 0x3) >> 2);".format(var)
                )
            call_args.append(var + "_a")
            call_args.append(var + "_b")
            continue

        call_args.append(var)

        if kind == "string":
            # A zero length word is passed on as a NULL string.
            emit("\tconst char *{} = (const char *)&payload[i + 1];".format(var))
            emit("\tif (!payload[i]) {} = NULL;".format(var))
            if more_follow:
                emit("\ti += 1 + ((payload[i] + 0x3) >> 2);")
            continue

        slot = "i++" if more_follow else "i"
        if kind in ("object", "new_id"):
            intf_ref = "NULL" if target is None else "&intf_" + target
            emit(
                "\tstruct wp_object *{} = get_object(mt, payload[{}], {});".format(
                    var, slot, intf_ref
                )
            )
        elif kind == "int":
            emit("\tint32_t {} = (int32_t)payload[{}];".format(var, slot))
        elif kind in ("uint", "fixed"):
            emit("\tuint32_t {} = payload[{}];".format(var, slot))

    emit("\tdo_{}({});".format(func_name, ", ".join(call_args)))
    # Mark parameters the generated body never touched as used.
    if not has_objects:
        emit("\t(void)mt;")
    if fd_total == 0:
        emit("\t(void)fds;")
    if reg_total == 0:
        emit("\t(void)payload;")

    emit("}")
243
244
def load_msg_data(func_name, func, for_export):
    """
    Summarise one request/event element for the generated message table.

    Parameters:
        func_name: full C-level name of the message; returned unchanged.
        func: the <request> or <event> XML element.
        for_export: whether a call_* handler is generated for this
            message; returned unchanged.

    Returns the tuple
        (is_request, func_name, short_name, new_objs, gap_codes,
         is_destructor, num_fd_args, for_export)
    where
        is_request is True when `func` is a <request> element,
        short_name is the element's "name" attribute,
        new_objs lists, for each new_id argument in order, "&intf_<name>"
            or "NULL" when the argument names no interface,
        gap_codes is the list of gap codes (as decimal strings) described
            in the body below,
        is_destructor is True when the element has type="destructor",
        num_fd_args counts the "fd" arguments.
    """
    w_args = []
    for arg in func:
        if arg.tag != "arg":
            continue
        arg_name = arg.attrib["name"]
        arg_type = arg.attrib["type"]
        arg_interface = arg.attrib.get("interface")
        if arg_type == "new_id" and arg_interface is None:
            # A new_id without an interface (the wl_registry_bind case) is
            # expanded into interface string, version and the new object.
            w_args.append(("interface", "string", None))
            w_args.append(("version", "uint", None))
            w_args.append((arg_name, "new_id", None))
        else:
            w_args.append((arg_name, arg_type, arg_interface))

    new_objs = [
        "NULL" if arg_interface is None else "&intf_" + arg_interface
        for _, arg_type, arg_interface in w_args
        if arg_type == "new_id"
    ]

    # gap coding: 0=end,1=new_obj,2=array,3=string
    # Each code is 4 * count + kind, where count is the number of non-fd
    # arguments since the previous code, including the new_id/array/string
    # argument that closes the run. The final code always has kind 0.
    num_fd_args = 0
    gaps = [0]
    gap_ends = []
    for _, arg_type, _ in w_args:
        if arg_type == "fd":
            num_fd_args += 1
            continue

        gaps[-1] += 1
        if arg_type in ("new_id", "string", "array"):
            gap_ends.append({"new_id": 1, "string": 3, "array": 2}[arg_type])
            gaps.append(0)
    gap_ends.append(0)
    gap_codes = [str(g * 4 + e) for g, e in zip(gaps, gap_ends)]

    is_destructor = func.attrib.get("type") == "destructor"
    # Classify the element that was passed in, not whatever global happens
    # to be named `item` in the calling script.
    is_request = func.tag == "request"
    short_name = func.attrib["name"]

    return (
        is_request,
        func_name,
        short_name,
        new_objs,
        gap_codes,
        is_destructor,
        num_fd_args,
        for_export,
    )
297
298
def write_interface(
    ostream, iface_name, func_data, gap_code_array, new_obj_array, dest_name
):
    """
    Write the C message table and wp_interface definition for one interface.

    `func_data` holds the tuples produced by load_msg_data. Requests are
    listed before events, each group keeping its input order. Every table
    row points into the shared gaps_<dest_name> array (and, when the
    message creates objects, the shared objt_<dest_name> array) at the
    offset where that message's codes are found in `gap_code_array` /
    `new_obj_array`. With no messages, no table is written and the
    interface's table pointer is NULL.
    """

    def emit(text):
        print(text, file=ostream)

    requests = [entry for entry in func_data if entry[0]]
    events = [entry for entry in func_data if not entry[0]]
    ordered = requests + events

    if ordered:
        emit("static const struct msg_data msgs_" + iface_name + "[] = {")

    names = []
    for (
        _is_request,
        func_name,
        short_name,
        new_objs,
        gap_codes,
        is_destructor,
        num_fd_args,
        for_export,
    ) in ordered:
        names.append(short_name)

        gap_ref = "gaps_{} + {}".format(
            dest_name, get_offset(gap_code_array, gap_codes)
        )
        if len(new_objs) > 0:
            obj_ref = "objt_{} + {}".format(
                dest_name, get_offset(new_obj_array, new_objs)
            )
        else:
            obj_ref = "NULL"

        fields = [
            gap_ref,
            obj_ref,
            ("call_" + func_name) if for_export else "NULL",
            str(num_fd_args),
            "true" if is_destructor else "false",
        ]
        emit("\t{" + ", ".join(fields) + "},")

    table_ref = "NULL"
    if ordered:
        emit("};")
        table_ref = "msgs_" + iface_name

    emit("const struct wp_interface intf_" + iface_name + " = {")
    emit("\t" + table_ref + ",")
    emit("\t" + str(len(requests)) + ",")
    emit("\t" + str(len(events)) + ",")
    emit('\t"{}",'.format(iface_name))
    # Message names are packed into one C string separated by "\0" escapes.
    emit('\t"{}",'.format("\\0".join(names)))
    emit("};")
357
358
if __name__ == "__main__":
    # Command-line driver: reads the protocol XML files and writes either a
    # C header ("header" mode: include guard, enums, do_* prototypes) or a
    # C source file ("data" mode: call_* handlers, shared lookup arrays and
    # one wp_interface definition per interface).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mode", choices=("header", "data"), help="Either 'header' or 'data'."
    )
    parser.add_argument(
        "export_list", help="List of events/requests which need parsing."
    )
    parser.add_argument("output_file", help="C file to create.")
    parser.add_argument("protocols", nargs="+", help="XML protocol files to use.")
    args = parser.parse_args()

    is_header = args.mode == "header"
    # Checked explicitly rather than with `assert`, which is skipped when
    # Python runs with -O.
    required_ext = ".h" if is_header else ".c"
    if not args.output_file.endswith(required_ext):
        parser.error(
            "output_file must end in '{}' in '{}' mode".format(
                required_ext, args.mode
            )
        )
    # Output base name without extension, made usable as a C identifier part.
    dest_name = os.path.basename(args.output_file)[:-2].replace("-", "_")

    # One shell-style pattern per line; the file is closed as soon as read.
    with open(args.export_list) as export_file:
        export_list = export_file.read().split("\n")

    # First pass: collect every interface that is defined or referenced by
    # an argument, so each one can be forward-declared.
    intfset = set()
    for source in args.protocols:
        tree = ET.parse(source)
        root = tree.getroot()
        for intf in root:
            if intf.tag == "interface":
                intfset.add(intf.attrib["name"])
                for msg in intf:
                    for arg in msg:
                        if "interface" in arg.attrib:
                            intfset.add(arg.attrib["interface"])
    interfaces = sorted(intfset)

    header_guard = "{}_H".format(dest_name.upper())
    with open(args.output_file, "w") as ostream:
        W = lambda *x: print(*x, file=ostream)

        if is_header:
            W("#ifndef {}".format(header_guard))
            W("#define {}".format(header_guard))
            W()
        W('#include "symgen_types.h"')
        if not is_header:
            W("#include <stddef.h>")

        for intf in interfaces:
            W("extern const struct wp_interface intf_{};".format(intf))

        gap_code_list = []
        new_obj_list = []

        interface_data = []

        # Second pass: emit enums and functions, and gather per-message data.
        for source in sorted(args.protocols):
            tree = ET.parse(source)
            root = tree.getroot()
            for interface in root:
                if interface.tag != "interface":
                    continue
                iface_name = interface.attrib["name"]

                func_data = []
                for item in interface:
                    if item.tag == "enum":
                        write_enum(is_header, ostream, iface_name, item)
                    elif item.tag == "request" or item.tag == "event":
                        is_req = item.tag == "request"
                        func_name = (
                            iface_name
                            + "_"
                            + ("req" if is_req else "evt")
                            + "_"
                            + item.attrib["name"]
                        )

                        for_export = is_exportable(func_name, export_list)
                        if for_export:
                            write_func(is_header, ostream, func_name, item)
                        if not is_header:
                            func_data.append(load_msg_data(func_name, item, for_export))

                    elif item.tag == "description":
                        pass
                    else:
                        raise Exception(item.tag)

                for x in func_data:
                    gap_code_list.append(x[4])
                    new_obj_list.append(x[3])

                interface_data.append((iface_name, func_data))

        if not is_header:
            # Pack all per-message code lists into two shared arrays; each
            # message later refers to its own run by offset.
            gap_code_array = shortest_superstring(gap_code_list)
            new_obj_array = shortest_superstring(new_obj_list)

            if new_obj_array is not None:
                W("static const struct wp_interface *objt_" + dest_name + "[] = {")
                W("\t" + ",\n\t".join(new_obj_array))
                W("};")

            if gap_code_array is not None:
                W("static const uint16_t gaps_" + dest_name + "[] = {")
                W("\t" + ",\n\t".join(gap_code_array))
                W("};")

            for iface_name, func_data in interface_data:
                write_interface(
                    ostream,
                    iface_name,
                    func_data,
                    gap_code_array,
                    new_obj_array,
                    dest_name,
                )

        if is_header:
            W()
            W("#endif /* {} */".format(header_guard))
477