1#! /usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4# Protocol Buffers - Google's data interchange format
5# Copyright 2008 Google Inc.  All rights reserved.
6# https://developers.google.com/protocol-buffers/
7#
8# Redistribution and use in source and binary forms, with or without
9# modification, are permitted provided that the following conditions are
10# met:
11#
12#     * Redistributions of source code must retain the above copyright
13# notice, this list of conditions and the following disclaimer.
14#     * Redistributions in binary form must reproduce the above
15# copyright notice, this list of conditions and the following disclaimer
16# in the documentation and/or other materials provided with the
17# distribution.
18#     * Neither the name of Google Inc. nor the names of its
19# contributors may be used to endorse or promote products derived from
20# this software without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34"""Unittest for reflection.py, which also indirectly tests the output of the
35pure-Python protocol compiler.
36"""
37
38import copy
39import gc
40import operator
41import six
42import struct
43import warnings
44
45try:
46  import unittest2 as unittest  #PY26
47except ImportError:
48  import unittest
49
50from google.protobuf import unittest_import_pb2
51from google.protobuf import unittest_mset_pb2
52from google.protobuf import unittest_pb2
53from google.protobuf import unittest_proto3_arena_pb2
54from google.protobuf import descriptor_pb2
55from google.protobuf import descriptor
56from google.protobuf import message
57from google.protobuf import reflection
58from google.protobuf import text_format
59from google.protobuf.internal import api_implementation
60from google.protobuf.internal import more_extensions_pb2
61from google.protobuf.internal import more_messages_pb2
62from google.protobuf.internal import message_set_extensions_pb2
63from google.protobuf.internal import wire_format
64from google.protobuf.internal import test_util
65from google.protobuf.internal import testing_refleaks
66from google.protobuf.internal import decoder
67from google.protobuf.internal import _parameterized
68
69
if six.PY3:
  # Python 3 has a single integer type; alias `long` so the integer type
  # checks below can name the 64-bit integer type on both major versions.
  long = int  # pylint: disable=redefined-builtin,invalid-name


# Escalate DeprecationWarnings to exceptions so deprecated API use is not
# silently ignored (a test runner may still reset warning filters per test).
warnings.simplefilter('error', DeprecationWarning)
75
76
class _MiniDecoder(object):
  """Minimal sequential reader over a serialized protobuf byte string.

  The library once shipped a decoder.Decoder class.  It was removed during a
  redesign that made decoding much faster, but a few tests in this file still
  check serialized output by reading it back value by value.  This helper
  provides just the read methods those tests need.
  """

  def __init__(self, bytes):
    # The parameter keeps its historical name (it shadows the builtin) so
    # that keyword callers continue to work.
    self._bytes = bytes
    self._pos = 0

  def ReadVarint(self):
    """Decodes the varint at the cursor and advances past it."""
    value, next_pos = decoder._DecodeVarint(self._bytes, self._pos)
    self._pos = next_pos
    return value

  # Every unsigned/plain signed integer type shares the raw varint reader.
  ReadInt32 = ReadInt64 = ReadUInt32 = ReadUInt64 = ReadVarint

  def ReadSInt64(self):
    """Reads a varint and undoes ZigZag encoding."""
    encoded = self.ReadVarint()
    return wire_format.ZigZagDecode(encoded)

  ReadSInt32 = ReadSInt64

  def ReadFieldNumberAndWireType(self):
    """Reads a tag varint and splits it via wire_format.UnpackTag."""
    tag = self.ReadVarint()
    return wire_format.UnpackTag(tag)

  def _ReadFixed(self, fmt, width):
    """Unpacks one fixed-width little-endian value and advances the cursor."""
    start = self._pos
    (value,) = struct.unpack(fmt, self._bytes[start:start + width])
    self._pos = start + width
    return value

  def ReadFloat(self):
    return self._ReadFixed('<f', 4)

  def ReadDouble(self):
    return self._ReadFixed('<d', 8)

  def EndOfStream(self):
    """True once the cursor has consumed every byte."""
    return len(self._bytes) == self._pos
120
121
122@_parameterized.named_parameters(
123    ('_proto2', unittest_pb2),
124    ('_proto3', unittest_proto3_arena_pb2))
125@testing_refleaks.TestCase
126class ReflectionTest(unittest.TestCase):
127
128  def assertListsEqual(self, values, others):
129    self.assertEqual(len(values), len(others))
130    for i in range(len(values)):
131      self.assertEqual(values[i], others[i])
132
133  def testScalarConstructor(self, message_module):
134    # Constructor with only scalar types should succeed.
135    proto = message_module.TestAllTypes(
136        optional_int32=24,
137        optional_double=54.321,
138        optional_string='optional_string',
139        optional_float=None)
140
141    self.assertEqual(24, proto.optional_int32)
142    self.assertEqual(54.321, proto.optional_double)
143    self.assertEqual('optional_string', proto.optional_string)
144    if message_module is unittest_pb2:
145      self.assertFalse(proto.HasField("optional_float"))
146
147  def testRepeatedScalarConstructor(self, message_module):
148    # Constructor with only repeated scalar types should succeed.
149    proto = message_module.TestAllTypes(
150        repeated_int32=[1, 2, 3, 4],
151        repeated_double=[1.23, 54.321],
152        repeated_bool=[True, False, False],
153        repeated_string=["optional_string"],
154        repeated_float=None)
155
156    self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
157    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
158    self.assertEqual([True, False, False], list(proto.repeated_bool))
159    self.assertEqual(["optional_string"], list(proto.repeated_string))
160    self.assertEqual([], list(proto.repeated_float))
161
162  def testMixedConstructor(self, message_module):
163    # Constructor with only mixed types should succeed.
164    proto = message_module.TestAllTypes(
165        optional_int32=24,
166        optional_string='optional_string',
167        repeated_double=[1.23, 54.321],
168        repeated_bool=[True, False, False],
169        repeated_nested_message=[
170            message_module.TestAllTypes.NestedMessage(
171                bb=message_module.TestAllTypes.FOO),
172            message_module.TestAllTypes.NestedMessage(
173                bb=message_module.TestAllTypes.BAR)],
174        repeated_foreign_message=[
175            message_module.ForeignMessage(c=-43),
176            message_module.ForeignMessage(c=45324),
177            message_module.ForeignMessage(c=12)],
178        optional_nested_message=None)
179
180    self.assertEqual(24, proto.optional_int32)
181    self.assertEqual('optional_string', proto.optional_string)
182    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
183    self.assertEqual([True, False, False], list(proto.repeated_bool))
184    self.assertEqual(
185        [message_module.TestAllTypes.NestedMessage(
186            bb=message_module.TestAllTypes.FOO),
187         message_module.TestAllTypes.NestedMessage(
188             bb=message_module.TestAllTypes.BAR)],
189        list(proto.repeated_nested_message))
190    self.assertEqual(
191        [message_module.ForeignMessage(c=-43),
192         message_module.ForeignMessage(c=45324),
193         message_module.ForeignMessage(c=12)],
194        list(proto.repeated_foreign_message))
195    self.assertFalse(proto.HasField("optional_nested_message"))
196
197  def testConstructorTypeError(self, message_module):
198    self.assertRaises(
199        TypeError, message_module.TestAllTypes, optional_int32='foo')
200    self.assertRaises(
201        TypeError, message_module.TestAllTypes, optional_string=1234)
202    self.assertRaises(
203        TypeError, message_module.TestAllTypes, optional_nested_message=1234)
204    self.assertRaises(
205        TypeError, message_module.TestAllTypes, repeated_int32=1234)
206    self.assertRaises(
207        TypeError, message_module.TestAllTypes, repeated_int32=['foo'])
208    self.assertRaises(
209        TypeError, message_module.TestAllTypes, repeated_string=1234)
210    self.assertRaises(
211        TypeError, message_module.TestAllTypes, repeated_string=[1234])
212    self.assertRaises(
213        TypeError, message_module.TestAllTypes, repeated_nested_message=1234)
214    self.assertRaises(
215        TypeError, message_module.TestAllTypes, repeated_nested_message=[1234])
216
217  def testConstructorInvalidatesCachedByteSize(self, message_module):
218    message = message_module.TestAllTypes(optional_int32=12)
219    self.assertEqual(2, message.ByteSize())
220
221    message = message_module.TestAllTypes(
222        optional_nested_message=message_module.TestAllTypes.NestedMessage())
223    self.assertEqual(3, message.ByteSize())
224
225    message = message_module.TestAllTypes(repeated_int32=[12])
226    # TODO(user): Add this test back for proto3
227    if message_module is unittest_pb2:
228      self.assertEqual(3, message.ByteSize())
229
230    message = message_module.TestAllTypes(
231        repeated_nested_message=[message_module.TestAllTypes.NestedMessage()])
232    self.assertEqual(3, message.ByteSize())
233
234  def testReferencesToNestedMessage(self, message_module):
235    proto = message_module.TestAllTypes()
236    nested = proto.optional_nested_message
237    del proto
238    # A previous version had a bug where this would raise an exception when
239    # hitting a now-dead weak reference.
240    nested.bb = 23
241
242  def testOneOf(self, message_module):
243    proto = message_module.TestAllTypes()
244    proto.oneof_uint32 = 10
245    proto.oneof_nested_message.bb = 11
246    self.assertEqual(11, proto.oneof_nested_message.bb)
247    self.assertFalse(proto.HasField('oneof_uint32'))
248    nested = proto.oneof_nested_message
249    proto.oneof_string = 'abc'
250    self.assertEqual('abc', proto.oneof_string)
251    self.assertEqual(11, nested.bb)
252    self.assertFalse(proto.HasField('oneof_nested_message'))
253
254  def testGetDefaultMessageAfterDisconnectingDefaultMessage(
255      self, message_module):
256    proto = message_module.TestAllTypes()
257    nested = proto.optional_nested_message
258    proto.ClearField('optional_nested_message')
259    del proto
260    del nested
261    # Force a garbage collect so that the underlying CMessages are freed along
262    # with the Messages they point to. This is to make sure we're not deleting
263    # default message instances.
264    gc.collect()
265    proto = message_module.TestAllTypes()
266    nested = proto.optional_nested_message
267
268  def testDisconnectingNestedMessageAfterSettingField(self, message_module):
269    proto = message_module.TestAllTypes()
270    nested = proto.optional_nested_message
271    nested.bb = 5
272    self.assertTrue(proto.HasField('optional_nested_message'))
273    proto.ClearField('optional_nested_message')  # Should disconnect from parent
274    self.assertEqual(5, nested.bb)
275    self.assertEqual(0, proto.optional_nested_message.bb)
276    self.assertIsNot(nested, proto.optional_nested_message)
277    nested.bb = 23
278    self.assertFalse(proto.HasField('optional_nested_message'))
279    self.assertEqual(0, proto.optional_nested_message.bb)
280
281  def testDisconnectingNestedMessageBeforeGettingField(self, message_module):
282    proto = message_module.TestAllTypes()
283    self.assertFalse(proto.HasField('optional_nested_message'))
284    proto.ClearField('optional_nested_message')
285    self.assertFalse(proto.HasField('optional_nested_message'))
286
287  def testDisconnectingNestedMessageAfterMerge(self, message_module):
288    # This test exercises the code path that does not use ReleaseMessage().
289    # The underlying fear is that if we use ReleaseMessage() incorrectly,
290    # we will have memory leaks.  It's hard to check that that doesn't happen,
291    # but at least we can exercise that code path to make sure it works.
292    proto1 = message_module.TestAllTypes()
293    proto2 = message_module.TestAllTypes()
294    proto2.optional_nested_message.bb = 5
295    proto1.MergeFrom(proto2)
296    self.assertTrue(proto1.HasField('optional_nested_message'))
297    proto1.ClearField('optional_nested_message')
298    self.assertFalse(proto1.HasField('optional_nested_message'))
299
300  def testDisconnectingLazyNestedMessage(self, message_module):
301    # This test exercises releasing a nested message that is lazy. This test
302    # only exercises real code in the C++ implementation as Python does not
303    # support lazy parsing, but the current C++ implementation results in
304    # memory corruption and a crash.
305    if api_implementation.Type() != 'python':
306      return
307    proto = message_module.TestAllTypes()
308    proto.optional_lazy_message.bb = 5
309    proto.ClearField('optional_lazy_message')
310    del proto
311    gc.collect()
312
313  def testSingularListFields(self, message_module):
314    proto = message_module.TestAllTypes()
315    proto.optional_fixed32 = 1
316    proto.optional_int32 = 5
317    proto.optional_string = 'foo'
318    # Access sub-message but don't set it yet.
319    nested_message = proto.optional_nested_message
320    self.assertEqual(
321      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
322        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
323        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
324      proto.ListFields())
325
326    proto.optional_nested_message.bb = 123
327    self.assertEqual(
328      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
329        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
330        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
331        (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
332             nested_message) ],
333      proto.ListFields())
334
335  def testRepeatedListFields(self, message_module):
336    proto = message_module.TestAllTypes()
337    proto.repeated_fixed32.append(1)
338    proto.repeated_int32.append(5)
339    proto.repeated_int32.append(11)
340    proto.repeated_string.extend(['foo', 'bar'])
341    proto.repeated_string.extend([])
342    proto.repeated_string.append('baz')
343    proto.repeated_string.extend(str(x) for x in range(2))
344    proto.optional_int32 = 21
345    proto.repeated_bool  # Access but don't set anything; should not be listed.
346    self.assertEqual(
347      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 21),
348        (proto.DESCRIPTOR.fields_by_name['repeated_int32'  ], [5, 11]),
349        (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
350        (proto.DESCRIPTOR.fields_by_name['repeated_string' ],
351          ['foo', 'bar', 'baz', '0', '1']) ],
352      proto.ListFields())
353
354  def testClearFieldWithUnknownFieldName(self, message_module):
355    proto = message_module.TestAllTypes()
356    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
357    self.assertRaises(ValueError, proto.ClearField, b'nonexistent_field')
358
359  def testDisallowedAssignments(self, message_module):
360    # It's illegal to assign values directly to repeated fields
361    # or to nonrepeated composite fields.  Ensure that this fails.
362    proto = message_module.TestAllTypes()
363    # Repeated fields.
364    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
365    # Lists shouldn't work, either.
366    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
367    # Composite fields.
368    self.assertRaises(AttributeError, setattr, proto,
369                      'optional_nested_message', 23)
370    # Assignment to a repeated nested message field without specifying
371    # the index in the array of nested messages.
372    self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
373                      'bb', 34)
374    # Assignment to an attribute of a repeated field.
375    self.assertRaises(AttributeError, setattr, proto.repeated_float,
376                      'some_attribute', 34)
377    # proto.nonexistent_field = 23 should fail as well.
378    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
379
380  def testSingleScalarTypeSafety(self, message_module):
381    proto = message_module.TestAllTypes()
382    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
383    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
384    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
385    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
386    self.assertRaises(TypeError, setattr, proto, 'optional_bool', 'foo')
387    self.assertRaises(TypeError, setattr, proto, 'optional_float', 'foo')
388    self.assertRaises(TypeError, setattr, proto, 'optional_double', 'foo')
389    # TODO(user): Fix type checking difference for python and c extension
390    if api_implementation.Type() == 'python':
391      self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1)
392    else:
393      proto.optional_bool = 1.1
394
395  def assertIntegerTypes(self, integer_fn, message_module):
396    """Verifies setting of scalar integers.
397
398    Args:
399      integer_fn: A function to wrap the integers that will be assigned.
400      message_module: unittest_pb2 or unittest_proto3_arena_pb2
401    """
402    def TestGetAndDeserialize(field_name, value, expected_type):
403      proto = message_module.TestAllTypes()
404      value = integer_fn(value)
405      setattr(proto, field_name, value)
406      self.assertIsInstance(getattr(proto, field_name), expected_type)
407      proto2 = message_module.TestAllTypes()
408      proto2.ParseFromString(proto.SerializeToString())
409      self.assertIsInstance(getattr(proto2, field_name), expected_type)
410
411    TestGetAndDeserialize('optional_int32', 1, int)
412    TestGetAndDeserialize('optional_int32', 1 << 30, int)
413    TestGetAndDeserialize('optional_uint32', 1 << 30, int)
414    integer_64 = long
415    if struct.calcsize('L') == 4:
416      # Python only has signed ints, so 32-bit python can't fit an uint32
417      # in an int.
418      TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64)
419    else:
420      # 64-bit python can fit uint32 inside an int
421      TestGetAndDeserialize('optional_uint32', 1 << 31, int)
422    TestGetAndDeserialize('optional_int64', 1 << 30, integer_64)
423    TestGetAndDeserialize('optional_int64', 1 << 60, integer_64)
424    TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
425    TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
426
427  def testIntegerTypes(self, message_module):
428    self.assertIntegerTypes(lambda x: x, message_module)
429
430  def testNonStandardIntegerTypes(self, message_module):
431    self.assertIntegerTypes(test_util.NonStandardInteger, message_module)
432
433  def testIllegalValuesForIntegers(self, message_module):
434    pb = message_module.TestAllTypes()
435
436    # Strings are illegal, even when the represent an integer.
437    with self.assertRaises(TypeError):
438      pb.optional_uint64 = '2'
439
440    # The exact error should propagate with a poorly written custom integer.
441    with self.assertRaisesRegexp(RuntimeError, 'my_error'):
442      pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error')
443
444  def assetIntegerBoundsChecking(self, integer_fn, message_module):
445    """Verifies bounds checking for scalar integer fields.
446
447    Args:
448      integer_fn: A function to wrap the integers that will be assigned.
449      message_module: unittest_pb2 or unittest_proto3_arena_pb2
450    """
451    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
452      pb = message_module.TestAllTypes()
453      expected_min = integer_fn(expected_min)
454      expected_max = integer_fn(expected_max)
455      setattr(pb, field_name, expected_min)
456      self.assertEqual(expected_min, getattr(pb, field_name))
457      setattr(pb, field_name, expected_max)
458      self.assertEqual(expected_max, getattr(pb, field_name))
459      self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
460                        expected_min - 1)
461      self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
462                        expected_max + 1)
463
464    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
465    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
466    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
467    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
468    # A bit of white-box testing since -1 is an int and not a long in C++ and
469    # so goes down a different path.
470    pb = message_module.TestAllTypes()
471    with self.assertRaises((ValueError, TypeError)):
472      pb.optional_uint64 = integer_fn(-(1 << 63))
473
474    pb = message_module.TestAllTypes()
475    pb.optional_nested_enum = integer_fn(1)
476    self.assertEqual(1, pb.optional_nested_enum)
477
478  def testSingleScalarBoundsChecking(self, message_module):
479    self.assetIntegerBoundsChecking(lambda x: x, message_module)
480
481  def testNonStandardSingleScalarBoundsChecking(self, message_module):
482    self.assetIntegerBoundsChecking(
483        test_util.NonStandardInteger, message_module)
484
485  def testRepeatedScalarTypeSafety(self, message_module):
486    proto = message_module.TestAllTypes()
487    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
488    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
489    self.assertRaises(TypeError, proto.repeated_string, 10)
490    self.assertRaises(TypeError, proto.repeated_bytes, 10)
491
492    proto.repeated_int32.append(10)
493    proto.repeated_int32[0] = 23
494    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
495    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
496    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, [])
497    self.assertRaises(TypeError, proto.repeated_int32.__setitem__,
498                      'index', 23)
499
500    proto.repeated_string.append('2')
501    self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10)
502
503    # Repeated enums tests.
504    #proto.repeated_nested_enum.append(0)
505
506  def testSingleScalarGettersAndSetters(self, message_module):
507    proto = message_module.TestAllTypes()
508    self.assertEqual(0, proto.optional_int32)
509    proto.optional_int32 = 1
510    self.assertEqual(1, proto.optional_int32)
511
512    proto.optional_uint64 = 0xffffffffffff
513    self.assertEqual(0xffffffffffff, proto.optional_uint64)
514    proto.optional_uint64 = 0xffffffffffffffff
515    self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
516    # TODO(user): Test all other scalar field types.
517
518  def testEnums(self, message_module):
519    proto = message_module.TestAllTypes()
520    self.assertEqual(1, proto.FOO)
521    self.assertEqual(1, message_module.TestAllTypes.FOO)
522    self.assertEqual(2, proto.BAR)
523    self.assertEqual(2, message_module.TestAllTypes.BAR)
524    self.assertEqual(3, proto.BAZ)
525    self.assertEqual(3, message_module.TestAllTypes.BAZ)
526
527  def testEnum_Name(self, message_module):
528    self.assertEqual(
529        'FOREIGN_FOO',
530        message_module.ForeignEnum.Name(message_module.FOREIGN_FOO))
531    self.assertEqual(
532        'FOREIGN_BAR',
533        message_module.ForeignEnum.Name(message_module.FOREIGN_BAR))
534    self.assertEqual(
535        'FOREIGN_BAZ',
536        message_module.ForeignEnum.Name(message_module.FOREIGN_BAZ))
537    self.assertRaises(ValueError,
538                      message_module.ForeignEnum.Name, 11312)
539
540    proto = message_module.TestAllTypes()
541    self.assertEqual('FOO',
542                     proto.NestedEnum.Name(proto.FOO))
543    self.assertEqual('FOO',
544                     message_module.TestAllTypes.NestedEnum.Name(proto.FOO))
545    self.assertEqual('BAR',
546                     proto.NestedEnum.Name(proto.BAR))
547    self.assertEqual('BAR',
548                     message_module.TestAllTypes.NestedEnum.Name(proto.BAR))
549    self.assertEqual('BAZ',
550                     proto.NestedEnum.Name(proto.BAZ))
551    self.assertEqual('BAZ',
552                     message_module.TestAllTypes.NestedEnum.Name(proto.BAZ))
553    self.assertRaises(ValueError,
554                      proto.NestedEnum.Name, 11312)
555    self.assertRaises(ValueError,
556                      message_module.TestAllTypes.NestedEnum.Name, 11312)
557
558    # Check some coercion cases.
559    self.assertRaises(TypeError, message_module.TestAllTypes.NestedEnum.Name,
560                      11312.0)
561    self.assertRaises(TypeError, message_module.TestAllTypes.NestedEnum.Name,
562                      None)
563    self.assertEqual('FOO', message_module.TestAllTypes.NestedEnum.Name(True))
564
565  def testEnum_Value(self, message_module):
566    self.assertEqual(message_module.FOREIGN_FOO,
567                     message_module.ForeignEnum.Value('FOREIGN_FOO'))
568    self.assertEqual(message_module.FOREIGN_FOO,
569                     message_module.ForeignEnum.FOREIGN_FOO)
570
571    self.assertEqual(message_module.FOREIGN_BAR,
572                     message_module.ForeignEnum.Value('FOREIGN_BAR'))
573    self.assertEqual(message_module.FOREIGN_BAR,
574                     message_module.ForeignEnum.FOREIGN_BAR)
575
576    self.assertEqual(message_module.FOREIGN_BAZ,
577                     message_module.ForeignEnum.Value('FOREIGN_BAZ'))
578    self.assertEqual(message_module.FOREIGN_BAZ,
579                     message_module.ForeignEnum.FOREIGN_BAZ)
580
581    self.assertRaises(ValueError,
582                      message_module.ForeignEnum.Value, 'FO')
583    with self.assertRaises(AttributeError):
584      message_module.ForeignEnum.FO
585
586    proto = message_module.TestAllTypes()
587    self.assertEqual(proto.FOO,
588                     proto.NestedEnum.Value('FOO'))
589    self.assertEqual(proto.FOO,
590                     proto.NestedEnum.FOO)
591
592    self.assertEqual(proto.FOO,
593                     message_module.TestAllTypes.NestedEnum.Value('FOO'))
594    self.assertEqual(proto.FOO,
595                     message_module.TestAllTypes.NestedEnum.FOO)
596
597    self.assertEqual(proto.BAR,
598                     proto.NestedEnum.Value('BAR'))
599    self.assertEqual(proto.BAR,
600                     proto.NestedEnum.BAR)
601
602    self.assertEqual(proto.BAR,
603                     message_module.TestAllTypes.NestedEnum.Value('BAR'))
604    self.assertEqual(proto.BAR,
605                     message_module.TestAllTypes.NestedEnum.BAR)
606
607    self.assertEqual(proto.BAZ,
608                     proto.NestedEnum.Value('BAZ'))
609    self.assertEqual(proto.BAZ,
610                     proto.NestedEnum.BAZ)
611
612    self.assertEqual(proto.BAZ,
613                     message_module.TestAllTypes.NestedEnum.Value('BAZ'))
614    self.assertEqual(proto.BAZ,
615                     message_module.TestAllTypes.NestedEnum.BAZ)
616
617    self.assertRaises(ValueError,
618                      proto.NestedEnum.Value, 'Foo')
619    with self.assertRaises(AttributeError):
620      proto.NestedEnum.Value.Foo
621
622    self.assertRaises(ValueError,
623                      message_module.TestAllTypes.NestedEnum.Value, 'Foo')
624    with self.assertRaises(AttributeError):
625      message_module.TestAllTypes.NestedEnum.Value.Foo
626
627  def testEnum_KeysAndValues(self, message_module):
628    if message_module == unittest_pb2:
629      keys = ['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ']
630      values = [4, 5, 6]
631      items = [('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), ('FOREIGN_BAZ', 6)]
632    else:
633      keys = ['FOREIGN_ZERO', 'FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ']
634      values = [0, 4, 5, 6]
635      items = [('FOREIGN_ZERO', 0), ('FOREIGN_FOO', 4),
636               ('FOREIGN_BAR', 5), ('FOREIGN_BAZ', 6)]
637    self.assertEqual(keys,
638                     list(message_module.ForeignEnum.keys()))
639    self.assertEqual(values,
640                     list(message_module.ForeignEnum.values()))
641    self.assertEqual(items,
642                     list(message_module.ForeignEnum.items()))
643
644    proto = message_module.TestAllTypes()
645    if message_module == unittest_pb2:
646      keys = ['FOO', 'BAR', 'BAZ', 'NEG']
647      values = [1, 2, 3, -1]
648      items = [('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)]
649    else:
650      keys = ['ZERO', 'FOO', 'BAR', 'BAZ', 'NEG']
651      values = [0, 1, 2, 3, -1]
652      items = [('ZERO', 0), ('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)]
653    self.assertEqual(keys, list(proto.NestedEnum.keys()))
654    self.assertEqual(values, list(proto.NestedEnum.values()))
655    self.assertEqual(items,
656                     list(proto.NestedEnum.items()))
657
658  def testStaticParseFrom(self, message_module):
659    proto1 = message_module.TestAllTypes()
660    test_util.SetAllFields(proto1)
661
662    string1 = proto1.SerializeToString()
663    proto2 = message_module.TestAllTypes.FromString(string1)
664
665    # Messages should be equal.
666    self.assertEqual(proto2, proto1)
667
668  def testMergeFromSingularField(self, message_module):
669    # Test merge with just a singular field.
670    proto1 = message_module.TestAllTypes()
671    proto1.optional_int32 = 1
672
673    proto2 = message_module.TestAllTypes()
674    # This shouldn't get overwritten.
675    proto2.optional_string = 'value'
676
677    proto2.MergeFrom(proto1)
678    self.assertEqual(1, proto2.optional_int32)
679    self.assertEqual('value', proto2.optional_string)
680
681  def testMergeFromRepeatedField(self, message_module):
682    # Test merge with just a repeated field.
683    proto1 = message_module.TestAllTypes()
684    proto1.repeated_int32.append(1)
685    proto1.repeated_int32.append(2)
686
687    proto2 = message_module.TestAllTypes()
688    proto2.repeated_int32.append(0)
689    proto2.MergeFrom(proto1)
690
691    self.assertEqual(0, proto2.repeated_int32[0])
692    self.assertEqual(1, proto2.repeated_int32[1])
693    self.assertEqual(2, proto2.repeated_int32[2])
694
695  def testMergeFromRepeatedNestedMessage(self, message_module):
696    # Test merge with a repeated nested message.
697    proto1 = message_module.TestAllTypes()
698    m = proto1.repeated_nested_message.add()
699    m.bb = 123
700    m = proto1.repeated_nested_message.add()
701    m.bb = 321
702
703    proto2 = message_module.TestAllTypes()
704    m = proto2.repeated_nested_message.add()
705    m.bb = 999
706    proto2.MergeFrom(proto1)
707    self.assertEqual(999, proto2.repeated_nested_message[0].bb)
708    self.assertEqual(123, proto2.repeated_nested_message[1].bb)
709    self.assertEqual(321, proto2.repeated_nested_message[2].bb)
710
711    proto3 = message_module.TestAllTypes()
712    proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message)
713    self.assertEqual(999, proto3.repeated_nested_message[0].bb)
714    self.assertEqual(123, proto3.repeated_nested_message[1].bb)
715    self.assertEqual(321, proto3.repeated_nested_message[2].bb)
716
717  def testMergeFromAllFields(self, message_module):
718    # With all fields set.
719    proto1 = message_module.TestAllTypes()
720    test_util.SetAllFields(proto1)
721    proto2 = message_module.TestAllTypes()
722    proto2.MergeFrom(proto1)
723
724    # Messages should be equal.
725    self.assertEqual(proto2, proto1)
726
727    # Serialized string should be equal too.
728    string1 = proto1.SerializeToString()
729    string2 = proto2.SerializeToString()
730    self.assertEqual(string1, string2)
731
732  def testMergeFromBug(self, message_module):
733    message1 = message_module.TestAllTypes()
734    message2 = message_module.TestAllTypes()
735
736    # Cause optional_nested_message to be instantiated within message1, even
737    # though it is not considered to be "present".
738    message1.optional_nested_message
739    self.assertFalse(message1.HasField('optional_nested_message'))
740
741    # Merge into message2.  This should not instantiate the field is message2.
742    message2.MergeFrom(message1)
743    self.assertFalse(message2.HasField('optional_nested_message'))
744
745  def testCopyFromSingularField(self, message_module):
746    # Test copy with just a singular field.
747    proto1 = message_module.TestAllTypes()
748    proto1.optional_int32 = 1
749    proto1.optional_string = 'important-text'
750
751    proto2 = message_module.TestAllTypes()
752    proto2.optional_string = 'value'
753
754    proto2.CopyFrom(proto1)
755    self.assertEqual(1, proto2.optional_int32)
756    self.assertEqual('important-text', proto2.optional_string)
757
758  def testCopyFromRepeatedField(self, message_module):
759    # Test copy with a repeated field.
760    proto1 = message_module.TestAllTypes()
761    proto1.repeated_int32.append(1)
762    proto1.repeated_int32.append(2)
763
764    proto2 = message_module.TestAllTypes()
765    proto2.repeated_int32.append(0)
766    proto2.CopyFrom(proto1)
767
768    self.assertEqual(1, proto2.repeated_int32[0])
769    self.assertEqual(2, proto2.repeated_int32[1])
770
771  def testCopyFromAllFields(self, message_module):
772    # With all fields set.
773    proto1 = message_module.TestAllTypes()
774    test_util.SetAllFields(proto1)
775    proto2 = message_module.TestAllTypes()
776    proto2.CopyFrom(proto1)
777
778    # Messages should be equal.
779    self.assertEqual(proto2, proto1)
780
781    # Serialized string should be equal too.
782    string1 = proto1.SerializeToString()
783    string2 = proto2.SerializeToString()
784    self.assertEqual(string1, string2)
785
786  def testCopyFromSelf(self, message_module):
787    proto1 = message_module.TestAllTypes()
788    proto1.repeated_int32.append(1)
789    proto1.optional_int32 = 2
790    proto1.optional_string = 'important-text'
791
792    proto1.CopyFrom(proto1)
793    self.assertEqual(1, proto1.repeated_int32[0])
794    self.assertEqual(2, proto1.optional_int32)
795    self.assertEqual('important-text', proto1.optional_string)
796
797  def testDeepCopy(self, message_module):
798    proto1 = message_module.TestAllTypes()
799    proto1.optional_int32 = 1
800    proto2 = copy.deepcopy(proto1)
801    self.assertEqual(1, proto2.optional_int32)
802
803    proto1.repeated_int32.append(2)
804    proto1.repeated_int32.append(3)
805    container = copy.deepcopy(proto1.repeated_int32)
806    self.assertEqual([2, 3], container)
807    container.remove(container[0])
808    self.assertEqual([3], container)
809
810    message1 = proto1.repeated_nested_message.add()
811    message1.bb = 1
812    messages = copy.deepcopy(proto1.repeated_nested_message)
813    self.assertEqual(proto1.repeated_nested_message, messages)
814    message1.bb = 2
815    self.assertNotEqual(proto1.repeated_nested_message, messages)
816    messages.remove(messages[0])
817    self.assertEqual(len(messages), 0)
818
819    # TODO(user): Implement deepcopy for extension dict
820
821  def testDisconnectingBeforeClear(self, message_module):
822    proto = message_module.TestAllTypes()
823    nested = proto.optional_nested_message
824    proto.Clear()
825    self.assertIsNot(nested, proto.optional_nested_message)
826    nested.bb = 23
827    self.assertFalse(proto.HasField('optional_nested_message'))
828    self.assertEqual(0, proto.optional_nested_message.bb)
829
830    proto = message_module.TestAllTypes()
831    nested = proto.optional_nested_message
832    nested.bb = 5
833    foreign = proto.optional_foreign_message
834    foreign.c = 6
835    proto.Clear()
836    self.assertIsNot(nested, proto.optional_nested_message)
837    self.assertIsNot(foreign, proto.optional_foreign_message)
838    self.assertEqual(5, nested.bb)
839    self.assertEqual(6, foreign.c)
840    nested.bb = 15
841    foreign.c = 16
842    self.assertFalse(proto.HasField('optional_nested_message'))
843    self.assertEqual(0, proto.optional_nested_message.bb)
844    self.assertFalse(proto.HasField('optional_foreign_message'))
845    self.assertEqual(0, proto.optional_foreign_message.c)
846
847  def testStringUTF8Encoding(self, message_module):
848    proto = message_module.TestAllTypes()
849
850    # Assignment of a unicode object to a field of type 'bytes' is not allowed.
851    self.assertRaises(TypeError,
852                      setattr, proto, 'optional_bytes', u'unicode object')
853
854    # Check that the default value is of python's 'unicode' type.
855    self.assertEqual(type(proto.optional_string), six.text_type)
856
857    proto.optional_string = six.text_type('Testing')
858    self.assertEqual(proto.optional_string, str('Testing'))
859
860    # Assign a value of type 'str' which can be encoded in UTF-8.
861    proto.optional_string = str('Testing')
862    self.assertEqual(proto.optional_string, six.text_type('Testing'))
863
864    # Try to assign a 'bytes' object which contains non-UTF-8.
865    self.assertRaises(ValueError,
866                      setattr, proto, 'optional_string', b'a\x80a')
867    # No exception: Assign already encoded UTF-8 bytes to a string field.
868    utf8_bytes = u'Тест'.encode('utf-8')
869    proto.optional_string = utf8_bytes
870    # No exception: Assign the a non-ascii unicode object.
871    proto.optional_string = u'Тест'
872    # No exception thrown (normal str assignment containing ASCII).
873    proto.optional_string = 'abc'
874
875  def testBytesInTextFormat(self, message_module):
876    proto = message_module.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
877    self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n',
878                     six.text_type(proto))
879
880  def testEmptyNestedMessage(self, message_module):
881    proto = message_module.TestAllTypes()
882    proto.optional_nested_message.MergeFrom(
883        message_module.TestAllTypes.NestedMessage())
884    self.assertTrue(proto.HasField('optional_nested_message'))
885
886    proto = message_module.TestAllTypes()
887    proto.optional_nested_message.CopyFrom(
888        message_module.TestAllTypes.NestedMessage())
889    self.assertTrue(proto.HasField('optional_nested_message'))
890
891    proto = message_module.TestAllTypes()
892    bytes_read = proto.optional_nested_message.MergeFromString(b'')
893    self.assertEqual(0, bytes_read)
894    self.assertTrue(proto.HasField('optional_nested_message'))
895
896    proto = message_module.TestAllTypes()
897    proto.optional_nested_message.ParseFromString(b'')
898    self.assertTrue(proto.HasField('optional_nested_message'))
899
900    serialized = proto.SerializeToString()
901    proto2 = message_module.TestAllTypes()
902    self.assertEqual(
903        len(serialized),
904        proto2.MergeFromString(serialized))
905    self.assertTrue(proto2.HasField('optional_nested_message'))
906
907
908# Class to test proto2-only features (required, extensions, etc.)
909@testing_refleaks.TestCase
910class Proto2ReflectionTest(unittest.TestCase):
911
912  def testRepeatedCompositeConstructor(self):
913    # Constructor with only repeated composite types should succeed.
914    proto = unittest_pb2.TestAllTypes(
915        repeated_nested_message=[
916            unittest_pb2.TestAllTypes.NestedMessage(
917                bb=unittest_pb2.TestAllTypes.FOO),
918            unittest_pb2.TestAllTypes.NestedMessage(
919                bb=unittest_pb2.TestAllTypes.BAR)],
920        repeated_foreign_message=[
921            unittest_pb2.ForeignMessage(c=-43),
922            unittest_pb2.ForeignMessage(c=45324),
923            unittest_pb2.ForeignMessage(c=12)],
924        repeatedgroup=[
925            unittest_pb2.TestAllTypes.RepeatedGroup(),
926            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
927            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
928
929    self.assertEqual(
930        [unittest_pb2.TestAllTypes.NestedMessage(
931            bb=unittest_pb2.TestAllTypes.FOO),
932         unittest_pb2.TestAllTypes.NestedMessage(
933             bb=unittest_pb2.TestAllTypes.BAR)],
934        list(proto.repeated_nested_message))
935    self.assertEqual(
936        [unittest_pb2.ForeignMessage(c=-43),
937         unittest_pb2.ForeignMessage(c=45324),
938         unittest_pb2.ForeignMessage(c=12)],
939        list(proto.repeated_foreign_message))
940    self.assertEqual(
941        [unittest_pb2.TestAllTypes.RepeatedGroup(),
942         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
943         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
944        list(proto.repeatedgroup))
945
946  def assertListsEqual(self, values, others):
947    self.assertEqual(len(values), len(others))
948    for i in range(len(values)):
949      self.assertEqual(values[i], others[i])
950
951  def testSimpleHasBits(self):
952    # Test a scalar.
953    proto = unittest_pb2.TestAllTypes()
954    self.assertFalse(proto.HasField('optional_int32'))
955    self.assertEqual(0, proto.optional_int32)
956    # HasField() shouldn't be true if all we've done is
957    # read the default value.
958    self.assertFalse(proto.HasField('optional_int32'))
959    proto.optional_int32 = 1
960    # Setting a value however *should* set the "has" bit.
961    self.assertTrue(proto.HasField('optional_int32'))
962    proto.ClearField('optional_int32')
963    # And clearing that value should unset the "has" bit.
964    self.assertFalse(proto.HasField('optional_int32'))
965
966  def testHasBitsWithSinglyNestedScalar(self):
967    # Helper used to test foreign messages and groups.
968    #
969    # composite_field_name should be the name of a non-repeated
970    # composite (i.e., foreign or group) field in TestAllTypes,
971    # and scalar_field_name should be the name of an integer-valued
972    # scalar field within that composite.
973    #
974    # I never thought I'd miss C++ macros and templates so much. :(
975    # This helper is semantically just:
976    #
977    #   assert proto.composite_field.scalar_field == 0
978    #   assert not proto.composite_field.HasField('scalar_field')
979    #   assert not proto.HasField('composite_field')
980    #
981    #   proto.composite_field.scalar_field = 10
982    #   old_composite_field = proto.composite_field
983    #
984    #   assert proto.composite_field.scalar_field == 10
985    #   assert proto.composite_field.HasField('scalar_field')
986    #   assert proto.HasField('composite_field')
987    #
988    #   proto.ClearField('composite_field')
989    #
990    #   assert not proto.composite_field.HasField('scalar_field')
991    #   assert not proto.HasField('composite_field')
992    #   assert proto.composite_field.scalar_field == 0
993    #
994    #   # Now ensure that ClearField('composite_field') disconnected
995    #   # the old field object from the object tree...
996    #   assert old_composite_field is not proto.composite_field
997    #   old_composite_field.scalar_field = 20
998    #   assert not proto.composite_field.HasField('scalar_field')
999    #   assert not proto.HasField('composite_field')
1000    def TestCompositeHasBits(composite_field_name, scalar_field_name):
1001      proto = unittest_pb2.TestAllTypes()
1002      # First, check that we can get the scalar value, and see that it's the
1003      # default (0), but that proto.HasField('omposite') and
1004      # proto.composite.HasField('scalar') will still return False.
1005      composite_field = getattr(proto, composite_field_name)
1006      original_scalar_value = getattr(composite_field, scalar_field_name)
1007      self.assertEqual(0, original_scalar_value)
1008      # Assert that the composite object does not "have" the scalar.
1009      self.assertFalse(composite_field.HasField(scalar_field_name))
1010      # Assert that proto does not "have" the composite field.
1011      self.assertFalse(proto.HasField(composite_field_name))
1012
1013      # Now set the scalar within the composite field.  Ensure that the setting
1014      # is reflected, and that proto.HasField('composite') and
1015      # proto.composite.HasField('scalar') now both return True.
1016      new_val = 20
1017      setattr(composite_field, scalar_field_name, new_val)
1018      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
1019      # Hold on to a reference to the current composite_field object.
1020      old_composite_field = composite_field
1021      # Assert that the has methods now return true.
1022      self.assertTrue(composite_field.HasField(scalar_field_name))
1023      self.assertTrue(proto.HasField(composite_field_name))
1024
1025      # Now call the clear method...
1026      proto.ClearField(composite_field_name)
1027
1028      # ...and ensure that the "has" bits are all back to False...
1029      composite_field = getattr(proto, composite_field_name)
1030      self.assertFalse(composite_field.HasField(scalar_field_name))
1031      self.assertFalse(proto.HasField(composite_field_name))
1032      # ...and ensure that the scalar field has returned to its default.
1033      self.assertEqual(0, getattr(composite_field, scalar_field_name))
1034
1035      self.assertIsNot(old_composite_field, composite_field)
1036      setattr(old_composite_field, scalar_field_name, new_val)
1037      self.assertFalse(composite_field.HasField(scalar_field_name))
1038      self.assertFalse(proto.HasField(composite_field_name))
1039      self.assertEqual(0, getattr(composite_field, scalar_field_name))
1040
1041    # Test simple, single-level nesting when we set a scalar.
1042    TestCompositeHasBits('optionalgroup', 'a')
1043    TestCompositeHasBits('optional_nested_message', 'bb')
1044    TestCompositeHasBits('optional_foreign_message', 'c')
1045    TestCompositeHasBits('optional_import_message', 'd')
1046
1047  def testHasBitsWhenModifyingRepeatedFields(self):
1048    # Test nesting when we add an element to a repeated field in a submessage.
1049    proto = unittest_pb2.TestNestedMessageHasBits()
1050    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
1051    self.assertEqual(
1052        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
1053    self.assertTrue(proto.HasField('optional_nested_message'))
1054
1055    # Do the same test, but with a repeated composite field within the
1056    # submessage.
1057    proto.ClearField('optional_nested_message')
1058    self.assertFalse(proto.HasField('optional_nested_message'))
1059    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
1060    self.assertTrue(proto.HasField('optional_nested_message'))
1061
1062  def testHasBitsForManyLevelsOfNesting(self):
1063    # Test nesting many levels deep.
1064    recursive_proto = unittest_pb2.TestMutualRecursionA()
1065    self.assertFalse(recursive_proto.HasField('bb'))
1066    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
1067    self.assertFalse(recursive_proto.HasField('bb'))
1068    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
1069    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
1070    self.assertTrue(recursive_proto.HasField('bb'))
1071    self.assertTrue(recursive_proto.bb.HasField('a'))
1072    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
1073    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
1074    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
1075    self.assertFalse(recursive_proto.bb.a.bb.a.bb.HasField('a'))
1076    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
1077
1078  def testSingularListExtensions(self):
1079    proto = unittest_pb2.TestAllExtensions()
1080    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
1081    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 5
1082    proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo'
1083    self.assertEqual(
1084      [ (unittest_pb2.optional_int32_extension  , 5),
1085        (unittest_pb2.optional_fixed32_extension, 1),
1086        (unittest_pb2.optional_string_extension , 'foo') ],
1087      proto.ListFields())
1088    del proto.Extensions[unittest_pb2.optional_fixed32_extension]
1089    self.assertEqual(
1090        [(unittest_pb2.optional_int32_extension, 5),
1091         (unittest_pb2.optional_string_extension, 'foo')],
1092        proto.ListFields())
1093
1094  def testRepeatedListExtensions(self):
1095    proto = unittest_pb2.TestAllExtensions()
1096    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
1097    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(5)
1098    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(11)
1099    proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo')
1100    proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar')
1101    proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz')
1102    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 21
1103    self.assertEqual(
1104      [ (unittest_pb2.optional_int32_extension  , 21),
1105        (unittest_pb2.repeated_int32_extension  , [5, 11]),
1106        (unittest_pb2.repeated_fixed32_extension, [1]),
1107        (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ],
1108      proto.ListFields())
1109    del proto.Extensions[unittest_pb2.repeated_int32_extension]
1110    del proto.Extensions[unittest_pb2.repeated_string_extension]
1111    self.assertEqual(
1112        [(unittest_pb2.optional_int32_extension, 21),
1113         (unittest_pb2.repeated_fixed32_extension, [1])],
1114        proto.ListFields())
1115
1116  def testListFieldsAndExtensions(self):
1117    proto = unittest_pb2.TestFieldOrderings()
1118    test_util.SetAllFieldsAndExtensions(proto)
1119    unittest_pb2.my_extension_int
1120    self.assertEqual(
1121      [ (proto.DESCRIPTOR.fields_by_name['my_int'   ], 1),
1122        (unittest_pb2.my_extension_int               , 23),
1123        (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
1124        (unittest_pb2.my_extension_string            , 'bar'),
1125        (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ],
1126      proto.ListFields())
1127
1128  def testDefaultValues(self):
1129    proto = unittest_pb2.TestAllTypes()
1130    self.assertEqual(0, proto.optional_int32)
1131    self.assertEqual(0, proto.optional_int64)
1132    self.assertEqual(0, proto.optional_uint32)
1133    self.assertEqual(0, proto.optional_uint64)
1134    self.assertEqual(0, proto.optional_sint32)
1135    self.assertEqual(0, proto.optional_sint64)
1136    self.assertEqual(0, proto.optional_fixed32)
1137    self.assertEqual(0, proto.optional_fixed64)
1138    self.assertEqual(0, proto.optional_sfixed32)
1139    self.assertEqual(0, proto.optional_sfixed64)
1140    self.assertEqual(0.0, proto.optional_float)
1141    self.assertEqual(0.0, proto.optional_double)
1142    self.assertEqual(False, proto.optional_bool)
1143    self.assertEqual('', proto.optional_string)
1144    self.assertEqual(b'', proto.optional_bytes)
1145
1146    self.assertEqual(41, proto.default_int32)
1147    self.assertEqual(42, proto.default_int64)
1148    self.assertEqual(43, proto.default_uint32)
1149    self.assertEqual(44, proto.default_uint64)
1150    self.assertEqual(-45, proto.default_sint32)
1151    self.assertEqual(46, proto.default_sint64)
1152    self.assertEqual(47, proto.default_fixed32)
1153    self.assertEqual(48, proto.default_fixed64)
1154    self.assertEqual(49, proto.default_sfixed32)
1155    self.assertEqual(-50, proto.default_sfixed64)
1156    self.assertEqual(51.5, proto.default_float)
1157    self.assertEqual(52e3, proto.default_double)
1158    self.assertEqual(True, proto.default_bool)
1159    self.assertEqual('hello', proto.default_string)
1160    self.assertEqual(b'world', proto.default_bytes)
1161    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
1162    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
1163    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
1164                     proto.default_import_enum)
1165
1166    proto = unittest_pb2.TestExtremeDefaultValues()
1167    self.assertEqual(u'\u1234', proto.utf8_string)
1168
1169  def testHasFieldWithUnknownFieldName(self):
1170    proto = unittest_pb2.TestAllTypes()
1171    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')
1172
1173  def testClearRemovesChildren(self):
1174    # Make sure there aren't any implementation bugs that are only partially
1175    # clearing the message (which can happen in the more complex C++
1176    # implementation which has parallel message lists).
1177    proto = unittest_pb2.TestRequiredForeign()
1178    for i in range(10):
1179      proto.repeated_message.add()
1180    proto2 = unittest_pb2.TestRequiredForeign()
1181    proto.CopyFrom(proto2)
1182    self.assertRaises(IndexError, lambda: proto.repeated_message[5])
1183
1184  def testSingleScalarClearField(self):
1185    proto = unittest_pb2.TestAllTypes()
1186    # Should be allowed to clear something that's not there (a no-op).
1187    proto.ClearField('optional_int32')
1188    proto.optional_int32 = 1
1189    self.assertTrue(proto.HasField('optional_int32'))
1190    proto.ClearField('optional_int32')
1191    self.assertEqual(0, proto.optional_int32)
1192    self.assertFalse(proto.HasField('optional_int32'))
1193    # TODO(user): Test all other scalar field types.
1194
1195  def testRepeatedScalars(self):
1196    proto = unittest_pb2.TestAllTypes()
1197
1198    self.assertFalse(proto.repeated_int32)
1199    self.assertEqual(0, len(proto.repeated_int32))
1200    proto.repeated_int32.append(5)
1201    proto.repeated_int32.append(10)
1202    proto.repeated_int32.append(15)
1203    self.assertTrue(proto.repeated_int32)
1204    self.assertEqual(3, len(proto.repeated_int32))
1205
1206    self.assertEqual([5, 10, 15], proto.repeated_int32)
1207
1208    # Test single retrieval.
1209    self.assertEqual(5, proto.repeated_int32[0])
1210    self.assertEqual(15, proto.repeated_int32[-1])
1211    # Test out-of-bounds indices.
1212    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
1213    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
1214    # Test incorrect types passed to __getitem__.
1215    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
1216    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)
1217
1218    # Test single assignment.
1219    proto.repeated_int32[1] = 20
1220    self.assertEqual([5, 20, 15], proto.repeated_int32)
1221
1222    # Test insertion.
1223    proto.repeated_int32.insert(1, 25)
1224    self.assertEqual([5, 25, 20, 15], proto.repeated_int32)
1225
1226    # Test slice retrieval.
1227    proto.repeated_int32.append(30)
1228    self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
1229    self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
1230
1231    # Test slice assignment with an iterator
1232    proto.repeated_int32[1:4] = (i for i in range(3))
1233    self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
1234
1235    # Test slice assignment.
1236    proto.repeated_int32[1:4] = [35, 40, 45]
1237    self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)
1238
1239    # Test that we can use the field as an iterator.
1240    result = []
1241    for i in proto.repeated_int32:
1242      result.append(i)
1243    self.assertEqual([5, 35, 40, 45, 30], result)
1244
1245    # Test single deletion.
1246    del proto.repeated_int32[2]
1247    self.assertEqual([5, 35, 45, 30], proto.repeated_int32)
1248
1249    # Test slice deletion.
1250    del proto.repeated_int32[2:]
1251    self.assertEqual([5, 35], proto.repeated_int32)
1252
1253    # Test extending.
1254    proto.repeated_int32.extend([3, 13])
1255    self.assertEqual([5, 35, 3, 13], proto.repeated_int32)
1256
1257    # Test clearing.
1258    proto.ClearField('repeated_int32')
1259    self.assertFalse(proto.repeated_int32)
1260    self.assertEqual(0, len(proto.repeated_int32))
1261
1262    proto.repeated_int32.append(1)
1263    self.assertEqual(1, proto.repeated_int32[-1])
1264    # Test assignment to a negative index.
1265    proto.repeated_int32[-1] = 2
1266    self.assertEqual(2, proto.repeated_int32[-1])
1267
1268    # Test deletion at negative indices.
1269    proto.repeated_int32[:] = [0, 1, 2, 3]
1270    del proto.repeated_int32[-1]
1271    self.assertEqual([0, 1, 2], proto.repeated_int32)
1272
1273    del proto.repeated_int32[-2]
1274    self.assertEqual([0, 2], proto.repeated_int32)
1275
1276    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3)
1277    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300)
1278
1279    del proto.repeated_int32[-2:-1]
1280    self.assertEqual([2], proto.repeated_int32)
1281
1282    del proto.repeated_int32[100:10000]
1283    self.assertEqual([2], proto.repeated_int32)
1284
1285  def testRepeatedScalarsRemove(self):
1286    proto = unittest_pb2.TestAllTypes()
1287
1288    self.assertFalse(proto.repeated_int32)
1289    self.assertEqual(0, len(proto.repeated_int32))
1290    proto.repeated_int32.append(5)
1291    proto.repeated_int32.append(10)
1292    proto.repeated_int32.append(5)
1293    proto.repeated_int32.append(5)
1294
1295    self.assertEqual(4, len(proto.repeated_int32))
1296    proto.repeated_int32.remove(5)
1297    self.assertEqual(3, len(proto.repeated_int32))
1298    self.assertEqual(10, proto.repeated_int32[0])
1299    self.assertEqual(5, proto.repeated_int32[1])
1300    self.assertEqual(5, proto.repeated_int32[2])
1301
1302    proto.repeated_int32.remove(5)
1303    self.assertEqual(2, len(proto.repeated_int32))
1304    self.assertEqual(10, proto.repeated_int32[0])
1305    self.assertEqual(5, proto.repeated_int32[1])
1306
1307    proto.repeated_int32.remove(10)
1308    self.assertEqual(1, len(proto.repeated_int32))
1309    self.assertEqual(5, proto.repeated_int32[0])
1310
1311    # Remove a non-existent element.
1312    self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
1313
1314  def testRepeatedComposites(self):
1315    proto = unittest_pb2.TestAllTypes()
1316    self.assertFalse(proto.repeated_nested_message)
1317    self.assertEqual(0, len(proto.repeated_nested_message))
1318    m0 = proto.repeated_nested_message.add()
1319    m1 = proto.repeated_nested_message.add()
1320    self.assertTrue(proto.repeated_nested_message)
1321    self.assertEqual(2, len(proto.repeated_nested_message))
1322    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
1323    self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage)
1324
1325    # Test out-of-bounds indices.
1326    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
1327                      1234)
1328    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
1329                      -1234)
1330
1331    # Test incorrect types passed to __getitem__.
1332    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
1333                      'foo')
1334    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
1335                      None)
1336
1337    # Test slice retrieval.
1338    m2 = proto.repeated_nested_message.add()
1339    m3 = proto.repeated_nested_message.add()
1340    m4 = proto.repeated_nested_message.add()
1341    self.assertListsEqual(
1342        [m1, m2, m3], proto.repeated_nested_message[1:4])
1343    self.assertListsEqual(
1344        [m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
1345    self.assertListsEqual(
1346        [m0, m1], proto.repeated_nested_message[:2])
1347    self.assertListsEqual(
1348        [m2, m3, m4], proto.repeated_nested_message[2:])
1349    self.assertEqual(
1350        m0, proto.repeated_nested_message[0])
1351    self.assertListsEqual(
1352        [m0], proto.repeated_nested_message[:1])
1353
1354    # Test that we can use the field as an iterator.
1355    result = []
1356    for i in proto.repeated_nested_message:
1357      result.append(i)
1358    self.assertListsEqual([m0, m1, m2, m3, m4], result)
1359
1360    # Test single deletion.
1361    del proto.repeated_nested_message[2]
1362    self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message)
1363
1364    # Test slice deletion.
1365    del proto.repeated_nested_message[2:]
1366    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
1367
1368    # Test extending.
1369    n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
1370    n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2)
1371    proto.repeated_nested_message.extend([n1,n2])
1372    self.assertEqual(4, len(proto.repeated_nested_message))
1373    self.assertEqual(n1, proto.repeated_nested_message[2])
1374    self.assertEqual(n2, proto.repeated_nested_message[3])
1375    self.assertRaises(TypeError,
1376                      proto.repeated_nested_message.extend, n1)
1377    self.assertRaises(TypeError,
1378                      proto.repeated_nested_message.extend, [0])
1379    wrong_message_type = unittest_pb2.TestAllTypes()
1380    self.assertRaises(TypeError,
1381                      proto.repeated_nested_message.extend,
1382                      [wrong_message_type])
1383
1384    # Test clearing.
1385    proto.ClearField('repeated_nested_message')
1386    self.assertFalse(proto.repeated_nested_message)
1387    self.assertEqual(0, len(proto.repeated_nested_message))
1388
1389    # Test constructing an element while adding it.
1390    proto.repeated_nested_message.add(bb=23)
1391    self.assertEqual(1, len(proto.repeated_nested_message))
1392    self.assertEqual(23, proto.repeated_nested_message[0].bb)
1393    self.assertRaises(TypeError, proto.repeated_nested_message.add, 23)
1394    with self.assertRaises(Exception):
1395      proto.repeated_nested_message[0] = 23
1396
1397  def testRepeatedCompositeRemove(self):
1398    proto = unittest_pb2.TestAllTypes()
1399
1400    self.assertEqual(0, len(proto.repeated_nested_message))
1401    m0 = proto.repeated_nested_message.add()
1402    # Need to set some differentiating variable so m0 != m1 != m2:
1403    m0.bb = len(proto.repeated_nested_message)
1404    m1 = proto.repeated_nested_message.add()
1405    m1.bb = len(proto.repeated_nested_message)
1406    self.assertTrue(m0 != m1)
1407    m2 = proto.repeated_nested_message.add()
1408    m2.bb = len(proto.repeated_nested_message)
1409    self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
1410
1411    self.assertEqual(3, len(proto.repeated_nested_message))
1412    proto.repeated_nested_message.remove(m0)
1413    self.assertEqual(2, len(proto.repeated_nested_message))
1414    self.assertEqual(m1, proto.repeated_nested_message[0])
1415    self.assertEqual(m2, proto.repeated_nested_message[1])
1416
1417    # Removing m0 again or removing None should raise error
1418    self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0)
1419    self.assertRaises(ValueError, proto.repeated_nested_message.remove, None)
1420    self.assertEqual(2, len(proto.repeated_nested_message))
1421
1422    proto.repeated_nested_message.remove(m2)
1423    self.assertEqual(1, len(proto.repeated_nested_message))
1424    self.assertEqual(m1, proto.repeated_nested_message[0])
1425
1426  def testHandWrittenReflection(self):
1427    # Hand written extensions are only supported by the pure-Python
1428    # implementation of the API.
1429    if api_implementation.Type() != 'python':
1430      return
1431
1432    FieldDescriptor = descriptor.FieldDescriptor
1433    foo_field_descriptor = FieldDescriptor(
1434        name='foo_field', full_name='MyProto.foo_field',
1435        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
1436        cpp_type=FieldDescriptor.CPPTYPE_INT64,
1437        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
1438        containing_type=None, message_type=None, enum_type=None,
1439        is_extension=False, extension_scope=None,
1440        options=descriptor_pb2.FieldOptions(),
1441        # pylint: disable=protected-access
1442        create_key=descriptor._internal_create_key)
1443    mydescriptor = descriptor.Descriptor(
1444        name='MyProto', full_name='MyProto', filename='ignored',
1445        containing_type=None, nested_types=[], enum_types=[],
1446        fields=[foo_field_descriptor], extensions=[],
1447        options=descriptor_pb2.MessageOptions(),
1448        # pylint: disable=protected-access
1449        create_key=descriptor._internal_create_key)
1450    class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
1451      DESCRIPTOR = mydescriptor
1452    myproto_instance = MyProtoClass()
1453    self.assertEqual(0, myproto_instance.foo_field)
1454    self.assertFalse(myproto_instance.HasField('foo_field'))
1455    myproto_instance.foo_field = 23
1456    self.assertEqual(23, myproto_instance.foo_field)
1457    self.assertTrue(myproto_instance.HasField('foo_field'))
1458
1459  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
1460  def testDescriptorProtoSupport(self):
1461    # Hand written descriptors/reflection are only supported by the pure-Python
1462    # implementation of the API.
1463    if api_implementation.Type() != 'python':
1464      return
1465
1466    def AddDescriptorField(proto, field_name, field_type):
1467      AddDescriptorField.field_index += 1
1468      new_field = proto.field.add()
1469      new_field.name = field_name
1470      new_field.type = field_type
1471      new_field.number = AddDescriptorField.field_index
1472      new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
1473
1474    AddDescriptorField.field_index = 0
1475
1476    desc_proto = descriptor_pb2.DescriptorProto()
1477    desc_proto.name = 'Car'
1478    fdp = descriptor_pb2.FieldDescriptorProto
1479    AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING)
1480    AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64)
1481    AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL)
1482    AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE)
1483    # Add a repeated field
1484    AddDescriptorField.field_index += 1
1485    new_field = desc_proto.field.add()
1486    new_field.name = 'owners'
1487    new_field.type = fdp.TYPE_STRING
1488    new_field.number = AddDescriptorField.field_index
1489    new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
1490
1491    desc = descriptor.MakeDescriptor(desc_proto)
1492    self.assertTrue('name' in desc.fields_by_name)
1493    self.assertTrue('year' in desc.fields_by_name)
1494    self.assertTrue('automatic' in desc.fields_by_name)
1495    self.assertTrue('price' in desc.fields_by_name)
1496    self.assertTrue('owners' in desc.fields_by_name)
1497
1498    class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType,
1499                                        message.Message)):
1500      DESCRIPTOR = desc
1501
1502    prius = CarMessage()
1503    prius.name = 'prius'
1504    prius.year = 2010
1505    prius.automatic = True
1506    prius.price = 25134.75
1507    prius.owners.extend(['bob', 'susan'])
1508
1509    serialized_prius = prius.SerializeToString()
1510    new_prius = reflection.ParseMessage(desc, serialized_prius)
1511    self.assertIsNot(new_prius, prius)
1512    self.assertEqual(prius, new_prius)
1513
1514    # these are unnecessary assuming message equality works as advertised but
1515    # explicitly check to be safe since we're mucking about in metaclass foo
1516    self.assertEqual(prius.name, new_prius.name)
1517    self.assertEqual(prius.year, new_prius.year)
1518    self.assertEqual(prius.automatic, new_prius.automatic)
1519    self.assertEqual(prius.price, new_prius.price)
1520    self.assertEqual(prius.owners, new_prius.owners)
1521
1522  def testExtensionDelete(self):
1523    extendee_proto = more_extensions_pb2.ExtendedMessage()
1524
1525    extension_int32 = more_extensions_pb2.optional_int_extension
1526    extendee_proto.Extensions[extension_int32] = 23
1527
1528    extension_repeated = more_extensions_pb2.repeated_int_extension
1529    extendee_proto.Extensions[extension_repeated].append(11)
1530
1531    extension_msg = more_extensions_pb2.optional_message_extension
1532    extendee_proto.Extensions[extension_msg].foreign_message_int = 56
1533
1534    self.assertEqual(len(extendee_proto.Extensions), 3)
1535    del extendee_proto.Extensions[extension_msg]
1536    self.assertEqual(len(extendee_proto.Extensions), 2)
1537    del extendee_proto.Extensions[extension_repeated]
1538    self.assertEqual(len(extendee_proto.Extensions), 1)
1539    # Delete a none exist extension. It is OK to "del m.Extensions[ext]"
1540    # even if the extension is not present in the message; we don't
1541    # raise KeyError. This is consistent with "m.Extensions[ext]"
1542    # returning a default value even if we did not set anything.
1543    del extendee_proto.Extensions[extension_repeated]
1544    self.assertEqual(len(extendee_proto.Extensions), 1)
1545    del extendee_proto.Extensions[extension_int32]
1546    self.assertEqual(len(extendee_proto.Extensions), 0)
1547
1548  def testExtensionIter(self):
1549    extendee_proto = more_extensions_pb2.ExtendedMessage()
1550
1551    extension_int32 = more_extensions_pb2.optional_int_extension
1552    extendee_proto.Extensions[extension_int32] = 23
1553
1554    extension_repeated = more_extensions_pb2.repeated_int_extension
1555    extendee_proto.Extensions[extension_repeated].append(11)
1556
1557    extension_msg = more_extensions_pb2.optional_message_extension
1558    extendee_proto.Extensions[extension_msg].foreign_message_int = 56
1559
1560    # Set some normal fields.
1561    extendee_proto.optional_int32 = 1
1562    extendee_proto.repeated_string.append('hi')
1563
1564    expected = (extension_int32, extension_msg, extension_repeated)
1565    count = 0
1566    for item in extendee_proto.Extensions:
1567      self.assertEqual(item.name, expected[count].name)
1568      self.assertIn(item, extendee_proto.Extensions)
1569      count += 1
1570    self.assertEqual(count, 3)
1571
1572  def testExtensionContainsError(self):
1573    extendee_proto = more_extensions_pb2.ExtendedMessage()
1574    self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, 0)
1575
1576    field = more_extensions_pb2.ExtendedMessage.DESCRIPTOR.fields_by_name[
1577        'optional_int32']
1578    self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, field)
1579
1580  def testTopLevelExtensionsForOptionalScalar(self):
1581    extendee_proto = unittest_pb2.TestAllExtensions()
1582    extension = unittest_pb2.optional_int32_extension
1583    self.assertFalse(extendee_proto.HasExtension(extension))
1584    self.assertNotIn(extension, extendee_proto.Extensions)
1585    self.assertEqual(0, extendee_proto.Extensions[extension])
1586    # As with normal scalar fields, just doing a read doesn't actually set the
1587    # "has" bit.
1588    self.assertFalse(extendee_proto.HasExtension(extension))
1589    self.assertNotIn(extension, extendee_proto.Extensions)
1590    # Actually set the thing.
1591    extendee_proto.Extensions[extension] = 23
1592    self.assertEqual(23, extendee_proto.Extensions[extension])
1593    self.assertTrue(extendee_proto.HasExtension(extension))
1594    self.assertIn(extension, extendee_proto.Extensions)
1595    # Ensure that clearing works as well.
1596    extendee_proto.ClearExtension(extension)
1597    self.assertEqual(0, extendee_proto.Extensions[extension])
1598    self.assertFalse(extendee_proto.HasExtension(extension))
1599    self.assertNotIn(extension, extendee_proto.Extensions)
1600
1601  def testTopLevelExtensionsForRepeatedScalar(self):
1602    extendee_proto = unittest_pb2.TestAllExtensions()
1603    extension = unittest_pb2.repeated_string_extension
1604    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1605    self.assertNotIn(extension, extendee_proto.Extensions)
1606    extendee_proto.Extensions[extension].append('foo')
1607    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
1608    self.assertIn(extension, extendee_proto.Extensions)
1609    string_list = extendee_proto.Extensions[extension]
1610    extendee_proto.ClearExtension(extension)
1611    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1612    self.assertNotIn(extension, extendee_proto.Extensions)
1613    self.assertIsNot(string_list, extendee_proto.Extensions[extension])
1614    # Shouldn't be allowed to do Extensions[extension] = 'a'
1615    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1616                      extension, 'a')
1617
1618  def testTopLevelExtensionsForOptionalMessage(self):
1619    extendee_proto = unittest_pb2.TestAllExtensions()
1620    extension = unittest_pb2.optional_foreign_message_extension
1621    self.assertFalse(extendee_proto.HasExtension(extension))
1622    self.assertNotIn(extension, extendee_proto.Extensions)
1623    self.assertEqual(0, extendee_proto.Extensions[extension].c)
1624    # As with normal (non-extension) fields, merely reading from the
1625    # thing shouldn't set the "has" bit.
1626    self.assertFalse(extendee_proto.HasExtension(extension))
1627    self.assertNotIn(extension, extendee_proto.Extensions)
1628    extendee_proto.Extensions[extension].c = 23
1629    self.assertEqual(23, extendee_proto.Extensions[extension].c)
1630    self.assertTrue(extendee_proto.HasExtension(extension))
1631    self.assertIn(extension, extendee_proto.Extensions)
1632    # Save a reference here.
1633    foreign_message = extendee_proto.Extensions[extension]
1634    extendee_proto.ClearExtension(extension)
1635    self.assertIsNot(foreign_message, extendee_proto.Extensions[extension])
1636    # Setting a field on foreign_message now shouldn't set
1637    # any "has" bits on extendee_proto.
1638    foreign_message.c = 42
1639    self.assertEqual(42, foreign_message.c)
1640    self.assertTrue(foreign_message.HasField('c'))
1641    self.assertFalse(extendee_proto.HasExtension(extension))
1642    self.assertNotIn(extension, extendee_proto.Extensions)
1643    # Shouldn't be allowed to do Extensions[extension] = 'a'
1644    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1645                      extension, 'a')
1646
1647  def testTopLevelExtensionsForRepeatedMessage(self):
1648    extendee_proto = unittest_pb2.TestAllExtensions()
1649    extension = unittest_pb2.repeatedgroup_extension
1650    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1651    group = extendee_proto.Extensions[extension].add()
1652    group.a = 23
1653    self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
1654    group.a = 42
1655    self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
1656    group_list = extendee_proto.Extensions[extension]
1657    extendee_proto.ClearExtension(extension)
1658    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1659    self.assertIsNot(group_list, extendee_proto.Extensions[extension])
1660    # Shouldn't be allowed to do Extensions[extension] = 'a'
1661    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1662                      extension, 'a')
1663
1664  def testNestedExtensions(self):
1665    extendee_proto = unittest_pb2.TestAllExtensions()
1666    extension = unittest_pb2.TestRequired.single
1667
1668    # We just test the non-repeated case.
1669    self.assertFalse(extendee_proto.HasExtension(extension))
1670    self.assertNotIn(extension, extendee_proto.Extensions)
1671    required = extendee_proto.Extensions[extension]
1672    self.assertEqual(0, required.a)
1673    self.assertFalse(extendee_proto.HasExtension(extension))
1674    self.assertNotIn(extension, extendee_proto.Extensions)
1675    required.a = 23
1676    self.assertEqual(23, extendee_proto.Extensions[extension].a)
1677    self.assertTrue(extendee_proto.HasExtension(extension))
1678    self.assertIn(extension, extendee_proto.Extensions)
1679    extendee_proto.ClearExtension(extension)
1680    self.assertIsNot(required, extendee_proto.Extensions[extension])
1681    self.assertFalse(extendee_proto.HasExtension(extension))
1682    self.assertNotIn(extension, extendee_proto.Extensions)
1683
1684  def testRegisteredExtensions(self):
1685    pool = unittest_pb2.DESCRIPTOR.pool
1686    self.assertTrue(
1687        pool.FindExtensionByNumber(
1688            unittest_pb2.TestAllExtensions.DESCRIPTOR, 1))
1689    self.assertIs(
1690        pool.FindExtensionByName(
1691            'protobuf_unittest.optional_int32_extension').containing_type,
1692        unittest_pb2.TestAllExtensions.DESCRIPTOR)
1693    # Make sure extensions haven't been registered into types that shouldn't
1694    # have any.
1695    self.assertEqual(0, len(
1696        pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR)))
1697
1698  # If message A directly contains message B, and
1699  # a.HasField('b') is currently False, then mutating any
1700  # extension in B should change a.HasField('b') to True
1701  # (and so on up the object tree).
1702  def testHasBitsForAncestorsOfExtendedMessage(self):
1703    # Optional scalar extension.
1704    toplevel = more_extensions_pb2.TopLevelMessage()
1705    self.assertFalse(toplevel.HasField('submessage'))
1706    self.assertEqual(0, toplevel.submessage.Extensions[
1707        more_extensions_pb2.optional_int_extension])
1708    self.assertFalse(toplevel.HasField('submessage'))
1709    toplevel.submessage.Extensions[
1710        more_extensions_pb2.optional_int_extension] = 23
1711    self.assertEqual(23, toplevel.submessage.Extensions[
1712        more_extensions_pb2.optional_int_extension])
1713    self.assertTrue(toplevel.HasField('submessage'))
1714
1715    # Repeated scalar extension.
1716    toplevel = more_extensions_pb2.TopLevelMessage()
1717    self.assertFalse(toplevel.HasField('submessage'))
1718    self.assertEqual([], toplevel.submessage.Extensions[
1719        more_extensions_pb2.repeated_int_extension])
1720    self.assertFalse(toplevel.HasField('submessage'))
1721    toplevel.submessage.Extensions[
1722        more_extensions_pb2.repeated_int_extension].append(23)
1723    self.assertEqual([23], toplevel.submessage.Extensions[
1724        more_extensions_pb2.repeated_int_extension])
1725    self.assertTrue(toplevel.HasField('submessage'))
1726
1727    # Optional message extension.
1728    toplevel = more_extensions_pb2.TopLevelMessage()
1729    self.assertFalse(toplevel.HasField('submessage'))
1730    self.assertEqual(0, toplevel.submessage.Extensions[
1731        more_extensions_pb2.optional_message_extension].foreign_message_int)
1732    self.assertFalse(toplevel.HasField('submessage'))
1733    toplevel.submessage.Extensions[
1734        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
1735    self.assertEqual(23, toplevel.submessage.Extensions[
1736        more_extensions_pb2.optional_message_extension].foreign_message_int)
1737    self.assertTrue(toplevel.HasField('submessage'))
1738
1739    # Repeated message extension.
1740    toplevel = more_extensions_pb2.TopLevelMessage()
1741    self.assertFalse(toplevel.HasField('submessage'))
1742    self.assertEqual(0, len(toplevel.submessage.Extensions[
1743        more_extensions_pb2.repeated_message_extension]))
1744    self.assertFalse(toplevel.HasField('submessage'))
1745    foreign = toplevel.submessage.Extensions[
1746        more_extensions_pb2.repeated_message_extension].add()
1747    self.assertEqual(foreign, toplevel.submessage.Extensions[
1748        more_extensions_pb2.repeated_message_extension][0])
1749    self.assertTrue(toplevel.HasField('submessage'))
1750
1751  def testDisconnectionAfterClearingEmptyMessage(self):
1752    toplevel = more_extensions_pb2.TopLevelMessage()
1753    extendee_proto = toplevel.submessage
1754    extension = more_extensions_pb2.optional_message_extension
1755    extension_proto = extendee_proto.Extensions[extension]
1756    extendee_proto.ClearExtension(extension)
1757    extension_proto.foreign_message_int = 23
1758
1759    self.assertIsNot(extension_proto, extendee_proto.Extensions[extension])
1760
1761  def testExtensionFailureModes(self):
1762    extendee_proto = unittest_pb2.TestAllExtensions()
1763
1764    # Try non-extension-handle arguments to HasExtension,
1765    # ClearExtension(), and Extensions[]...
1766    self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
1767    self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
1768    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
1769    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
1770
1771    # Try something that *is* an extension handle, just not for
1772    # this message...
1773    for unknown_handle in (more_extensions_pb2.optional_int_extension,
1774                           more_extensions_pb2.optional_message_extension,
1775                           more_extensions_pb2.repeated_int_extension,
1776                           more_extensions_pb2.repeated_message_extension):
1777      self.assertRaises(KeyError, extendee_proto.HasExtension,
1778                        unknown_handle)
1779      self.assertRaises(KeyError, extendee_proto.ClearExtension,
1780                        unknown_handle)
1781      self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
1782                        unknown_handle)
1783      self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
1784                        unknown_handle, 5)
1785
1786    # Try call HasExtension() with a valid handle, but for a
1787    # *repeated* field.  (Just as with non-extension repeated
1788    # fields, Has*() isn't supported for extension repeated fields).
1789    self.assertRaises(KeyError, extendee_proto.HasExtension,
1790                      unittest_pb2.repeated_string_extension)
1791
1792  def testMergeFromOptionalGroup(self):
1793    # Test merge with an optional group.
1794    proto1 = unittest_pb2.TestAllTypes()
1795    proto1.optionalgroup.a = 12
1796    proto2 = unittest_pb2.TestAllTypes()
1797    proto2.MergeFrom(proto1)
1798    self.assertEqual(12, proto2.optionalgroup.a)
1799
1800  def testMergeFromExtensionsSingular(self):
1801    proto1 = unittest_pb2.TestAllExtensions()
1802    proto1.Extensions[unittest_pb2.optional_int32_extension] = 1
1803
1804    proto2 = unittest_pb2.TestAllExtensions()
1805    proto2.MergeFrom(proto1)
1806    self.assertEqual(
1807        1, proto2.Extensions[unittest_pb2.optional_int32_extension])
1808
1809  def testMergeFromExtensionsRepeated(self):
1810    proto1 = unittest_pb2.TestAllExtensions()
1811    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
1812    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)
1813
1814    proto2 = unittest_pb2.TestAllExtensions()
1815    proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
1816    proto2.MergeFrom(proto1)
1817    self.assertEqual(
1818        3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
1819    self.assertEqual(
1820        0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
1821    self.assertEqual(
1822        1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
1823    self.assertEqual(
1824        2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])
1825
1826  def testMergeFromExtensionsNestedMessage(self):
1827    proto1 = unittest_pb2.TestAllExtensions()
1828    ext1 = proto1.Extensions[
1829        unittest_pb2.repeated_nested_message_extension]
1830    m = ext1.add()
1831    m.bb = 222
1832    m = ext1.add()
1833    m.bb = 333
1834
1835    proto2 = unittest_pb2.TestAllExtensions()
1836    ext2 = proto2.Extensions[
1837        unittest_pb2.repeated_nested_message_extension]
1838    m = ext2.add()
1839    m.bb = 111
1840
1841    proto2.MergeFrom(proto1)
1842    ext2 = proto2.Extensions[
1843        unittest_pb2.repeated_nested_message_extension]
1844    self.assertEqual(3, len(ext2))
1845    self.assertEqual(111, ext2[0].bb)
1846    self.assertEqual(222, ext2[1].bb)
1847    self.assertEqual(333, ext2[2].bb)
1848
1849  def testCopyFromBadType(self):
1850    # The python implementation doesn't raise an exception in this
1851    # case. In theory it should.
1852    if api_implementation.Type() == 'python':
1853      return
1854    proto1 = unittest_pb2.TestAllTypes()
1855    proto2 = unittest_pb2.TestAllExtensions()
1856    self.assertRaises(TypeError, proto1.CopyFrom, proto2)
1857
1858  def testClear(self):
1859    proto = unittest_pb2.TestAllTypes()
1860    # C++ implementation does not support lazy fields right now so leave it
1861    # out for now.
1862    if api_implementation.Type() == 'python':
1863      test_util.SetAllFields(proto)
1864    else:
1865      test_util.SetAllNonLazyFields(proto)
1866    # Clear the message.
1867    proto.Clear()
1868    self.assertEqual(proto.ByteSize(), 0)
1869    empty_proto = unittest_pb2.TestAllTypes()
1870    self.assertEqual(proto, empty_proto)
1871
1872    # Test if extensions which were set are cleared.
1873    proto = unittest_pb2.TestAllExtensions()
1874    test_util.SetAllExtensions(proto)
1875    # Clear the message.
1876    proto.Clear()
1877    self.assertEqual(proto.ByteSize(), 0)
1878    empty_proto = unittest_pb2.TestAllExtensions()
1879    self.assertEqual(proto, empty_proto)
1880
1881  def testDisconnectingInOneof(self):
1882    m = unittest_pb2.TestOneof2()  # This message has two messages in a oneof.
1883    m.foo_message.qux_int = 5
1884    sub_message = m.foo_message
1885    # Accessing another message's field does not clear the first one
1886    self.assertEqual(m.foo_lazy_message.qux_int, 0)
1887    self.assertEqual(m.foo_message.qux_int, 5)
1888    # But mutating another message in the oneof detaches the first one.
1889    m.foo_lazy_message.qux_int = 6
1890    self.assertEqual(m.foo_message.qux_int, 0)
1891    # The reference we got above was detached and is still valid.
1892    self.assertEqual(sub_message.qux_int, 5)
1893    sub_message.qux_int = 7
1894
1895  def assertInitialized(self, proto):
1896    self.assertTrue(proto.IsInitialized())
1897    # Neither method should raise an exception.
1898    proto.SerializeToString()
1899    proto.SerializePartialToString()
1900
1901  def assertNotInitialized(self, proto, error_size=None):
1902    errors = []
1903    self.assertFalse(proto.IsInitialized())
1904    self.assertFalse(proto.IsInitialized(errors))
1905    self.assertEqual(error_size, len(errors))
1906    self.assertRaises(message.EncodeError, proto.SerializeToString)
1907    # "Partial" serialization doesn't care if message is uninitialized.
1908    proto.SerializePartialToString()
1909
1910  def testIsInitialized(self):
1911    # Trivial cases - all optional fields and extensions.
1912    proto = unittest_pb2.TestAllTypes()
1913    self.assertInitialized(proto)
1914    proto = unittest_pb2.TestAllExtensions()
1915    self.assertInitialized(proto)
1916
1917    # The case of uninitialized required fields.
1918    proto = unittest_pb2.TestRequired()
1919    self.assertNotInitialized(proto, 3)
1920    proto.a = proto.b = proto.c = 2
1921    self.assertInitialized(proto)
1922
1923    # The case of uninitialized submessage.
1924    proto = unittest_pb2.TestRequiredForeign()
1925    self.assertInitialized(proto)
1926    proto.optional_message.a = 1
1927    self.assertNotInitialized(proto, 2)
1928    proto.optional_message.b = 0
1929    proto.optional_message.c = 0
1930    self.assertInitialized(proto)
1931
1932    # Uninitialized repeated submessage.
1933    message1 = proto.repeated_message.add()
1934    self.assertNotInitialized(proto, 3)
1935    message1.a = message1.b = message1.c = 0
1936    self.assertInitialized(proto)
1937
1938    # Uninitialized repeated group in an extension.
1939    proto = unittest_pb2.TestAllExtensions()
1940    extension = unittest_pb2.TestRequired.multi
1941    message1 = proto.Extensions[extension].add()
1942    message2 = proto.Extensions[extension].add()
1943    self.assertNotInitialized(proto, 6)
1944    message1.a = 1
1945    message1.b = 1
1946    message1.c = 1
1947    self.assertNotInitialized(proto, 3)
1948    message2.a = 2
1949    message2.b = 2
1950    message2.c = 2
1951    self.assertInitialized(proto)
1952
1953    # Uninitialized nonrepeated message in an extension.
1954    proto = unittest_pb2.TestAllExtensions()
1955    extension = unittest_pb2.TestRequired.single
1956    proto.Extensions[extension].a = 1
1957    self.assertNotInitialized(proto, 2)
1958    proto.Extensions[extension].b = 2
1959    proto.Extensions[extension].c = 3
1960    self.assertInitialized(proto)
1961
1962    # Try passing an errors list.
1963    errors = []
1964    proto = unittest_pb2.TestRequired()
1965    self.assertFalse(proto.IsInitialized(errors))
1966    self.assertEqual(errors, ['a', 'b', 'c'])
1967    self.assertRaises(TypeError, proto.IsInitialized, 1, 2, 3)
1968
1969  @unittest.skipIf(
1970      api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
1971      'Errors are only available from the most recent C++ implementation.')
1972  def testFileDescriptorErrors(self):
1973    file_name = 'test_file_descriptor_errors.proto'
1974    package_name = 'test_file_descriptor_errors.proto'
1975    file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
1976    file_descriptor_proto.name = file_name
1977    file_descriptor_proto.package = package_name
1978    m1 = file_descriptor_proto.message_type.add()
1979    m1.name = 'msg1'
1980    # Compiles the proto into the C++ descriptor pool
1981    descriptor.FileDescriptor(
1982        file_name,
1983        package_name,
1984        serialized_pb=file_descriptor_proto.SerializeToString())
1985    # Add a FileDescriptorProto that has duplicate symbols
1986    another_file_name = 'another_test_file_descriptor_errors.proto'
1987    file_descriptor_proto.name = another_file_name
1988    m2 = file_descriptor_proto.message_type.add()
1989    m2.name = 'msg2'
1990    with self.assertRaises(TypeError) as cm:
1991      descriptor.FileDescriptor(
1992          another_file_name,
1993          package_name,
1994          serialized_pb=file_descriptor_proto.SerializeToString())
1995      self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
1996                      getattr(cm.expected, '__name__', cm.expected))
1997      self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
1998      # Error message will say something about this definition being a
1999      # duplicate, though we don't check the message exactly to avoid a
2000      # dependency on the C++ logging code.
2001      self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
2002
2003  def testStringUTF8Serialization(self):
2004    proto = message_set_extensions_pb2.TestMessageSet()
2005    extension_message = message_set_extensions_pb2.TestMessageSetExtension2
2006    extension = extension_message.message_set_extension
2007
2008    test_utf8 = u'Тест'
2009    test_utf8_bytes = test_utf8.encode('utf-8')
2010
2011    # 'Test' in another language, using UTF-8 charset.
2012    proto.Extensions[extension].str = test_utf8
2013
2014    # Serialize using the MessageSet wire format (this is specified in the
2015    # .proto file).
2016    serialized = proto.SerializeToString()
2017
2018    # Check byte size.
2019    self.assertEqual(proto.ByteSize(), len(serialized))
2020
2021    raw = unittest_mset_pb2.RawMessageSet()
2022    bytes_read = raw.MergeFromString(serialized)
2023    self.assertEqual(len(serialized), bytes_read)
2024
2025    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
2026
2027    self.assertEqual(1, len(raw.item))
2028    # Check that the type_id is the same as the tag ID in the .proto file.
2029    self.assertEqual(raw.item[0].type_id, 98418634)
2030
2031    # Check the actual bytes on the wire.
2032    self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes))
2033    bytes_read = message2.MergeFromString(raw.item[0].message)
2034    self.assertEqual(len(raw.item[0].message), bytes_read)
2035
2036    self.assertEqual(type(message2.str), six.text_type)
2037    self.assertEqual(message2.str, test_utf8)
2038
2039    # The pure Python API throws an exception on MergeFromString(),
2040    # if any of the string fields of the message can't be UTF-8 decoded.
2041    # The C++ implementation of the API has no way to check that on
2042    # MergeFromString and thus has no way to throw the exception.
2043    #
2044    # The pure Python API always returns objects of type 'unicode' (UTF-8
2045    # encoded), or 'bytes' (in 7 bit ASCII).
2046    badbytes = raw.item[0].message.replace(
2047        test_utf8_bytes, len(test_utf8_bytes) * b'\xff')
2048
2049    unicode_decode_failed = False
2050    try:
2051      message2.MergeFromString(badbytes)
2052    except UnicodeDecodeError:
2053      unicode_decode_failed = True
2054    string_field = message2.str
2055    self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
2056
2057  def testSetInParent(self):
2058    proto = unittest_pb2.TestAllTypes()
2059    self.assertFalse(proto.HasField('optionalgroup'))
2060    proto.optionalgroup.SetInParent()
2061    self.assertTrue(proto.HasField('optionalgroup'))
2062
2063  def testPackageInitializationImport(self):
2064    """Test that we can import nested messages from their __init__.py.
2065
2066    Such setup is not trivial since at the time of processing of __init__.py one
2067    can't refer to its submodules by name in code, so expressions like
2068    google.protobuf.internal.import_test_package.inner_pb2
2069    don't work. They do work in imports, so we have assign an alias at import
2070    and then use that alias in generated code.
2071    """
2072    # We import here since it's the import that used to fail, and we want
2073    # the failure to have the right context.
2074    # pylint: disable=g-import-not-at-top
2075    from google.protobuf.internal import import_test_package
2076    # pylint: enable=g-import-not-at-top
2077    msg = import_test_package.myproto.Outer()
2078    # Just check the default value.
2079    self.assertEqual(57, msg.inner.value)
2080
2081#  Since we had so many tests for protocol buffer equality, we broke these out
2082#  into separate TestCase classes.
2083
2084
@testing_refleaks.TestCase
class TestAllTypesEqualityTest(unittest.TestCase):
  """Equality and hashing of freshly constructed TestAllTypes messages."""

  def setUp(self):
    self.first_proto, self.second_proto = (
        unittest_pb2.TestAllTypes(), unittest_pb2.TestAllTypes())

  def testNotHashable(self):
    # Messages must refuse to be hashed.
    with self.assertRaises(TypeError):
      hash(self.first_proto)

  def testSelfEquality(self):
    # A message always equals itself.
    proto = self.first_proto
    self.assertEqual(proto, proto)

  def testEmptyProtosEqual(self):
    # Two default-constructed messages of the same type compare equal.
    self.assertEqual(self.first_proto, self.second_proto)
2100
2101
@testing_refleaks.TestCase
class FullProtosEqualityTest(unittest.TestCase):
  """Equality checks that start from two identical, fully-populated protos.

  Each test perturbs one kind of field in first_proto (and sometimes
  second_proto) and verifies that equality follows the change.
  """

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    for proto in (self.first_proto, self.second_proto):
      test_util.SetAllFields(proto)

  def testNotHashable(self):
    with self.assertRaises(TypeError):
      hash(self.first_proto)

  def testNoneNotEqual(self):
    # Comparing with None is unequal from either side.
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    other = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, other)
    self.assertNotEqual(other, self.second_proto)

  def testAllFieldsFilledEquality(self):
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    first, second = self.first_proto, self.second_proto
    # Changing a singular scalar breaks equality...
    first.optional_int32 += 1
    self.assertNotEqual(first, second)
    # ...and so does clearing it.
    first.ClearField('optional_int32')
    self.assertNotEqual(first, second)

  def testNonRepeatedComposite(self):
    first, second = self.first_proto, self.second_proto
    # A changed value inside a singular submessage breaks equality, and
    # undoing the change restores it.
    first.optional_nested_message.bb += 1
    self.assertNotEqual(first, second)
    first.optional_nested_message.bb -= 1
    self.assertEqual(first, second)
    # Clearing a field of the submessage breaks equality; copying the value
    # back from second_proto restores it.
    first.optional_nested_message.ClearField('bb')
    self.assertNotEqual(first, second)
    first.optional_nested_message.bb = second.optional_nested_message.bb
    self.assertEqual(first, second)
    # Dropping the submessage entirely breaks equality.
    first.ClearField('optional_nested_message')
    self.assertNotEqual(first, second)

  def testRepeatedScalar(self):
    first, second = self.first_proto, self.second_proto
    # Growing a repeated scalar field, and then emptying it, both make the
    # message differ from the untouched copy.
    first.repeated_int32.append(5)
    self.assertNotEqual(first, second)
    first.ClearField('repeated_int32')
    self.assertNotEqual(first, second)

  def testRepeatedComposite(self):
    first, second = self.first_proto, self.second_proto
    # Mutating an element of a repeated message field, then undoing it.
    first.repeated_nested_message[0].bb += 1
    self.assertNotEqual(first, second)
    first.repeated_nested_message[0].bb -= 1
    self.assertEqual(first, second)
    # Appending an element on one side only, then on both sides.
    first.repeated_nested_message.add()
    self.assertNotEqual(first, second)
    second.repeated_nested_message.add()
    self.assertEqual(first, second)

  def testNonRepeatedScalarHasBits(self):
    # An unset field and a field explicitly set to 0 are not equal: presence
    # ("has" bits) is compared as well as the value.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    first, second = self.first_proto, self.second_proto
    # An absent submessage differs from a present submessage whose bb field
    # has been cleared...
    first.ClearField('optional_nested_message')
    second.optional_nested_message.ClearField('bb')
    self.assertNotEqual(first, second)
    # ...but setting and then clearing bb on first_proto makes its
    # submessage present again, and the two compare equal.
    first.optional_nested_message.bb = 0
    first.optional_nested_message.ClearField('bb')
    self.assertEqual(first, second)
2187
2188
@testing_refleaks.TestCase
class ExtensionEqualityTest(unittest.TestCase):
  """Equality of messages that carry extensions."""

  def testExtensionEquality(self):
    ext = unittest_pb2.optional_int32_extension
    first, second = (unittest_pb2.TestAllExtensions(),
                     unittest_pb2.TestAllExtensions())
    self.assertEqual(first, second)
    test_util.SetAllExtensions(first)
    self.assertNotEqual(first, second)
    test_util.SetAllExtensions(second)
    self.assertEqual(first, second)

    # Extension values are compared, not only their presence.
    first.Extensions[ext] += 1
    self.assertNotEqual(first, second)
    first.Extensions[ext] -= 1
    self.assertEqual(first, second)

    # Presence ("has" bits) is compared too: unset differs from set-to-0.
    first.ClearExtension(ext)
    second.Extensions[ext] = 0
    self.assertNotEqual(first, second)
    first.Extensions[ext] = 0
    self.assertEqual(first, second)

    # Merely reading an unset extension (which may cache a default value)
    # must not affect equality when neither side has it set.
    first, second = (unittest_pb2.TestAllExtensions(),
                     unittest_pb2.TestAllExtensions())
    self.assertEqual(0, first.Extensions[ext])
    self.assertEqual(first, second)
2221
2222
@testing_refleaks.TestCase
class MutualRecursionEqualityTest(unittest.TestCase):
  """Equality of messages whose types refer to each other recursively."""

  def testEqualityWithMutualRecursion(self):
    first, second = (unittest_pb2.TestMutualRecursionA(),
                     unittest_pb2.TestMutualRecursionA())
    self.assertEqual(first, second)
    # A value set several levels deep must be seen by the comparison.
    first.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(first, second)
    second.bb.a.bb.optional_int32 = 23
    self.assertEqual(first, second)
2234
2235
@testing_refleaks.TestCase
class ByteSizeTest(unittest.TestCase):
  """Checks ByteSize() against hand-computed wire-format sizes.

  Expected sizes are written as sums of tag, length and payload bytes so a
  regression points at the component that changed.
  """

  def setUp(self):
    self.proto = unittest_pb2.TestAllTypes()
    self.extended_proto = more_extensions_pb2.ExtendedMessage()
    self.packed_proto = unittest_pb2.TestPackedTypes()
    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()

  def Size(self):
    """Returns the serialized size of self.proto."""
    return self.proto.ByteSize()

  def testEmptyMessage(self):
    self.assertEqual(0, self.proto.ByteSize())

  def testSizedOnKwargs(self):
    # Use a separate message to ensure testing right after creation.
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.ByteSize())
    proto_kwargs = unittest_pb2.TestAllTypes(optional_int64=1)
    # One byte for the tag, one to encode varint 1.
    self.assertEqual(2, proto_kwargs.ByteSize())

  def testVarints(self):
    def Test(i, expected_varint_size):
      self.proto.Clear()
      self.proto.optional_int64 = i
      # Add one byte for the field's tag.
      self.assertEqual(expected_varint_size + 1, self.Size())
    Test(0, 1)
    Test(1, 1)
    # (1 << 7k) - 1 is the largest value that fits in k varint bytes.
    for num_bytes, i in enumerate(range(7, 63, 7), start=1):
      Test((1 << i) - 1, num_bytes)
    # Negative values take the maximum 10 varint bytes.
    Test(-1, 10)
    Test(-2, 10)
    Test(-(1 << 63), 10)

  def testStrings(self):
    self.proto.optional_string = ''
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2, self.Size())

    self.proto.optional_string = 'abc'
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2 + len(self.proto.optional_string), self.Size())

    self.proto.optional_string = 'x' * 128
    # Need one byte for tag info (tag #14), and TWO bytes for length.
    self.assertEqual(3 + len(self.proto.optional_string), self.Size())

  def testOtherNumerics(self):
    self.proto.optional_fixed32 = 1234
    # One byte for tag and 4 bytes for fixed32.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_fixed64 = 1234
    # One byte for tag and 8 bytes for fixed64.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_float = 1.234
    # One byte for tag and 4 bytes for float.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_double = 1.234
    # One byte for tag and 8 bytes for double.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_sint32 = 64
    # One byte for tag and 2 bytes for zig-zag-encoded 64.
    self.assertEqual(3, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

  def testComposites(self):
    # 3 bytes.
    self.proto.optional_nested_message.bb = (1 << 14)
    # Plus one byte for bb tag.
    # Plus 1 byte for optional_nested_message serialized size.
    # Plus two bytes for optional_nested_message tag.
    self.assertEqual(3 + 1 + 1 + 2, self.Size())

  def testGroups(self):
    # 4 bytes.
    self.proto.optionalgroup.a = (1 << 21)
    # Plus two bytes for |a| tag.
    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
    self.assertEqual(4 + 2 + 2*2, self.Size())

  def testRepeatedScalars(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsExtend(self):
    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsRemove(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())
    self.proto.repeated_int32.remove(128)
    self.assertEqual(1 + 2, self.Size())

  def testRepeatedComposites(self):
    # Empty message.  2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    foreign_message_1 = self.proto.repeated_nested_message.add()
    foreign_message_1.bb = 7
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

  def testRepeatedCompositesDelete(self):
    # Empty message.  2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    foreign_message_1 = self.proto.repeated_nested_message.add()
    foreign_message_1.bb = 9
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
    repeated_nested_message = copy.deepcopy(
        self.proto.repeated_nested_message)

    # Only the bb=9 element remains:
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[0]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    # Now add a new message.
    foreign_message_2 = self.proto.repeated_nested_message.add()
    foreign_message_2.bb = 12

    # Two elements, each:
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())

    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[1]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    del self.proto.repeated_nested_message[0]
    self.assertEqual(0, self.Size())

    self.assertEqual(2, len(repeated_nested_message))
    del repeated_nested_message[0:1]
    # TODO(user): Fix cpp extension bug when delete repeated message.
    if api_implementation.Type() == 'python':
      self.assertEqual(1, len(repeated_nested_message))
    del repeated_nested_message[-1]
    # TODO(user): Fix cpp extension bug when delete repeated message.
    if api_implementation.Type() == 'python':
      self.assertEqual(0, len(repeated_nested_message))

  def testRepeatedGroups(self):
    # 2-byte START_GROUP plus 2-byte END_GROUP.
    self.proto.repeatedgroup.add()
    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
    # plus 2-byte END_GROUP.
    group_1 = self.proto.repeatedgroup.add()
    group_1.a = 7
    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())

  def testExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, proto.ByteSize())
    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
    proto.Extensions[extension] = 23
    # 1 byte for tag, 1 byte for value.
    self.assertEqual(2, proto.ByteSize())
    # A regular (non-extension) field descriptor is rejected as a key.
    field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[
        'optional_int32']
    with self.assertRaises(KeyError):
      proto.Extensions[field] = 23

  def testCacheInvalidationForNonrepeatedScalar(self):
    # Test non-extension.
    self.proto.optional_int32 = 1
    self.assertEqual(2, self.proto.ByteSize())
    self.proto.optional_int32 = 128
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_int_extension
    self.extended_proto.Extensions[extension] = 1
    self.assertEqual(2, self.extended_proto.ByteSize())
    self.extended_proto.Extensions[extension] = 128
    self.assertEqual(3, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedScalar(self):
    # Test non-extension.
    self.proto.repeated_int32.append(1)
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_int32.append(1)
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.repeated_int32[1] = 128
    self.assertEqual(7, self.proto.ByteSize())
    self.proto.ClearField('repeated_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_int_extension
    repeated = self.extended_proto.Extensions[extension]
    repeated.append(1)
    self.assertEqual(2, self.extended_proto.ByteSize())
    repeated.append(1)
    self.assertEqual(4, self.extended_proto.ByteSize())
    repeated[1] = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForNonrepeatedMessage(self):
    # Test non-extension.
    self.proto.optional_foreign_message.c = 1
    self.assertEqual(5, self.proto.ByteSize())
    self.proto.optional_foreign_message.c = 128
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.optional_foreign_message.ClearField('c')
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    if api_implementation.Type() == 'python':
      # This is only possible in pure-Python implementation of the API.
      # A child detached by ClearField no longer contributes to the parent.
      child = self.proto.optional_foreign_message
      self.proto.ClearField('optional_foreign_message')
      child.c = 128
      self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_message_extension
    child = self.extended_proto.Extensions[extension]
    self.assertEqual(0, self.extended_proto.ByteSize())
    child.foreign_message_int = 1
    self.assertEqual(4, self.extended_proto.ByteSize())
    child.foreign_message_int = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedMessage(self):
    # Test non-extension.
    child0 = self.proto.repeated_foreign_message.add()
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_foreign_message.add()
    self.assertEqual(6, self.proto.ByteSize())
    child0.c = 1
    self.assertEqual(8, self.proto.ByteSize())
    self.proto.ClearField('repeated_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_message_extension
    child_list = self.extended_proto.Extensions[extension]
    child0 = child_list.add()
    self.assertEqual(2, self.extended_proto.ByteSize())
    child_list.add()
    self.assertEqual(4, self.extended_proto.ByteSize())
    child0.foreign_message_int = 1
    self.assertEqual(6, self.extended_proto.ByteSize())
    child0.ClearField('foreign_message_int')
    self.assertEqual(4, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testPackedRepeatedScalars(self):
    self.assertEqual(0, self.packed_proto.ByteSize())

    self.packed_proto.packed_int32.append(10)   # 1 byte.
    self.packed_proto.packed_int32.append(128)  # 2 bytes.
    # The tag is 2 bytes (the field number is 90), and the varint
    # storing the length is 1 byte.
    int_size = 1 + 2 + 3
    self.assertEqual(int_size, self.packed_proto.ByteSize())

    self.packed_proto.packed_double.append(4.2)   # 8 bytes
    self.packed_proto.packed_double.append(3.25)  # 8 bytes
    # 2 more tag bytes, 1 more length byte.
    double_size = 8 + 8 + 3
    self.assertEqual(int_size+double_size, self.packed_proto.ByteSize())

    self.packed_proto.ClearField('packed_int32')
    self.assertEqual(double_size, self.packed_proto.ByteSize())

  def testPackedExtensions(self):
    self.assertEqual(0, self.packed_extended_proto.ByteSize())
    extension = self.packed_extended_proto.Extensions[
        unittest_pb2.packed_fixed32_extension]
    extension.extend([1, 2, 3, 4])   # 16 bytes
    # Plus 2 bytes for the tag and 1 byte for the packed length (16).
    self.assertEqual(16 + 2 + 1, self.packed_extended_proto.ByteSize())
2536
2537
2538# Issues to be sure to cover include:
2539#   * Handling of unrecognized tags ("uninterpreted_bytes").
2540#   * Handling of MessageSets.
2541#   * Consistent ordering of tags in the wire format,
2542#     including ordering between extensions and non-extension
2543#     fields.
2544#   * Consistent serialization of negative numbers, especially
2545#     negative int32s.
2546#   * Handling of empty submessages (with and without "has"
2547#     bits set).
2548
2549@testing_refleaks.TestCase
2550class SerializationTest(unittest.TestCase):
2551
2552  def testSerializeEmtpyMessage(self):
2553    first_proto = unittest_pb2.TestAllTypes()
2554    second_proto = unittest_pb2.TestAllTypes()
2555    serialized = first_proto.SerializeToString()
2556    self.assertEqual(first_proto.ByteSize(), len(serialized))
2557    self.assertEqual(
2558        len(serialized),
2559        second_proto.MergeFromString(serialized))
2560    self.assertEqual(first_proto, second_proto)
2561
2562  def testSerializeAllFields(self):
2563    first_proto = unittest_pb2.TestAllTypes()
2564    second_proto = unittest_pb2.TestAllTypes()
2565    test_util.SetAllFields(first_proto)
2566    serialized = first_proto.SerializeToString()
2567    self.assertEqual(first_proto.ByteSize(), len(serialized))
2568    self.assertEqual(
2569        len(serialized),
2570        second_proto.MergeFromString(serialized))
2571    self.assertEqual(first_proto, second_proto)
2572
2573  def testSerializeAllExtensions(self):
2574    first_proto = unittest_pb2.TestAllExtensions()
2575    second_proto = unittest_pb2.TestAllExtensions()
2576    test_util.SetAllExtensions(first_proto)
2577    serialized = first_proto.SerializeToString()
2578    self.assertEqual(
2579        len(serialized),
2580        second_proto.MergeFromString(serialized))
2581    self.assertEqual(first_proto, second_proto)
2582
2583  def testSerializeWithOptionalGroup(self):
2584    first_proto = unittest_pb2.TestAllTypes()
2585    second_proto = unittest_pb2.TestAllTypes()
2586    first_proto.optionalgroup.a = 242
2587    serialized = first_proto.SerializeToString()
2588    self.assertEqual(
2589        len(serialized),
2590        second_proto.MergeFromString(serialized))
2591    self.assertEqual(first_proto, second_proto)
2592
2593  def testSerializeNegativeValues(self):
2594    first_proto = unittest_pb2.TestAllTypes()
2595
2596    first_proto.optional_int32 = -1
2597    first_proto.optional_int64 = -(2 << 40)
2598    first_proto.optional_sint32 = -3
2599    first_proto.optional_sint64 = -(4 << 40)
2600    first_proto.optional_sfixed32 = -5
2601    first_proto.optional_sfixed64 = -(6 << 40)
2602
2603    second_proto = unittest_pb2.TestAllTypes.FromString(
2604        first_proto.SerializeToString())
2605
2606    self.assertEqual(first_proto, second_proto)
2607
2608  def testParseTruncated(self):
2609    # This test is only applicable for the Python implementation of the API.
2610    if api_implementation.Type() != 'python':
2611      return
2612
2613    first_proto = unittest_pb2.TestAllTypes()
2614    test_util.SetAllFields(first_proto)
2615    serialized = memoryview(first_proto.SerializeToString())
2616
2617    for truncation_point in range(len(serialized) + 1):
2618      try:
2619        second_proto = unittest_pb2.TestAllTypes()
2620        unknown_fields = unittest_pb2.TestEmptyMessage()
2621        pos = second_proto._InternalParse(serialized, 0, truncation_point)
2622        # If we didn't raise an error then we read exactly the amount expected.
2623        self.assertEqual(truncation_point, pos)
2624
2625        # Parsing to unknown fields should not throw if parsing to known fields
2626        # did not.
2627        try:
2628          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
2629          self.assertEqual(truncation_point, pos2)
2630        except message.DecodeError:
2631          self.fail('Parsing unknown fields failed when parsing known fields '
2632                    'did not.')
2633      except message.DecodeError:
2634        # Parsing unknown fields should also fail.
2635        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
2636                          serialized, 0, truncation_point)
2637
2638  def testCanonicalSerializationOrder(self):
2639    proto = more_messages_pb2.OutOfOrderFields()
2640    # These are also their tag numbers.  Even though we're setting these in
2641    # reverse-tag order AND they're listed in reverse tag-order in the .proto
2642    # file, they should nonetheless be serialized in tag order.
2643    proto.optional_sint32 = 5
2644    proto.Extensions[more_messages_pb2.optional_uint64] = 4
2645    proto.optional_uint32 = 3
2646    proto.Extensions[more_messages_pb2.optional_int64] = 2
2647    proto.optional_int32 = 1
2648    serialized = proto.SerializeToString()
2649    self.assertEqual(proto.ByteSize(), len(serialized))
2650    d = _MiniDecoder(serialized)
2651    ReadTag = d.ReadFieldNumberAndWireType
2652    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
2653    self.assertEqual(1, d.ReadInt32())
2654    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
2655    self.assertEqual(2, d.ReadInt64())
2656    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
2657    self.assertEqual(3, d.ReadUInt32())
2658    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
2659    self.assertEqual(4, d.ReadUInt64())
2660    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
2661    self.assertEqual(5, d.ReadSInt32())
2662
2663  def testCanonicalSerializationOrderSameAsCpp(self):
2664    # Copy of the same test we use for C++.
2665    proto = unittest_pb2.TestFieldOrderings()
2666    test_util.SetAllFieldsAndExtensions(proto)
2667    serialized = proto.SerializeToString()
2668    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)
2669
2670  def testMergeFromStringWhenFieldsAlreadySet(self):
2671    first_proto = unittest_pb2.TestAllTypes()
2672    first_proto.repeated_string.append('foobar')
2673    first_proto.optional_int32 = 23
2674    first_proto.optional_nested_message.bb = 42
2675    serialized = first_proto.SerializeToString()
2676
2677    second_proto = unittest_pb2.TestAllTypes()
2678    second_proto.repeated_string.append('baz')
2679    second_proto.optional_int32 = 100
2680    second_proto.optional_nested_message.bb = 999
2681
2682    bytes_parsed = second_proto.MergeFromString(serialized)
2683    self.assertEqual(len(serialized), bytes_parsed)
2684
2685    # Ensure that we append to repeated fields.
2686    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
2687    # Ensure that we overwrite nonrepeatd scalars.
2688    self.assertEqual(23, second_proto.optional_int32)
2689    # Ensure that we recursively call MergeFromString() on
2690    # submessages.
2691    self.assertEqual(42, second_proto.optional_nested_message.bb)
2692
2693  def testMessageSetWireFormat(self):
2694    proto = message_set_extensions_pb2.TestMessageSet()
2695    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2696    extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
2697    extension1 = extension_message1.message_set_extension
2698    extension2 = extension_message2.message_set_extension
2699    extension3 = message_set_extensions_pb2.message_set_extension3
2700    proto.Extensions[extension1].i = 123
2701    proto.Extensions[extension2].str = 'foo'
2702    proto.Extensions[extension3].text = 'bar'
2703
2704    # Serialize using the MessageSet wire format (this is specified in the
2705    # .proto file).
2706    serialized = proto.SerializeToString()
2707
2708    raw = unittest_mset_pb2.RawMessageSet()
2709    self.assertEqual(False,
2710                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
2711    self.assertEqual(
2712        len(serialized),
2713        raw.MergeFromString(serialized))
2714    self.assertEqual(3, len(raw.item))
2715
2716    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
2717    self.assertEqual(
2718        len(raw.item[0].message),
2719        message1.MergeFromString(raw.item[0].message))
2720    self.assertEqual(123, message1.i)
2721
2722    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
2723    self.assertEqual(
2724        len(raw.item[1].message),
2725        message2.MergeFromString(raw.item[1].message))
2726    self.assertEqual('foo', message2.str)
2727
2728    message3 = message_set_extensions_pb2.TestMessageSetExtension3()
2729    self.assertEqual(
2730        len(raw.item[2].message),
2731        message3.MergeFromString(raw.item[2].message))
2732    self.assertEqual('bar', message3.text)
2733
2734    # Deserialize using the MessageSet wire format.
2735    proto2 = message_set_extensions_pb2.TestMessageSet()
2736    self.assertEqual(
2737        len(serialized),
2738        proto2.MergeFromString(serialized))
2739    self.assertEqual(123, proto2.Extensions[extension1].i)
2740    self.assertEqual('foo', proto2.Extensions[extension2].str)
2741    self.assertEqual('bar', proto2.Extensions[extension3].text)
2742
2743    # Check byte size.
2744    self.assertEqual(proto2.ByteSize(), len(serialized))
2745    self.assertEqual(proto.ByteSize(), len(serialized))
2746
2747  def testMessageSetWireFormatUnknownExtension(self):
2748    # Create a message using the message set wire format with an unknown
2749    # message.
2750    raw = unittest_mset_pb2.RawMessageSet()
2751
2752    # Add an item.
2753    item = raw.item.add()
2754    item.type_id = 98418603
2755    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2756    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
2757    message1.i = 12345
2758    item.message = message1.SerializeToString()
2759
2760    # Add a second, unknown extension.
2761    item = raw.item.add()
2762    item.type_id = 98418604
2763    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2764    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
2765    message1.i = 12346
2766    item.message = message1.SerializeToString()
2767
2768    # Add another unknown extension.
2769    item = raw.item.add()
2770    item.type_id = 98418605
2771    message1 = message_set_extensions_pb2.TestMessageSetExtension2()
2772    message1.str = 'foo'
2773    item.message = message1.SerializeToString()
2774
2775    serialized = raw.SerializeToString()
2776
2777    # Parse message using the message set wire format.
2778    proto = message_set_extensions_pb2.TestMessageSet()
2779    self.assertEqual(
2780        len(serialized),
2781        proto.MergeFromString(serialized))
2782
2783    # Check that the message parsed well.
2784    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2785    extension1 = extension_message1.message_set_extension
2786    self.assertEqual(12345, proto.Extensions[extension1].i)
2787
2788  def testUnknownFields(self):
2789    proto = unittest_pb2.TestAllTypes()
2790    test_util.SetAllFields(proto)
2791
2792    serialized = proto.SerializeToString()
2793
2794    # The empty message should be parsable with all of the fields
2795    # unknown.
2796    proto2 = unittest_pb2.TestEmptyMessage()
2797
2798    # Parsing this message should succeed.
2799    self.assertEqual(
2800        len(serialized),
2801        proto2.MergeFromString(serialized))
2802
2803    # Now test with a int64 field set.
2804    proto = unittest_pb2.TestAllTypes()
2805    proto.optional_int64 = 0x0fffffffffffffff
2806    serialized = proto.SerializeToString()
2807    # The empty message should be parsable with all of the fields
2808    # unknown.
2809    proto2 = unittest_pb2.TestEmptyMessage()
2810    # Parsing this message should succeed.
2811    self.assertEqual(
2812        len(serialized),
2813        proto2.MergeFromString(serialized))
2814
2815  def _CheckRaises(self, exc_class, callable_obj, exception):
2816    """This method checks if the excpetion type and message are as expected."""
2817    try:
2818      callable_obj()
2819    except exc_class as ex:
2820      # Check if the exception message is the right one.
2821      self.assertEqual(exception, str(ex))
2822      return
2823    else:
2824      raise self.failureException('%s not raised' % str(exc_class))
2825
2826  def testSerializeUninitialized(self):
2827    proto = unittest_pb2.TestRequired()
2828    self._CheckRaises(
2829        message.EncodeError,
2830        proto.SerializeToString,
2831        'Message protobuf_unittest.TestRequired is missing required fields: '
2832        'a,b,c')
2833    # Shouldn't raise exceptions.
2834    partial = proto.SerializePartialToString()
2835
2836    proto2 = unittest_pb2.TestRequired()
2837    self.assertFalse(proto2.HasField('a'))
2838    # proto2 ParseFromString does not check that required fields are set.
2839    proto2.ParseFromString(partial)
2840    self.assertFalse(proto2.HasField('a'))
2841
2842    proto.a = 1
2843    self._CheckRaises(
2844        message.EncodeError,
2845        proto.SerializeToString,
2846        'Message protobuf_unittest.TestRequired is missing required fields: b,c')
2847    # Shouldn't raise exceptions.
2848    partial = proto.SerializePartialToString()
2849
2850    proto.b = 2
2851    self._CheckRaises(
2852        message.EncodeError,
2853        proto.SerializeToString,
2854        'Message protobuf_unittest.TestRequired is missing required fields: c')
2855    # Shouldn't raise exceptions.
2856    partial = proto.SerializePartialToString()
2857
2858    proto.c = 3
2859    serialized = proto.SerializeToString()
2860    # Shouldn't raise exceptions.
2861    partial = proto.SerializePartialToString()
2862
2863    proto2 = unittest_pb2.TestRequired()
2864    self.assertEqual(
2865        len(serialized),
2866        proto2.MergeFromString(serialized))
2867    self.assertEqual(1, proto2.a)
2868    self.assertEqual(2, proto2.b)
2869    self.assertEqual(3, proto2.c)
2870    self.assertEqual(
2871        len(partial),
2872        proto2.MergeFromString(partial))
2873    self.assertEqual(1, proto2.a)
2874    self.assertEqual(2, proto2.b)
2875    self.assertEqual(3, proto2.c)
2876
2877  def testSerializeUninitializedSubMessage(self):
2878    proto = unittest_pb2.TestRequiredForeign()
2879
2880    # Sub-message doesn't exist yet, so this succeeds.
2881    proto.SerializeToString()
2882
2883    proto.optional_message.a = 1
2884    self._CheckRaises(
2885        message.EncodeError,
2886        proto.SerializeToString,
2887        'Message protobuf_unittest.TestRequiredForeign '
2888        'is missing required fields: '
2889        'optional_message.b,optional_message.c')
2890
2891    proto.optional_message.b = 2
2892    proto.optional_message.c = 3
2893    proto.SerializeToString()
2894
2895    proto.repeated_message.add().a = 1
2896    proto.repeated_message.add().b = 2
2897    self._CheckRaises(
2898        message.EncodeError,
2899        proto.SerializeToString,
2900        'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
2901        'repeated_message[0].b,repeated_message[0].c,'
2902        'repeated_message[1].a,repeated_message[1].c')
2903
2904    proto.repeated_message[0].b = 2
2905    proto.repeated_message[0].c = 3
2906    proto.repeated_message[1].a = 1
2907    proto.repeated_message[1].c = 3
2908    proto.SerializeToString()
2909
2910  def testSerializeAllPackedFields(self):
2911    first_proto = unittest_pb2.TestPackedTypes()
2912    second_proto = unittest_pb2.TestPackedTypes()
2913    test_util.SetAllPackedFields(first_proto)
2914    serialized = first_proto.SerializeToString()
2915    self.assertEqual(first_proto.ByteSize(), len(serialized))
2916    bytes_read = second_proto.MergeFromString(serialized)
2917    self.assertEqual(second_proto.ByteSize(), bytes_read)
2918    self.assertEqual(first_proto, second_proto)
2919
2920  def testSerializeAllPackedExtensions(self):
2921    first_proto = unittest_pb2.TestPackedExtensions()
2922    second_proto = unittest_pb2.TestPackedExtensions()
2923    test_util.SetAllPackedExtensions(first_proto)
2924    serialized = first_proto.SerializeToString()
2925    bytes_read = second_proto.MergeFromString(serialized)
2926    self.assertEqual(second_proto.ByteSize(), bytes_read)
2927    self.assertEqual(first_proto, second_proto)
2928
2929  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
2930    first_proto = unittest_pb2.TestPackedTypes()
2931    first_proto.packed_int32.extend([1, 2])
2932    first_proto.packed_double.append(3.0)
2933    serialized = first_proto.SerializeToString()
2934
2935    second_proto = unittest_pb2.TestPackedTypes()
2936    second_proto.packed_int32.append(3)
2937    second_proto.packed_double.extend([1.0, 2.0])
2938    second_proto.packed_sint32.append(4)
2939
2940    self.assertEqual(
2941        len(serialized),
2942        second_proto.MergeFromString(serialized))
2943    self.assertEqual([3, 1, 2], second_proto.packed_int32)
2944    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
2945    self.assertEqual([4], second_proto.packed_sint32)
2946
2947  def testPackedFieldsWireFormat(self):
2948    proto = unittest_pb2.TestPackedTypes()
2949    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
2950    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
2951    proto.packed_float.append(2.0)             # 4 bytes, will be before double
2952    serialized = proto.SerializeToString()
2953    self.assertEqual(proto.ByteSize(), len(serialized))
2954    d = _MiniDecoder(serialized)
2955    ReadTag = d.ReadFieldNumberAndWireType
2956    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
2957    self.assertEqual(1+1+1+2, d.ReadInt32())
2958    self.assertEqual(1, d.ReadInt32())
2959    self.assertEqual(2, d.ReadInt32())
2960    self.assertEqual(150, d.ReadInt32())
2961    self.assertEqual(3, d.ReadInt32())
2962    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
2963    self.assertEqual(4, d.ReadInt32())
2964    self.assertEqual(2.0, d.ReadFloat())
2965    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
2966    self.assertEqual(8+8, d.ReadInt32())
2967    self.assertEqual(1.0, d.ReadDouble())
2968    self.assertEqual(1000.0, d.ReadDouble())
2969    self.assertTrue(d.EndOfStream())
2970
2971  def testParsePackedFromUnpacked(self):
2972    unpacked = unittest_pb2.TestUnpackedTypes()
2973    test_util.SetAllUnpackedFields(unpacked)
2974    packed = unittest_pb2.TestPackedTypes()
2975    serialized = unpacked.SerializeToString()
2976    self.assertEqual(
2977        len(serialized),
2978        packed.MergeFromString(serialized))
2979    expected = unittest_pb2.TestPackedTypes()
2980    test_util.SetAllPackedFields(expected)
2981    self.assertEqual(expected, packed)
2982
2983  def testParseUnpackedFromPacked(self):
2984    packed = unittest_pb2.TestPackedTypes()
2985    test_util.SetAllPackedFields(packed)
2986    unpacked = unittest_pb2.TestUnpackedTypes()
2987    serialized = packed.SerializeToString()
2988    self.assertEqual(
2989        len(serialized),
2990        unpacked.MergeFromString(serialized))
2991    expected = unittest_pb2.TestUnpackedTypes()
2992    test_util.SetAllUnpackedFields(expected)
2993    self.assertEqual(expected, unpacked)
2994
2995  def testFieldNumbers(self):
2996    proto = unittest_pb2.TestAllTypes()
2997    self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
2998    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
2999    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
3000    self.assertEqual(
3001      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
3002    self.assertEqual(
3003      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
3004    self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
3005    self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
3006    self.assertEqual(
3007      unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
3008    self.assertEqual(
3009      unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)
3010
3011  def testExtensionFieldNumbers(self):
3012    self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
3013    self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
3014    self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
3015    self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
3016    self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
3017    self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
3018    self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
3019    self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
3020    self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
3021    self.assertEqual(
3022      unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
3023    self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
3024    self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
3025      21)
3026    self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
3027    self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
3028    self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
3029    self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
3030    self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
3031    self.assertEqual(
3032      unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
3033    self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
3034    self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
3035      51)
3036
3037  def testFieldProperties(self):
3038    cls = unittest_pb2.TestAllTypes
3039    self.assertIs(cls.optional_int32.DESCRIPTOR,
3040                  cls.DESCRIPTOR.fields_by_name['optional_int32'])
3041    self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER,
3042                     cls.optional_int32.DESCRIPTOR.number)
3043    self.assertIs(cls.optional_nested_message.DESCRIPTOR,
3044                  cls.DESCRIPTOR.fields_by_name['optional_nested_message'])
3045    self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER,
3046                     cls.optional_nested_message.DESCRIPTOR.number)
3047    self.assertIs(cls.repeated_int32.DESCRIPTOR,
3048                  cls.DESCRIPTOR.fields_by_name['repeated_int32'])
3049    self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER,
3050                     cls.repeated_int32.DESCRIPTOR.number)
3051
3052  def testFieldDataDescriptor(self):
3053    msg = unittest_pb2.TestAllTypes()
3054    msg.optional_int32 = 42
3055    self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42)
3056    unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25)
3057    self.assertEqual(msg.optional_int32, 25)
3058    with self.assertRaises(AttributeError):
3059      del msg.optional_int32
3060    try:
3061      unittest_pb2.ForeignMessage.c.__get__(msg)
3062    except TypeError:
3063      pass  # The cpp implementation cannot mix fields from other messages.
3064            # This test exercises a specific check that avoids a crash.
3065    else:
3066      pass  # The python implementation allows fields from other messages.
3067            # This is useless, but works.
3068
3069  def testInitKwargs(self):
3070    proto = unittest_pb2.TestAllTypes(
3071        optional_int32=1,
3072        optional_string='foo',
3073        optional_bool=True,
3074        optional_bytes=b'bar',
3075        optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
3076        optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
3077        optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
3078        optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
3079        repeated_int32=[1, 2, 3])
3080    self.assertTrue(proto.IsInitialized())
3081    self.assertTrue(proto.HasField('optional_int32'))
3082    self.assertTrue(proto.HasField('optional_string'))
3083    self.assertTrue(proto.HasField('optional_bool'))
3084    self.assertTrue(proto.HasField('optional_bytes'))
3085    self.assertTrue(proto.HasField('optional_nested_message'))
3086    self.assertTrue(proto.HasField('optional_foreign_message'))
3087    self.assertTrue(proto.HasField('optional_nested_enum'))
3088    self.assertTrue(proto.HasField('optional_foreign_enum'))
3089    self.assertEqual(1, proto.optional_int32)
3090    self.assertEqual('foo', proto.optional_string)
3091    self.assertEqual(True, proto.optional_bool)
3092    self.assertEqual(b'bar', proto.optional_bytes)
3093    self.assertEqual(1, proto.optional_nested_message.bb)
3094    self.assertEqual(1, proto.optional_foreign_message.c)
3095    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
3096                     proto.optional_nested_enum)
3097    self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
3098    self.assertEqual([1, 2, 3], proto.repeated_int32)
3099
3100  def testInitArgsUnknownFieldName(self):
3101    def InitalizeEmptyMessageWithExtraKeywordArg():
3102      unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
3103    self._CheckRaises(
3104        ValueError,
3105        InitalizeEmptyMessageWithExtraKeywordArg,
3106        'Protocol message TestEmptyMessage has no "unknown" field.')
3107
3108  def testInitRequiredKwargs(self):
3109    proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
3110    self.assertTrue(proto.IsInitialized())
3111    self.assertTrue(proto.HasField('a'))
3112    self.assertTrue(proto.HasField('b'))
3113    self.assertTrue(proto.HasField('c'))
3114    self.assertFalse(proto.HasField('dummy2'))
3115    self.assertEqual(1, proto.a)
3116    self.assertEqual(1, proto.b)
3117    self.assertEqual(1, proto.c)
3118
3119  def testInitRequiredForeignKwargs(self):
3120    proto = unittest_pb2.TestRequiredForeign(
3121        optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
3122    self.assertTrue(proto.IsInitialized())
3123    self.assertTrue(proto.HasField('optional_message'))
3124    self.assertTrue(proto.optional_message.IsInitialized())
3125    self.assertTrue(proto.optional_message.HasField('a'))
3126    self.assertTrue(proto.optional_message.HasField('b'))
3127    self.assertTrue(proto.optional_message.HasField('c'))
3128    self.assertFalse(proto.optional_message.HasField('dummy2'))
3129    self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
3130                     proto.optional_message)
3131    self.assertEqual(1, proto.optional_message.a)
3132    self.assertEqual(1, proto.optional_message.b)
3133    self.assertEqual(1, proto.optional_message.c)
3134
3135  def testInitRepeatedKwargs(self):
3136    proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
3137    self.assertTrue(proto.IsInitialized())
3138    self.assertEqual(1, proto.repeated_int32[0])
3139    self.assertEqual(2, proto.repeated_int32[1])
3140    self.assertEqual(3, proto.repeated_int32[2])
3141
3142
@testing_refleaks.TestCase
class OptionsTest(unittest.TestCase):
  """Checks message- and field-level options read back from descriptors."""

  def testMessageOptions(self):
    """TestMessageSet sets message_set_wire_format; TestAllTypes does not."""
    for msg_class, expected in (
        (message_set_extensions_pb2.TestMessageSet, True),
        (unittest_pb2.TestAllTypes, False)):
      instance = msg_class()
      self.assertEqual(
          expected, instance.DESCRIPTOR.GetOptions().message_set_wire_format)

  def testPackedOptions(self):
    """Only fields of TestPackedTypes report the packed option."""
    plain = unittest_pb2.TestAllTypes()
    plain.optional_int32 = 1
    plain.optional_double = 3.0
    for field_desc, _ in plain.ListFields():
      self.assertEqual(False, field_desc.GetOptions().packed)

    packed = unittest_pb2.TestPackedTypes()
    packed.packed_int32.append(1)
    packed.packed_double.append(3.0)
    for field_desc, _ in packed.ListFields():
      self.assertEqual(True, field_desc.GetOptions().packed)
      # Every packed field checked here is expected to be repeated.
      self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED,
                       field_desc.label)
3168
3169
3170
@testing_refleaks.TestCase
class ClassAPITest(unittest.TestCase):
  """Tests for building message classes directly from descriptors."""

  @unittest.skipIf(
      api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
      'C++ implementation requires a call to MakeDescriptor()')
  @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable')
  def testMakeClassWithNestedDescriptor(self):
    """MakeClass accepts a hand-built descriptor tree with nested types."""
    leaf_desc = descriptor.Descriptor(
        'leaf', 'package.parent.child.leaf', '',
        containing_type=None, fields=[],
        nested_types=[], enum_types=[],
        extensions=[],
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    child_desc = descriptor.Descriptor(
        'child', 'package.parent.child', '',
        containing_type=None, fields=[],
        nested_types=[leaf_desc], enum_types=[],
        extensions=[],
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    sibling_desc = descriptor.Descriptor(
        'sibling', 'package.parent.sibling',
        '', containing_type=None, fields=[],
        nested_types=[], enum_types=[],
        extensions=[],
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    parent_desc = descriptor.Descriptor(
        'parent', 'package.parent', '',
        containing_type=None, fields=[],
        nested_types=[child_desc, sibling_desc],
        enum_types=[], extensions=[],
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    reflection.MakeClass(parent_desc)

  def _GetSerializedFileDescriptor(self, name):
    """Get a serialized representation of a test FileDescriptorProto.

    Args:
      name: All calls to this must use a unique message name, to avoid
          collisions in the cpp descriptor pool.
    Returns:
      The serialized (SerializeToString) form of a test FileDescriptorProto.
    """
    file_descriptor_str = (
        'message_type {'
        '  name: "' + name + '"'
        '  field {'
        '    name: "flat"'
        '    number: 1'
        '    label: LABEL_REPEATED'
        '    type: TYPE_UINT32'
        '  }'
        '  field {'
        '    name: "bar"'
        '    number: 2'
        '    label: LABEL_OPTIONAL'
        '    type: TYPE_MESSAGE'
        '    type_name: "Bar"'
        '  }'
        '  nested_type {'
        '    name: "Bar"'
        '    field {'
        '      name: "baz"'
        '      number: 3'
        '      label: LABEL_OPTIONAL'
        '      type: TYPE_MESSAGE'
        '      type_name: "Baz"'
        '    }'
        '    nested_type {'
        '      name: "Baz"'
        '      enum_type {'
        '        name: "deep_enum"'
        '        value {'
        '          name: "VALUE_A"'
        '          number: 0'
        '        }'
        '      }'
        '      field {'
        '        name: "deep"'
        '        number: 4'
        '        label: LABEL_OPTIONAL'
        '        type: TYPE_UINT32'
        '      }'
        '    }'
        '  }'
        '}')
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    text_format.Merge(file_descriptor_str, file_descriptor)
    return file_descriptor.SerializeToString()

  # This test can only run once; the second time, it raises errors about
  # conflicting message descriptors.
  # TODO(user): This test fails with the cpp implementation in the call to
  # six.with_metaclass(). The other two callsites of with_metaclass in this
  # file are both excluded from cpp test, so it might be expected to fail.
  # Need someone more familiar with the python code to take a look at this.
  # Skipping (rather than returning early) makes the test report as skipped
  # instead of silently passing on other implementations.
  @unittest.skipIf(
      api_implementation.Type() != 'python',
      'Explicit metaclass class declaration is only exercised with the '
      'python implementation')
  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  def testParsingFlatClassWithExplicitClassDeclaration(self):
    """Test that the generated class can parse a flat message."""
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A'))
    msg_descriptor = descriptor.MakeDescriptor(
        file_descriptor.message_type[0])

    class MessageClass(six.with_metaclass(
        reflection.GeneratedProtocolMessageType, message.Message)):
      DESCRIPTOR = msg_descriptor
    msg = MessageClass()
    msg_str = (
        'flat: 0 '
        'flat: 1 '
        'flat: 2 ')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.flat, [0, 1, 2])

  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  def testParsingFlatClass(self):
    """Test that the generated class can parse a flat message."""
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B'))
    msg_descriptor = descriptor.MakeDescriptor(
        file_descriptor.message_type[0])
    msg_class = reflection.MakeClass(msg_descriptor)
    msg = msg_class()
    msg_str = (
        'flat: 0 '
        'flat: 1 '
        'flat: 2 ')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.flat, [0, 1, 2])

  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  def testParsingNestedClass(self):
    """Test that the generated class can parse a nested message."""
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
    msg_descriptor = descriptor.MakeDescriptor(
        file_descriptor.message_type[0])
    msg_class = reflection.MakeClass(msg_descriptor)
    msg = msg_class()
    msg_str = (
        'bar {'
        '  baz {'
        '    deep: 4'
        '  }'
        '}')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.bar.baz.deep, 4)
3325
if __name__ == '__main__':
  # When executed as a script, hand control to the unittest runner so the
  # test cases defined in this module run.
  unittest.main()
3328