1#!/usr/bin/env python3
2# kate: replace-tabs on; indent-width 4;
3
4from __future__ import unicode_literals
5
6'''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
7nanopb_version = "nanopb-0.4.4"
8
9import sys
10import re
11import codecs
12import copy
13import itertools
14import tempfile
15import shutil
16import os
17from functools import reduce
18
19try:
20    # Add some dummy imports to keep packaging tools happy.
21    import google, distutils.util # bbfreeze seems to need these
22    import pkg_resources # pyinstaller / protobuf 2.5 seem to need these
23    import proto.nanopb_pb2 as nanopb_pb2 # pyinstaller seems to need this
24    import pkg_resources.py2_warn
25except:
26    # Don't care, we will error out later if it is actually important.
27    pass
28
29try:
30    # Make sure grpc_tools gets included in binary package if it is available
31    import grpc_tools.protoc
32except:
33    pass
34
35try:
36    import google.protobuf.text_format as text_format
37    import google.protobuf.descriptor_pb2 as descriptor
38    import google.protobuf.compiler.plugin_pb2 as plugin_pb2
39    import google.protobuf.reflection as reflection
40    import google.protobuf.descriptor
41except:
42    sys.stderr.write('''
43         *************************************************************
44         *** Could not import the Google protobuf Python libraries ***
45         *** Try installing package 'python3-protobuf' or similar.  ***
46         *************************************************************
47    ''' + '\n')
48    raise
49
50try:
51    from .proto import nanopb_pb2
52    from .proto._utils import invoke_protoc
53except TypeError:
54    sys.stderr.write('''
55         ****************************************************************************
56         *** Got TypeError when importing the protocol definitions for generator. ***
57         *** This usually means that the protoc in your path doesn't match the    ***
58         *** Python protobuf library version.                                     ***
59         ***                                                                      ***
60         *** Please check the output of the following commands:                   ***
61         *** which protoc                                                         ***
62         *** protoc --version                                                     ***
63         *** python3 -c 'import google.protobuf; print(google.protobuf.__file__)'  ***
64         *** If you are not able to find the python protobuf version using the    ***
65         *** above command, use this command.                                     ***
66         *** pip freeze | grep -i protobuf                                        ***
67         ****************************************************************************
68    ''' + '\n')
69    raise
70except (ValueError, SystemError, ImportError):
71    # Probably invoked directly instead of via installed scripts.
72    import proto.nanopb_pb2 as nanopb_pb2
73    from proto._utils import invoke_protoc
74except:
75    sys.stderr.write('''
76         ********************************************************************
77         *** Failed to import the protocol definitions for generator.     ***
78         *** You have to run 'make' in the nanopb/generator/proto folder. ***
79         ********************************************************************
80    ''' + '\n')
81    raise
82
try:
    from tempfile import TemporaryDirectory
except ImportError:
    class TemporaryDirectory:
        '''TemporaryDirectory fallback for Python 2'''
        def __enter__(self):
            # Create the directory on context entry and hand back its path
            # (the stdlib version returns the path from __enter__ as well).
            self.dir = tempfile.mkdtemp()
            return self.dir

        def __exit__(self, *args):
            # Remove the directory and all of its contents on exit.
            shutil.rmtree(self.dir)
95# ---------------------------------------------------------------------------
96#                     Generation of single fields
97# ---------------------------------------------------------------------------
98
99import time
100import os.path
101
# Values are tuple (c type, pb type, encoded size, data_size)
# - c type: the C type used for the struct member in generated code
# - pb type: nanopb type name used when emitting the field descriptor
# - encoded size: worst-case number of bytes the value takes on the wire
#   (e.g. 10 for a worst-case 64-bit varint)
# - data_size: number of bytes the member occupies in the C struct
FieldD = descriptor.FieldDescriptorProto
datatypes = {
    FieldD.TYPE_BOOL:       ('bool',     'BOOL',        1,  4),
    FieldD.TYPE_DOUBLE:     ('double',   'DOUBLE',      8,  8),
    FieldD.TYPE_FIXED32:    ('uint32_t', 'FIXED32',     4,  4),
    FieldD.TYPE_FIXED64:    ('uint64_t', 'FIXED64',     8,  8),
    FieldD.TYPE_FLOAT:      ('float',    'FLOAT',       4,  4),
    FieldD.TYPE_INT32:      ('int32_t',  'INT32',      10,  4),
    FieldD.TYPE_INT64:      ('int64_t',  'INT64',      10,  8),
    FieldD.TYPE_SFIXED32:   ('int32_t',  'SFIXED32',    4,  4),
    FieldD.TYPE_SFIXED64:   ('int64_t',  'SFIXED64',    8,  8),
    FieldD.TYPE_SINT32:     ('int32_t',  'SINT32',      5,  4),
    FieldD.TYPE_SINT64:     ('int64_t',  'SINT64',     10,  8),
    FieldD.TYPE_UINT32:     ('uint32_t', 'UINT32',      5,  4),
    FieldD.TYPE_UINT64:     ('uint64_t', 'UINT64',     10,  8),

    # Integer size override options
    # Keyed by (wire type, int_size option); selected in Field.__init__ when
    # the user requests a narrower or wider C integer via the int_size option.
    (FieldD.TYPE_INT32,   nanopb_pb2.IS_8):   ('int8_t',   'INT32', 10,  1),
    (FieldD.TYPE_INT32,  nanopb_pb2.IS_16):   ('int16_t',  'INT32', 10,  2),
    (FieldD.TYPE_INT32,  nanopb_pb2.IS_32):   ('int32_t',  'INT32', 10,  4),
    (FieldD.TYPE_INT32,  nanopb_pb2.IS_64):   ('int64_t',  'INT32', 10,  8),
    (FieldD.TYPE_SINT32,  nanopb_pb2.IS_8):   ('int8_t',  'SINT32',  2,  1),
    (FieldD.TYPE_SINT32, nanopb_pb2.IS_16):   ('int16_t', 'SINT32',  3,  2),
    (FieldD.TYPE_SINT32, nanopb_pb2.IS_32):   ('int32_t', 'SINT32',  5,  4),
    (FieldD.TYPE_SINT32, nanopb_pb2.IS_64):   ('int64_t', 'SINT32', 10,  8),
    (FieldD.TYPE_UINT32,  nanopb_pb2.IS_8):   ('uint8_t', 'UINT32',  2,  1),
    (FieldD.TYPE_UINT32, nanopb_pb2.IS_16):   ('uint16_t','UINT32',  3,  2),
    (FieldD.TYPE_UINT32, nanopb_pb2.IS_32):   ('uint32_t','UINT32',  5,  4),
    (FieldD.TYPE_UINT32, nanopb_pb2.IS_64):   ('uint64_t','UINT32', 10,  8),
    (FieldD.TYPE_INT64,   nanopb_pb2.IS_8):   ('int8_t',   'INT64', 10,  1),
    (FieldD.TYPE_INT64,  nanopb_pb2.IS_16):   ('int16_t',  'INT64', 10,  2),
    (FieldD.TYPE_INT64,  nanopb_pb2.IS_32):   ('int32_t',  'INT64', 10,  4),
    (FieldD.TYPE_INT64,  nanopb_pb2.IS_64):   ('int64_t',  'INT64', 10,  8),
    (FieldD.TYPE_SINT64,  nanopb_pb2.IS_8):   ('int8_t',  'SINT64',  2,  1),
    (FieldD.TYPE_SINT64, nanopb_pb2.IS_16):   ('int16_t', 'SINT64',  3,  2),
    (FieldD.TYPE_SINT64, nanopb_pb2.IS_32):   ('int32_t', 'SINT64',  5,  4),
    (FieldD.TYPE_SINT64, nanopb_pb2.IS_64):   ('int64_t', 'SINT64', 10,  8),
    (FieldD.TYPE_UINT64,  nanopb_pb2.IS_8):   ('uint8_t', 'UINT64',  2,  1),
    (FieldD.TYPE_UINT64, nanopb_pb2.IS_16):   ('uint16_t','UINT64',  3,  2),
    (FieldD.TYPE_UINT64, nanopb_pb2.IS_32):   ('uint32_t','UINT64',  5,  4),
    (FieldD.TYPE_UINT64, nanopb_pb2.IS_64):   ('uint64_t','UINT64', 10,  8),
}
145
class Globals:
    '''Ugly global variables, should find a good way to pass these.'''
    # NOTE(review): these defaults are presumably overwritten by the
    # command-line / option-parsing code elsewhere in the file (not visible
    # in this chunk) — confirm against that code before relying on them.
    verbose_options = False
    separate_options = []
    matched_namemasks = set()
    protoc_insertion_points = False
152
# String types (for python 2 / python 3 compatibility)
try:
    # Python 2: both unicode and str are accepted as strings, and files
    # are opened with universal-newline mode ('rU').
    strtypes = (unicode, str)
    openmode_unicode = 'rU'
except NameError:
    # Python 3: the name 'unicode' does not exist; plain 'r' already
    # performs universal newline translation.
    strtypes = (str, )
    openmode_unicode = 'r'
160
161
class Names:
    '''Keeps a set of nested names and formats them to C identifier.'''
    def __init__(self, parts = ()):
        '''parts: another Names instance, a single string, or an iterable
        of name-part strings.'''
        if isinstance(parts, Names):
            parts = parts.parts
        elif isinstance(parts, strtypes):
            parts = (parts,)
        self.parts = tuple(parts)

    def __str__(self):
        # C identifier form: nested name parts joined with underscores.
        return '_'.join(self.parts)

    def __add__(self, other):
        if isinstance(other, strtypes):
            return Names(self.parts + (other,))
        elif isinstance(other, Names):
            return Names(self.parts + other.parts)
        elif isinstance(other, tuple):
            return Names(self.parts + other)
        else:
            raise ValueError("Name parts should be of type str")

    def __eq__(self, other):
        return isinstance(other, Names) and self.parts == other.parts

    def __hash__(self):
        # Fix: defining __eq__ without __hash__ makes the class unhashable
        # on Python 3, so Names could not be used as a set member or dict
        # key. Hash the same tuple that __eq__ compares, so equal names
        # hash equal.
        return hash(self.parts)

    def __lt__(self, other):
        if not isinstance(other, Names):
            return NotImplemented
        return str(self) < str(other)
191
def names_from_type_name(type_name):
    '''Parse Names() from FieldDescriptorProto type_name.

    type_name is expected to be absolute, i.e. to start with a dot
    (for example ".package.Message").
    '''
    if type_name[0] != '.':
        raise NotImplementedError("Lookup of non-absolute type names is not supported")
    parts = type_name[1:].split('.')
    return Names(parts)
197
def varint_max_size(max_value):
    '''Returns the maximum number of bytes a varint can take when encoded.'''
    if max_value < 0:
        # Negative values are mapped above 2**64, which always needs the
        # full 10 bytes, the same as a sign-extended 64-bit varint.
        max_value = 2**64 - max_value
    for nbytes in range(1, 11):
        if max_value < (1 << (nbytes * 7)):
            return nbytes
    raise ValueError("Value too large for varint: " + str(max_value))

# Self-checks for the varint size computation.
assert varint_max_size(-1) == 10
assert varint_max_size(0) == 1
assert varint_max_size(127) == 1
assert varint_max_size(128) == 2
211
class EncodedSize:
    '''Class used to represent the encoded size of a field or a message.
    Consists of a combination of symbolic sizes and integer sizes.'''
    def __init__(self, value = 0, symbols = None, declarations = None, required_defines = None):
        '''value: integer byte count, another EncodedSize to copy, or a
        string/Names symbol that contributes a symbolic size.
        symbols: list of C symbol names summed into the total size.
        declarations: helper declarations (e.g. union types) that must
        appear alongside the size definition.
        required_defines: macro names that must be defined for this size
        expression to be valid.

        Fix: the list parameters previously defaulted to shared mutable
        list literals ([]); they now default to None and a fresh list is
        created per instance, avoiding the mutable-default-argument pitfall
        if any caller mutates the attributes in place.
        '''
        if isinstance(value, EncodedSize):
            # Copy-construct from another EncodedSize. The lists are shared,
            # matching previous behavior; they are never mutated in place.
            self.value = value.value
            self.symbols = value.symbols
            self.declarations = value.declarations
            self.required_defines = value.required_defines
        elif isinstance(value, strtypes + (Names,)):
            # Purely symbolic size: contributes no integer bytes.
            self.symbols = [str(value)]
            self.value = 0
            self.declarations = []
            self.required_defines = [str(value)]
        else:
            self.value = value
            self.symbols = symbols if symbols is not None else []
            self.declarations = declarations if declarations is not None else []
            self.required_defines = required_defines if required_defines is not None else []

    def __add__(self, other):
        '''Add an integer, a symbol name (str/Names), or another EncodedSize.'''
        if isinstance(other, int):
            return EncodedSize(self.value + other, self.symbols, self.declarations, self.required_defines)
        elif isinstance(other, strtypes + (Names,)):
            return EncodedSize(self.value, self.symbols + [str(other)], self.declarations, self.required_defines + [str(other)])
        elif isinstance(other, EncodedSize):
            return EncodedSize(self.value + other.value, self.symbols + other.symbols,
                               self.declarations + other.declarations, self.required_defines + other.required_defines)
        else:
            raise ValueError("Cannot add size: " + repr(other))

    def __mul__(self, other):
        '''Multiply by an integer factor; symbols become "N*symbol" terms.'''
        if isinstance(other, int):
            return EncodedSize(self.value * other, [str(other) + '*' + s for s in self.symbols],
                               self.declarations, self.required_defines)
        else:
            raise ValueError("Cannot multiply size: " + repr(other))

    def __str__(self):
        # Numeric-only sizes render as a plain number; otherwise as a
        # parenthesized C sum expression.
        if not self.symbols:
            return str(self.value)
        else:
            return '(' + str(self.value) + ' + ' + ' + '.join(self.symbols) + ')'

    def get_declarations(self):
        '''Get any declarations that must appear alongside this encoded size definition,
        such as helper union {} types.'''
        return '\n'.join(self.declarations)

    def get_cpp_guard(self, local_defines):
        '''Get an #if preprocessor statement listing all defines that are required for this definition.'''
        needed = [x for x in self.required_defines if x not in local_defines]
        if needed:
            return '#if ' + ' && '.join(['defined(%s)' % x for x in needed]) + "\n"
        else:
            return ''

    def upperlimit(self):
        '''Numeric upper bound: exact when fully numeric, else 2**32-1.'''
        if not self.symbols:
            return self.value
        else:
            return 2**32 - 1
