xref: /linux/tools/net/ynl/ynl-gen-c.py (revision d62c5d48)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import os
8import re
9import shutil
10import tempfile
11import yaml
12
13from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
14
15
def c_upper(name):
    """Return *name* upper-cased with every '-' turned into '_' (C macro style)."""
    return '_'.join(name.upper().split('-'))
18
19
def c_lower(name):
    """Return *name* lower-cased with every '-' turned into '_' (C identifier style)."""
    return '_'.join(name.lower().split('-'))
22
23
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value.

    The name must be 'u' (unsigned) or 's' (signed), then the bit width,
    then '-min' or '-max'.  Malformed names are rejected instead of being
    silently mis-parsed (e.g. 'u32-foo' used to be read as 'u32-max').

    Raises:
        ValueError: if *name* is not a well-formed limit name.
    """
    match = re.fullmatch(r'([us])(\d+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit name "{name}"')
    sign, width, bound = match.group(1), int(match.group(2)), match.group(3)

    if sign == 'u':
        return 0 if bound == 'min' else (1 << width) - 1

    # Signed types spend one bit on the sign.
    value = (1 << (width - 1)) - 1
    if bound == 'min':
        value = -value - 1
    return value
37
38
class BaseNlLib:
    """C snippets for the netlink helper library used by the generated code."""

    def get_family_id(self):
        """Return the C expression holding the resolved family ID."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return a C mnl_cb_run2() call that parses ys->rx_buf.

        Dump parsing passes 0, 0 for the sequence/port arguments; otherwise
        ys->seq and ys->portid are passed.  *indent* adds that many extra
        tabs to the wrapped continuation lines.
        """
        wrap = '\n\t\t' + '\t' * indent + ' '
        if not is_dump:
            return (f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{wrap}{cb}, {data},{wrap}"
                    "ynl_cb_array, NLMSG_MIN_TYPE)")
        return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},{wrap}ynl_cb_array, NLMSG_MIN_TYPE)"
50
51
class Type(SpecAttr):
    """Base C code generator for one attribute of an attribute set.

    Wraps a spec attribute (*attr* is its YAML dict) and renders the pieces
    of generated C that refer to it: struct members, presence members,
    policy and parse-table entries, put/get code and setter helpers.
    Subclasses supply the type-specific hooks (_attr_policy, _attr_typol,
    _attr_get, _setter_lines, attr_put, ...).

    Methods that take *ri* emit C lines through its code writer, ri.cw.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE: when the spec has 'checks' this is the spec's own dict, not a
        # copy; subclasses (e.g. TypeScalar) add derived keys to it.
        self.checks = attr.get('checks', {})

        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.name}")
            else:
                self.nested_render_name = c_lower(f"{family.name}_{self.nested_attrs}")

            # A trailing '_' is added when the nested set's name is also a
            # definition in the family consts (presumably to avoid a C type
            # name clash -- confirm).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve(): declared for readability, then removed so that
        # any use before resolve() raises AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def get_limit(self, limit, default=None):
        """Return check *limit* as a number, or *default* when it is not set.

        String limits such as 'u32-max' are converted with limit_to_number().
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if not isinstance(value, int):
            value = limit_to_number(value)
        return value

    def resolve(self):
        """Compute the upper-case C enum name of this attribute."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Return how presence is tracked: 'bit', 'len', 'count' or ''."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the _present member declaration for this attribute.

        Returns None when presence_type() differs from *type_filter* or is
        neither 'bit' nor 'len'.  In 'user' space the type is spelled __u32.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        return False

    def free(self, ri, var, ref):
        """Emit the C free() of this member if it is heap-allocated."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declaration(s) a setter takes for this attr.

        Raises:
            Exception: if the subclass provides no complex member type.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit this attribute's member declaration(s) in the C struct."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            # Recursive nests are held by pointer.
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry in the kernel nla_policy array."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry in the user-space parse table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line*, guarded by the presence check when there is one."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the `if`/`else if (type == ENUM)` branch parsing this attr.

        _attr_get() returns (lines, init_lines, local_vars); lines and
        init_lines may each be a single string, a list, or None.  For
        single-value attributes the branch validates the attribute and, for
        'bit' presence, sets the presence bit.  Always returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a static inline C setter for this attribute.

        *ref* is the list of enclosing member names for attributes inside
        nests; a presence bit is set at every level for 'bit' presence.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i, level in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{level}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # Setters that free the old value without allocating a new one get a
        # '__' prefix (presumably because they take ownership of the caller's
        # buffer -- confirm against the generated API docs).
        needs_free = any('free(' in line for line in code)
        needs_alloc = any('alloc(' in line for line in code)
        if needs_free and not needs_alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
248
249
class TypeUnused(Type):
    """Attribute slot marked 'unused' in the spec.

    It has no struct member, setter, policy entry or put/get code.  Its
    parse-table entry is YNL_PT_REJECT.  Because attr_get() is overridden to
    emit nothing, the _attr_get() hook is not reached through it.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        return ['return MNL_CB_ERROR;'], None, None

    # The hooks below intentionally emit nothing.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
