#!/usr/bin/env python3
"""
A static protocol code generator.

Reads XML protocol descriptions (interfaces containing requests, events and
enums) and writes either a C header ("header" mode: enum definitions and
``do_*`` prototypes) or a C source file ("data" mode: ``call_*``
argument-decoding shims plus per-interface message tables).
"""

import os
import sys
import fnmatch
import io
import xml.etree.ElementTree as ET
import argparse

# C parameter type prefix (ending in a space or '*') for each protocol
# argument type; "array" is handled separately in write_func.
wltype_to_ctypes = {
    "uint": "uint32_t ",
    "fixed": "uint32_t ",
    "int": "int32_t ",
    "object": "struct wp_object *",
    "new_id": "struct wp_object *",
    "string": "const char *",
    "fd": "int ",
}


def superstring(a, b):
    """Return a sequence containing both ``a`` and ``b`` as contiguous runs.

    Works on any sliceable sequence (the generator passes lists of strings).
    If the longer input already contains the shorter, the longer one is
    returned unchanged. Otherwise the two are concatenated, merging whichever
    suffix/prefix overlap is larger; on a tie the longer input comes first.
    """
    na, nb = len(a), len(b)
    if nb > na:
        a, b, na, nb = b, a, nb, na
    # A contains B
    if get_offset(a, b) is not None:
        return a

    # suffix of B is prefix of A
    ba_overlap = 0
    for i in range(1, nb):
        if b[-i:] == a[:i]:
            ba_overlap = i

    # suffix of A is prefix of B
    ab_overlap = 0
    for i in range(1, nb):
        if a[-i:] == b[:i]:
            ab_overlap = i

    if ba_overlap > ab_overlap:
        return b + a[ba_overlap:]
    return a + b[ab_overlap:]


def get_offset(haystack, needle):
    """Return the first index at which ``needle`` occurs contiguously in
    ``haystack``, or ``None`` if it does not occur."""
    for i in range(len(haystack) - len(needle) + 1):
        if haystack[i : i + len(needle)] == needle:
            return i
    return None


def shortest_superstring(strings):
    """
    Given sequences L_1,...,L_n, return a greedy approximation of the
    shortest sequence containing every L_i as a contiguous run, or ``None``
    when ``strings`` is empty. Offsets of the L_i are not returned; recover
    them with get_offset(). Makes O(n^3) calls to superstring().
    """
    if not strings:
        return None

    # Deduplicate while keeping first-seen order (elements may be unhashable lists).
    pool = []
    for s in strings:
        if s not in pool:
            pool.append(s)

    # Repeatedly merge the pair with the largest overlap; ">=" makes later
    # pairs win ties, which keeps the generated tables stable.
    while len(pool) > 1:
        max_overlap = 0
        best = None
        for i in range(len(pool)):
            for j in range(i):
                d = len(pool[i]) + len(pool[j]) - len(superstring(pool[i], pool[j]))
                if d >= max_overlap:
                    max_overlap = d
                    best = (j, i)

        s = superstring(pool[best[0]], pool[best[1]])
        # best[1] > best[0], so delete the higher index first.
        del pool[best[1]]
        del pool[best[0]]
        pool.append(s)

    sstring = pool[0]
    for s in strings:
        assert get_offset(sstring, s) is not None, ("substring property", sstring, s)

    return sstring


def write_enum(is_header, ostream, iface_name, enum):
    """Write a C ``enum <iface>_<name>`` for an <enum> element (header mode only).

    Each <entry> becomes ``<IFACE>_<NAME>_<ENTRY> = <value>,`` with the value
    copied verbatim from the XML; non-entry children are skipped.
    """
    if not is_header:
        return

    long_name = iface_name + "_" + enum.attrib["name"]
    print("enum " + long_name + " {", file=ostream)
    for entry in enum:
        if entry.tag != "entry":
            continue
        full_name = long_name.upper() + "_" + entry.attrib["name"].upper()
        print("\t" + full_name + " = " + entry.attrib["value"] + ",", file=ostream)
    print("};", file=ostream)


def is_exportable(func_name, export_list):
    """Return True if ``func_name`` matches any case-sensitive fnmatch pattern."""
    return any(fnmatch.fnmatchcase(func_name, e) for e in export_list)