274
class Enum:
    '''Represents one protobuf enum and renders its C typedef and helpers.'''
    def __init__(self, names, desc, enum_options):
        '''desc is EnumDescriptorProto'''
        self.options = enum_options
        self.names = names

        # by definition, `names` include this enum's name
        base_name = Names(names.parts[:-1])

        # With long_names each value is prefixed with the full enum name,
        # otherwise only with the enclosing scope's name.
        value_prefix = names if enum_options.long_names else base_name
        self.values = [(value_prefix + member.name, member.number) for member in desc.value]

        self.value_longnames = [self.names + member.name for member in desc.value]
        self.packed = enum_options.packed_enum

    def has_negative(self):
        '''True when any value in the enum is negative.'''
        return any(number < 0 for _, number in self.values)

    def encoded_size(self):
        '''Worst-case varint size over all values in the enum.'''
        return max(varint_max_size(number) for _, number in self.values)

    def __str__(self):
        members = ',\n'.join('    %s = %d' % pair for pair in self.values)
        packed_attr = ' pb_packed' if self.packed else ''
        return 'typedef enum _%s {\n%s\n}%s %s;' % (self.names, members, packed_attr, self.names)

    def auxiliary_defines(self):
        '''Generate the _MIN/_MAX/_ARRAYSIZE defines and optional helpers.'''
        # sort the enum by value
        ordered = sorted(self.values, key = lambda pair: (pair[1], pair[0]))
        lowest = ordered[0][0]
        highest = ordered[-1][0]

        lines = ['#define _%s_MIN %s' % (self.names, lowest),
                 '#define _%s_MAX %s' % (self.names, highest),
                 '#define _%s_ARRAYSIZE ((%s)(%s+1))' % (self.names, self.names, highest)]

        if not self.options.long_names:
            # Define the long names always so that enum value references
            # from other files work properly.
            for longname, (shortname, _) in zip(self.value_longnames, self.values):
                lines.append('#define %s %s' % (longname, shortname))

        if self.options.enum_to_string:
            lines.append('const char *%s_name(%s v);' % (self.names, self.names))

        return '\n'.join(lines) + '\n'

    def enum_to_string_definition(self):
        '''Generate the C function that maps enum values to short names.'''
        if not self.options.enum_to_string:
            return ""

        # Strip off the leading type name from the string value.
        prefix_len = len(str(self.names)) + 1

        parts = ['const char *%s_name(%s v) {\n' % (self.names, self.names),
                 '    switch (v) {\n']

        for (enumname, _), longname in zip(self.values, self.value_longnames):
            parts.append('        case %s: return "%s";\n' % (enumname, str(longname)[prefix_len:]))

        parts.append('    }\n')
        parts.append('    return "unknown";\n')
        parts.append('}\n')

        return ''.join(parts)
348
349class FieldMaxSize:
350    def __init__(self, worst = 0, checks = [], field_name = 'undefined'):
351        if isinstance(worst, list):
352            self.worst = max(i for i in worst if i is not None)
353        else:
354            self.worst = worst
355
356        self.worst_field = field_name
357        self.checks = list(checks)
358
359    def extend(self, extend, field_name = None):
360        self.worst = max(self.worst, extend.worst)
361
362        if self.worst == extend.worst:
363            self.worst_field = extend.worst_field
364
365        self.checks.extend(extend.checks)
366
367class Field:
368    macro_x_param = 'X'
369    macro_a_param = 'a'
370
    def __init__(self, struct_name, desc, field_options):
        '''Initialize field metadata from a FieldDescriptorProto.

        struct_name: Names of the enclosing message struct.
        desc: FieldDescriptorProto describing this field (may be mutated
        when type_override is given).
        field_options: nanopb options for this field; mutated below to
        normalize legacy (FT_INLINE) and default (FT_DEFAULT) settings.
        '''
        self.tag = desc.number
        self.struct_name = struct_name
        self.union_name = None
        self.name = desc.name
        self.default = None
        self.max_size = None
        self.max_count = None
        self.array_decl = ""
        self.enc_size = None
        self.data_item_size = None
        self.ctype = None
        self.fixed_count = False
        self.callback_datatype = field_options.callback_datatype
        self.math_include_required = False
        self.sort_by_tag = field_options.sort_by_tag

        if field_options.type == nanopb_pb2.FT_INLINE:
            # Before nanopb-0.3.8, fixed length bytes arrays were specified
            # by setting type to FT_INLINE. But to handle pointer typed fields,
            # it makes sense to have it as a separate option.
            field_options.type = nanopb_pb2.FT_STATIC
            field_options.fixed_length = True

        # Parse field options
        if field_options.HasField("max_size"):
            self.max_size = field_options.max_size

        self.default_has = field_options.default_has

        if desc.type == FieldD.TYPE_STRING and field_options.HasField("max_length"):
            # max_length overrides max_size for strings
            # (+1 byte for the null terminator)
            self.max_size = field_options.max_length + 1

        if field_options.HasField("max_count"):
            self.max_count = field_options.max_count

        if desc.HasField('default_value'):
            self.default = desc.default_value

        # Check field rules, i.e. required/optional/repeated.
        can_be_static = True
        if desc.label == FieldD.LABEL_REPEATED:
            self.rules = 'REPEATED'
            if self.max_count is None:
                # Unbounded repeated fields cannot be statically allocated.
                can_be_static = False
            else:
                self.array_decl = '[%d]' % self.max_count
                if field_options.fixed_count:
                  self.rules = 'FIXARRAY'

        elif field_options.proto3:
            if desc.type == FieldD.TYPE_MESSAGE and not field_options.proto3_singular_msgs:
                # In most other protobuf libraries proto3 submessages have
                # "null" status. For nanopb, that is implemented as has_ field.
                self.rules = 'OPTIONAL'
            elif hasattr(desc, "proto3_optional") and desc.proto3_optional:
                # Protobuf 3.12 introduced optional fields for proto3 syntax
                self.rules = 'OPTIONAL'
            else:
                # Proto3 singular fields (without has_field)
                self.rules = 'SINGULAR'
        elif desc.label == FieldD.LABEL_REQUIRED:
            self.rules = 'REQUIRED'
        elif desc.label == FieldD.LABEL_OPTIONAL:
            self.rules = 'OPTIONAL'
        else:
            raise NotImplementedError(desc.label)

        # Check if the field can be implemented with static allocation
        # i.e. whether the data size is known.
        if desc.type == FieldD.TYPE_STRING and self.max_size is None:
            can_be_static = False

        if desc.type == FieldD.TYPE_BYTES and self.max_size is None:
            can_be_static = False

        # Decide how the field data will be allocated
        if field_options.type == nanopb_pb2.FT_DEFAULT:
            # Default policy: static when the size is bounded, callback otherwise.
            if can_be_static:
                field_options.type = nanopb_pb2.FT_STATIC
            else:
                field_options.type = nanopb_pb2.FT_CALLBACK

        if field_options.type == nanopb_pb2.FT_STATIC and not can_be_static:
            raise Exception("Field '%s' is defined as static, but max_size or "
                            "max_count is not given." % self.name)

        if field_options.fixed_count and self.max_count is None:
            raise Exception("Field '%s' is defined as fixed count, "
                            "but max_count is not given." % self.name)

        if field_options.type == nanopb_pb2.FT_STATIC:
            self.allocation = 'STATIC'
        elif field_options.type == nanopb_pb2.FT_POINTER:
            self.allocation = 'POINTER'
        elif field_options.type == nanopb_pb2.FT_CALLBACK:
            self.allocation = 'CALLBACK'
        else:
            raise NotImplementedError(field_options.type)

        if field_options.HasField("type_override"):
            # The .options file can force a different wire type for this field.
            desc.type = field_options.type_override

        # Decide the C data type to use in the struct.
        if desc.type in datatypes:
            self.ctype, self.pbtype, self.enc_size, self.data_item_size = datatypes[desc.type]

            # Override the field size if user wants to use smaller integers
            if (desc.type, field_options.int_size) in datatypes:
                self.ctype, self.pbtype, self.enc_size, self.data_item_size = datatypes[(desc.type, field_options.int_size)]
        elif desc.type == FieldD.TYPE_ENUM:
            self.pbtype = 'ENUM'
            self.data_item_size = 4
            self.ctype = names_from_type_name(desc.type_name)
            if self.default is not None:
                # Prefix the default value with the enum type name.
                self.default = self.ctype + self.default
            self.enc_size = None # Needs to be filled in when enum values are known
        elif desc.type == FieldD.TYPE_STRING:
            self.pbtype = 'STRING'
            self.ctype = 'char'
            if self.allocation == 'STATIC':
                self.ctype = 'char'
                self.array_decl += '[%d]' % self.max_size
                # -1 because of null terminator. Both pb_encode and pb_decode
                # check the presence of it.
                self.enc_size = varint_max_size(self.max_size) + self.max_size - 1
        elif desc.type == FieldD.TYPE_BYTES:
            if field_options.fixed_length:
                self.pbtype = 'FIXED_LENGTH_BYTES'

                if self.max_size is None:
                    raise Exception("Field '%s' is defined as fixed length, "
                                    "but max_size is not given." % self.name)

                self.enc_size = varint_max_size(self.max_size) + self.max_size
                self.ctype = 'pb_byte_t'
                self.array_decl += '[%d]' % self.max_size
            else:
                self.pbtype = 'BYTES'
                self.ctype = 'pb_bytes_array_t'
                if self.allocation == 'STATIC':
                    # Statically allocated bytes get a dedicated struct type;
                    # see types(), which emits a PB_BYTES_ARRAY_T typedef.
                    self.ctype = self.struct_name + self.name + 't'
                    self.enc_size = varint_max_size(self.max_size) + self.max_size
        elif desc.type == FieldD.TYPE_MESSAGE:
            self.pbtype = 'MESSAGE'
            self.ctype = self.submsgname = names_from_type_name(desc.type_name)
            self.enc_size = None # Needs to be filled in after the message type is available
            if field_options.submsg_callback:
                self.pbtype = 'MSG_W_CB'
        else:
            raise NotImplementedError(desc.type)

        if self.default and self.pbtype in ['FLOAT', 'DOUBLE']:
            # inf/nan defaults become INFINITY/NAN macros, which require
            # <math.h> in the generated code (see get_initializer()).
            if 'inf' in self.default or 'nan' in self.default:
                self.math_include_required = True