274
275
class TypePad(Type):
    """Padding attribute.

    It has no struct member, setter, policy entry or put/get code.  Its
    parse-table entry is YNL_PT_IGNORE.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The hooks below intentionally emit nothing.
    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def attr_policy(self, cw):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
297
298
class TypeScalar(Type):
    """Integer attribute, optionally tied to an enum or flags set.

    The constructor derives range checks ('min', 'max', 'range',
    'full-range') that _attr_policy() turns into NLA_POLICY_* entries.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # C comment placed after the member declaration, e.g. " /* big-endian */".
        self.byte_order_comment = (f" /* {attr['byte-order']} */"
                                   if 'byte-order' in attr else '')

        if 'enum' in self.attr:
            self._apply_enum_checks()

        if 'min' in self.checks and 'max' in self.checks:
            lower, upper = self.get_limit('min'), self.get_limit('max')
            if lower > upper:
                raise Exception(f'Invalid limit for "{self.name}" min: {lower} max: {upper}')
            self.checks['range'] = True

        bounds = (self.get_limit('min', 0), self.get_limit('max', 0))
        if min(bounds) < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside [-32768, 32767] are rendered via NLA_POLICY_FULL_RANGE.
        if min(bounds) < -32768 or max(bounds) > 32767:
            self.checks['full-range'] = True

        # Added by resolve(): declared, then removed until resolve() runs.
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def _apply_enum_checks(self):
        """Default missing min/max checks to the value range of the enum."""
        low, high = self.family.consts[self.attr['enum']].value_range()
        # An unsigned type with a zero-based enum needs no explicit minimum.
        if 'min' not in self.checks and (low != 0 or self.type[0] == 's'):
            self.checks['min'] = low
        if 'max' not in self.checks:
            self.checks['max'] = high

    def resolve(self):
        """Decide bitfield-ness and the C type used for the value."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if 'enum' in self.attr and not self.is_bitfield:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # Auto-sized scalars are stored as 64-bit values.
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _attr_policy(self, policy):
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            else:
                entries = self.family.consts[self.checks['flags-mask']]['entries']
                mask = (1 << len(entries)) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        if 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit('min')}, {self.get_limit('max')})"
        if 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit('min')})"
        if 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        width = c_upper(self.type[1:])
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        line = f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);"
        return line, None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
385
386
class TypeFlag(Type):
    """Flag attribute: no payload, only its presence bit matters."""

    def arg_member(self, ri):
        # Setting a flag takes no value argument.
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        # Nothing to copy out; the presence bit is set by Type.attr_get().
        return list(), None, None

    def _setter_lines(self, ri, member, presence):
        return list()
402
403
class TypeString(Type):
    """String attribute, stored as a heap-allocated, NUL-terminated char *.

    Presence is tracked as a length (_present.<name>_len).
    """

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.checks['exact-len']})"
        spec = '{ .type = ' + policy
        if 'max-len' in self.checks:
            spec += f", .len = {self.get_limit('max-len')}"
        return spec + ', }'

    def attr_policy(self, cw):
        policy = 'NLA_STRING' if self.checks.get('unterminated-ok', False) else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_present.{self.c_name}_len = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        # The copied length is bounded by the attribute's payload size.
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [
            f"free({member});",
            f"{length} = strlen({self.c_name});",
            f"{member} = malloc({length} + 1);",
            f'memcpy({member}, {self.c_name}, {length});',
            f'{member}[{length}] = 0;',
        ]
454
455
class TypeBinary(Type):
    """Opaque binary attribute, stored as a heap-allocated void * buffer.

    Presence is tracked as a length (_present.<name>_len).
    """

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.checks['exact-len']})"
        if not self.checks:
            body = '.type = NLA_BINARY'
        elif len(self.checks) == 1 and 'min-len' in self.checks:
            body = f".len = {self.get_limit('min-len')}"
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        return '{ ' + body + ', }'

    def attr_put(self, ri, var):
        put = (f"ynl_attr_put(nlh, {self.enum_name}, "
               f"{var}->{self.c_name}, {var}->_present.{self.c_name}_len)")
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_present.{self.c_name}_len = len;",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [
            f"free({member});",
            f"{length} = len;",
            f"{member} = malloc({length});",
            f'memcpy({member}, {self.c_name}, {length});',
        ]
500
501
class TypeBitfield32(Type):
    """Attribute carrying a struct nla_bitfield32; its spec must name an enum."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        put = (f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, "
               "sizeof(struct nla_bitfield32))")
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        copy = (f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), "
                "sizeof(struct nla_bitfield32));")
        return copy, None, None

    def _setter_lines(self, ri, member, presence):
        # The setter argument is a pointer (see arg_member()), so copy through it.
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
525
526
class TypeNest(Type):
    """Attribute holding a nested attribute set, rendered as an embedded struct.

    Recursive nests are held by pointer (see is_recursive_for_op()), so the
    free/put code drops the '&' for them.
    """

    def is_recursive(self):
        return self.family.pure_nested_structs[self.nested_attrs].recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        target = f"{var}->{ref}{self.c_name}"
        if self.is_recursive_for_op(ri):
            ri.cw.p(f'if ({target})')
            ri.cw.p(f'{self.nested_render_name}_free({target});')
        else:
            ri.cw.p(f'{self.nested_render_name}_free(&{target});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        addr = '' if self.is_recursive_for_op(ri) else '&'
        call = (f"{self.nested_render_name}_put(nlh, "
                f"{self.enum_name}, {addr}{var}->{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return MNL_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit setters for every non-recursive member of the nested set."""
        path = (ref or []) + [self.c_name]

        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
566
567
class TypeMultiAttr(Type):
    """Attribute that may repeat within one message ('multi-attr').

    Values are kept in an array with an n_<name> element count.
    *base_type* is the single-instance Type object for the same spec entry;
    policy and parse-table rendering are delegated to it.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def _is_nest(self):
        """Return True when the repeated element is a nested attribute set.

        Every method checks this before looking up self.attr['type'], so a
        missing 'type' consistently selects the nest path rather than
        raising KeyError in some methods only.
        """
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if self._is_nest():
            return self.nested_struct_type
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        # The nest branch of free() loops with a counter 'i' it does not
        # declare; presumably the caller declares it when this returns True.
        return self._is_nest()

    def free(self, ri, var, ref):
        if self._is_nest():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.attr['type'] in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only counts occurrences here; values are not copied in this branch.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        if self._is_nest():
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self.attr['type'] in scalars:
            put_type = self.type
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence: rewrite the trailing
        # "_present.<name>" of *presence* into "n_<name>".
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
629
630
class TypeArrayNest(Type):
    """Array-style nest: one attribute whose nested entries form an array.

    Elements are nested structs unless 'sub-type' names a scalar type.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            return ('__' if ri.ku_space == 'user' else '') + sub_type
        raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remembers the attribute in attr_<name> and counts its nested
        # entries; the elements are not parsed in this branch.
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'ynl_attr_for_each_nested(attr2, attr)',
            f'\t{var}->n_{self.c_name}++;',
        ]
        return get_lines, None, ['const struct nlattr *attr2;']
656
657
class TypeNestTypeValue(Type):
    """Nest whose outer attribute types carry values ('type-value').

    For each level listed in 'type-value', parsing steps one nest deeper and
    records that level's attribute type; the recorded types are passed as
    extra arguments to the nested parse function.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            names = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(names)};')
            local_vars.append(f'__u32 {", ".join(names)};')
            for name in names:
                get_lines.append(f'attr_{name} = ynl_attr_data({cursor});')
                get_lines.append(f'{name} = ynl_attr_type(attr_{name});')
                cursor = f'attr_{name}'
            extra_args = ', ' + ', '.join(names)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
686
687
class Struct:
    """Attributes of one attribute set rendered as a single C struct.

    With *type_list* None the struct is 'nested' and covers every attribute
    of the set; otherwise it holds only the attributes named in *type_list*.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets (a list never equals
        # a set in set_inherited()).
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            self.render_name = c_lower(family.name)
        else:
            self.render_name = c_lower(family.name + '-' + space_name)
        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]
        self.attrs = dict(self.attr_list)

        # Attribute with the highest value; on ties the later one wins.
        self.attr_max_val = None
        highest = 0
        for _, attr in self.attr_list:
            if attr.value < highest:
                continue
            highest = attr.value
            self.attr_max_val = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members; they must match those given at creation."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = list(map(c_lower, sorted(self._inherited)))
743
744
class EnumEntry(SpecEnumEntry):
    """One enum entry plus the data needed to render it in C."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # True when the value is not the implicit next one (prev + 1, or 0
        # for the first entry); always True for 'flags' enums.
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve(): declared, then removed until resolve() runs.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
763
764
class EnumSet(SpecEnumSet):
    """Enum or flags definition together with its C naming information."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        # 'enum-name' absent -> default name; present but empty -> anonymous.
        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        # Anonymous enums are passed around as plain int.
        if self.enum_name:
            self.user_type = self.enum_name
        else:
            self.user_type = 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) for an enum whose values form a contiguous range.

        Entries sharing a value (aliases) are allowed: what matters is that
        every integer between low and high is used by some entry.

        Raises:
            Exception: if the enum has no entries or its values have gaps.
        """
        values = {entry.value for entry in self.entries.values()}
        if not values:
            raise Exception("Can't get value range for an empty enum")
        low, high = min(values), max(values)

        # Compare against distinct values so duplicates cannot mask a gap.
        if len(values) != high - low + 1:
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
798
799
class AttrSet(SpecAttrSet):
    """Attribute set with the C names of its enum prefix and max/count macros.

    A 'subset-of' set reuses the naming of the set it is a subset of.
    """

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve(): declared, then removed until resolve() runs.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Set c_name; it is '' when it equals the family's own c_name."""
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching the spec entry's 'type'.

        Multi-attr entries are wrapped in TypeMultiAttr around the base type.
        """
        kind = elem['type']
        if kind in scalars:
            cls = TypeScalar
        else:
            cls = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'bitfield32': TypeBitfield32,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }.get(kind)
            if cls is None:
                raise Exception(f"No typed class for type {elem['type']}")
        t = cls(self.family, self, elem, value)

        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
858
859
class Operation(SpecOperation):
    """Code-generation view of one spec operation.

    Adds the C render name, ``dual_policy`` (both 'do' and 'dump' carry a
    request), a notification flag, and, once resolve() runs, the C enum name
    of the command.
    """

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.name + '_' + self.name)

        # Short-circuits on the first mode lacking a request, like the
        # equivalent chained 'and'.
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Set by resolve(); assign-then-delete only declares the name here.
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Resolve the base class, then pick the command's enum name."""
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that some message notifies via this operation."""
        self.has_ntf = True
885
886
class Family(SpecFamily):
    """Code-generation view of a netlink spec family.

    Extends SpecFamily with C naming (family name/version defines, command
    enum prefixes, uapi header) and, in resolve(), an analysis of which
    attribute sets are message roots vs. pure nested structs, in which
    direction(s) each is used, and which pre/post hooks the ops reference.
    """

    def __init__(self, file_name, exclude_ops):
        """Load the spec via SpecFamily and derive family-level C names.

        Args:
            file_name: spec file, passed to SpecFamily.
            exclude_ops: passed to SpecFamily unchanged.
        """
        # Added by resolve:
        # (each name is assigned then deleted, so it is declared here but does
        # not exist until set later. Doing this before super().__init__()
        # suggests SpecFamily.__init__ ends up calling resolve(); NOTE(review):
        # confirm in lib. 'consts' is not assigned in this class's resolve(),
        # presumably SpecFamily sets it.)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # Names of the C defines for the family name and version; the spec
        # can supply them via c-family-name / c-version-name.
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.name}.h"
        # Bare header name: "linux/<x>.h" -> "<x>", otherwise the family name.
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.name

    def resolve(self):
        """Derive C naming and attribute-set usage for the whole family.

        Raises:
            Exception: if the spec protocol is not a genetlink flavour.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.name)
        # Command enum prefix, "<FAMILY>_CMD_" unless overridden; async ops
        # use async-prefix when given, else the same prefix.
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode]: 'set' de-duplicates hook names, 'list' keeps
        # them in first-seen order (filled by _load_hooks()).
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> set('request', 'reply')
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        # 'split' is the default; only 'global' needs extra processing here.
        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        """Factory hook: build this generator's EnumSet."""
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        """Factory hook: build this generator's AttrSet."""
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        """Factory hook: build this generator's Operation."""
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        """Flag every op named by some message's 'notify' key as having a notification."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect attribute names used directly by messages, per attribute set.

        Fills self.root_sets[set] = {'request': names, 'reply': names}, merging
        across all messages using the set; event attributes count as reply.
        """
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so each struct follows the non-recursive sets it nests.

        A name is popped from a work list. If one of its non-recursive nested
        sets has not been placed yet, the struct is moved to the end of the
        insertion-ordered dict and retried later. The loop is bounded to n**2
        rounds, so a dependency cycle ends the loop instead of spinning forever.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_sets(self):
        """Discover nested-only attribute sets and how they are used.

        Walks 'nested-attributes' references breadth-first from the root sets
        and creates a Struct for each nested set. It records inherited members:
        type-value lists, or 'idx' for array-nests. It marks request/reply use
        from the root sets, then propagates request/reply flags and child-nest
        lists down the tree. A struct that appears among its own (transitive)
        children is flagged recursive.

        Raises:
            Exception: if a set is used both as a message root and as a nest.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        # (reverse of the dependency order above, so parents are mostly
        # visited before their non-recursive children)
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                if attr_set in struct.child_nests:
                    struct.recursive = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Copy request/reply usage onto individual attribute specs.

        Every member of a pure nested struct takes the struct's flags. Root-set
        members are flagged from the root_sets usage collected earlier.
        """
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.request = True
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.reply = True

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.request = True
                if attr in rs_members['reply']:
                    spec.reply = True

    def _load_global_policy(self):
        """Build the single kernel policy used with ``kernel-policy: global``.

        The policy lists every attribute appearing in a do/dump request, in
        attribute-set order.

        Raises:
            Exception: if ops use different attribute sets.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect pre/post hook names per do/dump, de-duplicated in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1148
1149
class RenderInfo:
    """Context for rendering code for one op/mode, or for one attribute set.

    Bundles the writer, family, op and op_mode, the fixed-header struct name,
    the base C type name, and request/reply Struct objects built from the
    op's attribute lists.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        # NOTE(review): op may be falsy when rendering a standalone attribute
        # set; attr_set must then be supplied. ku_space is stored unchanged;
        # presumably it selects kernel vs user output. Confirm with callers.
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        self.fixed_hdr = None
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)

        # 'do' and 'dump' response parsing is identical
        # (cleared for a non-'do' mode when the op has a dump and either no
        # 'do', or a 'do' whose reply presence/contents differ from the dump's)
        # NOTE(review): with a falsy op this relies on op_mode == 'do' or on
        # op supporting `in` (e.g. an empty string); confirm callers.
        self.type_consistent = True
        if op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_consistent = False

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        # Base C type name: the op name, or the attribute-set name. The latter
        # is flagged when it also names an entry in family.consts.
        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        # Notifications read their attribute lists from the 'do' spec; only
        # the local op_mode is rewritten, self.op_mode keeps 'notify'.
        self.struct = dict()
        if op_mode == 'notify':
            op_mode = 'do'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1198