def write_func(is_header, ostream, func_name, func):
    """Emit C code for one <request>/<event> element named ``func_name``.

    Header mode writes only the ``do_<func_name>`` prototype. Data mode writes
    the prototype followed by a ``static void call_<func_name>`` shim that
    reads non-fd arguments from ``payload`` (cursor ``i``) and fds from
    ``fds`` (cursor ``k``), then calls ``do_<func_name>``.
    """
    c_sig = ["struct context *ctx"]
    w_args = []

    num_fd_args = 0
    num_reg_args = 0
    num_obj_args = 0
    for arg in func:
        if arg.tag != "arg":
            continue

        arg_name = arg.attrib["name"]
        arg_type = arg.attrib["type"]
        arg_interface = arg.attrib.get("interface")
        if arg_type == "new_id" and arg_interface is None:
            # Special case, for wl_registry_bind: an untyped new_id is
            # passed as (interface name, version, object).
            c_sig.append("const char *interface")
            c_sig.append("uint32_t version")
            c_sig.append("struct wp_object *id")
            w_args.append(("interface", "string", None))
            w_args.append(("version", "uint", None))
            w_args.append((arg_name, "new_id", None))
            num_obj_args += 1
            num_reg_args += 3
            continue

        if arg_type == "array":
            c_sig.append("int " + arg_name + "_count")
            c_sig.append("const uint8_t *" + arg_name + "_val")
        else:
            c_sig.append(wltype_to_ctypes[arg_type] + arg_name)
        w_args.append((arg_name, arg_type, arg_interface))
        if arg_type == "fd":
            num_fd_args += 1
        else:
            num_reg_args += 1
        if arg_type in ("object", "new_id"):
            num_obj_args += 1

    def W(*x):
        print(*x, file=ostream)

    W("void do_{}({});".format(func_name, ", ".join(c_sig)))
    if is_header:
        return

    W(
        "static void call_{}(struct context *ctx, const uint32_t *payload, "
        "const int *fds, struct message_tracker *mt) {{".format(func_name)
    )
    if num_reg_args > 0:
        W("\tunsigned int i = 0;")
    if num_fd_args > 0:
        W("\tunsigned int k = 0;")

    tmp_names = ["ctx"]
    n_fds_left = num_fd_args
    n_reg_left = num_reg_args
    for i, (arg_name, arg_type, arg_interface) in enumerate(w_args):
        if arg_type == "array":
            # Length word followed by the data, padded to whole 32-bit words.
            n_reg_left -= 1
            W("\tconst uint8_t *arg{}_b = (const uint8_t *)&payload[i + 1];".format(i))
            W("\tint arg{}_a = (int)payload[i];".format(i))
            if n_reg_left > 0:
                W("\ti += 1 + (unsigned int)((arg{}_a + 0x3) >> 2);".format(i))
            tmp_names.append("arg{}_a".format(i))
            tmp_names.append("arg{}_b".format(i))
            continue

        tmp_names.append("arg{}".format(i))

        if arg_type == "fd":
            n_fds_left -= 1
            W("\tint arg{} = fds[{}];".format(i, "k++" if n_fds_left > 0 else "k"))
            continue

        n_reg_left -= 1
        if arg_type == "string":
            # A zero length word is decoded as a NULL string.
            W("\tconst char *arg{} = (const char *)&payload[i + 1];".format(i))
            W("\tif (!payload[i]) arg{} = NULL;".format(i))
            if n_reg_left > 0:
                W("\ti += 1 + ((payload[i] + 0x3) >> 2);")
            continue

        # Only advance the cursor if another payload argument follows, so the
        # generated C has no dead store.
        i_incr = "i++" if n_reg_left > 0 else "i"

        if arg_type in ("object", "new_id"):
            intf_str = "NULL" if arg_interface is None else "&intf_" + arg_interface
            W(
                "\tstruct wp_object *arg{} = get_object(mt, payload[{}], {});".format(
                    i, i_incr, intf_str
                )
            )
        elif arg_type == "int":
            W("\tint32_t arg{} = (int32_t)payload[{}];".format(i, i_incr))
        elif arg_type in ("uint", "fixed"):
            W("\tuint32_t arg{} = payload[{}];".format(i, i_incr))

    W("\tdo_{}({});".format(func_name, ", ".join(tmp_names)))
    # Silence unused-parameter warnings in the generated C.
    if num_obj_args == 0:
        W("\t(void)mt;")
    if num_fd_args == 0:
        W("\t(void)fds;")
    if num_reg_args == 0:
        W("\t(void)payload;")

    W("}")