528
529    def __lt__(self, other):
530        return self.tag < other.tag
531
532    def __str__(self):
533        result = ''
534        if self.allocation == 'POINTER':
535            if self.rules == 'REPEATED':
536                if self.pbtype == 'MSG_W_CB':
537                    result += '    pb_callback_t cb_' + self.name + ';\n'
538                result += '    pb_size_t ' + self.name + '_count;\n'
539
540            if self.pbtype in ['MESSAGE', 'MSG_W_CB']:
541                # Use struct definition, so recursive submessages are possible
542                result += '    struct _%s *%s;' % (self.ctype, self.name)
543            elif self.pbtype == 'FIXED_LENGTH_BYTES' or self.rules == 'FIXARRAY':
544                # Pointer to fixed size array
545                result += '    %s (*%s)%s;' % (self.ctype, self.name, self.array_decl)
546            elif self.rules in ['REPEATED', 'FIXARRAY'] and self.pbtype in ['STRING', 'BYTES']:
547                # String/bytes arrays need to be defined as pointers to pointers
548                result += '    %s **%s;' % (self.ctype, self.name)
549            else:
550                result += '    %s *%s;' % (self.ctype, self.name)
551        elif self.allocation == 'CALLBACK':
552            result += '    %s %s;' % (self.callback_datatype, self.name)
553        else:
554            if self.pbtype == 'MSG_W_CB' and self.rules in ['OPTIONAL', 'REPEATED']:
555                result += '    pb_callback_t cb_' + self.name + ';\n'
556
557            if self.rules == 'OPTIONAL':
558                result += '    bool has_' + self.name + ';\n'
559            elif self.rules == 'REPEATED':
560                result += '    pb_size_t ' + self.name + '_count;\n'
561            result += '    %s %s%s;' % (self.ctype, self.name, self.array_decl)
562        return result
563
564    def types(self):
565        '''Return definitions for any special types this field might need.'''
566        if self.pbtype == 'BYTES' and self.allocation == 'STATIC':
567            result = 'typedef PB_BYTES_ARRAY_T(%d) %s;\n' % (self.max_size, self.ctype)
568        else:
569            result = ''
570        return result
571
572    def get_dependencies(self):
573        '''Get list of type names used by this field.'''
574        if self.allocation == 'STATIC':
575            return [str(self.ctype)]
576        else:
577            return []
578
    def get_initializer(self, null_init, inner_init_only = False):
        '''Return literal expression for this field's default value.
        null_init: If True, initialize to a 0 value instead of default from .proto
        inner_init_only: If True, exclude initialization for any count/has fields
        '''

        inner_init = None
        if self.pbtype in ['MESSAGE', 'MSG_W_CB']:
            # Submessages delegate to their own generated init macros.
            if null_init:
                inner_init = '%s_init_zero' % self.ctype
            else:
                inner_init = '%s_init_default' % self.ctype
        elif self.default is None or null_init:
            # Zero-value initializer, chosen per field type.
            if self.pbtype == 'STRING':
                inner_init = '""'
            elif self.pbtype == 'BYTES':
                inner_init = '{0, {0}}'
            elif self.pbtype == 'FIXED_LENGTH_BYTES':
                inner_init = '{0}'
            elif self.pbtype in ('ENUM', 'UENUM'):
                inner_init = '_%s_MIN' % self.ctype
            else:
                inner_init = '0'
        else:
            # Non-null default taken from the .proto file.
            if self.pbtype == 'STRING':
                # escape_encode yields a C-safe escaped byte string.
                data = codecs.escape_encode(self.default.encode('utf-8'))[0]
                inner_init = '"' + data.decode('ascii') + '"'
            elif self.pbtype == 'BYTES':
                # The default is given in escaped form; emit as a hex array
                # with the length prefix in front.
                data = codecs.escape_decode(self.default)[0]
                data = ["0x%02x" % c for c in bytearray(data)]
                if len(data) == 0:
                    inner_init = '{0, {0}}'
                else:
                    inner_init = '{%d, {%s}}' % (len(data), ','.join(data))
            elif self.pbtype == 'FIXED_LENGTH_BYTES':
                data = codecs.escape_decode(self.default)[0]
                data = ["0x%02x" % c for c in bytearray(data)]
                if len(data) == 0:
                    inner_init = '{0}'
                else:
                    inner_init = '{%s}' % ','.join(data)
            elif self.pbtype in ['FIXED32', 'UINT32']:
                inner_init = str(self.default) + 'u'
            elif self.pbtype in ['FIXED64', 'UINT64']:
                inner_init = str(self.default) + 'ull'
            elif self.pbtype in ['SFIXED64', 'INT64']:
                inner_init = str(self.default) + 'll'
            elif self.pbtype in ['FLOAT', 'DOUBLE']:
                inner_init = str(self.default)
                # Map textual inf/nan defaults to the C macros from <math.h>
                # (math_include_required is set for these in __init__).
                if 'inf' in inner_init:
                    inner_init = inner_init.replace('inf', 'INFINITY')
                elif 'nan' in inner_init:
                    inner_init = inner_init.replace('nan', 'NAN')
                elif (not '.' in inner_init) and self.pbtype == 'FLOAT':
                    inner_init += '.0f'
                elif self.pbtype == 'FLOAT':
                    inner_init += 'f'
            else:
                inner_init = str(self.default)

        if inner_init_only:
            return inner_init

        outer_init = None
        if self.allocation == 'STATIC':
            if self.rules == 'REPEATED':
                # Count member (0) followed by max_count copies of the value.
                outer_init = '0, {' + ', '.join([inner_init] * self.max_count) + '}'
            elif self.rules == 'FIXARRAY':
                outer_init = '{' + ', '.join([inner_init] * self.max_count) + '}'
            elif self.rules == 'OPTIONAL':
                # has_ flag first, then the value itself.
                if null_init or not self.default_has:
                    outer_init = 'false, ' + inner_init
                else:
                    outer_init = 'true, ' + inner_init
            else:
                outer_init = inner_init
        elif self.allocation == 'POINTER':
            if self.rules == 'REPEATED':
                outer_init = '0, NULL'
            else:
                outer_init = 'NULL'
        elif self.allocation == 'CALLBACK':
            if self.pbtype == 'EXTENSION':
                outer_init = 'NULL'
            else:
                outer_init = '{{NULL}, NULL}'

        if self.pbtype == 'MSG_W_CB' and self.rules in ['REPEATED', 'OPTIONAL']:
            # Callback-wrapped messages carry an extra pb_callback_t member
            # in front (see __str__), which gets its own initializer.
            outer_init = '{{NULL}, NULL}, ' + outer_init

        return outer_init
