1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Contains well known classes.
32
This file defines well-known classes that need extra maintenance, including:
  - Any
  - Duration
  - FieldMask
  - ListValue
  - Struct
  - Timestamp
39"""
40
# Module author metadata.
__author__ = 'jieluo@google.com (Jie Luo)'
42
43import calendar
44from datetime import datetime
45from datetime import timedelta
46import six
47
48try:
49  # Since python 3
50  import collections.abc as collections_abc
51except ImportError:
52  # Won't work after python 3.8
53  import collections as collections_abc
54
55from google.protobuf.descriptor import FieldDescriptor
56
57_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
58_NANOS_PER_SECOND = 1000000000
59_NANOS_PER_MILLISECOND = 1000000
60_NANOS_PER_MICROSECOND = 1000
61_MILLIS_PER_SECOND = 1000
62_MICROS_PER_SECOND = 1000000
63_SECONDS_PER_DAY = 24 * 3600
64_DURATION_SECONDS_MAX = 315576000000
65
66
class Any(object):
  """Class for Any Message type.

  Adds helpers that pack an arbitrary message into this message's
  ``type_url``/``value`` fields and unpack it again.
  """

  __slots__ = ()

  def Pack(self, msg, type_url_prefix='type.googleapis.com/',
           deterministic=None):
    """Packs the specified message into current Any message.

    Args:
      msg: The message to pack.
      type_url_prefix: Prefix of the resulting type URL. A '/' is inserted
          between it and the type name unless it already ends with '/'.
      deterministic: Forwarded to msg.SerializeToString().
    """
    if type_url_prefix[-1:] == '/':
      separator = ''
    else:
      separator = '/'
    self.type_url = '%s%s%s' % (
        type_url_prefix, separator, msg.DESCRIPTOR.full_name)
    self.value = msg.SerializeToString(deterministic=deterministic)

  def Unpack(self, msg):
    """Unpacks the current Any message into specified message.

    Returns:
      True if the packed type matches msg's type and msg was parsed,
      False (leaving msg untouched) otherwise.
    """
    if self.Is(msg.DESCRIPTOR):
      msg.ParseFromString(self.value)
      return True
    return False

  def TypeName(self):
    """Returns the protobuf type name of the inner message."""
    # Only the component after the last '/' is used: b/25630112
    return self.type_url.rsplit('/', 1)[-1]

  def Is(self, descriptor):
    """Checks if this Any represents the given protobuf type."""
    if '/' not in self.type_url:
      return False
    return self.TypeName() == descriptor.full_name
97
98
99_EPOCH_DATETIME = datetime.utcfromtimestamp(0)
100
101
class Timestamp(object):
  """Class for Timestamp message type.

  Adds conversions between the message's ``seconds``/``nanos`` fields and
  RFC 3339 strings, epoch-based integer counts and ``datetime`` objects.
  The From* methods below always store ``nanos`` in [0, 999999999].
  """

  __slots__ = ()

  def ToJsonString(self):
    """Converts Timestamp to RFC 3339 date string format.

    Returns:
      A string converted from timestamp. The string is always Z-normalized
      and uses 3, 6 or 9 fractional digits as required to represent the
      exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
    """
    # Normalize nanos into [0, 1e9) and carry any whole seconds it held.
    nanos = self.nanos % _NANOS_PER_SECOND
    total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND
    seconds = total_sec % _SECONDS_PER_DAY
    days = (total_sec - seconds) // _SECONDS_PER_DAY
    dt = datetime(1970, 1, 1) + timedelta(days, seconds)

    result = dt.isoformat()
    # Integer arithmetic throughout: nanos is already in [0, 1e9).
    if nanos == 0:
      # If there are 0 fractional digits, the fractional
      # point '.' should be omitted when serializing.
      return result + 'Z'
    if nanos % _NANOS_PER_MILLISECOND == 0:
      # Serialize 3 fractional digits.
      return result + '.%03dZ' % (nanos // _NANOS_PER_MILLISECOND)
    if nanos % _NANOS_PER_MICROSECOND == 0:
      # Serialize 6 fractional digits.
      return result + '.%06dZ' % (nanos // _NANOS_PER_MICROSECOND)
    # Serialize 9 fractional digits.
    return result + '.%09dZ' % nanos

  def FromJsonString(self, value):
    """Parse a RFC 3339 date string format to Timestamp.

    Args:
      value: A date string. Any fractional digits (or none) and any offset are
          accepted as long as they fit into nano-seconds precision.
          Example of accepted format: '1972-01-01T10:00:20.021-05:00'

    Raises:
      ValueError: On parsing problems, including fractional seconds with
          more than 9 digits or with characters other than decimal digits.
    """
    timezone_offset = value.find('Z')
    if timezone_offset == -1:
      timezone_offset = value.find('+')
    if timezone_offset == -1:
      # A negative offset starts at the last '-'; earlier ones belong to
      # the date.
      timezone_offset = value.rfind('-')
    if timezone_offset == -1:
      raise ValueError(
          'Failed to parse timestamp: missing valid timezone offset.')
    time_value = value[0:timezone_offset]
    # Parse datetime and nanos.
    point_position = time_value.find('.')
    if point_position == -1:
      second_value = time_value
      nano_value = ''
    else:
      second_value = time_value[:point_position]
      nano_value = time_value[point_position + 1:]
    if 't' in second_value:
      raise ValueError(
          'time data \'{0}\' does not match format \'%Y-%m-%dT%H:%M:%S\', '
          'lowercase \'t\' is not accepted'.format(second_value))
    date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT)
    td = date_object - datetime(1970, 1, 1)
    seconds = td.seconds + td.days * _SECONDS_PER_DAY
    if len(nano_value) > 9:
      raise ValueError(
          'Failed to parse Timestamp: nanos {0} more than '
          '9 fractional digits.'.format(nano_value))
    # RFC 3339 fractional seconds are decimal digits only. Without this
    # check float() would accept e.g. '1e5' or '1e-5' and yield a wrong or
    # out-of-range nanos value.
    if not all(c in '0123456789' for c in nano_value):
      raise ValueError(
          'Failed to parse Timestamp: nanos {0} contains non-digit '
          'characters.'.format(nano_value))
    # Right-pad to 9 digits to read the fraction as an exact nanosecond
    # count, e.g. '021' -> '021000000' -> 21000000.
    nanos = int(nano_value.ljust(9, '0')) if nano_value else 0
    # Parse timezone offsets.
    if value[timezone_offset] == 'Z':
      if len(value) != timezone_offset + 1:
        raise ValueError('Failed to parse timestamp: invalid trailing'
                         ' data {0}.'.format(value))
    else:
      timezone = value[timezone_offset:]
      pos = timezone.find(':')
      if pos == -1:
        raise ValueError(
            'Invalid timezone offset value: {0}.'.format(timezone))
      if timezone[0] == '+':
        seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
      else:
        seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
    # Set seconds and nanos
    self.seconds = int(seconds)
    self.nanos = int(nanos)

  def GetCurrentTime(self):
    """Get the current UTC into Timestamp."""
    try:
      from datetime import timezone  # pylint: disable=g-import-not-at-top
    except ImportError:
      # Python 2 has no datetime.timezone; use a naive UTC datetime instead.
      now = datetime.utcnow()
    else:
      # datetime.utcnow() is deprecated since Python 3.12. FromDatetime()
      # converts an aware UTC datetime to the same Timestamp.
      now = datetime.now(timezone.utc)
    self.FromDatetime(now)

  def ToNanoseconds(self):
    """Converts Timestamp to nanoseconds since epoch."""
    return self.seconds * _NANOS_PER_SECOND + self.nanos

  def ToMicroseconds(self):
    """Converts Timestamp to microseconds since epoch."""
    return (self.seconds * _MICROS_PER_SECOND +
            self.nanos // _NANOS_PER_MICROSECOND)

  def ToMilliseconds(self):
    """Converts Timestamp to milliseconds since epoch."""
    return (self.seconds * _MILLIS_PER_SECOND +
            self.nanos // _NANOS_PER_MILLISECOND)

  def ToSeconds(self):
    """Converts Timestamp to seconds since epoch."""
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds since epoch to Timestamp."""
    # Floor division keeps nanos non-negative for negative inputs.
    self.seconds = nanos // _NANOS_PER_SECOND
    self.nanos = nanos % _NANOS_PER_SECOND

  def FromMicroseconds(self, micros):
    """Converts microseconds since epoch to Timestamp."""
    self.seconds = micros // _MICROS_PER_SECOND
    self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND

  def FromMilliseconds(self, millis):
    """Converts milliseconds since epoch to Timestamp."""
    self.seconds = millis // _MILLIS_PER_SECOND
    self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND

  def FromSeconds(self, seconds):
    """Converts seconds since epoch to Timestamp."""
    self.seconds = seconds
    self.nanos = 0

  def ToDatetime(self):
    """Converts Timestamp to a naive datetime expressed in UTC.

    Sub-microsecond precision is truncated.
    """
    return _EPOCH_DATETIME + timedelta(
        seconds=self.seconds, microseconds=_RoundTowardZero(
            self.nanos, _NANOS_PER_MICROSECOND))

  def FromDatetime(self, dt):
    """Converts datetime to Timestamp.

    dt.utctimetuple() treats a naive datetime as UTC and converts an aware
    one to UTC; calendar.timegm() then maps that UTC tuple to epoch seconds.
    """
    self.seconds = calendar.timegm(dt.utctimetuple())
    self.nanos = dt.microsecond * _NANOS_PER_MICROSECOND