1199
class CodeWriter:
    """Write C source text with tab indentation and brace bookkeeping.

    Output goes to stdout, or, when ``out_file`` is given, to a temporary file
    that close_out_file() copies over ``out_file``. With ``overwrite=False``
    the destination is left untouched if it already holds identical text.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # blank line pending before the next line
        self._block_end = False     # '}' pending (may merge with a following 'else')
        self._silent_block = False  # previous line was a brace-less if/while/for
        self._ind = 0
        self._ifdef_block = None
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Publish the generated text to ``out_file`` and release the temp file.

        Does nothing when writing to stdout. On every path (written, skipped
        because unchanged, or error) the temporary file is closed and the
        writer is pointed back at stdout, so a second call (e.g. from
        __del__) is a no-op.
        """
        if self._out == os.sys.stdout:
            return
        try:
            self._out.flush()
            # Avoid modifying the file if contents didn't change
            unchanged = (not self._overwrite and
                         os.path.isfile(self._out_file) and
                         filecmp.cmp(self._out.name, self._out_file, shallow=False))
            if not unchanged:
                with open(self._out_file, 'w+') as out_file:
                    self._out.seek(0)
                    shutil.copyfileobj(self._out, out_file)
        finally:
            self._out.close()
            self._out = os.sys.stdout

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Any pending '}' is flushed first; it becomes '} else ...' when
        ``line`` starts with 'else'. A pending blank line is also flushed.
        Lines ending in ':' are outdented one level. The line after a
        brace-less if/while/for condition is indented one extra level.
        Lines starting with '#' go to column 0. ``line`` must be non-empty.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line[-1] == ':':
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line[0] == '#':
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write ``line {`` (or a bare ``{``) and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block.

        With no ``line`` the '}' is deferred so that a following 'else' can
        join it. Otherwise '}' is written together with ``line`` (e.g. ';').
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write ``doc`` as ' *'-prefixed comment lines wrapped before column 79.

        Continuation lines get two extra spaces when ``indent`` is true.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, optionally preceded by a one-line comment.

        A prototype shorter than 80 columns goes on one line. Otherwise a
        return type longer than 3 characters gets its own line, and the
        arguments wrap, aligned under the opening parenthesis. ``suffix``
        (e.g. ';') is appended after ')'.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local declarations, longest first, followed by a blank line.

        Accepts a single string or a list of strings. The caller's list is not
        modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() is stable, so equal-length declarations keep their order.
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, '{', locals, body lines, '}'.

        Local declarations must follow the opening brace; emitting them
        before it (between the prototype and the body) is not valid C.
        """
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write tab-aligned ``#define NAME value`` lines.

        ``defines`` holds (name, value) pairs. str values are quoted, int
        values are written as-is, and any other value yields a bare define.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write tab-aligned designated initializers ``.name = value,``.

        ``members`` holds (name, value) pairs; an empty sequence writes nothing.
        """
        longest = max((len(x[0]) for x in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the active ``#ifdef CONFIG_<config>`` region.

        The previous region (if any) is closed with ``#endif``. A falsy
        ``config`` only closes it, and repeating the current config is a no-op.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1395
1396
# Attribute types handled as plain integer scalars.
scalars = set('u8 u16 u32 u64 s32 s64 uint sint'.split())

# Struct-name suffix for each message direction ('' means no direction).
direction_to_suffix = dict([
    ('reply', '_rsp'),
    ('request', '_req'),
    ('', ''),
])

# Suffix used for the response type of each operation mode.
op_mode_to_wrapper = dict(do='', dump='_list', notify='_ntf', event='')

# C keywords; identifiers derived from spec names must not collide with these.
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1448
1449
def rdir(direction):
    """Return the opposite direction ('request' <-> 'reply'); others pass through."""
    if direction in ('reply', 'request'):
        return 'request' if direction == 'reply' else 'reply'
    return direction
1456
1457
def op_prefix(ri, direction, deref=False):
    """Return the C identifier ``<family>_<type>[suffix]`` for ri in ``direction``.

    In 'do' mode (or no mode) the suffix is the direction's. Otherwise a
    request becomes ``_req_dump``. A reply uses the mode wrapper, or the
    direction suffix with ``deref``, when do/dump replies are consistent;
    when they are not, it becomes ``_rsp_dump``/``_rsp_list``.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1477
1478
def type_name(ri, direction, deref=False):
    """Return the full C struct type, ``struct <op_prefix(...)>``."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1481
1482
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the user-space call for ri.op in ri.op_mode.

    Dump calls get a ``_dump`` name suffix. The request struct becomes an
    argument only when the mode has a 'request'. The function returns a
    pointer to the reply struct when the mode has a 'reply', else ``int``.
    ``terminate`` appends ';' to produce a declaration.
    """
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')
    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        req_type = type_name(ri, direction)
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{req_type} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1499
1500
def print_req_prototype(ri):
    """Declare the request call for ri, commented with the operation's doc."""
    op_doc = ri.op['doc']
    print_prototype(ri, direction="request", doc=op_doc)
1503
1504
def print_dump_prototype(ri):
    """Declare the request call for ri (presumably a dump) without a doc comment."""
    print_prototype(ri, direction="request")
1507
1508
def put_typol_fwd(cw, struct):
    """Forward-declare the ``ynl_policy_nest`` object for ``struct``."""
    nest_name = f'{struct.render_name}_nest'
    cw.p(f'extern struct ynl_policy_nest {nest_name};')
1511
1512
def put_typol(cw, struct):
    """Write the policy table for ``struct`` and the nest object wrapping it.

    Each member renders its own table entry via ``attr_typol``. The table is
    sized by the attribute set's max name + 1.
    """
    max_attr = struct.attr_set.max_name
    name = struct.render_name

    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    for field in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(field)
    cw.block_end(line=';')
    cw.nl()
1528
1529
1530def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
1531    args = [f'int {arg_name}']
1532    if enum:
1533        args = [enum.user_type + ' ' + arg_name]
1534    cw.write_func_prot('const char *', f'{render_name}_str', args)
1535    cw.block_start()
1536    if enum and enum.type == 'flags':
1537        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
1538    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
1539    cw.p('return NULL;')
1540    cw.p(f'return {map_name}[{arg_name}];')
1541    cw.block_end()
1542    cw.nl()
1543
1544
def put_op_name_fwd(family, cw):
    """Declare ``<family>_op_str()``, the op-value-to-name lookup."""
    fname = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', fname, ['int op'], suffix=';')
1547
1548
def put_op_name(family, cw):
    """Write the op-value-to-name string table and its ``_op_str()`` lookup.

    Only messages with a response value are listed. When several commands
    share a reply value, only the message registered in
    ``family.rsp_by_value`` gets an entry; the others leave a comment.
    """
    table = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {table}[] =")
    for msg_name, msg in family.msgs.items():
        if not msg.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[msg.rsp_value] != msg:
            cw.p(f'// skip "{msg_name}", duplicate reply value')
            continue
        index = msg.enum_name if msg.req_value == msg.rsp_value else msg.rsp_value
        cw.p(f'[{index}] = "{msg_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', table, 'op')
1567    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1568
1569
def put_enum_to_str_fwd(family, cw, enum):
    """Declare ``<enum>_str()``, taking a value of the enum's user type."""
    fname = f'{enum.render_name}_str'
    param = enum.user_type + ' value'
    cw.write_func_prot('const char *', fname, [param], suffix=';')
1573
1574
def put_enum_to_str(family, cw, enum):
    """Write the value-to-name string table for ``enum`` and its ``_str()`` lookup."""
    table = f'{enum.render_name}_strmap'
    entry_lines = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {table}[] =")
    for line in entry_lines:
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, table, 'value', enum=enum)
1584
1585
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of ``<struct>_put()``, which adds ``struct`` as a nest."""
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    ri.cw.write_func_prot('int', f'{struct.render_name}_put', params, suffix=suffix)
1593
1594
def put_req_nested(ri, struct):
    """Write ``<struct>_put()``: open a nest of ``attr_type``, put members, close it."""
    put_req_nested_prototype(ri, struct, suffix='')

    writer = ri.cw
    writer.block_start()
    writer.write_func_lvar('struct nlattr *nest;')

    writer.p("nest = ynl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    writer.p("ynl_attr_nest_end(nlh, nest);")

    writer.nl()
    writer.p('return 0;')
    writer.block_end()
    writer.nl()
1611
1612
1613def _multi_parse(ri, struct, init_lines, local_vars):
1614    if struct.nested:
1615        iter_line = "ynl_attr_for_each_nested(attr, nested)"
1616    else:
1617        if ri.fixed_hdr:
1618            local_vars += ['void *hdr;']
1619        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
1620
1621    array_nests = set()
1622    multi_attrs = set()
1623    needs_parg = False
1624    for arg, aspec in struct.member_list():
1625        if aspec['type'] == 'array-nest':
1626            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
1627            array_nests.add(arg)
1628        if 'multi-attr' in aspec:
1629            multi_attrs.add(arg)
1630        needs_parg |= 'nested-attributes' in aspec
1631    if array_nests or multi_attrs:
1632        local_vars.append('int i;')
1633    if needs_parg:
1634        local_vars.append('struct ynl_parse_arg parg;')
1635        init_lines.append('parg.ys = yarg->ys;')
1636
1637    all_multi = array_nests | multi_attrs
1638
1639    for anest in sorted(all_multi):
1640        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")
1641
1642    ri.cw.block_start()
1643    ri.cw.write_func_lvar(local_vars)
1644
1645    for line in init_lines:
1646        ri.cw.p(line)
1647    ri.cw.nl()
1648
1649    for arg in struct.inherited:
1650        ri.cw.p(f'dst->{arg} = {arg};')
1651
1652    if ri.fixed_hdr:
1653        ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
1654        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
1655    for anest in sorted(all_multi):
1656        aspec = struct[anest]
1657        ri.cw.p(f"if (dst->{aspec.c_name})")
1658        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')
1659
1660    ri.cw.nl()
1661    ri.cw.block_start(line=iter_line)
1662    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
1663    ri.cw.nl()
1664
1665    first = True
1666    for _, arg in struct.member_list():
1667        good = arg.attr_get(ri, 'dst', first=first)
1668        # First may be 'unused' or 'pad', ignore those
1669        first &= not good
1670
1671    ri.cw.block_end()
1672    ri.cw.nl()
1673
1674    for anest in sorted(array_nests):
1675        aspec = struct[anest]
1676
1677        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1678        ri.cw.p(f"dst->{aspec.c_name} = calloc({aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1679        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1680        ri.cw.p('i = 0;')
1681        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1682        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
1683        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1684        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
1685        ri.cw.p('return MNL_CB_ERROR;')
1686        ri.cw.p('i++;')
1687        ri.cw.block_end()
1688        ri.cw.block_end()
1689    ri.cw.nl()
1690
1691    for anest in sorted(multi_attrs):
1692        aspec = struct[anest]
1693        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1694        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1695        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1696        ri.cw.p('i = 0;')
1697        if 'nested-attributes' in aspec:
1698            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1699        ri.cw.block_start(line=iter_line)
1700        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
1701        if 'nested-attributes' in aspec:
1702            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1703            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
1704            ri.cw.p('return MNL_CB_ERROR;')
1705        elif aspec.type in scalars:
1706            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
1707        else:
1708            raise Exception('Nest parsing type not supported yet')
1709        ri.cw.p('i++;')
1710        ri.cw.block_end()
1711        ri.cw.block_end()
1712        ri.cw.block_end()
1713    ri.cw.nl()
1714
1715    if struct.nested:
1716        ri.cw.p('return 0;')
1717    else:
1718        ri.cw.p('return MNL_CB_OK;')
1719    ri.cw.block_end()
1720    ri.cw.nl()
1721
1722
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of ``<struct>_parse()``.

    After the parse arg and the nest attribute, every inherited member of
    ``struct`` becomes an extra ``__u32`` parameter.
    """
    params = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    params += ['__u32 ' + member for member in struct.inherited]
    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params, suffix=suffix)
1731
1732
def parse_rsp_nested(ri, struct):
    """Emit the full definition (prototype and body) of a nested-struct parser.

    The body is produced by _multi_parse(). The only locals seeded here are
    the attribute iterator and ``dst``, which points at ``yarg->data``.
    No extra init lines are needed.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')
    _multi_parse(ri, struct, [],
                 ['const struct nlattr *attr;',
                  f'{struct.ptr_name}dst = yarg->data;'])
1741
1742
def parse_rsp_msg(ri, deref=False):
    """Emit the callback that parses a reply (or event) message into its struct.

    Nothing is emitted when the op has no reply in the current mode, unless
    the mode is 'event'. A reply struct without members gets a trivial
    body that just returns MNL_CB_OK.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    dst_type = type_name(ri, "reply", deref=deref)
    fn_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', fn_name,
                          ['const struct nlmsghdr *nlh', 'void *data'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply: nothing to parse, accept the message as-is
        ri.cw.block_start()
        ri.cw.p('return MNL_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    lvars = [f'{dst_type} *dst;',
             'struct ynl_parse_arg *yarg = data;',
             'const struct nlattr *attr;']
    _multi_parse(ri, reply, ['dst = yarg->data;'], lvars)
1765
1766
def print_req(ri):
    """Emit the user-space C function implementing a 'do' request.

    The generated function builds the message (optional fixed header plus
    request attributes) and executes it with ynl_exec(). When the op has a
    reply, it returns a freshly calloc()ed reply struct, or NULL after
    freeing it on error. Otherwise it returns 0 on success and -1 on error.
    """
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]
    cw = ri.cw

    lvars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
             'struct nlmsghdr *nlh;',
             'int err;']
    if has_reply:
        lvars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if ri.fixed_hdr:
        lvars.extend(['size_t hdr_len;', 'void *hdr;'])
    ret_ok, ret_err = ('rsp', 'NULL') if has_reply else ('0', '-1')

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(lvars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()

    if ri.fixed_hdr:
        for hdr_line in ("hdr_len = sizeof(req->_hdr);",
                         "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                         "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(hdr_line)
        cw.nl()

    for _, member in ri.struct["request"].member_list():
        member.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()
    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {ret_ok};")
    cw.nl()

    if has_reply:
        # Error path: release the partially filled reply
        cw.p('err_free:')
        cw.p(call_free(ri, rdir(direction), 'rsp'))
        cw.p(f"return {ret_err};")

    cw.block_end()
1831
1832
def print_dump(ri):
    """Emit the user-space C function implementing a 'dump' request.

    The generated function fills a ynl_dump_state and builds the dump
    message (optional fixed header and request attributes). It then runs
    ynl_exec_dump() and returns the head of the reply list. On failure the
    partial list is freed and NULL is returned.
    """
    direction = "request"
    cw = ri.cw

    print_prototype(ri, direction, terminate=False)
    cw.block_start()

    lvars = ['struct ynl_dump_state yds = {};',
             'struct nlmsghdr *nlh;',
             'int err;']
    if ri.fixed_hdr:
        lvars += ['size_t hdr_len;', 'void *hdr;']
    cw.write_func_lvar(lvars)

    cw.p('yds.yarg.ys = ys;')
    cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.p("yds.yarg.data = NULL;")
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    if ri.op.value is None:
        cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    else:
        cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        cw.p("hdr_len = sizeof(req->_hdr);")
        cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        cw.nl()

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _, member in ri.struct["request"].member_list():
            member.attr_put(ri, "req")
    cw.nl()

    cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    cw.p('if (err < 0)')
    cw.p('goto free_list;')
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    # Error path: free whatever part of the list was already built
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1883
1884
def call_free(ri, direction, var):
    """Return the C statement freeing *var* with the op's *direction* free helper."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
1887
1888
def free_arg_name(direction):
    """Return the parameter name used by generated free() helpers.

    For a direction this is its struct-name suffix without the leading
    character. With no direction the name is 'obj'.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
1893
1894
def print_alloc_wrapper(ri, direction):
    """Emit a static inline allocator that returns a zeroed op struct."""
    struct_name = op_prefix(ri, direction)
    ri.cw.write_func_prot(f'static inline struct {struct_name} *',
                          struct_name + '_alloc', ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof(struct {struct_name}));')
    ri.cw.block_end()
1901
1902
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of the free() helper for *direction*'s struct.

    When the type name conflicts, the struct tag gets a trailing underscore
    while the function name does not.
    """
    fn_prefix = op_prefix(ri, direction)
    struct_name = fn_prefix + '_' if ri.type_name_conflict else fn_prefix
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{fn_prefix}_free", [param], suffix=suffix)
1910
1911
def _print_type(ri, direction, struct):
    """Emit the C struct definition for *struct*.

    Layout, in order:
      * an optional fixed header;
      * a '_present' sub-struct collecting each attribute's presence
        members ('len' then 'bit');
      * the inherited __u32 values;
      * the attribute members.
    A '_dump' suffix is added in dump mode.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'
    if ri.op_mode == 'dump':
        suffix += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.fixed_hdr:
        cw.p(ri.fixed_hdr + ' _hdr;')
        cw.nl()

    # The '_present' sub-struct is opened lazily, only if some member needs it
    in_present = False
    for _, member in struct.member_list():
        for type_filter in ('len', 'bit'):
            presence_line = member.presence_member(ri.ku_space, type_filter)
            if not presence_line:
                continue
            if not in_present:
                cw.block_start(line="struct")
                in_present = True
            cw.p(presence_line)
    if in_present:
        cw.block_end(line='_present;')
        cw.nl()

    for inherited in struct.inherited:
        cw.p(f"__u32 {inherited};")

    for _, member in struct.member_list():
        member.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1947
1948
def print_type(ri, direction):
    """Emit the struct definition for the op's *direction* struct."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1951
1952
def print_type_full(ri, struct):
    """Emit a direction-less struct definition for a standalone *struct*."""
    _print_type(ri, direction="", struct=struct)
1955
1956
def print_type_helpers(ri, direction, deref=False):
    """Emit the free() prototype and, for user-space requests, the attr setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1965
1966
def print_req_type_helpers(ri):
    """Emit the alloc wrapper and helpers for a request struct with attributes."""
    if not ri.struct["request"].attr_list:
        return
    print_alloc_wrapper(ri, "request")
    print_type_helpers(ri, "request")
1972
1973
def print_rsp_type_helpers(ri):
    """Emit reply-struct helpers when the op has a reply in the current mode."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1978
1979
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of a tb[] -> struct parse helper for the op.

    The struct is '<op>_rsp' for replies and '<op>_req' otherwise. The
    prototype ends with ';' when *terminate* is set.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    struct_name = f"{ri.op.render_name}{kind}"
    ri.cw.write_func_prot('void', struct_name + "_parse",
                          ['const struct nlattr **tb',
                           f"struct {struct_name} *req"],
                          suffix=';' if terminate else '')
1988
1989
def print_req_type(ri):
    """Emit the request struct definition unless it has no attributes."""
    if ri.struct["request"].attr_list:
        print_type(ri, "request")
1994
1995
def print_req_free(ri):
    """Emit the request struct's free() helper if the op takes a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2000
2001
def print_rsp_type(ri):
    """Emit the reply struct definition when the current mode produces one.

    'do' and 'dump' modes need a 'reply' section in the spec. 'event' mode
    always renders the reply struct. Other modes emit nothing.
    """
    mode = ri.op_mode
    if mode == 'event' or (mode in ('do', 'dump') and 'reply' in ri.op[mode]):
        print_type(ri, 'reply')
2010
2011
def print_wrapped_type(ri):
    """Emit the wrapper struct around a reply, followed by its free() prototype.

    Dump replies get a 'next' pointer. Notify/event replies get family,
    cmd, next and free fields. The reply itself is embedded as the 8-byte
    aligned 'obj' member.
    """
    wrapper = type_name(ri, 'reply')
    cw = ri.cw
    cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p('struct ynl_ntf_base_type *next;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
2026
2027
2028def _free_type_members_iter(ri, struct):
2029    for _, attr in struct.member_list():
2030        if attr.free_needs_iter():
2031            ri.cw.p('unsigned int i;')
2032            ri.cw.nl()
2033            break
2034
2035
2036def _free_type_members(ri, var, struct, ref=''):
2037    for _, attr in struct.member_list():
2038        attr.free(ri, var, ref)
2039
2040
def _free_type(ri, direction, struct):
    """Emit the free() helper definition for *struct*.

    Member storage is released first. The struct pointer itself is freed
    only when a direction is given. Direction-less structs get only their
    members released.
    """
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        ri.cw.p(f'free({arg});')
    ri.cw.block_end()
    ri.cw.nl()
2052
2053
def free_rsp_nested_prototype(ri):
    """Emit the free() prototype for a direction-less (nested) struct."""
    print_free_prototype(ri, "")
2056
2057
def free_rsp_nested(ri, struct):
    """Emit the free() helper definition for a direction-less (nested) struct."""
    _free_type(ri, direction="", struct=struct)
2060
2061
def print_rsp_free(ri):
    """Emit the reply struct's free() helper if the op has a reply in this mode."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2066
2067
def print_dump_type_free(ri):
    """Emit the free() helper that walks and releases a dump reply list.

    The generated loop runs until YNL_LIST_END. For each node it releases
    the members (through the 'obj.' member) and then the node itself.
    """
    node_type = type_name(ri, 'reply')
    reply = ri.struct['reply']
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    # Advance before freeing so the successor pointer is not lost
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
2086
2087
def print_ntf_type_free(ri):
    """Emit the free() helper for a notification: members via 'obj.', then the wrapper."""
    reply = ri.struct['reply']
    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
2096
2097
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the declaration line of an nla_policy array.

    With *terminate* set, this is an 'extern' forward declaration. It is
    skipped entirely for op policies that should be static. Otherwise it is
    the opening line of the definition, marked 'static' for such op
    policies. The array is named after the op (plus the mode for dual
    policies) when *ri* is given, else after *struct*.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix, suffix = ('static ' if is_static else ''), ' = {'

    max_enum = struct.attr_max_val.enum_name
    if ri:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    else:
        name = struct.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_enum} + 1]{suffix}")
2120
2121
def print_req_policy(cw, struct, ri=None):
    """Emit the full nla_policy array definition for *struct*.

    Op policies (an *ri* carrying an op) are wrapped in the op's optional
    'config-cond' #ifdef. ifdef_block(None) is always called afterwards.
    """
    op = ri.op if ri else None
    if op:
        cw.ifdef_block(op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2131
2132
def kernel_can_gen_family_struct(family):
    """Return True if the genl_family struct is generated for *family*.

    Only families using the plain 'genetlink' protocol qualify.
    """
    proto = family.proto
    return proto == 'genetlink'
2135
2136
def policy_should_be_static(family):
    """Return True if generated nla_policy arrays should be 'static'.

    This holds for split-policy families and for families whose
    genl_family struct is generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2139
2140
def print_kernel_policy_ranges(family, cw):
    """Emit netlink_range_validation structs for request attrs with 'full-range' checks.

    A single comment header precedes the first emitted struct. Unsigned
    types (type starting with 'u') use the plain struct variant and ULL
    literals. Everything else uses the '_signed' variant and LL literals.
    """
    header_done = False
    for attr_set in family.attr_sets.values():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_done:
                cw.p('/* Integer value ranges */')
                header_done = True

            unsigned = attr.type[0] == 'u'
            sign, suffix = ('', 'ULL') if unsigned else ('_signed', 'LL')
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(limit, str(attr.get_limit(limit)) + suffix)
                       for limit in ('min', 'max') if limit in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2168
2169
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops-table header line and, for headers, the handler prototypes.

    With *terminate* False this always opens the ops array definition.
    With *terminate* True the 'extern' array declaration is emitted only for
    families whose genl_family is not generated. It is followed by the
    prototypes of the pre/post hooks and of every doit/dumpit handler of
    non-async ops.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # Split ops have one table entry per supported mode of each op
            cnt = sum(int('do' in op) + int('dump' in op)
                      for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {decl};")
        else:
            cw.block_start(line=decl + ' =')

    if not terminate:
        return

    cw.nl()
    doit_hook_args = ('const struct genl_split_ops *ops',
                      'struct sk_buff *skb', 'struct genl_info *info')
    dump_hook_args = ('struct netlink_callback *cb',)
    hook_kinds = (('pre', 'do', 'int', doit_hook_args),
                  ('post', 'do', 'void', doit_hook_args),
                  ('pre', 'dump', 'int', dump_hook_args),
                  ('post', 'dump', 'int', dump_hook_args))
    for when, mode, ret_type, hook_args in hook_kinds:
        for hook in family.hooks[when][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(hook_args), suffix=';')

    cw.nl()

    handler_ctx_arg = {'do': 'struct genl_info *info',
                       'dump': 'struct netlink_callback *cb'}
    for op_name, op in family.ops.items():
        if op.is_async:
            continue
        for mode in ('do', 'dump'):
            if mode not in op:
                continue
            handler = c_lower(f"{family.name}-nl-{op_name}-{mode}it")
            cw.write_func_prot('int', handler,
                               ['struct sk_buff *skb', handler_ctx_arg[mode]], suffix=';')
    cw.nl()
2235
2236
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side part of the ops table (declarations and prototypes)."""
    print_kernel_op_table_fwd(family, cw, terminate=True)
2239
2240
def print_kernel_op_table(family, cw):
    """Emit the kernel ops table definition for *family*.

    The opening line comes from print_kernel_op_table_fwd(). For the
    'global' and 'per-op' policies there is one entry per non-async op.
    For 'split' there is one entry per (op, mode) pair, with 'do' and
    'dump' emitted separately. Each entry is preceded by
    cw.ifdef_block() with the op's optional 'config-cond'. Its members are
    rendered via cw.write_struct_init().
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            # Async ops get no table entry
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            # doit/dumpit point at the handlers for whichever modes the op has
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): the per-op policy is always built from the
                # 'do' request; an op lacking do.request would presumably
                # raise a KeyError here.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # genl_split_ops member names of the pre/post hooks, per mode
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the flags that apply to this mode: dump-only
                    # flags are dropped for 'do', 'strict' is dropped for 'dump'
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                # Member order: pre hook, handler, post hook
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Ops with distinct do/dump policies get a mode-specific name
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry carries the cmd-cap flag for its mode
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Reset the #ifdef state after the last entry (presumably closes any open conditional)
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2319
2320
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum of multicast group indices, if the family has groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2331
2332
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array, indexed by the group enum values."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        enum_id = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2344
2345
def print_kernel_family_struct_hdr(family, cw):
    """Emit the extern declaration of the generated genl_family, if one is generated."""
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
        cw.nl()
2352
2353
def print_kernel_family_struct_src(family, cw):
    """Emit the genl_family definition for families whose struct is generated.

    The symbol is derived from family.c_name, the same name used by the
    extern declaration in print_kernel_family_struct_hdr(). The two
    therefore always match, and the symbol is a valid C identifier even
    when the spec's family name contains '-'. The ops and multicast group
    arrays referenced here are also named after family.c_name.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    # NOTE(review): no ops member is emitted for the 'global' policy; confirm intended.
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    cw.block_end(';')
2374
2375
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI enum block, choosing its tag from the spec object *obj*.

    The tag is chosen as follows:
      * if the *enum_name* key is present, its value is the tag, and an
        empty value means an anonymous enum;
      * otherwise, if *ckey* is present, the tag is the family C name
        joined with that key's value;
      * otherwise the enum is anonymous.
    """
    tag = None
    if enum_name in obj:
        tag = c_lower(obj[enum_name]) if obj[enum_name] else None
    elif ckey and ckey in obj:
        tag = f"{family.c_name}_{c_lower(obj[ckey])}"
    cw.block_start(line='enum' if tag is None else 'enum ' + tag)
2384
2385
def render_uapi(family, cw):
    """Emit the uAPI header for *family*.

    Sections, in order:
      * the include guard;
      * family name/version #defines;
      * the 'definitions' (enums/flags, with kdoc when documented, and
        consts as #defines);
      * one enum per non-subset attribute set, with count and max entries;
      * the command enum, plus a separate enum for notify/event messages
        when the operations section has 'async-prefix';
      * the multicast group name #defines;
      * the guard's #endif.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            # Flush the pending (possibly empty) run of const #defines first
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # A 'c-define-name' override names this one constant, so it is
            # looked up on the definition itself. Looking it up on the family
            # would give every constant the same #define name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            # Only spell out the value when it breaks the implicit sequence
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2529
2530
def _render_user_ntf_entry(ri, op):
    """Emit one designated initializer of the ynl_ntf_info table for *op*."""
    ri.cw.block_start(line=f"[{op.enum_name}] = ")
    entry_fields = (
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    )
    for field in entry_fields:
        ri.cw.p(field)
    ri.cw.block_end(line=',')
2538
2539
def render_user_family(family, cw, prototype):
    """Emit the user-space ynl_family descriptor, or just its extern declaration.

    Args:
        family: the parsed family spec.
        cw: the code writer.
        prototype: when true, only the 'extern' declaration is emitted.

    When the family has notifications, a static ynl_ntf_info table is
    emitted first and referenced from the descriptor. The table symbol is
    derived from family.c_name, like the descriptor symbol, so it is a valid
    C identifier even when the spec's family name contains '-'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    ntf_table = f"{family.c_name}_ntf_info"
    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_table}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Notifications reuse the reply layout of the op they reference
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op_name, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_table},")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({ntf_table}),")
    cw.block_end(line=';')
2575
2576
def family_contains_bitfield32(family):
    """Return True if any non-subset attribute set has a bitfield32 attribute."""
    return any(attr.type == "bitfield32"
               for attr_set in family.attr_sets.values()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
2585
2586
def find_kernel_root(full_path):
    """Locate the kernel source tree containing *full_path*.

    Walks up the ancestors of *full_path* until a directory holding a
    MAINTAINERS file is found.

    Returns:
        A (root, sub_path) tuple. root is the directory containing
        MAINTAINERS, and sub_path is *full_path* relative to it.

    Raises:
        FileNotFoundError: the walk reached the top without finding
        MAINTAINERS. The top is '/' for an absolute path, or '' (the
        current directory) for a relative one. There os.path.dirname() no
        longer changes the path, so without this check the loop would
        never terminate.
    """
    start = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        maintainers = os.path.join(parent, "MAINTAINERS")
        if os.path.exists(maintainers):
            # Drop the trailing separator left by the first join
            return parent, sub_path[:-1]
        if parent == full_path:
            raise FileNotFoundError(f"no MAINTAINERS file found above {start!r}")
        full_path = parent
2595
2596
def _main_kernel_header(parsed, cw, mode):
    """Emit the kernel-mode header body: policy forward declarations, then op/mcgrp/family declarations."""
    if any(struct.request for struct in parsed.pure_nested_structs.values()):
        cw.p('/* Common nested types */')
    for _, struct in sorted(parsed.pure_nested_structs.items()):
        if struct.request:
            print_req_policy_fwd(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy_fwd(cw, struct)
        cw.nl()

    if parsed.kernel_policy in {'per-op', 'split'}:
        for _, op in parsed.ops.items():
            if 'do' in op and 'event' not in op:
                ri = RenderInfo(cw, parsed, mode, op, "do")
                print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                cw.nl()

    print_kernel_op_table_hdr(parsed, cw)
    print_kernel_mcgrp_hdr(parsed, cw)
    print_kernel_family_struct_hdr(parsed, cw)


def _main_kernel_source(parsed, cw, mode):
    """Emit the kernel-mode source body: policy ranges and policies, then op/mcgrp/family definitions."""
    print_kernel_policy_ranges(parsed, cw)

    if any(struct.request for struct in parsed.pure_nested_structs.values()):
        cw.p('/* Common nested types */')
    for _, struct in sorted(parsed.pure_nested_structs.items()):
        if struct.request:
            print_req_policy(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy(cw, struct)
        cw.nl()

    # kernel_policy does not change per op, so test it once outside the loop
    if parsed.kernel_policy in {'per-op', 'split'}:
        for _, op in parsed.ops.items():
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    cw.p(f"/* {op.enum_name} - {op_mode} */")
                    ri = RenderInfo(cw, parsed, mode, op, op_mode)
                    print_req_policy(cw, ri.struct['request'], ri=ri)
                    cw.nl()

    print_kernel_op_table(parsed, cw)
    print_kernel_mcgrp_src(parsed, cw)
    print_kernel_family_struct_src(parsed, cw)


def _main_user_header(parsed, cw, mode):
    """Emit the user-mode header body: enum helpers, nested types, and per-op/notification types."""
    cw.p('/* Enums */')
    put_op_name_fwd(parsed, cw)

    for _, const in parsed.consts.items():
        if isinstance(const, EnumSet):
            put_enum_to_str_fwd(parsed, cw, const)
    cw.nl()

    cw.p('/* Common nested types */')
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
        print_type_full(ri, struct)

    for _, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")

        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_type(ri)
            print_req_type_helpers(ri)
            cw.nl()
            print_rsp_type(ri)
            print_rsp_type_helpers(ri)
            cw.nl()
            print_req_prototype(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, 'dump')
            print_req_type(ri)
            print_req_type_helpers(ri)
            if not ri.type_consistent:
                print_rsp_type(ri)
            print_wrapped_type(ri)
            print_dump_prototype(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_wrapped_type(ri)

    for _, op in parsed.ntfs.items():
        if 'event' in op:
            ri = RenderInfo(cw, parsed, mode, op, 'event')
            cw.p(f"/* {op.enum_name} - event */")
            print_rsp_type(ri)
            cw.nl()
            print_wrapped_type(ri)
    cw.nl()


def _main_user_source(parsed, cw, mode):
    """Emit the user-mode source body: enum helpers, policies, nested/op code, and the family descriptor."""
    cw.p('/* Enums */')
    put_op_name(parsed, cw)

    for _, const in parsed.consts.items():
        if isinstance(const, EnumSet):
            put_enum_to_str(parsed, cw, const)
    cw.nl()

    has_recursive_nests = False
    cw.p('/* Policies */')
    for struct in parsed.pure_nested_structs.values():
        if struct.recursive:
            put_typol_fwd(cw, struct)
            has_recursive_nests = True
    if has_recursive_nests:
        cw.nl()
    for name in parsed.pure_nested_structs:
        struct = Struct(parsed, name)
        put_typol(cw, struct)
    for name in parsed.root_sets:
        struct = Struct(parsed, name)
        put_typol(cw, struct)

    cw.p('/* Common nested types */')
    # With recursive nests, print all prototypes before any definition so
    # the definitions can refer to one another.
    if has_recursive_nests:
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
            free_rsp_nested_prototype(ri)
            if struct.request:
                put_req_nested_prototype(ri, struct)
            if struct.reply:
                parse_rsp_nested_prototype(ri, struct)
        cw.nl()
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)

        free_rsp_nested(ri, struct)
        if struct.request:
            put_req_nested(ri, struct)
        if struct.reply:
            parse_rsp_nested(ri, struct)

    for _, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")
        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_free(ri)
            print_rsp_free(ri)
            parse_rsp_msg(ri)
            print_req(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, "dump")
            if not ri.type_consistent:
                parse_rsp_msg(ri, deref=True)
            print_req_free(ri)
            print_dump_type_free(ri)
            print_dump(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_ntf_type_free(ri)

    for _, op in parsed.ntfs.items():
        if 'event' in op:
            cw.p(f"/* {op.enum_name} - event */")

            # Events get a "do"-style response parser plus their own free helper
            ri = RenderInfo(cw, parsed, mode, op, "do")
            parse_rsp_msg(ri)

            ri = RenderInfo(cw, parsed, mode, op, "event")
            print_ntf_type_free(ri)
    cw.nl()
    render_user_family(parsed, cw, False)


def main():
    """
    Command-line entry point for the generator.

    Parses arguments, loads the YAML spec into a Family, writes the
    common preamble (SPDX line, provenance comments, include guard and
    includes), then hands off to the renderer for --mode:
    'uapi' -> render_uapi(), 'kernel' / 'user' -> the _main_* helpers
    (header or source, chosen by --header / --source).

    Exits with status 1 (diagnostics on stderr) when the spec fails to
    parse as YAML, carries a license other than the required one, or
    uses a message enum-model the selected mode does not support.
    """
    import sys

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    # --header and --source share one destination; None means neither was given
    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    # Only the spec load can raise a YAML error; keep the try that narrow
    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)

    if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
        print('Spec license:', parsed.license, file=sys.stderr)
        print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)', file=sys.stderr)
        sys.exit(1)

    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation',
              file=sys.stderr)
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Echo the extra command-line arguments into the generated file
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # The source includes its companion header: foo.c -> "foo.h"
                stem = os.path.splitext(os.path.basename(args.out_file))[0]
                cw.p(f'#include "{stem}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            _main_kernel_header(parsed, cw, args.mode)
        else:
            _main_kernel_source(parsed, cw, args.mode)

    if args.mode == "user":
        if args.header:
            _main_user_header(parsed, cw, args.mode)
        else:
            _main_user_source(parsed, cw, args.mode)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
2901
2902
if __name__ == '__main__':
    # Run the generator only when executed as a script, not when imported.
    main()
2905