670
671    def tags(self):
672        '''Return the #define for the tag number of this field.'''
673        identifier = '%s_%s_tag' % (self.struct_name, self.name)
674        return '#define %-40s %d\n' % (identifier, self.tag)
675
676    def fieldlist(self):
677        '''Return the FIELDLIST macro entry for this field.
678        Format is: X(a, ATYPE, HTYPE, LTYPE, field_name, tag)
679        '''
680        name = self.name
681
682        if self.rules == "ONEOF":
683          # For oneofs, make a tuple of the union name, union member name,
684          # and the name inside the parent struct.
685          if not self.anonymous:
686            name = '(%s,%s,%s)' % (self.union_name, self.name, self.union_name + '.' + self.name)
687          else:
688            name = '(%s,%s,%s)' % (self.union_name, self.name, self.name)
689
690        return '%s(%s, %-9s %-9s %-9s %-16s %3d)' % (self.macro_x_param,
691                                                     self.macro_a_param,
692                                                     self.allocation + ',',
693                                                     self.rules + ',',
694                                                     self.pbtype + ',',
695                                                     name + ',',
696                                                     self.tag)
697
698    def data_size(self, dependencies):
699        '''Return estimated size of this field in the C struct.
700        This is used to try to automatically pick right descriptor size.
701        If the estimate is wrong, it will result in compile time error and
702        user having to specify descriptor_width option.
703        '''
704        if self.allocation == 'POINTER' or self.pbtype == 'EXTENSION':
705            size = 8
706        elif self.allocation == 'CALLBACK':
707            size = 16
708        elif self.pbtype in ['MESSAGE', 'MSG_W_CB']:
709            if str(self.submsgname) in dependencies:
710                other_dependencies = dict(x for x in dependencies.items() if x[0] != str(self.struct_name))
711                size = dependencies[str(self.submsgname)].data_size(other_dependencies)
712            else:
713                size = 256 # Message is in other file, this is reasonable guess for most cases
714
715            if self.pbtype == 'MSG_W_CB':
716                size += 16
717        elif self.pbtype in ['STRING', 'FIXED_LENGTH_BYTES']:
718            size = self.max_size
719        elif self.pbtype == 'BYTES':
720            size = self.max_size + 4
721        elif self.data_item_size is not None:
722            size = self.data_item_size
723        else:
724            raise Exception("Unhandled field type: %s" % self.pbtype)
725
726        if self.rules in ['REPEATED', 'FIXARRAY'] and self.allocation == 'STATIC':
727            size *= self.max_count
728
729        if self.rules not in ('REQUIRED', 'SINGULAR'):
730            size += 4
731
732        if size % 4 != 0:
733            # Estimate how much alignment requirements will increase the size.
734            size += 4 - (size % 4)
735
736        return size
737
738    def encoded_size(self, dependencies):
739        '''Return the maximum size that this field can take when encoded,
740        including the field tag. If the size cannot be determined, returns
741        None.'''
742
743        if self.allocation != 'STATIC':
744            return None
745
746        if self.pbtype in ['MESSAGE', 'MSG_W_CB']:
747            encsize = None
748            if str(self.submsgname) in dependencies:
749                submsg = dependencies[str(self.submsgname)]
750                other_dependencies = dict(x for x in dependencies.items() if x[0] != str(self.struct_name))
751                encsize = submsg.encoded_size(other_dependencies)
752                if encsize is not None:
753                    # Include submessage length prefix
754                    encsize += varint_max_size(encsize.upperlimit())
755                else:
756                    my_msg = dependencies.get(str(self.struct_name))
757                    if my_msg and submsg.protofile == my_msg.protofile:
758                        # The dependency is from the same file and size cannot be
759                        # determined for it, thus we know it will not be possible
760                        # in runtime either.
761                        return None
762
763            if encsize is None:
764                # Submessage or its size cannot be found.
765                # This can occur if submessage is defined in different
766                # file, and it or its .options could not be found.
767                # Instead of direct numeric value, reference the size that
768                # has been #defined in the other file.
769                encsize = EncodedSize(self.submsgname + 'size')
770
771                # We will have to make a conservative assumption on the length
772                # prefix size, though.
773                encsize += 5
774
775        elif self.pbtype in ['ENUM', 'UENUM']:
776            if str(self.ctype) in dependencies:
777                enumtype = dependencies[str(self.ctype)]
778                encsize = enumtype.encoded_size()
779            else:
780                # Conservative assumption
781                encsize = 10
782
783        elif self.enc_size is None:
784            raise RuntimeError("Could not determine encoded size for %s.%s"
785                               % (self.struct_name, self.name))
786        else:
787            encsize = EncodedSize(self.enc_size)
788
789        encsize += varint_max_size(self.tag << 3) # Tag + wire type
790
791        if self.rules in ['REPEATED', 'FIXARRAY']:
792            # Decoders must be always able to handle unpacked arrays.
793            # Therefore we have to reserve space for it, even though
794            # we emit packed arrays ourselves. For length of 1, packed
795            # arrays are larger however so we need to add allowance
796            # for the length byte.
797            encsize *= self.max_count
798
799            if self.max_count == 1:
800                encsize += 1
801
802        return encsize
803
804    def has_callbacks(self):
805        return self.allocation == 'CALLBACK'
806
807    def requires_custom_field_callback(self):
808        return self.allocation == 'CALLBACK' and self.callback_datatype != 'pb_callback_t'
809
class ExtensionRange(Field):
    '''Special pb_extension_t* field in an extensible message structure.

    range_start signifies the index at which the extensions start. Not
    necessarily all tags above it are extensions; it is merely a speed
    optimization.
    '''
    def __init__(self, struct_name, range_start, field_options):
        # Identity within the parent struct.
        self.struct_name = struct_name
        self.name = 'extensions'
        self.tag = range_start

        # Fixed type/allocation properties of the extension pseudo-field.
        self.pbtype = 'EXTENSION'
        self.rules = 'OPTIONAL'
        self.allocation = 'CALLBACK'
        self.ctype = 'pb_extension_t'
        self.callback_datatype = 'pb_extension_t*'

        # No array, default value, or size bookkeeping applies.
        self.array_decl = ''
        self.default = None
        self.max_size = 0
        self.max_count = 0
        self.data_item_size = 0
        self.fixed_count = False

    def requires_custom_field_callback(self):
        '''The extensions pointer never needs a custom callback type.'''
        return False

    def __str__(self):
        '''Struct member declaration for the extensions pointer.'''
        return '    pb_extension_t *extensions;'

    def types(self):
        '''No extra typedefs are needed.'''
        return ''

    def tags(self):
        '''No tag #defines are generated for the extension range.'''
        return ''

    def encoded_size(self, dependencies):
        # We exclude extensions from the count, because they cannot be known
        # until runtime. The other option would be to return None here, but
        # this way the value remains useful if extensions are not used.
        return EncodedSize(0)
849
class ExtensionField(Field):
    '''A protobuf extension field, generated as a pb_extension_type_t.'''

    def __init__(self, fullname, desc, field_options):
        self.fullname = fullname
        self.extendee_name = names_from_type_name(desc.extendee)
        Field.__init__(self, self.fullname + "extmsg", desc, field_options)

        # Only "optional" extension fields are supported; others are skipped.
        self.skip = (self.rules != 'OPTIONAL')
        if not self.skip:
            # We don't really want the has_field for extensions.
            self.rules = 'REQUIRED'
            # Wrap the field in a synthetic single-field message.
            self.msg = Message(self.fullname + "extmsg", None, field_options)
            self.msg.fields.append(self)

    def tags(self):
        '''Return the #define for the tag number of this field.'''
        identifier = '%s_tag' % self.fullname
        return '#define %-40s %d\n' % (identifier, self.tag)

    def extension_decl(self):
        '''Declaration of the extension type in the .pb.h file'''
        if self.skip:
            msg = '/* Extension field %s was skipped because only "optional"\n' % self.fullname
            msg +='   type of extension fields is currently supported. */\n'
            return msg

        return ('extern const pb_extension_type_t %s; /* field type: %s */\n' %
            (self.fullname, str(self).strip()))

    def extension_def(self, dependencies):
        '''Definition of the extension type in the .pb.c file'''
        if self.skip:
            return ''

        pieces = [
            "/* Definition for extension field %s */\n" % self.fullname,
            str(self.msg),
            self.msg.fields_declaration(dependencies),
            'pb_byte_t %s_default[] = {0x00};\n' % self.msg.name,
            self.msg.fields_definition(dependencies),
            'const pb_extension_type_t %s = {\n' % self.fullname,
            '    NULL,\n',
            '    NULL,\n',
            '    &%s_msg\n' % self.msg.name,
            '};\n',
        ]
        return ''.join(pieces)
896
897
898# ---------------------------------------------------------------------------
899#                   Generation of oneofs (unions)
900# ---------------------------------------------------------------------------
901
class OneOf(Field):
    '''Represents a protobuf oneof, generated as a C union with a
    which_<name> discriminator member.

    Acts as a composite Field: it aggregates its member Fields and
    forwards metadata and size queries to them.
    '''
    def __init__(self, struct_name, oneof_desc, oneof_options):
        # Name of the struct that contains this oneof.
        self.struct_name = struct_name
        self.name = oneof_desc.name
        self.ctype = 'union'
        self.pbtype = 'oneof'
        self.fields = []
        self.allocation = 'ONEOF'
        self.default = None
        self.rules = 'ONEOF'
        # anonymous_oneof option: emit an unnamed union member.
        self.anonymous = oneof_options.anonymous_oneof
        self.sort_by_tag = oneof_options.sort_by_tag
        # Set to True when any member is a message with callback (MSG_W_CB).
        self.has_msg_cb = False

    def add_field(self, field):
        '''Add a member field to this union and update derived state.'''
        field.union_name = self.name
        field.rules = 'ONEOF'
        field.anonymous = self.anonymous
        self.fields.append(field)

        if self.sort_by_tag:
            self.fields.sort()

        if field.pbtype == 'MSG_W_CB':
            self.has_msg_cb = True

        # The oneof as a whole uses the lowest tag number inside the union.
        self.tag = min([f.tag for f in self.fields])

    def __str__(self):
        '''Generate the struct member declaration for this union.'''
        result = ''
        if self.fields:
            if self.has_msg_cb:
                result += '    pb_callback_t cb_' + self.name + ';\n'

            result += '    pb_size_t which_' + self.name + ";\n"
            result += '    union {\n'
            for f in self.fields:
                result += '    ' + str(f).replace('\n', '\n    ') + '\n'
            if self.anonymous:
                result += '    };'
            else:
                result += '    } ' + self.name + ';'
        return result

    def types(self):
        '''Return typedefs needed by the member fields.'''
        return ''.join([f.types() for f in self.fields])

    def get_dependencies(self):
        '''Get list of type names referred to by the member fields.'''
        deps = []
        for f in self.fields:
            deps += f.get_dependencies()
        return deps

    def get_initializer(self, null_init):
        '''Initializer: optional callback, which_ tag, then first member.'''
        if self.has_msg_cb:
            return '{{NULL}, NULL}, 0, {' + self.fields[0].get_initializer(null_init) + '}'
        else:
            return '0, {' + self.fields[0].get_initializer(null_init) + '}'

    def tags(self):
        '''Return the #define lines for all member field tags.'''
        return ''.join([f.tags() for f in self.fields])

    def fieldlist(self):
        '''Return FIELDLIST macro entries for all member fields.'''
        return ' \\\n'.join(field.fieldlist() for field in self.fields)

    def data_size(self, dependencies):
        '''Size of the union is the size of its largest member.'''
        return max(f.data_size(dependencies) for f in self.fields)

    def encoded_size(self, dependencies):
        '''Returns the size of the largest oneof field.'''
        largest = 0
        dynamic_sizes = {}
        for f in self.fields:
            size = EncodedSize(f.encoded_size(dependencies))
            if size is None or size.value is None:
                return None
            elif size.symbols:
                # Size depends on symbols resolved only at compile time.
                dynamic_sizes[f.tag] = size
            elif size.value > largest:
                largest = size.value

        if not dynamic_sizes:
            # Simple case, all sizes were known at generator time
            return EncodedSize(largest)

        if largest > 0:
            # Some sizes were known, some were not
            dynamic_sizes[0] = EncodedSize(largest)

        # Couldn't find size for submessage at generation time,
        # have to rely on macro resolution at compile time.
        if len(dynamic_sizes) == 1:
            # Only one symbol was needed
            return list(dynamic_sizes.values())[0]
        else:
            # Use sizeof(union{}) construct to find the maximum size of
            # submessages.
            union_name = "%s_%s_size_union" % (self.struct_name, self.name)
            union_def = 'union %s {%s};\n' % (union_name, ' '.join('char f%d[%s];' % (k, s) for k,s in dynamic_sizes.items()))
            required_defs = list(itertools.chain.from_iterable(s.required_defines for k,s in dynamic_sizes.items()))
            return EncodedSize(0, ['sizeof(%s)' % union_name], [union_def], required_defs)

    def has_callbacks(self):
        '''True if any member field uses a callback.'''
        return bool([f for f in self.fields if f.has_callbacks()])

    def requires_custom_field_callback(self):
        '''True if any member field needs a custom callback function.'''
        return bool([f for f in self.fields if f.requires_custom_field_callback()])