256
257
258class Duration(object):
259  """Class for Duration message type."""
260
261  __slots__ = ()
262
263  def ToJsonString(self):
264    """Converts Duration to string format.
265
266    Returns:
267      A string converted from self. The string format will contains
268      3, 6, or 9 fractional digits depending on the precision required to
269      represent the exact Duration value. For example: "1s", "1.010s",
270      "1.000000100s", "-3.100s"
271    """
272    _CheckDurationValid(self.seconds, self.nanos)
273    if self.seconds < 0 or self.nanos < 0:
274      result = '-'
275      seconds = - self.seconds + int((0 - self.nanos) // 1e9)
276      nanos = (0 - self.nanos) % 1e9
277    else:
278      result = ''
279      seconds = self.seconds + int(self.nanos // 1e9)
280      nanos = self.nanos % 1e9
281    result += '%d' % seconds
282    if (nanos % 1e9) == 0:
283      # If there are 0 fractional digits, the fractional
284      # point '.' should be omitted when serializing.
285      return result + 's'
286    if (nanos % 1e6) == 0:
287      # Serialize 3 fractional digits.
288      return result + '.%03ds' % (nanos / 1e6)
289    if (nanos % 1e3) == 0:
290      # Serialize 6 fractional digits.
291      return result + '.%06ds' % (nanos / 1e3)
292    # Serialize 9 fractional digits.
293    return result + '.%09ds' % nanos
294
295  def FromJsonString(self, value):
296    """Converts a string to Duration.
297
298    Args:
299      value: A string to be converted. The string must end with 's'. Any
300          fractional digits (or none) are accepted as long as they fit into
301          precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s
302
303    Raises:
304      ValueError: On parsing problems.
305    """
306    if len(value) < 1 or value[-1] != 's':
307      raise ValueError(
308          'Duration must end with letter "s": {0}.'.format(value))
309    try:
310      pos = value.find('.')
311      if pos == -1:
312        seconds = int(value[:-1])
313        nanos = 0
314      else:
315        seconds = int(value[:pos])
316        if value[0] == '-':
317          nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
318        else:
319          nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
320      _CheckDurationValid(seconds, nanos)
321      self.seconds = seconds
322      self.nanos = nanos
323    except ValueError as e:
324      raise ValueError(
325          'Couldn\'t parse duration: {0} : {1}.'.format(value, e))
326
327  def ToNanoseconds(self):
328    """Converts a Duration to nanoseconds."""
329    return self.seconds * _NANOS_PER_SECOND + self.nanos
330
331  def ToMicroseconds(self):
332    """Converts a Duration to microseconds."""
333    micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
334    return self.seconds * _MICROS_PER_SECOND + micros
335
336  def ToMilliseconds(self):
337    """Converts a Duration to milliseconds."""
338    millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND)
339    return self.seconds * _MILLIS_PER_SECOND + millis
340
341  def ToSeconds(self):
342    """Converts a Duration to seconds."""
343    return self.seconds
344
345  def FromNanoseconds(self, nanos):
346    """Converts nanoseconds to Duration."""
347    self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
348                            nanos % _NANOS_PER_SECOND)
349
350  def FromMicroseconds(self, micros):
351    """Converts microseconds to Duration."""
352    self._NormalizeDuration(
353        micros // _MICROS_PER_SECOND,
354        (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)
355
356  def FromMilliseconds(self, millis):
357    """Converts milliseconds to Duration."""
358    self._NormalizeDuration(
359        millis // _MILLIS_PER_SECOND,
360        (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)
361
362  def FromSeconds(self, seconds):
363    """Converts seconds to Duration."""
364    self.seconds = seconds
365    self.nanos = 0
366
367  def ToTimedelta(self):
368    """Converts Duration to timedelta."""
369    return timedelta(
370        seconds=self.seconds, microseconds=_RoundTowardZero(
371            self.nanos, _NANOS_PER_MICROSECOND))
372
373  def FromTimedelta(self, td):
374    """Converts timedelta to Duration."""
375    self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
376                            td.microseconds * _NANOS_PER_MICROSECOND)
377
378  def _NormalizeDuration(self, seconds, nanos):
379    """Set Duration by seconds and nanos."""
380    # Force nanos to be negative if the duration is negative.
381    if seconds < 0 and nanos > 0:
382      seconds += 1
383      nanos -= _NANOS_PER_SECOND
384    self.seconds = seconds
385    self.nanos = nanos
386
387
def _CheckDurationValid(seconds, nanos):
  """Raises ValueError unless (seconds, nanos) form a valid Duration.

  Valid means: |seconds| <= _DURATION_SECONDS_MAX, |nanos| is below one
  second, and the two fields do not have opposite signs (zero is neutral).
  """
  if seconds > _DURATION_SECONDS_MAX or seconds < -_DURATION_SECONDS_MAX:
    raise ValueError(
        'Duration is not valid: Seconds {0} must be in range '
        '[-315576000000, 315576000000].'.format(seconds))
  if nanos >= _NANOS_PER_SECOND or nanos <= -_NANOS_PER_SECOND:
    raise ValueError(
        'Duration is not valid: Nanos {0} must be in range '
        '[-999999999, 999999999].'.format(nanos))
  signs_disagree = (seconds > 0 > nanos) or (seconds < 0 < nanos)
  if signs_disagree:
    raise ValueError(
        'Duration is not valid: Sign mismatch.')
400
401
402def _RoundTowardZero(value, divider):
403  """Truncates the remainder part after division."""
404  # For some languages, the sign of the remainder is implementation
405  # dependent if any of the operands is negative. Here we enforce
406  # "rounded toward zero" semantics. For example, for (-5) / 2 an
407  # implementation may give -3 as the result with the remainder being
408  # 1. This function ensures we always return -2 (closer to zero).
409  result = value // divider
410  remainder = value % divider
411  if result < 0 and remainder > 0:
412    return result + 1
413  else:
414    return result
415
416
class FieldMask(object):
  """Class for FieldMask message type.

  Adds JSON conversion, validation and set-like operations (canonical form,
  union, intersection) on top of the message's ``paths`` field.
  """

  __slots__ = ()

  def ToJsonString(self):
    """Converts FieldMask to string according to proto3 JSON spec.

    Every snake_case path is rewritten in camelCase and the results are
    joined with commas.

    Raises:
      ValueError: If a path cannot be expressed in camelCase.
    """
    return ','.join([_SnakeCaseToCamelCase(path) for path in self.paths])

  def FromJsonString(self, value):
    """Converts string to FieldMask according to proto3 JSON spec.

    Replaces the current paths with the comma-separated camelCase paths of
    value, each converted to snake_case. An empty value clears the mask.

    Raises:
      ValueError: If a path contains an underscore.
    """
    self.Clear()
    if not value:
      return
    for camelcase_path in value.split(','):
      self.paths.append(_CamelCaseToSnakeCase(camelcase_path))

  def IsValidForDescriptor(self, message_descriptor):
    """Checks whether the FieldMask is valid for Message Descriptor."""
    return all(_IsValidPath(message_descriptor, path) for path in self.paths)

  def AllFieldsFromDescriptor(self, message_descriptor):
    """Gets all direct fields of Message Descriptor to FieldMask."""
    self.Clear()
    self.paths.extend(field.name for field in message_descriptor.fields)

  def CanonicalFormFromMask(self, mask):
    """Converts a FieldMask to the canonical form.

    Removes paths that are covered by another path. For example,
    "foo.bar" is covered by "foo" and will be removed if "foo"
    is also in the FieldMask. Then sorts all paths in alphabetical order.

    Args:
      mask: The original FieldMask to be converted.
    """
    _FieldMaskTree(mask).ToFieldMask(self)

  def Union(self, mask1, mask2):
    """Merges mask1 and mask2 into this FieldMask.

    Raises:
      ValueError: If mask1 or mask2 is not a FieldMask message.
    """
    for mask in (mask1, mask2):
      _CheckFieldMaskMessage(mask)
    merged = _FieldMaskTree(mask1)
    merged.MergeFromFieldMask(mask2)
    merged.ToFieldMask(self)

  def Intersect(self, mask1, mask2):
    """Intersects mask1 and mask2 into this FieldMask.

    Raises:
      ValueError: If mask1 or mask2 is not a FieldMask message.
    """
    for mask in (mask1, mask2):
      _CheckFieldMaskMessage(mask)
    first = _FieldMaskTree(mask1)
    common = _FieldMaskTree()
    for path in mask2.paths:
      first.IntersectPath(path, common)
    common.ToFieldMask(self)

  def MergeMessage(
      self, source, destination,
      replace_message_field=False, replace_repeated_field=False):
    """Merges fields specified in FieldMask from source to destination.

    Args:
      source: Source message.
      destination: The destination message to be merged into.
      replace_message_field: Replace message field if True. Merge message
          field if False.
      replace_repeated_field: Replace repeated field if True. Append
          elements of repeated field if False.
    """
    _FieldMaskTree(self).MergeMessage(
        source, destination, replace_message_field, replace_repeated_field)
496
497
498def _IsValidPath(message_descriptor, path):
499  """Checks whether the path is valid for Message Descriptor."""
500  parts = path.split('.')
501  last = parts.pop()
502  for name in parts:
503    field = message_descriptor.fields_by_name.get(name)
504    if (field is None or
505        field.label == FieldDescriptor.LABEL_REPEATED or
506        field.type != FieldDescriptor.TYPE_MESSAGE):
507      return False
508    message_descriptor = field.message_type
509  return last in message_descriptor.fields_by_name
510
511
512def _CheckFieldMaskMessage(message):
513  """Raises ValueError if message is not a FieldMask."""
514  message_descriptor = message.DESCRIPTOR
515  if (message_descriptor.name != 'FieldMask' or
516      message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
517    raise ValueError('Message {0} is not a FieldMask.'.format(
518        message_descriptor.full_name))
519
520
521def _SnakeCaseToCamelCase(path_name):
522  """Converts a path name from snake_case to camelCase."""
523  result = []
524  after_underscore = False
525  for c in path_name:
526    if c.isupper():
527      raise ValueError(
528          'Fail to print FieldMask to Json string: Path name '
529          '{0} must not contain uppercase letters.'.format(path_name))
530    if after_underscore:
531      if c.islower():
532        result.append(c.upper())
533        after_underscore = False
534      else:
535        raise ValueError(
536            'Fail to print FieldMask to Json string: The '
537            'character after a "_" must be a lowercase letter '
538            'in path name {0}.'.format(path_name))
539    elif c == '_':
540      after_underscore = True
541    else:
542      result += c
543
544  if after_underscore:
545    raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
546                     'in path name {0}.'.format(path_name))
547  return ''.join(result)
548
549
550def _CamelCaseToSnakeCase(path_name):
551  """Converts a field name from camelCase to snake_case."""
552  result = []
553  for c in path_name:
554    if c == '_':
555      raise ValueError('Fail to parse FieldMask: Path name '
556                       '{0} must not contain "_"s.'.format(path_name))
557    if c.isupper():
558      result += '_'
559      result += c.lower()
560    else:
561      result += c
562  return ''.join(result)
563
564
565class _FieldMaskTree(object):
566  """Represents a FieldMask in a tree structure.
567
568  For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
569  the FieldMaskTree will be:
570      [_root] -+- foo -+- bar
571            |       |
572            |       +- baz
573            |
574            +- bar --- baz
575  In the tree, each leaf node represents a field path.
576  """
577
578  __slots__ = ('_root',)
579
580  def __init__(self, field_mask=None):
581    """Initializes the tree by FieldMask."""
582    self._root = {}
583    if field_mask:
584      self.MergeFromFieldMask(field_mask)
585
586  def MergeFromFieldMask(self, field_mask):
587    """Merges a FieldMask to the tree."""
588    for path in field_mask.paths:
589      self.AddPath(path)
590
591  def AddPath(self, path):
592    """Adds a field path into the tree.
593
594    If the field path to add is a sub-path of an existing field path
595    in the tree (i.e., a leaf node), it means the tree already matches
596    the given path so nothing will be added to the tree. If the path
597    matches an existing non-leaf node in the tree, that non-leaf node
598    will be turned into a leaf node with all its children removed because
599    the path matches all the node's children. Otherwise, a new path will
600    be added.
601
602    Args:
603      path: The field path to add.
604    """
605    node = self._root
606    for name in path.split('.'):
607      if name not in node:
608        node[name] = {}
609      elif not node[name]:
610        # Pre-existing empty node implies we already have this entire tree.
611        return
612      node = node[name]
613    # Remove any sub-trees we might have had.
614    node.clear()
615
616  def ToFieldMask(self, field_mask):
617    """Converts the tree to a FieldMask."""
618    field_mask.Clear()
619    _AddFieldPaths(self._root, '', field_mask)
620
621  def IntersectPath(self, path, intersection):
622    """Calculates the intersection part of a field path with this tree.
623
624    Args:
625      path: The field path to calculates.
626      intersection: The out tree to record the intersection part.
627    """
628    node = self._root
629    for name in path.split('.'):
630      if name not in node:
631        return
632      elif not node[name]:
633        intersection.AddPath(path)
634        return
635      node = node[name]
636    intersection.AddLeafNodes(path, node)
637
638  def AddLeafNodes(self, prefix, node):
639    """Adds leaf nodes begin with prefix to this tree."""
640    if not node:
641      self.AddPath(prefix)
642    for name in node:
643      child_path = prefix + '.' + name
644      self.AddLeafNodes(child_path, node[name])
645
646  def MergeMessage(
647      self, source, destination,
648      replace_message, replace_repeated):
649    """Merge all fields specified by this tree from source to destination."""
650    _MergeMessage(
651        self._root, source, destination, replace_message, replace_repeated)
652
653
654def _StrConvert(value):
655  """Converts value to str if it is not."""
656  # This file is imported by c extension and some methods like ClearField
657  # requires string for the field name. py2/py3 has different text
658  # type and may use unicode.
659  if not isinstance(value, str):
660    return value.encode('utf-8')
661  return value
662
663
664def _MergeMessage(
665    node, source, destination, replace_message, replace_repeated):
666  """Merge all fields specified by a sub-tree from source to destination."""
667  source_descriptor = source.DESCRIPTOR
668  for name in node:
669    child = node[name]
670    field = source_descriptor.fields_by_name[name]
671    if field is None:
672      raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
673          name, source_descriptor.full_name))
674    if child:
675      # Sub-paths are only allowed for singular message fields.
676      if (field.label == FieldDescriptor.LABEL_REPEATED or
677          field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
678        raise ValueError('Error: Field {0} in message {1} is not a singular '
679                         'message field and cannot have sub-fields.'.format(
680                             name, source_descriptor.full_name))
681      if source.HasField(name):
682        _MergeMessage(
683            child, getattr(source, name), getattr(destination, name),
684            replace_message, replace_repeated)
685      continue
686    if field.label == FieldDescriptor.LABEL_REPEATED:
687      if replace_repeated:
688        destination.ClearField(_StrConvert(name))
689      repeated_source = getattr(source, name)
690      repeated_destination = getattr(destination, name)
691      repeated_destination.MergeFrom(repeated_source)
692    else:
693      if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
694        if replace_message:
695          destination.ClearField(_StrConvert(name))
696        if source.HasField(name):
697          getattr(destination, name).MergeFrom(getattr(source, name))
698      else:
699        setattr(destination, name, getattr(source, name))
700
701
702def _AddFieldPaths(node, prefix, field_mask):
703  """Adds the field paths descended from node to field_mask."""
704  if not node and prefix:
705    field_mask.paths.append(prefix)
706    return
707  for name in sorted(node):
708    if prefix:
709      child_path = prefix + '.' + name
710    else:
711      child_path = name
712    _AddFieldPaths(node[name], child_path, field_mask)
713
714
# Numeric types that _SetStructValue stores as Value.number_value (bool is
# checked before this tuple is consulted).
_INT_OR_FLOAT = tuple(six.integer_types) + (float,)
716
717
def _SetStructValue(struct_value, value):
  """Stores a Python value into a google.protobuf.Value message.

  Args:
    struct_value: The Value message to populate.
    value: None, a bool, a string, an int/float, a dict/Struct or a
        list/ListValue.

  Raises:
    ValueError: If value is of none of the supported types.
  """
  if value is None:
    struct_value.null_value = 0
    return
  # bool must be tested before the numeric types because in Python True and
  # False are also numbers.
  if isinstance(value, bool):
    struct_value.bool_value = value
    return
  if isinstance(value, six.string_types):
    struct_value.string_value = value
    return
  if isinstance(value, _INT_OR_FLOAT):
    struct_value.number_value = value
    return
  if isinstance(value, (dict, Struct)):
    nested_struct = struct_value.struct_value
    nested_struct.Clear()
    nested_struct.update(value)
    return
  if isinstance(value, (list, ListValue)):
    nested_list = struct_value.list_value
    nested_list.Clear()
    nested_list.extend(value)
    return
  raise ValueError('Unexpected type')
737
738
739def _GetStructValue(struct_value):
740  which = struct_value.WhichOneof('kind')
741  if which == 'struct_value':
742    return struct_value.struct_value
743  elif which == 'null_value':
744    return None
745  elif which == 'number_value':
746    return struct_value.number_value
747  elif which == 'string_value':
748    return struct_value.string_value
749  elif which == 'bool_value':
750    return struct_value.bool_value
751  elif which == 'list_value':
752    return struct_value.list_value
753  elif which is None:
754    raise ValueError('Value not set')
755
756
class Struct(object):
  """Class for Struct message type.

  Exposes the message's ``fields`` map through a dict-like interface whose
  values are converted with _GetStructValue / _SetStructValue.
  """

  __slots__ = ()

  def __getitem__(self, key):
    entry = self.fields[key]
    return _GetStructValue(entry)

  def __contains__(self, item):
    return item in self.fields

  def __setitem__(self, key, value):
    entry = self.fields[key]
    _SetStructValue(entry, value)

  def __delitem__(self, key):
    del self.fields[key]

  def __len__(self):
    return len(self.fields)

  def __iter__(self):
    return iter(self.fields)

  def keys(self):  # pylint: disable=invalid-name
    return self.fields.keys()

  def values(self):  # pylint: disable=invalid-name
    return [self[name] for name in self]

  def items(self):  # pylint: disable=invalid-name
    return [(name, self[name]) for name in self]

  def get_or_create_list(self, key):
    """Returns a list for this key, creating if it didn't exist already."""
    entry = self.fields[key]
    if not entry.HasField('list_value'):
      # Clear() marks list_value as modified, which actually creates it.
      entry.list_value.Clear()
    return entry.list_value

  def get_or_create_struct(self, key):
    """Returns a struct for this key, creating if it didn't exist already."""
    entry = self.fields[key]
    if not entry.HasField('struct_value'):
      # Clear() marks struct_value as modified, which actually creates it.
      entry.struct_value.Clear()
    return entry.struct_value

  def update(self, dictionary):  # pylint: disable=invalid-name
    """Sets each key of dictionary to its converted value."""
    for name, item in dictionary.items():
      entry = self.fields[name]
      _SetStructValue(entry, item)

collections_abc.MutableMapping.register(Struct)
808
809
class ListValue(object):
  """Class for ListValue message type.

  Exposes the message's repeated ``values`` field through a list-like
  interface, converting elements with _GetStructValue / _SetStructValue.
  """

  __slots__ = ()

  def __len__(self):
    return len(self.values)

  def append(self, value):
    _SetStructValue(self.values.add(), value)

  def extend(self, elem_seq):
    for item in elem_seq:
      self.append(item)

  def __getitem__(self, index):
    """Retrieves item by the specified index."""
    return _GetStructValue(self.values[index])

  def __setitem__(self, index, value):
    _SetStructValue(self.values[index], value)

  def __delitem__(self, key):
    del self.values[key]

  def items(self):
    # The length is taken once, when iteration starts.
    for position in range(len(self)):
      yield self[position]

  def add_struct(self):
    """Appends and returns a struct value as the next value in the list."""
    new_entry = self.values.add()
    struct_value = new_entry.struct_value
    # Clear() marks struct_value as modified, which actually creates it.
    struct_value.Clear()
    return struct_value

  def add_list(self):
    """Appends and returns a list value as the next value in the list."""
    new_entry = self.values.add()
    list_value = new_entry.list_value
    # Clear() marks list_value as modified, which actually creates it.
    list_value.Clear()
    return list_value

collections_abc.MutableSequence.register(ListValue)
854
855
# Maps each well-known type's full proto name ('google.protobuf.<Class>') to
# the helper class defined for it in this file.
WKTBASES = {
    'google.protobuf.' + cls.__name__: cls
    for cls in (Any, Duration, FieldMask, ListValue, Struct, Timestamp)
}
864