def load_msg_data(func_name, func, for_export):
    """Summarize one <request>/<event> element for the message tables.

    Returns the tuple ``(is_request, func_name, short_name, new_objs,
    gap_codes, is_destructor, num_fd_args, for_export)``. ``new_objs`` lists
    the C interface pointer expression for each new_id argument. ``gap_codes``
    are decimal strings as described in the gap-coding comment below.
    """
    w_args = []
    for arg in func:
        if arg.tag != "arg":
            continue
        arg_name = arg.attrib["name"]
        arg_type = arg.attrib["type"]
        arg_interface = arg.attrib.get("interface")
        if arg_type == "new_id" and arg_interface is None:
            # Same expansion as write_func's wl_registry_bind special case.
            w_args.append(("interface", "string", None))
            w_args.append(("version", "uint", None))
            w_args.append((arg_name, "new_id", None))
        else:
            w_args.append((arg_name, arg_type, arg_interface))

    new_objs = [
        "&intf_" + arg_interface if arg_interface is not None else "NULL"
        for _, arg_type, arg_interface in w_args
        if arg_type == "new_id"
    ]

    # gap coding: 0=end,1=new_obj,2=array,3=string
    # Each code is (number of non-fd args in this run, including the one that
    # ends it) * 4 + the kind of arg ending the run; the last code ends in 0.
    num_fd_args = 0
    gaps = [0]
    gap_ends = []
    for _, arg_type, _ in w_args:
        if arg_type == "fd":
            num_fd_args += 1
            continue

        gaps[-1] += 1
        if arg_type in ("new_id", "string", "array"):
            gap_ends.append({"new_id": 1, "string": 3, "array": 2}[arg_type])
            gaps.append(0)
    gap_ends.append(0)
    gap_codes = [str(g * 4 + e) for g, e in zip(gaps, gap_ends)]

    is_destructor = func.attrib.get("type") == "destructor"
    # Classify by the element we were given, not by any caller variable.
    is_request = func.tag == "request"
    short_name = func.attrib["name"]

    return (
        is_request,
        func_name,
        short_name,
        new_objs,
        gap_codes,
        is_destructor,
        num_fd_args,
        for_export,
    )


def write_interface(
    ostream, iface_name, func_data, gap_code_array, new_obj_array, dest_name
):
    """Write the ``msgs_<iface>`` table and the ``intf_<iface>`` descriptor.

    ``func_data`` holds load_msg_data() tuples. Requests are listed before
    events. Each entry points at its run inside the shared ``gaps_<dest>``
    and ``objt_<dest>`` arrays, whose contents are ``gap_code_array`` and
    ``new_obj_array``.
    """
    reqs = [x for x in func_data if x[0]]
    evts = [x for x in func_data if not x[0]]

    def W(*x):
        print(*x, file=ostream)

    if reqs or evts:
        W("static const struct msg_data msgs_" + iface_name + "[] = {")

    msg_names = []
    for x in reqs + evts:
        (
            _is_request,
            func_name,
            short_name,
            new_objs,
            gap_codes,
            is_destructor,
            num_fd_args,
            for_export,
        ) = x
        msg_names.append(short_name)

        mda = ["gaps_{} + {}".format(dest_name, get_offset(gap_code_array, gap_codes))]
        if new_objs:
            mda.append(
                "objt_{} + {}".format(dest_name, get_offset(new_obj_array, new_objs))
            )
        else:
            mda.append("NULL")
        mda.append(("call_" + func_name) if for_export else "NULL")
        mda.append(str(num_fd_args))
        mda.append("true" if is_destructor else "false")

        W("\t{" + ", ".join(mda) + "},")

    mcn = "NULL"
    if reqs or evts:
        W("};")
        mcn = "msgs_" + iface_name

    W("const struct wp_interface intf_" + iface_name + " = {")
    W("\t" + mcn + ",")
    W("\t" + str(len(reqs)) + ",")
    W("\t" + str(len(evts)) + ",")
    W('\t"{}",'.format(iface_name))
    # Names are joined with a literal "\0" escape in the C string.
    # NOTE(review): a name starting with an octal digit would be absorbed
    # into that escape; protocol message names presumably never do.
    W('\t"{}",'.format("\\0".join(msg_names)))
    W("};")