1010
1011# ---------------------------------------------------------------------------
1012#                   Generation of messages (structures)
1013# ---------------------------------------------------------------------------
1014
1015
class Message:
    '''Represents a protobuf message, generated as a C struct plus the
    associated field descriptor macros.'''

    def __init__(self, names, desc, message_options):
        '''names: Names instance identifying the message type.
        desc: DescriptorProto, or None when building a synthetic message
              (e.g. the wrapper for an extension field).
        message_options: nanopb options that apply to this message.
        '''
        self.name = names
        self.fields = []
        self.oneofs = {}
        self.desc = desc
        self.math_include_required = False
        self.packed = message_options.packed_struct
        self.descriptorsize = message_options.descriptorsize

        if message_options.msgid:
            self.msgid = message_options.msgid

        if desc is not None:
            self.load_fields(desc, message_options)

        self.callback_function = message_options.callback_function
        if not message_options.HasField('callback_function'):
            # Automatically assign a per-message callback if any field has
            # a special callback_datatype.
            for field in self.fields:
                if field.requires_custom_field_callback():
                    self.callback_function = "%s_callback" % self.name
                    break

    def load_fields(self, desc, message_options):
        '''Load field list from DescriptorProto'''

        # Oneof indices that should be handled as plain fields, not unions.
        no_unions = []

        if hasattr(desc, 'oneof_decl'):
            for i, f in enumerate(desc.oneof_decl):
                oneof_options = get_nanopb_suboptions(desc, message_options, self.name + f.name)
                if oneof_options.no_unions:
                    no_unions.append(i) # No union, but add fields normally
                elif oneof_options.type == nanopb_pb2.FT_IGNORE:
                    pass # No union and skip fields also
                else:
                    oneof = OneOf(self.name, f, oneof_options)
                    self.oneofs[i] = oneof
        else:
            sys.stderr.write('Note: This Python protobuf library has no OneOf support\n')

        for f in desc.field:
            field_options = get_nanopb_suboptions(f, message_options, self.name + f.name)
            if field_options.type == nanopb_pb2.FT_IGNORE:
                continue

            # A field can raise the descriptor width requirement of the
            # whole message.
            if field_options.descriptorsize > self.descriptorsize:
                self.descriptorsize = field_options.descriptorsize

            field = Field(self.name, f, field_options)
            if hasattr(f, 'oneof_index') and f.HasField('oneof_index'):
                if hasattr(f, 'proto3_optional') and f.proto3_optional:
                    # Proto3 optional fields arrive wrapped in a oneof;
                    # treat them as plain fields instead of a union.
                    no_unions.append(f.oneof_index)

                if f.oneof_index in no_unions:
                    self.fields.append(field)
                elif f.oneof_index in self.oneofs:
                    self.oneofs[f.oneof_index].add_field(field)

                    # Add the union itself to the struct when its first
                    # member appears.
                    if self.oneofs[f.oneof_index] not in self.fields:
                        self.fields.append(self.oneofs[f.oneof_index])
            else:
                self.fields.append(field)

            if field.math_include_required:
                self.math_include_required = True

        if len(desc.extension_range) > 0:
            field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions')
            range_start = min([r.start for r in desc.extension_range])
            if field_options.type != nanopb_pb2.FT_IGNORE:
                self.fields.append(ExtensionRange(self.name, range_start, field_options))

        if message_options.sort_by_tag:
            self.fields.sort()

    def get_dependencies(self):
        '''Get list of type names that this structure refers to.'''
        deps = []
        for f in self.fields:
            deps += f.get_dependencies()
        return deps

    def __str__(self):
        '''Generate the C struct definition for this message.'''
        result = 'typedef struct _%s {\n' % self.name

        if not self.fields:
            # Empty structs are not allowed in C standard.
            # Therefore add a dummy field if an empty message occurs.
            result += '    char dummy_field;'

        result += '\n'.join([str(f) for f in self.fields])

        if Globals.protoc_insertion_points:
            result += '\n/* @@protoc_insertion_point(struct:%s) */' % self.name

        result += '\n}'

        if self.packed:
            result += ' pb_packed'

        result += ' %s;' % self.name

        if self.packed:
            result = 'PB_PACKED_STRUCT_START\n' + result
            result += '\nPB_PACKED_STRUCT_END'

        return result + '\n'

    def types(self):
        '''Return typedefs needed by the fields of this message.'''
        return ''.join([f.types() for f in self.fields])

    def get_initializer(self, null_init):
        '''Return a C brace-initializer for the whole struct.'''
        if not self.fields:
            # Matches the dummy_field added for empty messages.
            return '{0}'

        parts = []
        for field in self.fields:
            parts.append(field.get_initializer(null_init))
        return '{' + ', '.join(parts) + '}'

    def count_required_fields(self):
        '''Returns number of required fields inside this message'''
        count = 0
        for f in self.fields:
            if not isinstance(f, OneOf):
                if f.rules == 'REQUIRED':
                    count += 1
        return count

    def all_fields(self):
        '''Iterate over all fields in this message, including nested OneOfs.'''
        for f in self.fields:
            if isinstance(f, OneOf):
                for f2 in f.fields:
                    yield f2
            else:
                yield f


    def field_for_tag(self, tag):
        '''Given a tag number, return the Field instance.'''
        for field in self.all_fields():
            if field.tag == tag:
                return field
        return None

    def count_all_fields(self):
        '''Count the total number of fields in this message.'''
        count = 0
        for f in self.fields:
            if isinstance(f, OneOf):
                count += len(f.fields)
            else:
                count += 1
        return count

    def fields_declaration(self, dependencies):
        '''Return X-macro declaration of all fields in this message.'''
        # Pick macro parameter names that do not collide with field names.
        Field.macro_x_param = 'X'
        Field.macro_a_param = 'a'
        while any(field.name == Field.macro_x_param for field in self.all_fields()):
            Field.macro_x_param += '_'
        while any(field.name == Field.macro_a_param for field in self.all_fields()):
            Field.macro_a_param += '_'

        result = '#define %s_FIELDLIST(%s, %s) \\\n' % (self.name,
                                                        Field.macro_x_param,
                                                        Field.macro_a_param)
        result += ' \\\n'.join(field.fieldlist() for field in sorted(self.fields))
        result += '\n'

        has_callbacks = bool([f for f in self.fields if f.has_callbacks()])
        if has_callbacks:
            if self.callback_function != 'pb_default_field_callback':
                result += "extern bool %s(pb_istream_t *istream, pb_ostream_t *ostream, const pb_field_t *field);\n" % self.callback_function
            result += "#define %s_CALLBACK %s\n" % (self.name, self.callback_function)
        else:
            result += "#define %s_CALLBACK NULL\n" % self.name

        defval = self.default_value(dependencies)
        if defval:
            # Emit the serialized default message as an escaped byte string.
            hexcoded = ''.join("\\x%02x" % ord(defval[i:i+1]) for i in range(len(defval)))
            result += '#define %s_DEFAULT (const pb_byte_t*)"%s\\x00"\n' % (self.name, hexcoded)
        else:
            result += '#define %s_DEFAULT NULL\n' % self.name

        for field in sorted(self.fields):
            if field.pbtype in ['MESSAGE', 'MSG_W_CB']:
                result += "#define %s_%s_MSGTYPE %s\n" % (self.name, field.name, field.ctype)
            elif field.rules == 'ONEOF':
                for member in field.fields:
                    if member.pbtype in ['MESSAGE', 'MSG_W_CB']:
                        result += "#define %s_%s_%s_MSGTYPE %s\n" % (self.name, member.union_name, member.name, member.ctype)

        return result

    def fields_declaration_cpp_lookup(self):
        '''Return a C++ MessageDescriptor specialization for this message.'''
        result = 'template <>\n'
        result += 'struct MessageDescriptor<%s> {\n' % (self.name)
        result += '    static PB_INLINE_CONSTEXPR const pb_size_t fields_array_length = %d;\n' % (self.count_all_fields())
        result += '    static inline const pb_msgdesc_t* fields() {\n'
        result += '        return &%s_msg;\n' % (self.name)
        result += '    }\n'
        result += '};'
        return result

    def fields_definition(self, dependencies):
        '''Return the field descriptor definition that goes in .pb.c file.'''
        width = self.required_descriptor_width(dependencies)
        if width == 1:
          width = 'AUTO'

        result = 'PB_BIND(%s, %s, %s)\n' % (self.name, self.name, width)
        return result

    def required_descriptor_width(self, dependencies):
        '''Estimate how many words are necessary for each field descriptor.'''
        if self.descriptorsize != nanopb_pb2.DS_AUTO:
            # User explicitly chose a width via options.
            return int(self.descriptorsize)

        if not self.fields:
          return 1

        max_tag = max(field.tag for field in self.all_fields())
        max_offset = self.data_size(dependencies)
        max_arraysize = max((field.max_count or 0) for field in self.all_fields())
        max_datasize = max(field.data_size(dependencies) for field in self.all_fields())

        # Thresholds correspond to the bit field layouts of the descriptor
        # word sizes.
        if max_arraysize > 0xFFFF:
            return 8
        elif (max_tag > 0x3FF or max_offset > 0xFFFF or
              max_arraysize > 0x0FFF or max_datasize > 0x0FFF):
            return 4
        elif max_tag > 0x3F or max_offset > 0xFF:
            return 2
        else:
            # NOTE: Macro logic in pb.h ensures that width 1 will
            # be raised to 2 automatically for string/submsg fields
            # and repeated fields. Thus only tag and offset need to
            # be checked.
            return 1

    def data_size(self, dependencies):
        '''Return approximate sizeof(struct) in the compiled code.'''
        return sum(f.data_size(dependencies) for f in self.fields)

    def encoded_size(self, dependencies):
        '''Return the maximum size that this message can take when encoded.
        If the size cannot be determined, returns None.
        '''
        size = EncodedSize(0)
        for field in self.fields:
            fsize = field.encoded_size(dependencies)
            if fsize is None:
                return None
            size += fsize

        return size

    def default_value(self, dependencies):
        '''Generate serialized protobuf message that contains the
        default values for optional fields.'''

        if not self.desc:
            return b''

        if self.desc.options.map_entry:
            return b''

        # Work on a copy of the descriptor so the original stays intact.
        optional_only = copy.deepcopy(self.desc)

        # Remove fields without default values
        # The iteration is done in reverse order to avoid remove() messing up iteration.
        for field in reversed(list(optional_only.field)):
            field.ClearField(str('extendee'))
            parsed_field = self.field_for_tag(field.number)
            if parsed_field is None or parsed_field.allocation != 'STATIC':
                optional_only.field.remove(field)
            elif (field.label == FieldD.LABEL_REPEATED or
                  field.type == FieldD.TYPE_MESSAGE):
                optional_only.field.remove(field)
            elif hasattr(field, 'oneof_index') and field.HasField('oneof_index'):
                optional_only.field.remove(field)
            elif field.type == FieldD.TYPE_ENUM:
                # The partial descriptor doesn't include the enum type
                # so we fake it with int64.
                enumname = names_from_type_name(field.type_name)
                try:
                    enumtype = dependencies[str(enumname)]
                except KeyError:
                    raise Exception("Could not find enum type %s while generating default values for %s.\n" % (enumname, self.name)
                                    + "Try passing all source files to generator at once, or use -I option.")

                if field.HasField('default_value'):
                    defvals = [v for n,v in enumtype.values if n.parts[-1] == field.default_value]
                else:
                    # If no default is specified, the default is the first value.
                    defvals = [v for n,v in enumtype.values]
                if defvals and defvals[0] != 0:
                    field.type = FieldD.TYPE_INT64
                    field.default_value = str(defvals[0])
                    field.ClearField(str('type_name'))
                else:
                    # Zero default needs no explicit encoding.
                    optional_only.field.remove(field)
            elif not field.HasField('default_value'):
                optional_only.field.remove(field)

        if len(optional_only.field) == 0:
            return b''

        # Strip everything except the surviving scalar fields, then build a
        # dynamic message class from the pruned descriptor.
        optional_only.ClearField(str('oneof_decl'))
        optional_only.ClearField(str('nested_type'))
        optional_only.ClearField(str('extension'))
        optional_only.ClearField(str('enum_type'))
        desc = google.protobuf.descriptor.MakeDescriptor(optional_only)
        msg = reflection.MakeClass(desc)()

        for field in optional_only.field:
            if field.type == FieldD.TYPE_STRING:
                setattr(msg, field.name, field.default_value)
            elif field.type == FieldD.TYPE_BYTES:
                setattr(msg, field.name, codecs.escape_decode(field.default_value)[0])
            elif field.type in [FieldD.TYPE_FLOAT, FieldD.TYPE_DOUBLE]:
                setattr(msg, field.name, float(field.default_value))
            elif field.type == FieldD.TYPE_BOOL:
                setattr(msg, field.name, field.default_value == 'true')
            else:
                setattr(msg, field.name, int(field.default_value))

        return msg.SerializeToString()
1349
1350
1351# ---------------------------------------------------------------------------
1352#                    Processing of entire .proto files
1353# ---------------------------------------------------------------------------
1354
def iterate_messages(desc, flatten = False, names = None):
    '''Recursively find all messages. For each, yield name, DescriptorProto.

    desc: FileDescriptorProto or DescriptorProto to search.
    flatten: if True, yield only the message's own name instead of the
             fully qualified name path.
    names: name prefix used for recursion; defaults to an empty Names().
    '''
    # Build the default lazily instead of using a shared default instance
    # (mutable-default-argument pitfall: 'names = Names()' in the signature
    # would be evaluated only once at definition time).
    if names is None:
        names = Names()

    if hasattr(desc, 'message_type'):
        submsgs = desc.message_type # FileDescriptorProto (top level)
    else:
        submsgs = desc.nested_type # DescriptorProto (nested message)

    for submsg in submsgs:
        sub_names = names + submsg.name
        if flatten:
            yield Names(submsg.name), submsg
        else:
            yield sub_names, submsg

        # Recurse into messages nested inside this one.
        for x in iterate_messages(submsg, flatten, sub_names):
            yield x
1371
def iterate_extensions(desc, flatten = False, names = None):
    '''Recursively find all extensions.
    For each, yield name, FieldDescriptorProto.

    names: name prefix used for recursion; defaults to an empty Names().
    '''
    # Build the default lazily instead of using a shared default instance
    # (mutable-default-argument pitfall: 'names = Names()' in the signature
    # would be evaluated only once at definition time).
    if names is None:
        names = Names()

    # Extensions declared at this level.
    for extension in desc.extension:
        yield names, extension

    # Extensions declared inside nested messages.
    for subname, subdesc in iterate_messages(desc, flatten, names):
        for extension in subdesc.extension:
            yield subname, extension
1382
def toposort2(data):
    '''Topological sort.
    From http://code.activestate.com/recipes/577413-topological-sort/
    This function is under the MIT license.

    data: dict mapping item -> set of items it depends on.
    Yields items so that every item comes after its dependencies.
    Raises AssertionError if a cyclic dependency exists.
    '''
    # Work on a copy of the dict and its sets so the caller's argument is
    # not mutated (the original recipe modified it in place).
    data = dict((item, set(deps)) for item, deps in data.items())
    for k, v in data.items():
        v.discard(k) # Ignore self dependencies
    # Items that appear only as dependencies get an empty dependency set.
    extra_items_in_deps = reduce(set.union, list(data.values()), set()) - set(data.keys())
    data.update(dict([(item, set()) for item in extra_items_in_deps]))
    while True:
        # Everything with no remaining dependencies can be emitted now;
        # sorted() keeps the output deterministic.
        ordered = set(item for item, dep in data.items() if not dep)
        if not ordered:
            break
        for item in sorted(ordered):
            yield item
        data = dict((item, dep - ordered) for item, dep in data.items()
                    if item not in ordered)
    assert not data, "A cyclic dependency exists amongst %r" % data
1401
def sort_dependencies(messages):
    '''Sort a list of Messages based on dependencies.

    Yields the messages so that every message comes after the messages
    it depends on (as reported by get_dependencies()).
    '''
    graph = {}
    by_name = {}
    for msg in messages:
        key = str(msg.name)
        graph[key] = set(msg.get_dependencies())
        by_name[key] = msg

    # toposort2 also yields external dependency names; only yield the
    # messages we actually have.
    for name in toposort2(graph):
        message = by_name.get(name)
        if message is not None:
            yield message
1413
def make_identifier(headername):
    '''Make #ifndef identifier that contains uppercase A-Z and digits 0-9.

    Every non-alphanumeric character is replaced with an underscore.
    '''
    # Single-pass join instead of repeated string concatenation.
    return ''.join(c if c.isalnum() else '_' for c in headername.upper())
1423
class ProtoFile:
    '''Represents one parsed .proto file: the enums, messages and extensions
    defined in it, plus generation of the matching C header/source text.'''
    def __init__(self, fdesc, file_options):
        '''Takes a FileDescriptorProto and parses it.'''
        self.fdesc = fdesc
        self.file_options = file_options
        # Maps fully qualified type name (str) -> Enum/Message instance,
        # filled in by add_dependency().
        self.dependencies = {}
        self.math_include_required = False
        self.parse()
        # Propagate the <math.h> requirement from any contained message.
        for message in self.messages:
            if message.math_include_required:
                self.math_include_required = True
                break

        # Some of types used in this file probably come from the file itself.
        # Thus it has implicit dependency on itself.
        self.add_dependency(self)

    def parse(self):
        '''Extract enums, messages and extensions from the file descriptor,
        applying the configured name-mangling rules.'''
        self.enums = []
        self.messages = []
        self.extensions = []

        mangle_names = self.file_options.mangle_names
        flatten = mangle_names == nanopb_pb2.M_FLATTEN
        strip_prefix = None
        replacement_prefix = None
        if mangle_names == nanopb_pb2.M_STRIP_PACKAGE:
            # Drop the package prefix from generated names.
            strip_prefix = "." + self.fdesc.package
        elif mangle_names == nanopb_pb2.M_PACKAGE_INITIALS:
            # Replace the package prefix with its initials, e.g. foo.bar -> fb.
            strip_prefix = "." + self.fdesc.package
            replacement_prefix = ""
            for part in self.fdesc.package.split("."):
                replacement_prefix += part[0]
        elif self.file_options.package:
            # An explicit replacement package name was given in the options.
            strip_prefix = "." + self.fdesc.package
            replacement_prefix = self.file_options.package


        def create_name(names):
            # Build the mangled name for a type defined in this file,
            # according to the selected mangling mode.
            if mangle_names in (nanopb_pb2.M_NONE, nanopb_pb2.M_PACKAGE_INITIALS):
                return base_name + names
            if mangle_names == nanopb_pb2.M_STRIP_PACKAGE:
                return Names(names)
            # M_FLATTEN: keep only the last component of a nested name.
            single_name = names
            if isinstance(names, Names):
                single_name = names.parts[-1]
            return Names(single_name)

        def mangle_field_typename(typename):
            # Apply the same mangling to field type references so that they
            # keep matching the mangled type definitions.
            if mangle_names == nanopb_pb2.M_FLATTEN:
                return "." + typename.split(".")[-1]
            if strip_prefix is not None and typename.startswith(strip_prefix):
                if replacement_prefix is not None:
                    return "." + replacement_prefix + typename[len(strip_prefix):]
                else:
                    return typename[len(strip_prefix):]
            if self.file_options.package:
                return "." + replacement_prefix + typename
            return typename

        if replacement_prefix is not None:
            base_name = Names(replacement_prefix.split('.'))
        elif self.fdesc.package:
            base_name = Names(self.fdesc.package.split('.'))
        else:
            base_name = Names()

        for enum in self.fdesc.enum_type:
            name = create_name(enum.name)
            enum_options = get_nanopb_suboptions(enum, self.file_options, name)
            self.enums.append(Enum(name, enum, enum_options))

        for names, message in iterate_messages(self.fdesc, flatten):
            name = create_name(names)
            message_options = get_nanopb_suboptions(message, self.file_options, name)

            if message_options.skip_message:
                continue

            # Deep copy so the typename mangling below does not modify the
            # original descriptor object.
            message = copy.deepcopy(message)
            for field in message.field:
                if field.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM):
                    field.type_name = mangle_field_typename(field.type_name)

            self.messages.append(Message(name, message, message_options))
            for enum in message.enum_type:
                name = create_name(names + enum.name)
                enum_options = get_nanopb_suboptions(enum, message_options, name)
                self.enums.append(Enum(name, enum, enum_options))

        for names, extension in iterate_extensions(self.fdesc, flatten):
            name = create_name(names + extension.name)
            field_options = get_nanopb_suboptions(extension, self.file_options, name)

            extension = copy.deepcopy(extension)
            if extension.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM):
                extension.type_name = mangle_field_typename(extension.type_name)

            if field_options.type != nanopb_pb2.FT_IGNORE:
                self.extensions.append(ExtensionField(name, extension, field_options))

    def add_dependency(self, other):
        '''Register the types defined in another ProtoFile so fields in this
        file can resolve references to them.'''
        for enum in other.enums:
            self.dependencies[str(enum.names)] = enum
            enum.protofile = other

        for msg in other.messages:
            self.dependencies[str(msg.name)] = msg
            msg.protofile = other

        # Fix field default values where enum short names are used.
        for enum in other.enums:
            if not enum.options.long_names:
                for message in self.messages:
                    for field in message.all_fields():
                        if field.default in enum.value_longnames:
                            idx = enum.value_longnames.index(field.default)
                            field.default = enum.values[idx][0]

        # Fix field data types where enums have negative values.
        for enum in other.enums:
            if not enum.has_negative():
                for message in self.messages:
                    for field in message.all_fields():
                        if field.pbtype == 'ENUM' and field.ctype == enum.names:
                            field.pbtype = 'UENUM'

    def generate_header(self, includes, headername, options):
        '''Generate content for a header file.
        Generates strings, which should be concatenated and stored to file.
        '''

        yield '/* Automatically generated nanopb header */\n'
        if options.notimestamp:
            yield '/* Generated by %s */\n\n' % (nanopb_version)
        else:
            # NOTE(review): relies on a module-level 'time' import that is
            # not visible in this part of the file -- confirm it exists.
            yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())

        # The include guard symbol is derived from package + header file name.
        if self.fdesc.package:
            symbol = make_identifier(self.fdesc.package + '_' + headername)
        else:
            symbol = make_identifier(headername)
        yield '#ifndef PB_%s_INCLUDED\n' % symbol
        yield '#define PB_%s_INCLUDED\n' % symbol
        if self.math_include_required:
            yield '#include <math.h>\n'
        try:
            yield options.libformat % ('pb.h')
        except TypeError:
            # no %s specified - use whatever was passed in as options.libformat
            yield options.libformat
        yield '\n'

        for incfile in self.file_options.include:
            # allow including system headers
            if (incfile.startswith('<')):
                yield '#include %s\n' % incfile
            else:
                yield options.genformat % incfile
                yield '\n'

        for incfile in includes:
            noext = os.path.splitext(incfile)[0]
            yield options.genformat % (noext + options.extension + options.header_extension)
            yield '\n'

        if Globals.protoc_insertion_points:
            yield '/* @@protoc_insertion_point(includes) */\n'

        yield '\n'

        yield '#if PB_PROTO_HEADER_VERSION != 40\n'
        yield '#error Regenerate this file with the current version of nanopb generator.\n'
        yield '#endif\n'
        yield '\n'

        if self.enums:
            yield '/* Enum definitions */\n'
            for enum in self.enums:
                yield str(enum) + '\n\n'

        if self.messages:
            yield '/* Struct definitions */\n'
            # Structs must be emitted in dependency order for valid C.
            for msg in sort_dependencies(self.messages):
                yield msg.types()
                yield str(msg) + '\n'
            yield '\n'

        if self.extensions:
            yield '/* Extensions */\n'
            for extension in self.extensions:
                yield extension.extension_decl()
            yield '\n'

        if self.enums:
                yield '/* Helper constants for enums */\n'
                for enum in self.enums:
                    yield enum.auxiliary_defines() + '\n'
                yield '\n'

        yield '#ifdef __cplusplus\n'
        yield 'extern "C" {\n'
        yield '#endif\n\n'

        if self.messages:
            yield '/* Initializer values for message structs */\n'
            for msg in self.messages:
                identifier = '%s_init_default' % msg.name
                yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False))
            for msg in self.messages:
                identifier = '%s_init_zero' % msg.name
                yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True))
            yield '\n'

            yield '/* Field tags (for use in manual encoding/decoding) */\n'
            for msg in sort_dependencies(self.messages):
                for field in msg.fields:
                    yield field.tags()
            for extension in self.extensions:
                yield extension.tags()
            yield '\n'

            yield '/* Struct field encoding specification for nanopb */\n'
            for msg in self.messages:
                yield msg.fields_declaration(self.dependencies) + '\n'
            for msg in self.messages:
                yield 'extern const pb_msgdesc_t %s_msg;\n' % msg.name
            yield '\n'

            yield '/* Defines for backwards compatibility with code written before nanopb-0.4.0 */\n'
            for msg in self.messages:
              yield '#define %s_fields &%s_msg\n' % (msg.name, msg.name)
            yield '\n'

            yield '/* Maximum encoded size of messages (where known) */\n'
            messagesizes = []
            for msg in self.messages:
                identifier = '%s_size' % msg.name
                messagesizes.append((identifier, msg.encoded_size(self.dependencies)))

            # If we require a symbol from another file, put a preprocessor if statement
            # around it to prevent compilation errors if the symbol is not actually available.
            local_defines = [identifier for identifier, msize in messagesizes if msize is not None]
            for identifier, msize in messagesizes:
                if msize is not None:
                    cpp_guard = msize.get_cpp_guard(local_defines)
                    yield cpp_guard
                    yield msize.get_declarations()
                    yield '#define %-40s %s\n' % (identifier, msize)
                    if cpp_guard: yield "#endif\n"
                else:
                    yield '/* %s depends on runtime parameters */\n' % identifier
            yield '\n'

            if [msg for msg in self.messages if hasattr(msg,'msgid')]:
              yield '/* Message IDs (where set with "msgid" option) */\n'
              for msg in self.messages:
                  if hasattr(msg,'msgid'):
                      yield '#define PB_MSG_%d %s\n' % (msg.msgid, msg.name)
              yield '\n'

              symbol = make_identifier(headername.split('.')[0])
              yield '#define %s_MESSAGES \\\n' % symbol

              for msg in self.messages:
                  # "-1" marks messages whose encoded size is not known
                  # at generation time.
                  m = "-1"
                  msize = msg.encoded_size(self.dependencies)
                  if msize is not None:
                      m = msize
                  if hasattr(msg,'msgid'):
                      yield '\tPB_MSG(%d,%s,%s) \\\n' % (msg.msgid, m, msg.name)
              yield '\n'

              for msg in self.messages:
                  if hasattr(msg,'msgid'):
                      yield '#define %s_msgid %d\n' % (msg.name, msg.msgid)
              yield '\n'

        yield '#ifdef __cplusplus\n'
        yield '} /* extern "C" */\n'
        yield '#endif\n'

        if options.cpp_descriptors:
            yield '\n'
            yield '#ifdef __cplusplus\n'
            yield '/* Message descriptors for nanopb */\n'
            yield 'namespace nanopb {\n'
            for msg in self.messages:
                yield msg.fields_declaration_cpp_lookup() + '\n'
            yield '}  // namespace nanopb\n'
            yield '\n'
            yield '#endif  /* __cplusplus */\n'
            yield '\n'

        if Globals.protoc_insertion_points:
            yield '/* @@protoc_insertion_point(eof) */\n'

        # End of header
        yield '\n#endif\n'

    def generate_source(self, headername, options):
        '''Generate content for a source file.'''

        yield '/* Automatically generated nanopb constant definitions */\n'
        if options.notimestamp:
            yield '/* Generated by %s */\n\n' % (nanopb_version)
        else:
            # See the NOTE about the 'time' import in generate_header().
            yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
        yield options.genformat % (headername)
        yield '\n'

        if Globals.protoc_insertion_points:
            yield '/* @@protoc_insertion_point(includes) */\n'

        yield '#if PB_PROTO_HEADER_VERSION != 40\n'
        yield '#error Regenerate this file with the current version of nanopb generator.\n'
        yield '#endif\n'
        yield '\n'

        for msg in self.messages:
            yield msg.fields_definition(self.dependencies) + '\n\n'

        for ext in self.extensions:
            yield ext.extension_def(self.dependencies) + '\n'

        for enum in self.enums:
            yield enum.enum_to_string_definition() + '\n'

        # Add checks for numeric limits
        if self.messages:
            largest_msg = max(self.messages, key = lambda m: m.count_required_fields())
            largest_count = largest_msg.count_required_fields()
            if largest_count > 64:
                yield '\n/* Check that missing required fields will be properly detected */\n'
                yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count
                yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name
                yield '       setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count
                yield '#endif\n'

        # Add check for sizeof(double)
        has_double = False
        for msg in self.messages:
            for field in msg.all_fields():
                if field.ctype == 'double':
                    has_double = True

        if has_double:
            yield '\n'
            yield '#ifndef PB_CONVERT_DOUBLE_FLOAT\n'
            yield '/* On some platforms (such as AVR), double is really float.\n'
            yield ' * To be able to encode/decode double on these platforms, you need.\n'
            yield ' * to define PB_CONVERT_DOUBLE_FLOAT in pb.h or compiler command line.\n'
            yield ' */\n'
            yield 'PB_STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n'
            yield '#endif\n'

        yield '\n'

        if Globals.protoc_insertion_points:
            yield '/* @@protoc_insertion_point(eof) */\n'