def _collect_interface_names(protocols):
    """Return the sorted names of all interfaces defined in, or referenced by
    an arg ``interface=`` attribute in, the given protocol files."""
    names = set()
    for source in protocols:
        root = ET.parse(source).getroot()
        for intf in root:
            if intf.tag != "interface":
                continue
            names.add(intf.attrib["name"])
            for msg in intf:
                for arg in msg:
                    if "interface" in arg.attrib:
                        names.add(arg.attrib["interface"])
    return sorted(names)


def main(argv=None):
    """Parse command-line arguments and generate the requested C file.

    ``argv`` defaults to ``sys.argv[1:]`` (argparse's default). The output is
    assembled in memory and written only after generation succeeds.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mode", choices=("header", "data"), help="Either 'header' or 'data'."
    )
    parser.add_argument(
        "export_list", help="List of events/requests which need parsing."
    )
    parser.add_argument("output_file", help="C file to create.")
    parser.add_argument("protocols", nargs="+", help="XML protocol files to use.")
    args = parser.parse_args(argv)

    is_header = args.mode == "header"
    suffix = ".h" if is_header else ".c"
    if not args.output_file.endswith(suffix):
        parser.error(
            "output_file must end in '{}' in {} mode".format(suffix, args.mode)
        )
    dest_name = os.path.basename(args.output_file)[:-2].replace("-", "_")

    # One fnmatch pattern per line; blank lines and surrounding whitespace
    # (including a stray '\r') are ignored.
    with open(args.export_list) as f:
        export_list = [line.strip() for line in f if line.strip()]

    interfaces = _collect_interface_names(args.protocols)

    header_guard = "{}_H".format(dest_name.upper())
    # Build the output in memory so a malformed protocol file cannot leave a
    # truncated output file behind.
    ostream = io.StringIO()

    def W(*x):
        print(*x, file=ostream)

    if is_header:
        W("#ifndef {}".format(header_guard))
        W("#define {}".format(header_guard))
        W()
    W('#include "symgen_types.h"')
    if not is_header:
        W("#include <stddef.h>")

    for intf in interfaces:
        W("extern const struct wp_interface intf_{};".format(intf))

    gap_code_list = []
    new_obj_list = []
    interface_data = []

    for source in sorted(args.protocols):
        root = ET.parse(source).getroot()
        for interface in root:
            if interface.tag != "interface":
                continue
            iface_name = interface.attrib["name"]

            func_data = []
            for item in interface:
                if item.tag == "enum":
                    write_enum(is_header, ostream, iface_name, item)
                elif item.tag in ("request", "event"):
                    func_name = "{}_{}_{}".format(
                        iface_name,
                        "req" if item.tag == "request" else "evt",
                        item.attrib["name"],
                    )
                    for_export = is_exportable(func_name, export_list)
                    if for_export:
                        write_func(is_header, ostream, func_name, item)
                    if not is_header:
                        func_data.append(load_msg_data(func_name, item, for_export))
                elif item.tag == "description":
                    pass
                else:
                    raise Exception(
                        "{}: unexpected <{}> element in interface {}".format(
                            source, item.tag, iface_name
                        )
                    )

            for x in func_data:
                gap_code_list.append(x[4])
                new_obj_list.append(x[3])

            interface_data.append((iface_name, func_data))

    if not is_header:
        gap_code_array = shortest_superstring(gap_code_list)
        new_obj_array = shortest_superstring(new_obj_list)

        # Skip empty arrays: an empty initializer is not valid C before C23,
        # and nothing references objt_* unless some message has a new_id.
        if new_obj_array:
            W("static const struct wp_interface *objt_" + dest_name + "[] = {")
            W("\t" + ",\n\t".join(new_obj_array))
            W("};")

        if gap_code_array:
            W("static const uint16_t gaps_" + dest_name + "[] = {")
            W("\t" + ",\n\t".join(gap_code_array))
            W("};")

        for iface_name, func_data in interface_data:
            write_interface(
                ostream,
                iface_name,
                func_data,
                gap_code_array,
                new_obj_array,
                dest_name,
            )

    if is_header:
        W()
        W("#endif /* {} */".format(header_guard))

    with open(args.output_file, "w") as f:
        f.write(ostream.getvalue())


if __name__ == "__main__":
    main()