1784
1785# ---------------------------------------------------------------------------
1786#                    Options parsing for the .proto files
1787# ---------------------------------------------------------------------------
1788
1789from fnmatch import fnmatchcase
1790
def read_options_file(infile):
    '''Parse a separate options file to list:
        [(namemask, options), ...]

    Each non-comment line is "namemask option [option ...]"; the options
    part is parsed as a text-format NanoPBOptions message. Exits the
    process on malformed lines.
    '''
    results = []
    data = infile.read()
    # Strip C-style comments. re.DOTALL lets /* ... */ span multiple lines;
    # the previous re.MULTILINE flag only affects ^/$ and left multi-line
    # block comments in place, breaking the line parsing below. Replace the
    # comment with its newlines so reported line numbers stay correct.
    data = re.sub(r'/\*.*?\*/', lambda m: '\n' * m.group(0).count('\n'), data, flags = re.DOTALL)
    data = re.sub(r'//.*?$', '', data, flags = re.MULTILINE)
    data = re.sub(r'#.*?$', '', data, flags = re.MULTILINE)
    for i, line in enumerate(data.split('\n')):
        line = line.strip()
        if not line:
            continue

        parts = line.split(None, 1)

        if len(parts) < 2:
            # This is fatal (the old message claimed the line was skipped,
            # but generation has always aborted here).
            sys.stderr.write("%s:%d: " % (infile.name, i + 1) +
                             "Option lines should have space between field name and options. " +
                             "Offending line: '%s'\n" % line)
            sys.exit(1)

        opts = nanopb_pb2.NanoPBOptions()

        try:
            text_format.Merge(parts[1], opts)
        except Exception as e:
            sys.stderr.write("%s:%d: " % (infile.name, i + 1) +
                             "Unparseable option line: '%s'. " % line +
                             "Error: %s\n" % str(e))
            sys.exit(1)
        results.append((parts[0], opts))

    return results
1825
def get_nanopb_suboptions(subdesc, options, name):
    '''Get copy of options, and merge information from subdesc.'''
    new_options = nanopb_pb2.NanoPBOptions()
    new_options.CopyFrom(options)

    # Files using proto3 syntax get the proto3 flag set implicitly.
    if hasattr(subdesc, 'syntax') and subdesc.syntax == "proto3":
        new_options.proto3 = True

    # Merge in any matching rules from the separate .options file.
    # (The loop variable is named mask_options to avoid shadowing the
    # 'options' parameter.)
    dotname = '.'.join(name.parts)
    for namemask, mask_options in Globals.separate_options:
        if fnmatchcase(dotname, namemask):
            Globals.matched_namemasks.add(namemask)
            new_options.MergeFrom(mask_options)

    # Merge in options written inline in the .proto file. The extension
    # identifier depends on what kind of descriptor this is.
    proto_options = subdesc.options
    if isinstance(proto_options, descriptor.FieldOptions):
        ext_type = nanopb_pb2.nanopb
    elif isinstance(proto_options, descriptor.FileOptions):
        ext_type = nanopb_pb2.nanopb_fileopt
    elif isinstance(proto_options, descriptor.MessageOptions):
        ext_type = nanopb_pb2.nanopb_msgopt
    elif isinstance(proto_options, descriptor.EnumOptions):
        ext_type = nanopb_pb2.nanopb_enumopt
    else:
        raise Exception("Unknown options type")

    if proto_options.HasExtension(ext_type):
        new_options.MergeFrom(proto_options.Extensions[ext_type])

    if Globals.verbose_options:
        sys.stderr.write("Options for %s: " % dotname)
        sys.stderr.write(text_format.MessageToString(new_options) + "\n")

    return new_options
1862
1863
1864# ---------------------------------------------------------------------------
1865#                         Command line interface
1866# ---------------------------------------------------------------------------
1867
1868import sys
1869import os.path
1870from optparse import OptionParser
1871
# Command line option definitions shared by the CLI entry point and the
# protoc plugin mode (main_plugin overrides usage/epilog before parsing).
optparser = OptionParser(
    usage = "Usage: nanopb_generator.py [options] file.pb ...",
    epilog = "Compile file.pb from file.proto by: 'protoc -ofile.pb file.proto'. " +
             "Output will be written to file.pb.h and file.pb.c.")
optparser.add_option("--version", dest="version", action="store_true",
    help="Show version info and exit")
# Output file naming and placement.
optparser.add_option("-x", dest="exclude", metavar="FILE", action="append", default=[],
    help="Exclude file from generated #include list.")
optparser.add_option("-e", "--extension", dest="extension", metavar="EXTENSION", default=".pb",
    help="Set extension to use instead of '.pb' for generated files. [default: %default]")
optparser.add_option("-H", "--header-extension", dest="header_extension", metavar="EXTENSION", default=".h",
    help="Set extension to use for generated header files. [default: %default]")
optparser.add_option("-S", "--source-extension", dest="source_extension", metavar="EXTENSION", default=".c",
    help="Set extension to use for generated source files. [default: %default]")
# Separate .options file handling.
optparser.add_option("-f", "--options-file", dest="options_file", metavar="FILE", default="%s.options",
    help="Set name of a separate generator options file.")
optparser.add_option("-I", "--options-path", dest="options_path", metavar="DIR",
    action="append", default = [],
    help="Search for .options files additionally in this path")
optparser.add_option("--error-on-unmatched", dest="error_on_unmatched", action="store_true", default=False,
                     help ="Stop generation if there are unmatched fields in options file")
optparser.add_option("--no-error-on-unmatched", dest="error_on_unmatched", action="store_false", default=False,
                     help ="Continue generation if there are unmatched fields in options file (default)")
optparser.add_option("-D", "--output-dir", dest="output_dir",
                     metavar="OUTPUTDIR", default=None,
                     help="Output directory of .pb.h and .pb.c files")
# Formatting of #include directives in generated files.
optparser.add_option("-Q", "--generated-include-format", dest="genformat",
    metavar="FORMAT", default='#include "%s"',
    help="Set format string to use for including other .pb.h files. [default: %default]")
optparser.add_option("-L", "--library-include-format", dest="libformat",
    metavar="FORMAT", default='#include <%s>',
    help="Set format string to use for including the nanopb pb.h header. [default: %default]")
optparser.add_option("--strip-path", dest="strip_path", action="store_true", default=False,
    help="Strip directory path from #included .pb.h file name")
optparser.add_option("--no-strip-path", dest="strip_path", action="store_false",
    help="Opposite of --strip-path (default since 0.4.0)")
optparser.add_option("--cpp-descriptors", action="store_true",
    help="Generate C++ descriptors to lookup by type (e.g. pb_field_t for a message)")
optparser.add_option("-T", "--no-timestamp", dest="notimestamp", action="store_true", default=True,
    help="Don't add timestamp to .pb.h and .pb.c preambles (default since 0.4.0)")
optparser.add_option("-t", "--timestamp", dest="notimestamp", action="store_false", default=True,
    help="Add timestamp to .pb.h and .pb.c preambles")
# Verbosity and generator behavior.
optparser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False,
    help="Don't print anything except errors.")
optparser.add_option("-v", "--verbose", dest="verbose", action="store_true", default=False,
    help="Print more information.")
optparser.add_option("-s", dest="settings", metavar="OPTION:VALUE", action="append", default=[],
    help="Set generator option (max_size, max_count etc.).")
optparser.add_option("--protoc-insertion-points", dest="protoc_insertion_points", action="store_true", default=False,
                     help="Include insertion point comments in output for use by custom protoc plugins")
1922
def parse_file(filename, fdesc, options):
    '''Parse a single file. Returns a ProtoFile instance.

    filename: path to the .proto/.pb file (also used to locate the
        associated .options file)
    fdesc: FileDescriptorProto, or None to read a FileDescriptorSet
        from filename
    options: command line options as they come from OptionParser
    '''
    toplevel_options = nanopb_pb2.NanoPBOptions()
    for s in options.settings:
        text_format.Merge(s, toplevel_options)

    if not fdesc:
        # Read the serialized descriptor set; use a context manager so the
        # file handle is closed promptly instead of being leaked.
        with open(filename, 'rb') as pbfile:
            data = pbfile.read()
        fdesc = descriptor.FileDescriptorSet.FromString(data).file[0]

    # Check if there is a separate .options file
    had_abspath = False
    try:
        optfilename = options.options_file % os.path.splitext(filename)[0]
    except TypeError:
        # No %s specified, use the filename as-is
        optfilename = options.options_file
        had_abspath = True

    paths = ['.'] + options.options_path
    for p in paths:
        if os.path.isfile(os.path.join(p, optfilename)):
            optfilename = os.path.join(p, optfilename)
            if options.verbose:
                sys.stderr.write('Reading options from ' + optfilename + '\n')
            # Close the options file after reading (was previously leaked).
            with open(optfilename, openmode_unicode) as optfile:
                Globals.separate_options = read_options_file(optfile)
            break
    else:
        # If we are given a full filename and it does not exist, give an error.
        # However, don't give error when we automatically look for .options file
        # with the same name as .proto.
        if options.verbose or had_abspath:
            sys.stderr.write('Options file not found: ' + optfilename + '\n')
        Globals.separate_options = []

    Globals.matched_namemasks = set()
    Globals.protoc_insertion_points = options.protoc_insertion_points

    # Parse the file
    file_options = get_nanopb_suboptions(fdesc, toplevel_options, Names([filename]))
    f = ProtoFile(fdesc, file_options)
    f.optfilename = optfilename

    return f
1967
def process_file(filename, fdesc, options, other_files = None):
    '''Process a single file.
    filename: The full path to the .proto or .pb source file, as string.
    fdesc: The loaded FileDescriptorSet, or None to read from the input file.
    options: Command line options as they come from OptionsParser.
    other_files: Optional dict {filename: ProtoFile} of already-parsed
        files used to resolve dependencies. Defaults to empty.

    Returns a dict:
        {'headername': Name of header file,
         'headerdata': Data for the .h header file,
         'sourcename': Name of the source code file,
         'sourcedata': Data for the .c source code file
        }
    '''
    # Avoid the mutable-default-argument pitfall; None means "no other files".
    if other_files is None:
        other_files = {}

    f = parse_file(filename, fdesc, options)

    # Provide dependencies if available
    for dep in f.fdesc.dependency:
        if dep in other_files:
            f.add_dependency(other_files[dep])

    # Decide the file names
    noext = os.path.splitext(filename)[0]
    headername = noext + options.extension + options.header_extension
    sourcename = noext + options.extension + options.source_extension

    if options.strip_path:
        headerbasename = os.path.basename(headername)
    else:
        headerbasename = headername

    # List of .proto files that should not be included in the C header file
    # even if they are mentioned in the source .proto.
    excludes = ['nanopb.proto', 'google/protobuf/descriptor.proto'] + options.exclude + list(f.file_options.exclude)
    includes = [d for d in f.fdesc.dependency if d not in excludes]

    headerdata = ''.join(f.generate_header(includes, headerbasename, options))
    sourcedata = ''.join(f.generate_source(headerbasename, options))

    # Check if there were any lines in .options that did not match a member
    unmatched = [n for n, o in Globals.separate_options if n not in Globals.matched_namemasks]
    if unmatched:
        if options.error_on_unmatched:
            raise Exception("Following patterns in " + f.optfilename + " did not match any fields: "
                            + ', '.join(unmatched))
        elif not options.quiet:
            sys.stderr.write("Following patterns in " + f.optfilename + " did not match any fields: "
                            + ', '.join(unmatched) + "\n")

        if not Globals.verbose_options:
            sys.stderr.write("Use  protoc --nanopb-out=-v:.   to see a list of the field names.\n")

    return {'headername': headername, 'headerdata': headerdata,
            'sourcename': sourcename, 'sourcedata': sourcedata}
2021
def main_cli():
    '''Main function when invoked directly from the command line.

    Parses command line arguments, compiles/loads the given .proto/.pb
    files and writes the generated .pb.h/.pb.c files to disk.
    '''

    options, filenames = optparser.parse_args()

    if options.version:
        print(nanopb_version)
        sys.exit(0)

    if not filenames:
        optparser.print_help()
        sys.exit(1)

    if options.quiet:
        options.verbose = False

    if options.output_dir and not os.path.exists(options.output_dir):
        optparser.print_help()
        sys.stderr.write("\noutput_dir does not exist: %s\n" % options.output_dir)
        sys.exit(1)

    if options.verbose:
        sys.stderr.write("Nanopb version %s\n" % nanopb_version)
        sys.stderr.write('Google Python protobuf library imported from %s, version %s\n'
                         % (google.protobuf.__file__, google.protobuf.__version__))

    # Load .pb files into memory and compile any .proto files.
    fdescs = {}
    include_path = ['-I%s' % p for p in options.options_path]
    for filename in filenames:
        if filename.endswith(".proto"):
            # Compile the .proto into a temporary FileDescriptorSet with
            # protoc and read that back in.
            with TemporaryDirectory() as tmpdir:
                tmpname = os.path.join(tmpdir, os.path.basename(filename) + ".pb")
                status = invoke_protoc(["protoc"] + include_path + ['--include_imports', '-o' + tmpname, filename])
                if status != 0: sys.exit(status)
                # Close handles promptly instead of leaking them
                # (previously open(...).read() left the file open).
                with open(tmpname, 'rb') as f:
                    data = f.read()
        else:
            with open(filename, 'rb') as f:
                data = f.read()

        # Take the last entry of the set as the main file; presumably the
        # imports produced by --include_imports come first. TODO confirm.
        fdesc = descriptor.FileDescriptorSet.FromString(data).file[-1]
        fdescs[fdesc.name] = fdesc

    # Process any include files first, in order to have them
    # available as dependencies
    other_files = {}
    for fdesc in fdescs.values():
        other_files[fdesc.name] = parse_file(fdesc.name, fdesc, options)

    # Then generate the headers / sources
    Globals.verbose_options = options.verbose
    for fdesc in fdescs.values():
        results = process_file(fdesc.name, fdesc, options, other_files)

        base_dir = options.output_dir or ''
        to_write = [
            (os.path.join(base_dir, results['headername']), results['headerdata']),
            (os.path.join(base_dir, results['sourcename']), results['sourcedata']),
        ]

        if not options.quiet:
            paths = " and ".join([x[0] for x in to_write])
            sys.stderr.write("Writing to %s\n" % paths)

        for path, data in to_write:
            # Create output subdirectories on demand.
            dirname = os.path.dirname(path)
            if dirname and not os.path.exists(dirname):
                os.makedirs(dirname)

            with open(path, 'w') as f:
                f.write(data)
2092
def main_plugin():
    '''Main function when invoked as a protoc plugin.'''

    import io, sys
    if sys.platform == "win32":
        import os, msvcrt
        # Set stdin and stdout to binary mode
        msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
        msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)

    # The entire serialized CodeGeneratorRequest arrives on stdin.
    data = io.open(sys.stdin.fileno(), "rb").read()

    request = plugin_pb2.CodeGeneratorRequest.FromString(data)

    try:
        # Versions of Python prior to 2.7.3 do not support unicode
        # input to shlex.split(). Try to convert to str if possible.
        params = str(request.parameter)
    except UnicodeEncodeError:
        params = request.parameter

    import shlex
    args = shlex.split(params)

    if len(args) == 1 and ',' in args[0]:
        # For compatibility with other protoc plugins, support options
        # separated by comma.
        lex = shlex.shlex(params)
        lex.whitespace_split = True
        lex.whitespace = ','
        lex.commenters = ''
        args = list(lex)

    # Reuse the CLI option parser but with plugin-appropriate usage text.
    optparser.usage = "Usage: protoc --nanopb_out=[options][,more_options]:outdir file.proto"
    optparser.epilog = "Output will be written to file.pb.h and file.pb.c."

    if '-h' in args or '--help' in args:
        # By default optparser prints help to stdout, which doesn't work for
        # protoc plugins.
        optparser.print_help(sys.stderr)
        sys.exit(1)

    options, dummy = optparser.parse_args(args)

    if options.version:
        sys.stderr.write('%s\n' % (nanopb_version))
        sys.exit(0)

    Globals.verbose_options = options.verbose

    if options.verbose:
        sys.stderr.write("Nanopb version %s\n" % nanopb_version)
        sys.stderr.write('Google Python protobuf library imported from %s, version %s\n'
                         % (google.protobuf.__file__, google.protobuf.__version__))

    response = plugin_pb2.CodeGeneratorResponse()

    # Google's protoc does not currently indicate the full path of proto files.
    # Instead always add the main file path to the search dirs, that works for
    # the common case.
    import os.path
    options.options_path.append(os.path.dirname(request.file_to_generate[0]))

    # Process any include files first, in order to have them
    # available as dependencies
    other_files = {}
    for fdesc in request.proto_file:
        other_files[fdesc.name] = parse_file(fdesc.name, fdesc, options)

    for filename in request.file_to_generate:
        for fdesc in request.proto_file:
            if fdesc.name == filename:
                results = process_file(filename, fdesc, options, other_files)

                # Add the generated header and source to the plugin response.
                f = response.file.add()
                f.name = results['headername']
                f.content = results['headerdata']

                f = response.file.add()
                f.name = results['sourcename']
                f.content = results['sourcedata']

    # Declare support for proto3 optional fields when the installed protobuf
    # library is new enough to define the feature flag.
    if hasattr(plugin_pb2.CodeGeneratorResponse, "FEATURE_PROTO3_OPTIONAL"):
        response.supported_features = plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL

    # The serialized response goes back to protoc via binary stdout.
    io.open(sys.stdout.fileno(), "wb").write(response.SerializeToString())
2179
if __name__ == '__main__':
    # Check if we are running as a plugin under protoc
    # (protoc invokes plugins through executables named protoc-gen-<name>).
    if 'protoc-gen-' in sys.argv[0] or '--protoc-plugin' in sys.argv:
        main_plugin()
    else:
        main_cli()
2186