1#! /usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4# Protocol Buffers - Google's data interchange format
5# Copyright 2008 Google Inc.  All rights reserved.
6# https://developers.google.com/protocol-buffers/
7#
8# Redistribution and use in source and binary forms, with or without
9# modification, are permitted provided that the following conditions are
10# met:
11#
12#     * Redistributions of source code must retain the above copyright
13# notice, this list of conditions and the following disclaimer.
14#     * Redistributions in binary form must reproduce the above
15# copyright notice, this list of conditions and the following disclaimer
16# in the documentation and/or other materials provided with the
17# distribution.
18#     * Neither the name of Google Inc. nor the names of its
19# contributors may be used to endorse or promote products derived from
20# this software without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34"""Unittest for reflection.py, which also indirectly tests the output of the
35pure-Python protocol compiler.
36"""
37
38import copy
39import gc
40import operator
41import six
42import struct
43
44try:
45  import unittest2 as unittest  #PY26
46except ImportError:
47  import unittest
48
49from google.protobuf import unittest_import_pb2
50from google.protobuf import unittest_mset_pb2
51from google.protobuf import unittest_pb2
52from google.protobuf import descriptor_pb2
53from google.protobuf import descriptor
54from google.protobuf import message
55from google.protobuf import reflection
56from google.protobuf import text_format
57from google.protobuf.internal import api_implementation
58from google.protobuf.internal import more_extensions_pb2
59from google.protobuf.internal import more_messages_pb2
60from google.protobuf.internal import message_set_extensions_pb2
61from google.protobuf.internal import wire_format
62from google.protobuf.internal import test_util
63from google.protobuf.internal import testing_refleaks
64from google.protobuf.internal import decoder
65
66
if six.PY3:
  # Python 3 has a single integer type; keep the Python 2 name ``long``
  # bound so the integer-type checks below can refer to it on both versions.
  long = int  # pylint: disable=redefined-builtin,invalid-name
69
70
class _MiniDecoder(object):
  """Sequential reader over the serialized bytes of a message.

  The old decoder.Decoder class was removed when decoding was redesigned to
  be much faster.  A few tests in this file still inspect serialized output
  one value at a time, so this class supplies just the read methods those
  tests use, sparing them a rewrite.
  """

  def __init__(self, bytes):
    # Parameter name kept as-is so existing callers are unaffected.
    self._bytes = bytes
    self._pos = 0

  def ReadVarint(self):
    """Decodes one varint at the cursor and moves the cursor past it."""
    value, next_pos = decoder._DecodeVarint(self._bytes, self._pos)
    self._pos = next_pos
    return value

  # All plain varint-encoded integer kinds decode to the raw varint value.
  ReadInt32 = ReadInt64 = ReadUInt32 = ReadUInt64 = ReadVarint

  def ReadSInt64(self):
    """Decodes one ZigZag-encoded varint at the cursor."""
    raw = self.ReadVarint()
    return wire_format.ZigZagDecode(raw)

  ReadSInt32 = ReadSInt64

  def ReadFieldNumberAndWireType(self):
    """Reads a tag varint and splits it with wire_format.UnpackTag."""
    tag = self.ReadVarint()
    return wire_format.UnpackTag(tag)

  def _ReadFixed(self, fmt, width):
    """Unpacks one `width`-byte value with struct format `fmt` and advances."""
    start = self._pos
    (value,) = struct.unpack(fmt, self._bytes[start:start + width])
    self._pos = start + width
    return value

  def ReadFloat(self):
    """Reads a little-endian 4-byte float."""
    return self._ReadFixed('<f', 4)

  def ReadDouble(self):
    """Reads a little-endian 8-byte double."""
    return self._ReadFixed('<d', 8)

  def EndOfStream(self):
    """Returns True once every byte has been consumed."""
    return len(self._bytes) == self._pos
114
115
116@testing_refleaks.TestCase
117class ReflectionTest(unittest.TestCase):
118
119  def assertListsEqual(self, values, others):
120    self.assertEqual(len(values), len(others))
121    for i in range(len(values)):
122      self.assertEqual(values[i], others[i])
123
124  def testScalarConstructor(self):
125    # Constructor with only scalar types should succeed.
126    proto = unittest_pb2.TestAllTypes(
127        optional_int32=24,
128        optional_double=54.321,
129        optional_string='optional_string',
130        optional_float=None)
131
132    self.assertEqual(24, proto.optional_int32)
133    self.assertEqual(54.321, proto.optional_double)
134    self.assertEqual('optional_string', proto.optional_string)
135    self.assertFalse(proto.HasField("optional_float"))
136
137  def testRepeatedScalarConstructor(self):
138    # Constructor with only repeated scalar types should succeed.
139    proto = unittest_pb2.TestAllTypes(
140        repeated_int32=[1, 2, 3, 4],
141        repeated_double=[1.23, 54.321],
142        repeated_bool=[True, False, False],
143        repeated_string=["optional_string"],
144        repeated_float=None)
145
146    self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
147    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
148    self.assertEqual([True, False, False], list(proto.repeated_bool))
149    self.assertEqual(["optional_string"], list(proto.repeated_string))
150    self.assertEqual([], list(proto.repeated_float))
151
152  def testRepeatedCompositeConstructor(self):
153    # Constructor with only repeated composite types should succeed.
154    proto = unittest_pb2.TestAllTypes(
155        repeated_nested_message=[
156            unittest_pb2.TestAllTypes.NestedMessage(
157                bb=unittest_pb2.TestAllTypes.FOO),
158            unittest_pb2.TestAllTypes.NestedMessage(
159                bb=unittest_pb2.TestAllTypes.BAR)],
160        repeated_foreign_message=[
161            unittest_pb2.ForeignMessage(c=-43),
162            unittest_pb2.ForeignMessage(c=45324),
163            unittest_pb2.ForeignMessage(c=12)],
164        repeatedgroup=[
165            unittest_pb2.TestAllTypes.RepeatedGroup(),
166            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
167            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
168
169    self.assertEqual(
170        [unittest_pb2.TestAllTypes.NestedMessage(
171            bb=unittest_pb2.TestAllTypes.FOO),
172         unittest_pb2.TestAllTypes.NestedMessage(
173             bb=unittest_pb2.TestAllTypes.BAR)],
174        list(proto.repeated_nested_message))
175    self.assertEqual(
176        [unittest_pb2.ForeignMessage(c=-43),
177         unittest_pb2.ForeignMessage(c=45324),
178         unittest_pb2.ForeignMessage(c=12)],
179        list(proto.repeated_foreign_message))
180    self.assertEqual(
181        [unittest_pb2.TestAllTypes.RepeatedGroup(),
182         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
183         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
184        list(proto.repeatedgroup))
185
186  def testMixedConstructor(self):
187    # Constructor with only mixed types should succeed.
188    proto = unittest_pb2.TestAllTypes(
189        optional_int32=24,
190        optional_string='optional_string',
191        repeated_double=[1.23, 54.321],
192        repeated_bool=[True, False, False],
193        repeated_nested_message=[
194            unittest_pb2.TestAllTypes.NestedMessage(
195                bb=unittest_pb2.TestAllTypes.FOO),
196            unittest_pb2.TestAllTypes.NestedMessage(
197                bb=unittest_pb2.TestAllTypes.BAR)],
198        repeated_foreign_message=[
199            unittest_pb2.ForeignMessage(c=-43),
200            unittest_pb2.ForeignMessage(c=45324),
201            unittest_pb2.ForeignMessage(c=12)],
202        optional_nested_message=None)
203
204    self.assertEqual(24, proto.optional_int32)
205    self.assertEqual('optional_string', proto.optional_string)
206    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
207    self.assertEqual([True, False, False], list(proto.repeated_bool))
208    self.assertEqual(
209        [unittest_pb2.TestAllTypes.NestedMessage(
210            bb=unittest_pb2.TestAllTypes.FOO),
211         unittest_pb2.TestAllTypes.NestedMessage(
212             bb=unittest_pb2.TestAllTypes.BAR)],
213        list(proto.repeated_nested_message))
214    self.assertEqual(
215        [unittest_pb2.ForeignMessage(c=-43),
216         unittest_pb2.ForeignMessage(c=45324),
217         unittest_pb2.ForeignMessage(c=12)],
218        list(proto.repeated_foreign_message))
219    self.assertFalse(proto.HasField("optional_nested_message"))
220
221  def testConstructorTypeError(self):
222    self.assertRaises(
223        TypeError, unittest_pb2.TestAllTypes, optional_int32="foo")
224    self.assertRaises(
225        TypeError, unittest_pb2.TestAllTypes, optional_string=1234)
226    self.assertRaises(
227        TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234)
228    self.assertRaises(
229        TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234)
230    self.assertRaises(
231        TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"])
232    self.assertRaises(
233        TypeError, unittest_pb2.TestAllTypes, repeated_string=1234)
234    self.assertRaises(
235        TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234])
236    self.assertRaises(
237        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234)
238    self.assertRaises(
239        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234])
240
241  def testConstructorInvalidatesCachedByteSize(self):
242    message = unittest_pb2.TestAllTypes(optional_int32 = 12)
243    self.assertEqual(2, message.ByteSize())
244
245    message = unittest_pb2.TestAllTypes(
246        optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage())
247    self.assertEqual(3, message.ByteSize())
248
249    message = unittest_pb2.TestAllTypes(repeated_int32 = [12])
250    self.assertEqual(3, message.ByteSize())
251
252    message = unittest_pb2.TestAllTypes(
253        repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()])
254    self.assertEqual(3, message.ByteSize())
255
256  def testSimpleHasBits(self):
257    # Test a scalar.
258    proto = unittest_pb2.TestAllTypes()
259    self.assertTrue(not proto.HasField('optional_int32'))
260    self.assertEqual(0, proto.optional_int32)
261    # HasField() shouldn't be true if all we've done is
262    # read the default value.
263    self.assertTrue(not proto.HasField('optional_int32'))
264    proto.optional_int32 = 1
265    # Setting a value however *should* set the "has" bit.
266    self.assertTrue(proto.HasField('optional_int32'))
267    proto.ClearField('optional_int32')
268    # And clearing that value should unset the "has" bit.
269    self.assertTrue(not proto.HasField('optional_int32'))
270
271  def testHasBitsWithSinglyNestedScalar(self):
272    # Helper used to test foreign messages and groups.
273    #
274    # composite_field_name should be the name of a non-repeated
275    # composite (i.e., foreign or group) field in TestAllTypes,
276    # and scalar_field_name should be the name of an integer-valued
277    # scalar field within that composite.
278    #
279    # I never thought I'd miss C++ macros and templates so much. :(
280    # This helper is semantically just:
281    #
282    #   assert proto.composite_field.scalar_field == 0
283    #   assert not proto.composite_field.HasField('scalar_field')
284    #   assert not proto.HasField('composite_field')
285    #
286    #   proto.composite_field.scalar_field = 10
287    #   old_composite_field = proto.composite_field
288    #
289    #   assert proto.composite_field.scalar_field == 10
290    #   assert proto.composite_field.HasField('scalar_field')
291    #   assert proto.HasField('composite_field')
292    #
293    #   proto.ClearField('composite_field')
294    #
295    #   assert not proto.composite_field.HasField('scalar_field')
296    #   assert not proto.HasField('composite_field')
297    #   assert proto.composite_field.scalar_field == 0
298    #
299    #   # Now ensure that ClearField('composite_field') disconnected
300    #   # the old field object from the object tree...
301    #   assert old_composite_field is not proto.composite_field
302    #   old_composite_field.scalar_field = 20
303    #   assert not proto.composite_field.HasField('scalar_field')
304    #   assert not proto.HasField('composite_field')
305    def TestCompositeHasBits(composite_field_name, scalar_field_name):
306      proto = unittest_pb2.TestAllTypes()
307      # First, check that we can get the scalar value, and see that it's the
308      # default (0), but that proto.HasField('omposite') and
309      # proto.composite.HasField('scalar') will still return False.
310      composite_field = getattr(proto, composite_field_name)
311      original_scalar_value = getattr(composite_field, scalar_field_name)
312      self.assertEqual(0, original_scalar_value)
313      # Assert that the composite object does not "have" the scalar.
314      self.assertTrue(not composite_field.HasField(scalar_field_name))
315      # Assert that proto does not "have" the composite field.
316      self.assertTrue(not proto.HasField(composite_field_name))
317
318      # Now set the scalar within the composite field.  Ensure that the setting
319      # is reflected, and that proto.HasField('composite') and
320      # proto.composite.HasField('scalar') now both return True.
321      new_val = 20
322      setattr(composite_field, scalar_field_name, new_val)
323      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
324      # Hold on to a reference to the current composite_field object.
325      old_composite_field = composite_field
326      # Assert that the has methods now return true.
327      self.assertTrue(composite_field.HasField(scalar_field_name))
328      self.assertTrue(proto.HasField(composite_field_name))
329
330      # Now call the clear method...
331      proto.ClearField(composite_field_name)
332
333      # ...and ensure that the "has" bits are all back to False...
334      composite_field = getattr(proto, composite_field_name)
335      self.assertTrue(not composite_field.HasField(scalar_field_name))
336      self.assertTrue(not proto.HasField(composite_field_name))
337      # ...and ensure that the scalar field has returned to its default.
338      self.assertEqual(0, getattr(composite_field, scalar_field_name))
339
340      self.assertTrue(old_composite_field is not composite_field)
341      setattr(old_composite_field, scalar_field_name, new_val)
342      self.assertTrue(not composite_field.HasField(scalar_field_name))
343      self.assertTrue(not proto.HasField(composite_field_name))
344      self.assertEqual(0, getattr(composite_field, scalar_field_name))
345
346    # Test simple, single-level nesting when we set a scalar.
347    TestCompositeHasBits('optionalgroup', 'a')
348    TestCompositeHasBits('optional_nested_message', 'bb')
349    TestCompositeHasBits('optional_foreign_message', 'c')
350    TestCompositeHasBits('optional_import_message', 'd')
351
352  def testReferencesToNestedMessage(self):
353    proto = unittest_pb2.TestAllTypes()
354    nested = proto.optional_nested_message
355    del proto
356    # A previous version had a bug where this would raise an exception when
357    # hitting a now-dead weak reference.
358    nested.bb = 23
359
360  def testDisconnectingNestedMessageBeforeSettingField(self):
361    proto = unittest_pb2.TestAllTypes()
362    nested = proto.optional_nested_message
363    proto.ClearField('optional_nested_message')  # Should disconnect from parent
364    self.assertTrue(nested is not proto.optional_nested_message)
365    nested.bb = 23
366    self.assertTrue(not proto.HasField('optional_nested_message'))
367    self.assertEqual(0, proto.optional_nested_message.bb)
368
369  def testGetDefaultMessageAfterDisconnectingDefaultMessage(self):
370    proto = unittest_pb2.TestAllTypes()
371    nested = proto.optional_nested_message
372    proto.ClearField('optional_nested_message')
373    del proto
374    del nested
375    # Force a garbage collect so that the underlying CMessages are freed along
376    # with the Messages they point to. This is to make sure we're not deleting
377    # default message instances.
378    gc.collect()
379    proto = unittest_pb2.TestAllTypes()
380    nested = proto.optional_nested_message
381
382  def testDisconnectingNestedMessageAfterSettingField(self):
383    proto = unittest_pb2.TestAllTypes()
384    nested = proto.optional_nested_message
385    nested.bb = 5
386    self.assertTrue(proto.HasField('optional_nested_message'))
387    proto.ClearField('optional_nested_message')  # Should disconnect from parent
388    self.assertEqual(5, nested.bb)
389    self.assertEqual(0, proto.optional_nested_message.bb)
390    self.assertTrue(nested is not proto.optional_nested_message)
391    nested.bb = 23
392    self.assertTrue(not proto.HasField('optional_nested_message'))
393    self.assertEqual(0, proto.optional_nested_message.bb)
394
395  def testDisconnectingNestedMessageBeforeGettingField(self):
396    proto = unittest_pb2.TestAllTypes()
397    self.assertTrue(not proto.HasField('optional_nested_message'))
398    proto.ClearField('optional_nested_message')
399    self.assertTrue(not proto.HasField('optional_nested_message'))
400
401  def testDisconnectingNestedMessageAfterMerge(self):
402    # This test exercises the code path that does not use ReleaseMessage().
403    # The underlying fear is that if we use ReleaseMessage() incorrectly,
404    # we will have memory leaks.  It's hard to check that that doesn't happen,
405    # but at least we can exercise that code path to make sure it works.
406    proto1 = unittest_pb2.TestAllTypes()
407    proto2 = unittest_pb2.TestAllTypes()
408    proto2.optional_nested_message.bb = 5
409    proto1.MergeFrom(proto2)
410    self.assertTrue(proto1.HasField('optional_nested_message'))
411    proto1.ClearField('optional_nested_message')
412    self.assertTrue(not proto1.HasField('optional_nested_message'))
413
414  def testDisconnectingLazyNestedMessage(self):
415    # This test exercises releasing a nested message that is lazy. This test
416    # only exercises real code in the C++ implementation as Python does not
417    # support lazy parsing, but the current C++ implementation results in
418    # memory corruption and a crash.
419    if api_implementation.Type() != 'python':
420      return
421    proto = unittest_pb2.TestAllTypes()
422    proto.optional_lazy_message.bb = 5
423    proto.ClearField('optional_lazy_message')
424    del proto
425    gc.collect()
426
427  def testHasBitsWhenModifyingRepeatedFields(self):
428    # Test nesting when we add an element to a repeated field in a submessage.
429    proto = unittest_pb2.TestNestedMessageHasBits()
430    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
431    self.assertEqual(
432        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
433    self.assertTrue(proto.HasField('optional_nested_message'))
434
435    # Do the same test, but with a repeated composite field within the
436    # submessage.
437    proto.ClearField('optional_nested_message')
438    self.assertTrue(not proto.HasField('optional_nested_message'))
439    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
440    self.assertTrue(proto.HasField('optional_nested_message'))
441
442  def testHasBitsForManyLevelsOfNesting(self):
443    # Test nesting many levels deep.
444    recursive_proto = unittest_pb2.TestMutualRecursionA()
445    self.assertTrue(not recursive_proto.HasField('bb'))
446    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
447    self.assertTrue(not recursive_proto.HasField('bb'))
448    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
449    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
450    self.assertTrue(recursive_proto.HasField('bb'))
451    self.assertTrue(recursive_proto.bb.HasField('a'))
452    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
453    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
454    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
455    self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a'))
456    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
457
458  def testSingularListFields(self):
459    proto = unittest_pb2.TestAllTypes()
460    proto.optional_fixed32 = 1
461    proto.optional_int32 = 5
462    proto.optional_string = 'foo'
463    # Access sub-message but don't set it yet.
464    nested_message = proto.optional_nested_message
465    self.assertEqual(
466      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
467        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
468        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
469      proto.ListFields())
470
471    proto.optional_nested_message.bb = 123
472    self.assertEqual(
473      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
474        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
475        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
476        (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
477             nested_message) ],
478      proto.ListFields())
479
480  def testRepeatedListFields(self):
481    proto = unittest_pb2.TestAllTypes()
482    proto.repeated_fixed32.append(1)
483    proto.repeated_int32.append(5)
484    proto.repeated_int32.append(11)
485    proto.repeated_string.extend(['foo', 'bar'])
486    proto.repeated_string.extend([])
487    proto.repeated_string.append('baz')
488    proto.repeated_string.extend(str(x) for x in range(2))
489    proto.optional_int32 = 21
490    proto.repeated_bool  # Access but don't set anything; should not be listed.
491    self.assertEqual(
492      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 21),
493        (proto.DESCRIPTOR.fields_by_name['repeated_int32'  ], [5, 11]),
494        (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
495        (proto.DESCRIPTOR.fields_by_name['repeated_string' ],
496          ['foo', 'bar', 'baz', '0', '1']) ],
497      proto.ListFields())
498
499  def testSingularListExtensions(self):
500    proto = unittest_pb2.TestAllExtensions()
501    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
502    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 5
503    proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo'
504    self.assertEqual(
505      [ (unittest_pb2.optional_int32_extension  , 5),
506        (unittest_pb2.optional_fixed32_extension, 1),
507        (unittest_pb2.optional_string_extension , 'foo') ],
508      proto.ListFields())
509
510  def testRepeatedListExtensions(self):
511    proto = unittest_pb2.TestAllExtensions()
512    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
513    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(5)
514    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(11)
515    proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo')
516    proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar')
517    proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz')
518    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 21
519    self.assertEqual(
520      [ (unittest_pb2.optional_int32_extension  , 21),
521        (unittest_pb2.repeated_int32_extension  , [5, 11]),
522        (unittest_pb2.repeated_fixed32_extension, [1]),
523        (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ],
524      proto.ListFields())
525
526  def testListFieldsAndExtensions(self):
527    proto = unittest_pb2.TestFieldOrderings()
528    test_util.SetAllFieldsAndExtensions(proto)
529    unittest_pb2.my_extension_int
530    self.assertEqual(
531      [ (proto.DESCRIPTOR.fields_by_name['my_int'   ], 1),
532        (unittest_pb2.my_extension_int               , 23),
533        (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
534        (unittest_pb2.my_extension_string            , 'bar'),
535        (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ],
536      proto.ListFields())
537
538  def testDefaultValues(self):
539    proto = unittest_pb2.TestAllTypes()
540    self.assertEqual(0, proto.optional_int32)
541    self.assertEqual(0, proto.optional_int64)
542    self.assertEqual(0, proto.optional_uint32)
543    self.assertEqual(0, proto.optional_uint64)
544    self.assertEqual(0, proto.optional_sint32)
545    self.assertEqual(0, proto.optional_sint64)
546    self.assertEqual(0, proto.optional_fixed32)
547    self.assertEqual(0, proto.optional_fixed64)
548    self.assertEqual(0, proto.optional_sfixed32)
549    self.assertEqual(0, proto.optional_sfixed64)
550    self.assertEqual(0.0, proto.optional_float)
551    self.assertEqual(0.0, proto.optional_double)
552    self.assertEqual(False, proto.optional_bool)
553    self.assertEqual('', proto.optional_string)
554    self.assertEqual(b'', proto.optional_bytes)
555
556    self.assertEqual(41, proto.default_int32)
557    self.assertEqual(42, proto.default_int64)
558    self.assertEqual(43, proto.default_uint32)
559    self.assertEqual(44, proto.default_uint64)
560    self.assertEqual(-45, proto.default_sint32)
561    self.assertEqual(46, proto.default_sint64)
562    self.assertEqual(47, proto.default_fixed32)
563    self.assertEqual(48, proto.default_fixed64)
564    self.assertEqual(49, proto.default_sfixed32)
565    self.assertEqual(-50, proto.default_sfixed64)
566    self.assertEqual(51.5, proto.default_float)
567    self.assertEqual(52e3, proto.default_double)
568    self.assertEqual(True, proto.default_bool)
569    self.assertEqual('hello', proto.default_string)
570    self.assertEqual(b'world', proto.default_bytes)
571    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
572    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
573    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
574                     proto.default_import_enum)
575
576    proto = unittest_pb2.TestExtremeDefaultValues()
577    self.assertEqual(u'\u1234', proto.utf8_string)
578
579  def testHasFieldWithUnknownFieldName(self):
580    proto = unittest_pb2.TestAllTypes()
581    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')
582
583  def testClearFieldWithUnknownFieldName(self):
584    proto = unittest_pb2.TestAllTypes()
585    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
586
587  def testClearRemovesChildren(self):
588    # Make sure there aren't any implementation bugs that are only partially
589    # clearing the message (which can happen in the more complex C++
590    # implementation which has parallel message lists).
591    proto = unittest_pb2.TestRequiredForeign()
592    for i in range(10):
593      proto.repeated_message.add()
594    proto2 = unittest_pb2.TestRequiredForeign()
595    proto.CopyFrom(proto2)
596    self.assertRaises(IndexError, lambda: proto.repeated_message[5])
597
598  def testDisallowedAssignments(self):
599    # It's illegal to assign values directly to repeated fields
600    # or to nonrepeated composite fields.  Ensure that this fails.
601    proto = unittest_pb2.TestAllTypes()
602    # Repeated fields.
603    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
604    # Lists shouldn't work, either.
605    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
606    # Composite fields.
607    self.assertRaises(AttributeError, setattr, proto,
608                      'optional_nested_message', 23)
609    # Assignment to a repeated nested message field without specifying
610    # the index in the array of nested messages.
611    self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
612                      'bb', 34)
613    # Assignment to an attribute of a repeated field.
614    self.assertRaises(AttributeError, setattr, proto.repeated_float,
615                      'some_attribute', 34)
616    # proto.nonexistent_field = 23 should fail as well.
617    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
618
619  def testSingleScalarTypeSafety(self):
620    proto = unittest_pb2.TestAllTypes()
621    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
622    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
623    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
624    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
625    self.assertRaises(TypeError, setattr, proto, 'optional_bool', 'foo')
626    self.assertRaises(TypeError, setattr, proto, 'optional_float', 'foo')
627    self.assertRaises(TypeError, setattr, proto, 'optional_double', 'foo')
628    # TODO(jieluo): Fix type checking difference for python and c extension
629    if api_implementation.Type() == 'python':
630      self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1)
631    else:
632      proto.optional_bool = 1.1
633
634  def assertIntegerTypes(self, integer_fn):
635    """Verifies setting of scalar integers.
636
637    Args:
638      integer_fn: A function to wrap the integers that will be assigned.
639    """
640    def TestGetAndDeserialize(field_name, value, expected_type):
641      proto = unittest_pb2.TestAllTypes()
642      value = integer_fn(value)
643      setattr(proto, field_name, value)
644      self.assertIsInstance(getattr(proto, field_name), expected_type)
645      proto2 = unittest_pb2.TestAllTypes()
646      proto2.ParseFromString(proto.SerializeToString())
647      self.assertIsInstance(getattr(proto2, field_name), expected_type)
648
649    TestGetAndDeserialize('optional_int32', 1, int)
650    TestGetAndDeserialize('optional_int32', 1 << 30, int)
651    TestGetAndDeserialize('optional_uint32', 1 << 30, int)
652    integer_64 = long
653    if struct.calcsize('L') == 4:
654      # Python only has signed ints, so 32-bit python can't fit an uint32
655      # in an int.
656      TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64)
657    else:
658      # 64-bit python can fit uint32 inside an int
659      TestGetAndDeserialize('optional_uint32', 1 << 31, int)
660    TestGetAndDeserialize('optional_int64', 1 << 30, integer_64)
661    TestGetAndDeserialize('optional_int64', 1 << 60, integer_64)
662    TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
663    TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
664
665  def testIntegerTypes(self):
666    self.assertIntegerTypes(lambda x: x)
667
668  def testNonStandardIntegerTypes(self):
669    self.assertIntegerTypes(test_util.NonStandardInteger)
670
671  def testIllegalValuesForIntegers(self):
672    pb = unittest_pb2.TestAllTypes()
673
674    # Strings are illegal, even when the represent an integer.
675    with self.assertRaises(TypeError):
676      pb.optional_uint64 = '2'
677
678    # The exact error should propagate with a poorly written custom integer.
679    with self.assertRaisesRegexp(RuntimeError, 'my_error'):
680      pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error')
681
682  def assetIntegerBoundsChecking(self, integer_fn):
683    """Verifies bounds checking for scalar integer fields.
684
685    Args:
686      integer_fn: A function to wrap the integers that will be assigned.
687    """
688    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
689      pb = unittest_pb2.TestAllTypes()
690      expected_min = integer_fn(expected_min)
691      expected_max = integer_fn(expected_max)
692      setattr(pb, field_name, expected_min)
693      self.assertEqual(expected_min, getattr(pb, field_name))
694      setattr(pb, field_name, expected_max)
695      self.assertEqual(expected_max, getattr(pb, field_name))
696      self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
697                        expected_min - 1)
698      self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
699                        expected_max + 1)
700
701    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
702    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
703    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
704    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
705    # A bit of white-box testing since -1 is an int and not a long in C++ and
706    # so goes down a different path.
707    pb = unittest_pb2.TestAllTypes()
708    with self.assertRaises((ValueError, TypeError)):
709      pb.optional_uint64 = integer_fn(-(1 << 63))
710
711    pb = unittest_pb2.TestAllTypes()
712    pb.optional_nested_enum = integer_fn(1)
713    self.assertEqual(1, pb.optional_nested_enum)
714
715  def testSingleScalarBoundsChecking(self):
716    self.assetIntegerBoundsChecking(lambda x: x)
717
718  def testNonStandardSingleScalarBoundsChecking(self):
719    self.assetIntegerBoundsChecking(test_util.NonStandardInteger)
720
721  def testRepeatedScalarTypeSafety(self):
722    proto = unittest_pb2.TestAllTypes()
723    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
724    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
725    self.assertRaises(TypeError, proto.repeated_string, 10)
726    self.assertRaises(TypeError, proto.repeated_bytes, 10)
727
728    proto.repeated_int32.append(10)
729    proto.repeated_int32[0] = 23
730    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
731    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
732    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, [])
733    self.assertRaises(TypeError, proto.repeated_int32.__setitem__,
734                      'index', 23)
735
736    proto.repeated_string.append('2')
737    self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10)
738
739    # Repeated enums tests.
740    #proto.repeated_nested_enum.append(0)
741
742  def testSingleScalarGettersAndSetters(self):
743    proto = unittest_pb2.TestAllTypes()
744    self.assertEqual(0, proto.optional_int32)
745    proto.optional_int32 = 1
746    self.assertEqual(1, proto.optional_int32)
747
748    proto.optional_uint64 = 0xffffffffffff
749    self.assertEqual(0xffffffffffff, proto.optional_uint64)
750    proto.optional_uint64 = 0xffffffffffffffff
751    self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
752    # TODO(robinson): Test all other scalar field types.
753
754  def testSingleScalarClearField(self):
755    proto = unittest_pb2.TestAllTypes()
756    # Should be allowed to clear something that's not there (a no-op).
757    proto.ClearField('optional_int32')
758    proto.optional_int32 = 1
759    self.assertTrue(proto.HasField('optional_int32'))
760    proto.ClearField('optional_int32')
761    self.assertEqual(0, proto.optional_int32)
762    self.assertTrue(not proto.HasField('optional_int32'))
763    # TODO(robinson): Test all other scalar field types.
764
765  def testEnums(self):
766    proto = unittest_pb2.TestAllTypes()
767    self.assertEqual(1, proto.FOO)
768    self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
769    self.assertEqual(2, proto.BAR)
770    self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
771    self.assertEqual(3, proto.BAZ)
772    self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
773
774  def testEnum_Name(self):
775    self.assertEqual('FOREIGN_FOO',
776                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO))
777    self.assertEqual('FOREIGN_BAR',
778                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR))
779    self.assertEqual('FOREIGN_BAZ',
780                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ))
781    self.assertRaises(ValueError,
782                      unittest_pb2.ForeignEnum.Name, 11312)
783
784    proto = unittest_pb2.TestAllTypes()
785    self.assertEqual('FOO',
786                     proto.NestedEnum.Name(proto.FOO))
787    self.assertEqual('FOO',
788                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO))
789    self.assertEqual('BAR',
790                     proto.NestedEnum.Name(proto.BAR))
791    self.assertEqual('BAR',
792                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR))
793    self.assertEqual('BAZ',
794                     proto.NestedEnum.Name(proto.BAZ))
795    self.assertEqual('BAZ',
796                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ))
797    self.assertRaises(ValueError,
798                      proto.NestedEnum.Name, 11312)
799    self.assertRaises(ValueError,
800                      unittest_pb2.TestAllTypes.NestedEnum.Name, 11312)
801
802  def testEnum_Value(self):
803    self.assertEqual(unittest_pb2.FOREIGN_FOO,
804                     unittest_pb2.ForeignEnum.Value('FOREIGN_FOO'))
805    self.assertEqual(unittest_pb2.FOREIGN_FOO,
806                     unittest_pb2.ForeignEnum.FOREIGN_FOO)
807
808    self.assertEqual(unittest_pb2.FOREIGN_BAR,
809                     unittest_pb2.ForeignEnum.Value('FOREIGN_BAR'))
810    self.assertEqual(unittest_pb2.FOREIGN_BAR,
811                     unittest_pb2.ForeignEnum.FOREIGN_BAR)
812
813    self.assertEqual(unittest_pb2.FOREIGN_BAZ,
814                     unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ'))
815    self.assertEqual(unittest_pb2.FOREIGN_BAZ,
816                     unittest_pb2.ForeignEnum.FOREIGN_BAZ)
817
818    self.assertRaises(ValueError,
819                      unittest_pb2.ForeignEnum.Value, 'FO')
820    with self.assertRaises(AttributeError):
821      unittest_pb2.ForeignEnum.FO
822
823    proto = unittest_pb2.TestAllTypes()
824    self.assertEqual(proto.FOO,
825                     proto.NestedEnum.Value('FOO'))
826    self.assertEqual(proto.FOO,
827                     proto.NestedEnum.FOO)
828
829    self.assertEqual(proto.FOO,
830                     unittest_pb2.TestAllTypes.NestedEnum.Value('FOO'))
831    self.assertEqual(proto.FOO,
832                     unittest_pb2.TestAllTypes.NestedEnum.FOO)
833
834    self.assertEqual(proto.BAR,
835                     proto.NestedEnum.Value('BAR'))
836    self.assertEqual(proto.BAR,
837                     proto.NestedEnum.BAR)
838
839    self.assertEqual(proto.BAR,
840                     unittest_pb2.TestAllTypes.NestedEnum.Value('BAR'))
841    self.assertEqual(proto.BAR,
842                     unittest_pb2.TestAllTypes.NestedEnum.BAR)
843
844    self.assertEqual(proto.BAZ,
845                     proto.NestedEnum.Value('BAZ'))
846    self.assertEqual(proto.BAZ,
847                     proto.NestedEnum.BAZ)
848
849    self.assertEqual(proto.BAZ,
850                     unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ'))
851    self.assertEqual(proto.BAZ,
852                     unittest_pb2.TestAllTypes.NestedEnum.BAZ)
853
854    self.assertRaises(ValueError,
855                      proto.NestedEnum.Value, 'Foo')
856    with self.assertRaises(AttributeError):
857      proto.NestedEnum.Value.Foo
858
859    self.assertRaises(ValueError,
860                      unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo')
861    with self.assertRaises(AttributeError):
862      unittest_pb2.TestAllTypes.NestedEnum.Value.Foo
863
864  def testEnum_KeysAndValues(self):
865    self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'],
866                     list(unittest_pb2.ForeignEnum.keys()))
867    self.assertEqual([4, 5, 6],
868                     list(unittest_pb2.ForeignEnum.values()))
869    self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5),
870                      ('FOREIGN_BAZ', 6)],
871                     list(unittest_pb2.ForeignEnum.items()))
872
873    proto = unittest_pb2.TestAllTypes()
874    self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], list(proto.NestedEnum.keys()))
875    self.assertEqual([1, 2, 3, -1], list(proto.NestedEnum.values()))
876    self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)],
877                     list(proto.NestedEnum.items()))
878
  def testRepeatedScalars(self):
    """Exercises list-like operations on a repeated int32 field.

    Covers truthiness, len, indexing (including negative and out-of-range
    indices), slicing, slice assignment, iteration, insertion, deletion,
    extend(), and ClearField. Each step builds on the state left by the
    previous one.
    """
    proto = unittest_pb2.TestAllTypes()

    self.assertTrue(not proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))
    proto.repeated_int32.append(5)
    proto.repeated_int32.append(10)
    proto.repeated_int32.append(15)
    self.assertTrue(proto.repeated_int32)
    self.assertEqual(3, len(proto.repeated_int32))

    # A repeated field compares equal to a plain list with the same items.
    self.assertEqual([5, 10, 15], proto.repeated_int32)

    # Test single retrieval.
    self.assertEqual(5, proto.repeated_int32[0])
    self.assertEqual(15, proto.repeated_int32[-1])
    # Test out-of-bounds indices.
    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
    # Test incorrect types passed to __getitem__.
    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)

    # Test single assignment.
    proto.repeated_int32[1] = 20
    self.assertEqual([5, 20, 15], proto.repeated_int32)

    # Test insertion.
    proto.repeated_int32.insert(1, 25)
    self.assertEqual([5, 25, 20, 15], proto.repeated_int32)

    # Test slice retrieval.
    proto.repeated_int32.append(30)
    self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
    self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])

    # Test slice assignment with an iterator
    proto.repeated_int32[1:4] = (i for i in range(3))
    self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)

    # Test slice assignment.
    proto.repeated_int32[1:4] = [35, 40, 45]
    self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)

    # Test that we can use the field as an iterator.
    result = []
    for i in proto.repeated_int32:
      result.append(i)
    self.assertEqual([5, 35, 40, 45, 30], result)

    # Test single deletion.
    del proto.repeated_int32[2]
    self.assertEqual([5, 35, 45, 30], proto.repeated_int32)

    # Test slice deletion.
    del proto.repeated_int32[2:]
    self.assertEqual([5, 35], proto.repeated_int32)

    # Test extending.
    proto.repeated_int32.extend([3, 13])
    self.assertEqual([5, 35, 3, 13], proto.repeated_int32)

    # Test clearing.
    proto.ClearField('repeated_int32')
    self.assertTrue(not proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))

    proto.repeated_int32.append(1)
    self.assertEqual(1, proto.repeated_int32[-1])
    # Test assignment to a negative index.
    proto.repeated_int32[-1] = 2
    self.assertEqual(2, proto.repeated_int32[-1])

    # Test deletion at negative indices.
    proto.repeated_int32[:] = [0, 1, 2, 3]
    del proto.repeated_int32[-1]
    self.assertEqual([0, 1, 2], proto.repeated_int32)

    del proto.repeated_int32[-2]
    self.assertEqual([0, 2], proto.repeated_int32)

    # Deleting a single out-of-range index (negative or positive) fails.
    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3)
    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300)

    # Slice deletion accepts negative bounds.
    del proto.repeated_int32[-2:-1]
    self.assertEqual([2], proto.repeated_int32)

    # Deleting an out-of-range slice is a no-op, as with Python lists.
    del proto.repeated_int32[100:10000]
    self.assertEqual([2], proto.repeated_int32)
968
969  def testRepeatedScalarsRemove(self):
970    proto = unittest_pb2.TestAllTypes()
971
972    self.assertTrue(not proto.repeated_int32)
973    self.assertEqual(0, len(proto.repeated_int32))
974    proto.repeated_int32.append(5)
975    proto.repeated_int32.append(10)
976    proto.repeated_int32.append(5)
977    proto.repeated_int32.append(5)
978
979    self.assertEqual(4, len(proto.repeated_int32))
980    proto.repeated_int32.remove(5)
981    self.assertEqual(3, len(proto.repeated_int32))
982    self.assertEqual(10, proto.repeated_int32[0])
983    self.assertEqual(5, proto.repeated_int32[1])
984    self.assertEqual(5, proto.repeated_int32[2])
985
986    proto.repeated_int32.remove(5)
987    self.assertEqual(2, len(proto.repeated_int32))
988    self.assertEqual(10, proto.repeated_int32[0])
989    self.assertEqual(5, proto.repeated_int32[1])
990
991    proto.repeated_int32.remove(10)
992    self.assertEqual(1, len(proto.repeated_int32))
993    self.assertEqual(5, proto.repeated_int32[0])
994
995    # Remove a non-existent element.
996    self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
997
  def testRepeatedComposites(self):
    """Exercises list-like operations on a repeated message field.

    Covers add(), truthiness, len, indexing and slicing, iteration,
    deletion, extend() (including its TypeError cases), ClearField,
    add() with keyword initializers, and rejection of direct element
    assignment. Each step builds on the state left by the previous one.
    """
    proto = unittest_pb2.TestAllTypes()
    self.assertTrue(not proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))
    m0 = proto.repeated_nested_message.add()
    m1 = proto.repeated_nested_message.add()
    self.assertTrue(proto.repeated_nested_message)
    self.assertEqual(2, len(proto.repeated_nested_message))
    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
    self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage)

    # Test out-of-bounds indices.
    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
                      1234)
    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
                      -1234)

    # Test incorrect types passed to __getitem__.
    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
                      'foo')
    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
                      None)

    # Test slice retrieval.
    m2 = proto.repeated_nested_message.add()
    m3 = proto.repeated_nested_message.add()
    m4 = proto.repeated_nested_message.add()
    self.assertListsEqual(
        [m1, m2, m3], proto.repeated_nested_message[1:4])
    self.assertListsEqual(
        [m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
    self.assertListsEqual(
        [m0, m1], proto.repeated_nested_message[:2])
    self.assertListsEqual(
        [m2, m3, m4], proto.repeated_nested_message[2:])
    self.assertEqual(
        m0, proto.repeated_nested_message[0])
    self.assertListsEqual(
        [m0], proto.repeated_nested_message[:1])

    # Test that we can use the field as an iterator.
    result = []
    for i in proto.repeated_nested_message:
      result.append(i)
    self.assertListsEqual([m0, m1, m2, m3, m4], result)

    # Test single deletion.
    del proto.repeated_nested_message[2]
    self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message)

    # Test slice deletion.
    del proto.repeated_nested_message[2:]
    self.assertListsEqual([m0, m1], proto.repeated_nested_message)

    # Test extending.
    n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
    n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2)
    proto.repeated_nested_message.extend([n1,n2])
    self.assertEqual(4, len(proto.repeated_nested_message))
    self.assertEqual(n1, proto.repeated_nested_message[2])
    self.assertEqual(n2, proto.repeated_nested_message[3])
    # extend() rejects a bare message, a non-message element, and a message
    # of the wrong type.
    self.assertRaises(TypeError,
                      proto.repeated_nested_message.extend, n1)
    self.assertRaises(TypeError,
                      proto.repeated_nested_message.extend, [0])
    wrong_message_type = unittest_pb2.TestAllTypes()
    self.assertRaises(TypeError,
                      proto.repeated_nested_message.extend,
                      [wrong_message_type])

    # Test clearing.
    proto.ClearField('repeated_nested_message')
    self.assertTrue(not proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))

    # Test constructing an element while adding it.
    proto.repeated_nested_message.add(bb=23)
    self.assertEqual(1, len(proto.repeated_nested_message))
    self.assertEqual(23, proto.repeated_nested_message[0].bb)
    # add() only takes keyword initializers, not positional values.
    self.assertRaises(TypeError, proto.repeated_nested_message.add, 23)
    # Elements of a repeated message field cannot be assigned directly.
    with self.assertRaises(Exception):
      proto.repeated_nested_message[0] = 23
1080
1081  def testRepeatedCompositeRemove(self):
1082    proto = unittest_pb2.TestAllTypes()
1083
1084    self.assertEqual(0, len(proto.repeated_nested_message))
1085    m0 = proto.repeated_nested_message.add()
1086    # Need to set some differentiating variable so m0 != m1 != m2:
1087    m0.bb = len(proto.repeated_nested_message)
1088    m1 = proto.repeated_nested_message.add()
1089    m1.bb = len(proto.repeated_nested_message)
1090    self.assertTrue(m0 != m1)
1091    m2 = proto.repeated_nested_message.add()
1092    m2.bb = len(proto.repeated_nested_message)
1093    self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
1094
1095    self.assertEqual(3, len(proto.repeated_nested_message))
1096    proto.repeated_nested_message.remove(m0)
1097    self.assertEqual(2, len(proto.repeated_nested_message))
1098    self.assertEqual(m1, proto.repeated_nested_message[0])
1099    self.assertEqual(m2, proto.repeated_nested_message[1])
1100
1101    # Removing m0 again or removing None should raise error
1102    self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0)
1103    self.assertRaises(ValueError, proto.repeated_nested_message.remove, None)
1104    self.assertEqual(2, len(proto.repeated_nested_message))
1105
1106    proto.repeated_nested_message.remove(m2)
1107    self.assertEqual(1, len(proto.repeated_nested_message))
1108    self.assertEqual(m1, proto.repeated_nested_message[0])
1109
1110  def testHandWrittenReflection(self):
1111    # Hand written extensions are only supported by the pure-Python
1112    # implementation of the API.
1113    if api_implementation.Type() != 'python':
1114      return
1115
1116    FieldDescriptor = descriptor.FieldDescriptor
1117    foo_field_descriptor = FieldDescriptor(
1118        name='foo_field', full_name='MyProto.foo_field',
1119        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
1120        cpp_type=FieldDescriptor.CPPTYPE_INT64,
1121        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
1122        containing_type=None, message_type=None, enum_type=None,
1123        is_extension=False, extension_scope=None,
1124        options=descriptor_pb2.FieldOptions())
1125    mydescriptor = descriptor.Descriptor(
1126        name='MyProto', full_name='MyProto', filename='ignored',
1127        containing_type=None, nested_types=[], enum_types=[],
1128        fields=[foo_field_descriptor], extensions=[],
1129        options=descriptor_pb2.MessageOptions())
1130    class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
1131      DESCRIPTOR = mydescriptor
1132    myproto_instance = MyProtoClass()
1133    self.assertEqual(0, myproto_instance.foo_field)
1134    self.assertTrue(not myproto_instance.HasField('foo_field'))
1135    myproto_instance.foo_field = 23
1136    self.assertEqual(23, myproto_instance.foo_field)
1137    self.assertTrue(myproto_instance.HasField('foo_field'))
1138
1139  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
1140  def testDescriptorProtoSupport(self):
1141    # Hand written descriptors/reflection are only supported by the pure-Python
1142    # implementation of the API.
1143    if api_implementation.Type() != 'python':
1144      return
1145
1146    def AddDescriptorField(proto, field_name, field_type):
1147      AddDescriptorField.field_index += 1
1148      new_field = proto.field.add()
1149      new_field.name = field_name
1150      new_field.type = field_type
1151      new_field.number = AddDescriptorField.field_index
1152      new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
1153
1154    AddDescriptorField.field_index = 0
1155
1156    desc_proto = descriptor_pb2.DescriptorProto()
1157    desc_proto.name = 'Car'
1158    fdp = descriptor_pb2.FieldDescriptorProto
1159    AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING)
1160    AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64)
1161    AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL)
1162    AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE)
1163    # Add a repeated field
1164    AddDescriptorField.field_index += 1
1165    new_field = desc_proto.field.add()
1166    new_field.name = 'owners'
1167    new_field.type = fdp.TYPE_STRING
1168    new_field.number = AddDescriptorField.field_index
1169    new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
1170
1171    desc = descriptor.MakeDescriptor(desc_proto)
1172    self.assertTrue('name' in desc.fields_by_name)
1173    self.assertTrue('year' in desc.fields_by_name)
1174    self.assertTrue('automatic' in desc.fields_by_name)
1175    self.assertTrue('price' in desc.fields_by_name)
1176    self.assertTrue('owners' in desc.fields_by_name)
1177
1178    class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType,
1179                                        message.Message)):
1180      DESCRIPTOR = desc
1181
1182    prius = CarMessage()
1183    prius.name = 'prius'
1184    prius.year = 2010
1185    prius.automatic = True
1186    prius.price = 25134.75
1187    prius.owners.extend(['bob', 'susan'])
1188
1189    serialized_prius = prius.SerializeToString()
1190    new_prius = reflection.ParseMessage(desc, serialized_prius)
1191    self.assertTrue(new_prius is not prius)
1192    self.assertEqual(prius, new_prius)
1193
1194    # these are unnecessary assuming message equality works as advertised but
1195    # explicitly check to be safe since we're mucking about in metaclass foo
1196    self.assertEqual(prius.name, new_prius.name)
1197    self.assertEqual(prius.year, new_prius.year)
1198    self.assertEqual(prius.automatic, new_prius.automatic)
1199    self.assertEqual(prius.price, new_prius.price)
1200    self.assertEqual(prius.owners, new_prius.owners)
1201
1202  def testExtensionIter(self):
1203    extendee_proto = more_extensions_pb2.ExtendedMessage()
1204
1205    extension_int32 = more_extensions_pb2.optional_int_extension
1206    extendee_proto.Extensions[extension_int32] = 23
1207
1208    extension_repeated = more_extensions_pb2.repeated_int_extension
1209    extendee_proto.Extensions[extension_repeated].append(11)
1210
1211    extension_msg = more_extensions_pb2.optional_message_extension
1212    extendee_proto.Extensions[extension_msg].foreign_message_int = 56
1213
1214    # Set some normal fields.
1215    extendee_proto.optional_int32 = 1
1216    extendee_proto.repeated_string.append('hi')
1217
1218    expected = (extension_int32, extension_msg, extension_repeated)
1219    count = 0
1220    for item in extendee_proto.Extensions:
1221      self.assertEqual(item.name, expected[count].name)
1222      self.assertIn(item, extendee_proto.Extensions)
1223      count += 1
1224    self.assertEqual(count, 3)
1225
1226  def testExtensionContainsError(self):
1227    extendee_proto = more_extensions_pb2.ExtendedMessage()
1228    self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, 0)
1229
1230    field = more_extensions_pb2.ExtendedMessage.DESCRIPTOR.fields_by_name[
1231        'optional_int32']
1232    self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, field)
1233
1234  def testTopLevelExtensionsForOptionalScalar(self):
1235    extendee_proto = unittest_pb2.TestAllExtensions()
1236    extension = unittest_pb2.optional_int32_extension
1237    self.assertTrue(not extendee_proto.HasExtension(extension))
1238    self.assertNotIn(extension, extendee_proto.Extensions)
1239    self.assertEqual(0, extendee_proto.Extensions[extension])
1240    # As with normal scalar fields, just doing a read doesn't actually set the
1241    # "has" bit.
1242    self.assertTrue(not extendee_proto.HasExtension(extension))
1243    self.assertNotIn(extension, extendee_proto.Extensions)
1244    # Actually set the thing.
1245    extendee_proto.Extensions[extension] = 23
1246    self.assertEqual(23, extendee_proto.Extensions[extension])
1247    self.assertTrue(extendee_proto.HasExtension(extension))
1248    self.assertIn(extension, extendee_proto.Extensions)
1249    # Ensure that clearing works as well.
1250    extendee_proto.ClearExtension(extension)
1251    self.assertEqual(0, extendee_proto.Extensions[extension])
1252    self.assertTrue(not extendee_proto.HasExtension(extension))
1253    self.assertNotIn(extension, extendee_proto.Extensions)
1254
1255  def testTopLevelExtensionsForRepeatedScalar(self):
1256    extendee_proto = unittest_pb2.TestAllExtensions()
1257    extension = unittest_pb2.repeated_string_extension
1258    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1259    self.assertNotIn(extension, extendee_proto.Extensions)
1260    extendee_proto.Extensions[extension].append('foo')
1261    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
1262    self.assertIn(extension, extendee_proto.Extensions)
1263    string_list = extendee_proto.Extensions[extension]
1264    extendee_proto.ClearExtension(extension)
1265    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1266    self.assertNotIn(extension, extendee_proto.Extensions)
1267    self.assertTrue(string_list is not extendee_proto.Extensions[extension])
1268    # Shouldn't be allowed to do Extensions[extension] = 'a'
1269    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1270                      extension, 'a')
1271
1272  def testTopLevelExtensionsForOptionalMessage(self):
1273    extendee_proto = unittest_pb2.TestAllExtensions()
1274    extension = unittest_pb2.optional_foreign_message_extension
1275    self.assertTrue(not extendee_proto.HasExtension(extension))
1276    self.assertNotIn(extension, extendee_proto.Extensions)
1277    self.assertEqual(0, extendee_proto.Extensions[extension].c)
1278    # As with normal (non-extension) fields, merely reading from the
1279    # thing shouldn't set the "has" bit.
1280    self.assertTrue(not extendee_proto.HasExtension(extension))
1281    self.assertNotIn(extension, extendee_proto.Extensions)
1282    extendee_proto.Extensions[extension].c = 23
1283    self.assertEqual(23, extendee_proto.Extensions[extension].c)
1284    self.assertTrue(extendee_proto.HasExtension(extension))
1285    self.assertIn(extension, extendee_proto.Extensions)
1286    # Save a reference here.
1287    foreign_message = extendee_proto.Extensions[extension]
1288    extendee_proto.ClearExtension(extension)
1289    self.assertTrue(foreign_message is not extendee_proto.Extensions[extension])
1290    # Setting a field on foreign_message now shouldn't set
1291    # any "has" bits on extendee_proto.
1292    foreign_message.c = 42
1293    self.assertEqual(42, foreign_message.c)
1294    self.assertTrue(foreign_message.HasField('c'))
1295    self.assertTrue(not extendee_proto.HasExtension(extension))
1296    self.assertNotIn(extension, extendee_proto.Extensions)
1297    # Shouldn't be allowed to do Extensions[extension] = 'a'
1298    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1299                      extension, 'a')
1300
1301  def testTopLevelExtensionsForRepeatedMessage(self):
1302    extendee_proto = unittest_pb2.TestAllExtensions()
1303    extension = unittest_pb2.repeatedgroup_extension
1304    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1305    group = extendee_proto.Extensions[extension].add()
1306    group.a = 23
1307    self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
1308    group.a = 42
1309    self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
1310    group_list = extendee_proto.Extensions[extension]
1311    extendee_proto.ClearExtension(extension)
1312    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1313    self.assertTrue(group_list is not extendee_proto.Extensions[extension])
1314    # Shouldn't be allowed to do Extensions[extension] = 'a'
1315    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1316                      extension, 'a')
1317
1318  def testNestedExtensions(self):
1319    extendee_proto = unittest_pb2.TestAllExtensions()
1320    extension = unittest_pb2.TestRequired.single
1321
1322    # We just test the non-repeated case.
1323    self.assertTrue(not extendee_proto.HasExtension(extension))
1324    self.assertNotIn(extension, extendee_proto.Extensions)
1325    required = extendee_proto.Extensions[extension]
1326    self.assertEqual(0, required.a)
1327    self.assertTrue(not extendee_proto.HasExtension(extension))
1328    self.assertNotIn(extension, extendee_proto.Extensions)
1329    required.a = 23
1330    self.assertEqual(23, extendee_proto.Extensions[extension].a)
1331    self.assertTrue(extendee_proto.HasExtension(extension))
1332    self.assertIn(extension, extendee_proto.Extensions)
1333    extendee_proto.ClearExtension(extension)
1334    self.assertTrue(required is not extendee_proto.Extensions[extension])
1335    self.assertTrue(not extendee_proto.HasExtension(extension))
1336    self.assertNotIn(extension, extendee_proto.Extensions)
1337
1338  def testRegisteredExtensions(self):
1339    pool = unittest_pb2.DESCRIPTOR.pool
1340    self.assertTrue(
1341        pool.FindExtensionByNumber(
1342            unittest_pb2.TestAllExtensions.DESCRIPTOR, 1))
1343    self.assertIs(
1344        pool.FindExtensionByName(
1345            'protobuf_unittest.optional_int32_extension').containing_type,
1346        unittest_pb2.TestAllExtensions.DESCRIPTOR)
1347    # Make sure extensions haven't been registered into types that shouldn't
1348    # have any.
1349    self.assertEqual(0, len(
1350        pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR)))
1351
1352  # If message A directly contains message B, and
1353  # a.HasField('b') is currently False, then mutating any
1354  # extension in B should change a.HasField('b') to True
1355  # (and so on up the object tree).
1356  def testHasBitsForAncestorsOfExtendedMessage(self):
1357    # Optional scalar extension.
1358    toplevel = more_extensions_pb2.TopLevelMessage()
1359    self.assertTrue(not toplevel.HasField('submessage'))
1360    self.assertEqual(0, toplevel.submessage.Extensions[
1361        more_extensions_pb2.optional_int_extension])
1362    self.assertTrue(not toplevel.HasField('submessage'))
1363    toplevel.submessage.Extensions[
1364        more_extensions_pb2.optional_int_extension] = 23
1365    self.assertEqual(23, toplevel.submessage.Extensions[
1366        more_extensions_pb2.optional_int_extension])
1367    self.assertTrue(toplevel.HasField('submessage'))
1368
1369    # Repeated scalar extension.
1370    toplevel = more_extensions_pb2.TopLevelMessage()
1371    self.assertTrue(not toplevel.HasField('submessage'))
1372    self.assertEqual([], toplevel.submessage.Extensions[
1373        more_extensions_pb2.repeated_int_extension])
1374    self.assertTrue(not toplevel.HasField('submessage'))
1375    toplevel.submessage.Extensions[
1376        more_extensions_pb2.repeated_int_extension].append(23)
1377    self.assertEqual([23], toplevel.submessage.Extensions[
1378        more_extensions_pb2.repeated_int_extension])
1379    self.assertTrue(toplevel.HasField('submessage'))
1380
1381    # Optional message extension.
1382    toplevel = more_extensions_pb2.TopLevelMessage()
1383    self.assertTrue(not toplevel.HasField('submessage'))
1384    self.assertEqual(0, toplevel.submessage.Extensions[
1385        more_extensions_pb2.optional_message_extension].foreign_message_int)
1386    self.assertTrue(not toplevel.HasField('submessage'))
1387    toplevel.submessage.Extensions[
1388        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
1389    self.assertEqual(23, toplevel.submessage.Extensions[
1390        more_extensions_pb2.optional_message_extension].foreign_message_int)
1391    self.assertTrue(toplevel.HasField('submessage'))
1392
1393    # Repeated message extension.
1394    toplevel = more_extensions_pb2.TopLevelMessage()
1395    self.assertTrue(not toplevel.HasField('submessage'))
1396    self.assertEqual(0, len(toplevel.submessage.Extensions[
1397        more_extensions_pb2.repeated_message_extension]))
1398    self.assertTrue(not toplevel.HasField('submessage'))
1399    foreign = toplevel.submessage.Extensions[
1400        more_extensions_pb2.repeated_message_extension].add()
1401    self.assertEqual(foreign, toplevel.submessage.Extensions[
1402        more_extensions_pb2.repeated_message_extension][0])
1403    self.assertTrue(toplevel.HasField('submessage'))
1404
1405  def testDisconnectionAfterClearingEmptyMessage(self):
1406    toplevel = more_extensions_pb2.TopLevelMessage()
1407    extendee_proto = toplevel.submessage
1408    extension = more_extensions_pb2.optional_message_extension
1409    extension_proto = extendee_proto.Extensions[extension]
1410    extendee_proto.ClearExtension(extension)
1411    extension_proto.foreign_message_int = 23
1412
1413    self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
1414
1415  def testExtensionFailureModes(self):
1416    extendee_proto = unittest_pb2.TestAllExtensions()
1417
1418    # Try non-extension-handle arguments to HasExtension,
1419    # ClearExtension(), and Extensions[]...
1420    self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
1421    self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
1422    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
1423    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
1424
1425    # Try something that *is* an extension handle, just not for
1426    # this message...
1427    for unknown_handle in (more_extensions_pb2.optional_int_extension,
1428                           more_extensions_pb2.optional_message_extension,
1429                           more_extensions_pb2.repeated_int_extension,
1430                           more_extensions_pb2.repeated_message_extension):
1431      self.assertRaises(KeyError, extendee_proto.HasExtension,
1432                        unknown_handle)
1433      self.assertRaises(KeyError, extendee_proto.ClearExtension,
1434                        unknown_handle)
1435      self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
1436                        unknown_handle)
1437      self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
1438                        unknown_handle, 5)
1439
1440    # Try call HasExtension() with a valid handle, but for a
1441    # *repeated* field.  (Just as with non-extension repeated
1442    # fields, Has*() isn't supported for extension repeated fields).
1443    self.assertRaises(KeyError, extendee_proto.HasExtension,
1444                      unittest_pb2.repeated_string_extension)
1445
1446  def testStaticParseFrom(self):
1447    proto1 = unittest_pb2.TestAllTypes()
1448    test_util.SetAllFields(proto1)
1449
1450    string1 = proto1.SerializeToString()
1451    proto2 = unittest_pb2.TestAllTypes.FromString(string1)
1452
1453    # Messages should be equal.
1454    self.assertEqual(proto2, proto1)
1455
1456  def testMergeFromSingularField(self):
1457    # Test merge with just a singular field.
1458    proto1 = unittest_pb2.TestAllTypes()
1459    proto1.optional_int32 = 1
1460
1461    proto2 = unittest_pb2.TestAllTypes()
1462    # This shouldn't get overwritten.
1463    proto2.optional_string = 'value'
1464
1465    proto2.MergeFrom(proto1)
1466    self.assertEqual(1, proto2.optional_int32)
1467    self.assertEqual('value', proto2.optional_string)
1468
1469  def testMergeFromRepeatedField(self):
1470    # Test merge with just a repeated field.
1471    proto1 = unittest_pb2.TestAllTypes()
1472    proto1.repeated_int32.append(1)
1473    proto1.repeated_int32.append(2)
1474
1475    proto2 = unittest_pb2.TestAllTypes()
1476    proto2.repeated_int32.append(0)
1477    proto2.MergeFrom(proto1)
1478
1479    self.assertEqual(0, proto2.repeated_int32[0])
1480    self.assertEqual(1, proto2.repeated_int32[1])
1481    self.assertEqual(2, proto2.repeated_int32[2])
1482
1483  def testMergeFromOptionalGroup(self):
1484    # Test merge with an optional group.
1485    proto1 = unittest_pb2.TestAllTypes()
1486    proto1.optionalgroup.a = 12
1487    proto2 = unittest_pb2.TestAllTypes()
1488    proto2.MergeFrom(proto1)
1489    self.assertEqual(12, proto2.optionalgroup.a)
1490
1491  def testMergeFromRepeatedNestedMessage(self):
1492    # Test merge with a repeated nested message.
1493    proto1 = unittest_pb2.TestAllTypes()
1494    m = proto1.repeated_nested_message.add()
1495    m.bb = 123
1496    m = proto1.repeated_nested_message.add()
1497    m.bb = 321
1498
1499    proto2 = unittest_pb2.TestAllTypes()
1500    m = proto2.repeated_nested_message.add()
1501    m.bb = 999
1502    proto2.MergeFrom(proto1)
1503    self.assertEqual(999, proto2.repeated_nested_message[0].bb)
1504    self.assertEqual(123, proto2.repeated_nested_message[1].bb)
1505    self.assertEqual(321, proto2.repeated_nested_message[2].bb)
1506
1507    proto3 = unittest_pb2.TestAllTypes()
1508    proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message)
1509    self.assertEqual(999, proto3.repeated_nested_message[0].bb)
1510    self.assertEqual(123, proto3.repeated_nested_message[1].bb)
1511    self.assertEqual(321, proto3.repeated_nested_message[2].bb)
1512
1513  def testMergeFromAllFields(self):
1514    # With all fields set.
1515    proto1 = unittest_pb2.TestAllTypes()
1516    test_util.SetAllFields(proto1)
1517    proto2 = unittest_pb2.TestAllTypes()
1518    proto2.MergeFrom(proto1)
1519
1520    # Messages should be equal.
1521    self.assertEqual(proto2, proto1)
1522
1523    # Serialized string should be equal too.
1524    string1 = proto1.SerializeToString()
1525    string2 = proto2.SerializeToString()
1526    self.assertEqual(string1, string2)
1527
1528  def testMergeFromExtensionsSingular(self):
1529    proto1 = unittest_pb2.TestAllExtensions()
1530    proto1.Extensions[unittest_pb2.optional_int32_extension] = 1
1531
1532    proto2 = unittest_pb2.TestAllExtensions()
1533    proto2.MergeFrom(proto1)
1534    self.assertEqual(
1535        1, proto2.Extensions[unittest_pb2.optional_int32_extension])
1536
1537  def testMergeFromExtensionsRepeated(self):
1538    proto1 = unittest_pb2.TestAllExtensions()
1539    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
1540    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)
1541
1542    proto2 = unittest_pb2.TestAllExtensions()
1543    proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
1544    proto2.MergeFrom(proto1)
1545    self.assertEqual(
1546        3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
1547    self.assertEqual(
1548        0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
1549    self.assertEqual(
1550        1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
1551    self.assertEqual(
1552        2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])
1553
1554  def testMergeFromExtensionsNestedMessage(self):
1555    proto1 = unittest_pb2.TestAllExtensions()
1556    ext1 = proto1.Extensions[
1557        unittest_pb2.repeated_nested_message_extension]
1558    m = ext1.add()
1559    m.bb = 222
1560    m = ext1.add()
1561    m.bb = 333
1562
1563    proto2 = unittest_pb2.TestAllExtensions()
1564    ext2 = proto2.Extensions[
1565        unittest_pb2.repeated_nested_message_extension]
1566    m = ext2.add()
1567    m.bb = 111
1568
1569    proto2.MergeFrom(proto1)
1570    ext2 = proto2.Extensions[
1571        unittest_pb2.repeated_nested_message_extension]
1572    self.assertEqual(3, len(ext2))
1573    self.assertEqual(111, ext2[0].bb)
1574    self.assertEqual(222, ext2[1].bb)
1575    self.assertEqual(333, ext2[2].bb)
1576
  def testMergeFromBug(self):
    """MergeFrom() must not mark an unset submessage as present.

    Reading an unset submessage field instantiates an object in the source
    without setting its has-bit; merging that source must not set the field
    in the destination either.
    """
    message1 = unittest_pb2.TestAllTypes()
    message2 = unittest_pb2.TestAllTypes()

    # Cause optional_nested_message to be instantiated within message1, even
    # though it is not considered to be "present".
    message1.optional_nested_message
    self.assertFalse(message1.HasField('optional_nested_message'))

    # Merge into message2.  This should not instantiate the field in message2.
    message2.MergeFrom(message1)
    self.assertFalse(message2.HasField('optional_nested_message'))
1589
1590  def testCopyFromSingularField(self):
1591    # Test copy with just a singular field.
1592    proto1 = unittest_pb2.TestAllTypes()
1593    proto1.optional_int32 = 1
1594    proto1.optional_string = 'important-text'
1595
1596    proto2 = unittest_pb2.TestAllTypes()
1597    proto2.optional_string = 'value'
1598
1599    proto2.CopyFrom(proto1)
1600    self.assertEqual(1, proto2.optional_int32)
1601    self.assertEqual('important-text', proto2.optional_string)
1602
1603  def testCopyFromRepeatedField(self):
1604    # Test copy with a repeated field.
1605    proto1 = unittest_pb2.TestAllTypes()
1606    proto1.repeated_int32.append(1)
1607    proto1.repeated_int32.append(2)
1608
1609    proto2 = unittest_pb2.TestAllTypes()
1610    proto2.repeated_int32.append(0)
1611    proto2.CopyFrom(proto1)
1612
1613    self.assertEqual(1, proto2.repeated_int32[0])
1614    self.assertEqual(2, proto2.repeated_int32[1])
1615
1616  def testCopyFromAllFields(self):
1617    # With all fields set.
1618    proto1 = unittest_pb2.TestAllTypes()
1619    test_util.SetAllFields(proto1)
1620    proto2 = unittest_pb2.TestAllTypes()
1621    proto2.CopyFrom(proto1)
1622
1623    # Messages should be equal.
1624    self.assertEqual(proto2, proto1)
1625
1626    # Serialized string should be equal too.
1627    string1 = proto1.SerializeToString()
1628    string2 = proto2.SerializeToString()
1629    self.assertEqual(string1, string2)
1630
1631  def testCopyFromSelf(self):
1632    proto1 = unittest_pb2.TestAllTypes()
1633    proto1.repeated_int32.append(1)
1634    proto1.optional_int32 = 2
1635    proto1.optional_string = 'important-text'
1636
1637    proto1.CopyFrom(proto1)
1638    self.assertEqual(1, proto1.repeated_int32[0])
1639    self.assertEqual(2, proto1.optional_int32)
1640    self.assertEqual('important-text', proto1.optional_string)
1641
1642  def testCopyFromBadType(self):
1643    # The python implementation doesn't raise an exception in this
1644    # case. In theory it should.
1645    if api_implementation.Type() == 'python':
1646      return
1647    proto1 = unittest_pb2.TestAllTypes()
1648    proto2 = unittest_pb2.TestAllExtensions()
1649    self.assertRaises(TypeError, proto1.CopyFrom, proto2)
1650
  def testDeepCopy(self):
    """copy.deepcopy() yields independent messages and repeated containers."""
    proto1 = unittest_pb2.TestAllTypes()
    proto1.optional_int32 = 1
    proto2 = copy.deepcopy(proto1)
    self.assertEqual(1, proto2.optional_int32)

    # Deep-copy a repeated scalar container, then mutate the copy.
    proto1.repeated_int32.append(2)
    proto1.repeated_int32.append(3)
    container = copy.deepcopy(proto1.repeated_int32)
    self.assertEqual([2, 3], container)
    container.remove(container[0])
    self.assertEqual([3], container)

    # Deep-copy a repeated message container: it starts out equal, then
    # diverges once the original element is mutated.
    message1 = proto1.repeated_nested_message.add()
    message1.bb = 1
    messages = copy.deepcopy(proto1.repeated_nested_message)
    self.assertEqual(proto1.repeated_nested_message, messages)
    message1.bb = 2
    self.assertNotEqual(proto1.repeated_nested_message, messages)
    messages.remove(messages[0])
    self.assertEqual(len(messages), 0)

    # TODO(anuraag): Implement deepcopy for extension dict
1674
1675  def testClear(self):
1676    proto = unittest_pb2.TestAllTypes()
1677    # C++ implementation does not support lazy fields right now so leave it
1678    # out for now.
1679    if api_implementation.Type() == 'python':
1680      test_util.SetAllFields(proto)
1681    else:
1682      test_util.SetAllNonLazyFields(proto)
1683    # Clear the message.
1684    proto.Clear()
1685    self.assertEqual(proto.ByteSize(), 0)
1686    empty_proto = unittest_pb2.TestAllTypes()
1687    self.assertEqual(proto, empty_proto)
1688
1689    # Test if extensions which were set are cleared.
1690    proto = unittest_pb2.TestAllExtensions()
1691    test_util.SetAllExtensions(proto)
1692    # Clear the message.
1693    proto.Clear()
1694    self.assertEqual(proto.ByteSize(), 0)
1695    empty_proto = unittest_pb2.TestAllExtensions()
1696    self.assertEqual(proto, empty_proto)
1697
1698  def testDisconnectingBeforeClear(self):
1699    proto = unittest_pb2.TestAllTypes()
1700    nested = proto.optional_nested_message
1701    proto.Clear()
1702    self.assertTrue(nested is not proto.optional_nested_message)
1703    nested.bb = 23
1704    self.assertTrue(not proto.HasField('optional_nested_message'))
1705    self.assertEqual(0, proto.optional_nested_message.bb)
1706
1707    proto = unittest_pb2.TestAllTypes()
1708    nested = proto.optional_nested_message
1709    nested.bb = 5
1710    foreign = proto.optional_foreign_message
1711    foreign.c = 6
1712
1713    proto.Clear()
1714    self.assertTrue(nested is not proto.optional_nested_message)
1715    self.assertTrue(foreign is not proto.optional_foreign_message)
1716    self.assertEqual(5, nested.bb)
1717    self.assertEqual(6, foreign.c)
1718    nested.bb = 15
1719    foreign.c = 16
1720    self.assertFalse(proto.HasField('optional_nested_message'))
1721    self.assertEqual(0, proto.optional_nested_message.bb)
1722    self.assertFalse(proto.HasField('optional_foreign_message'))
1723    self.assertEqual(0, proto.optional_foreign_message.c)
1724
1725  def testDisconnectingInOneof(self):
1726    m = unittest_pb2.TestOneof2()  # This message has two messages in a oneof.
1727    m.foo_message.qux_int = 5
1728    sub_message = m.foo_message
1729    # Accessing another message's field does not clear the first one
1730    self.assertEqual(m.foo_lazy_message.qux_int, 0)
1731    self.assertEqual(m.foo_message.qux_int, 5)
1732    # But mutating another message in the oneof detaches the first one.
1733    m.foo_lazy_message.qux_int = 6
1734    self.assertEqual(m.foo_message.qux_int, 0)
1735    # The reference we got above was detached and is still valid.
1736    self.assertEqual(sub_message.qux_int, 5)
1737    sub_message.qux_int = 7
1738
1739  def testOneOf(self):
1740    proto = unittest_pb2.TestAllTypes()
1741    proto.oneof_uint32 = 10
1742    proto.oneof_nested_message.bb = 11
1743    self.assertEqual(11, proto.oneof_nested_message.bb)
1744    self.assertFalse(proto.HasField('oneof_uint32'))
1745    nested = proto.oneof_nested_message
1746    proto.oneof_string = 'abc'
1747    self.assertEqual('abc', proto.oneof_string)
1748    self.assertEqual(11, nested.bb)
1749    self.assertFalse(proto.HasField('oneof_nested_message'))
1750
1751  def assertInitialized(self, proto):
1752    self.assertTrue(proto.IsInitialized())
1753    # Neither method should raise an exception.
1754    proto.SerializeToString()
1755    proto.SerializePartialToString()
1756
1757  def assertNotInitialized(self, proto, error_size=None):
1758    errors = []
1759    self.assertFalse(proto.IsInitialized())
1760    self.assertFalse(proto.IsInitialized(errors))
1761    self.assertEqual(error_size, len(errors))
1762    self.assertRaises(message.EncodeError, proto.SerializeToString)
1763    # "Partial" serialization doesn't care if message is uninitialized.
1764    proto.SerializePartialToString()
1765
  def testIsInitialized(self):
    """IsInitialized() tracks required fields, including nested ones.

    The second argument to assertNotInitialized is the number of errors
    IsInitialized(errors) is expected to report at that point.
    """
    # Trivial cases - all optional fields and extensions.
    proto = unittest_pb2.TestAllTypes()
    self.assertInitialized(proto)
    proto = unittest_pb2.TestAllExtensions()
    self.assertInitialized(proto)

    # The case of uninitialized required fields.
    # An empty TestRequired reports three errors ('a', 'b', 'c'; see the
    # errors-list check at the end of this test).
    proto = unittest_pb2.TestRequired()
    self.assertNotInitialized(proto, 3)
    proto.a = proto.b = proto.c = 2
    self.assertInitialized(proto)

    # The case of uninitialized submessage.
    proto = unittest_pb2.TestRequiredForeign()
    self.assertInitialized(proto)
    # Setting one required field of the submessage leaves its other two
    # required fields as errors.
    proto.optional_message.a = 1
    self.assertNotInitialized(proto, 2)
    proto.optional_message.b = 0
    proto.optional_message.c = 0
    self.assertInitialized(proto)

    # Uninitialized repeated submessage.
    message1 = proto.repeated_message.add()
    self.assertNotInitialized(proto, 3)
    message1.a = message1.b = message1.c = 0
    self.assertInitialized(proto)

    # Uninitialized repeated group in an extension.
    # Each of the two empty elements contributes three errors.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.multi
    message1 = proto.Extensions[extension].add()
    message2 = proto.Extensions[extension].add()
    self.assertNotInitialized(proto, 6)
    message1.a = 1
    message1.b = 1
    message1.c = 1
    self.assertNotInitialized(proto, 3)
    message2.a = 2
    message2.b = 2
    message2.c = 2
    self.assertInitialized(proto)

    # Uninitialized nonrepeated message in an extension.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.single
    proto.Extensions[extension].a = 1
    self.assertNotInitialized(proto, 2)
    proto.Extensions[extension].b = 2
    proto.Extensions[extension].c = 3
    self.assertInitialized(proto)

    # Try passing an errors list.
    # The list is filled with the names of the missing required fields.
    errors = []
    proto = unittest_pb2.TestRequired()
    self.assertFalse(proto.IsInitialized(errors))
    self.assertEqual(errors, ['a', 'b', 'c'])
1823
1824  @unittest.skipIf(
1825      api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
1826      'Errors are only available from the most recent C++ implementation.')
1827  def testFileDescriptorErrors(self):
1828    file_name = 'test_file_descriptor_errors.proto'
1829    package_name = 'test_file_descriptor_errors.proto'
1830    file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
1831    file_descriptor_proto.name = file_name
1832    file_descriptor_proto.package = package_name
1833    m1 = file_descriptor_proto.message_type.add()
1834    m1.name = 'msg1'
1835    # Compiles the proto into the C++ descriptor pool
1836    descriptor.FileDescriptor(
1837        file_name,
1838        package_name,
1839        serialized_pb=file_descriptor_proto.SerializeToString())
1840    # Add a FileDescriptorProto that has duplicate symbols
1841    another_file_name = 'another_test_file_descriptor_errors.proto'
1842    file_descriptor_proto.name = another_file_name
1843    m2 = file_descriptor_proto.message_type.add()
1844    m2.name = 'msg2'
1845    with self.assertRaises(TypeError) as cm:
1846      descriptor.FileDescriptor(
1847          another_file_name,
1848          package_name,
1849          serialized_pb=file_descriptor_proto.SerializeToString())
1850      self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
1851                      getattr(cm.expected, '__name__', cm.expected))
1852      self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
1853      # Error message will say something about this definition being a
1854      # duplicate, though we don't check the message exactly to avoid a
1855      # dependency on the C++ logging code.
1856      self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
1857
  def testStringUTF8Encoding(self):
    """String fields accept text and valid UTF-8 bytes; bytes fields reject text."""
    proto = unittest_pb2.TestAllTypes()

    # Assignment of a unicode object to a field of type 'bytes' is not allowed.
    self.assertRaises(TypeError,
                      setattr, proto, 'optional_bytes', u'unicode object')

    # Check that the default value is of python's 'unicode' type.
    self.assertEqual(type(proto.optional_string), six.text_type)

    # Assign a text value; it compares equal to the native str form.
    proto.optional_string = six.text_type('Testing')
    self.assertEqual(proto.optional_string, str('Testing'))

    # Assign a value of type 'str' which can be encoded in UTF-8.
    proto.optional_string = str('Testing')
    self.assertEqual(proto.optional_string, six.text_type('Testing'))

    # Try to assign a 'bytes' object which contains non-UTF-8.
    self.assertRaises(ValueError,
                      setattr, proto, 'optional_string', b'a\x80a')
    # No exception: Assign already encoded UTF-8 bytes to a string field.
    utf8_bytes = u'Тест'.encode('utf-8')
    proto.optional_string = utf8_bytes
    # No exception: Assign a non-ASCII unicode object.
    proto.optional_string = u'Тест'
    # No exception thrown (normal str assignment containing ASCII).
    proto.optional_string = 'abc'
1885
  def testStringUTF8Serialization(self):
    """Non-ASCII strings survive MessageSet serialization as UTF-8 bytes.

    Also checks that parsing invalid UTF-8 into a string field either raises
    UnicodeDecodeError or leaves the field as raw bytes.
    """
    proto = message_set_extensions_pb2.TestMessageSet()
    extension_message = message_set_extensions_pb2.TestMessageSetExtension2
    extension = extension_message.message_set_extension

    test_utf8 = u'Тест'
    test_utf8_bytes = test_utf8.encode('utf-8')

    # 'Test' in another language, using UTF-8 charset.
    proto.Extensions[extension].str = test_utf8

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    # Check byte size.
    self.assertEqual(proto.ByteSize(), len(serialized))

    # Re-parse the bytes as a raw MessageSet to inspect individual items.
    raw = unittest_mset_pb2.RawMessageSet()
    bytes_read = raw.MergeFromString(serialized)
    self.assertEqual(len(serialized), bytes_read)

    message2 = message_set_extensions_pb2.TestMessageSetExtension2()

    self.assertEqual(1, len(raw.item))
    # Check that the type_id is the same as the tag ID in the .proto file.
    self.assertEqual(raw.item[0].type_id, 98418634)

    # Check the actual bytes on the wire.
    self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes))
    bytes_read = message2.MergeFromString(raw.item[0].message)
    self.assertEqual(len(raw.item[0].message), bytes_read)

    self.assertEqual(type(message2.str), six.text_type)
    self.assertEqual(message2.str, test_utf8)

    # The pure Python API throws an exception on MergeFromString(),
    # if any of the string fields of the message can't be UTF-8 decoded.
    # The C++ implementation of the API has no way to check that on
    # MergeFromString and thus has no way to throw the exception.
    #
    # The pure Python API always returns objects of type 'unicode' (UTF-8
    # encoded), or 'bytes' (in 7 bit ASCII).
    # Overwrite the UTF-8 payload with the same number of 0xff bytes, which
    # is not valid UTF-8.
    badbytes = raw.item[0].message.replace(
        test_utf8_bytes, len(test_utf8_bytes) * b'\xff')

    unicode_decode_failed = False
    try:
      message2.MergeFromString(badbytes)
    except UnicodeDecodeError:
      unicode_decode_failed = True
    string_field = message2.str
    # Either the merge raised, or the field came back as raw bytes.
    self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
1939
1940  def testBytesInTextFormat(self):
1941    proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
1942    self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n',
1943                     six.text_type(proto))
1944
1945  def testEmptyNestedMessage(self):
1946    proto = unittest_pb2.TestAllTypes()
1947    proto.optional_nested_message.MergeFrom(
1948        unittest_pb2.TestAllTypes.NestedMessage())
1949    self.assertTrue(proto.HasField('optional_nested_message'))
1950
1951    proto = unittest_pb2.TestAllTypes()
1952    proto.optional_nested_message.CopyFrom(
1953        unittest_pb2.TestAllTypes.NestedMessage())
1954    self.assertTrue(proto.HasField('optional_nested_message'))
1955
1956    proto = unittest_pb2.TestAllTypes()
1957    bytes_read = proto.optional_nested_message.MergeFromString(b'')
1958    self.assertEqual(0, bytes_read)
1959    self.assertTrue(proto.HasField('optional_nested_message'))
1960
1961    proto = unittest_pb2.TestAllTypes()
1962    proto.optional_nested_message.ParseFromString(b'')
1963    self.assertTrue(proto.HasField('optional_nested_message'))
1964
1965    serialized = proto.SerializeToString()
1966    proto2 = unittest_pb2.TestAllTypes()
1967    self.assertEqual(
1968        len(serialized),
1969        proto2.MergeFromString(serialized))
1970    self.assertTrue(proto2.HasField('optional_nested_message'))
1971
1972  def testSetInParent(self):
1973    proto = unittest_pb2.TestAllTypes()
1974    self.assertFalse(proto.HasField('optionalgroup'))
1975    proto.optionalgroup.SetInParent()
1976    self.assertTrue(proto.HasField('optionalgroup'))
1977
1978  def testPackageInitializationImport(self):
1979    """Test that we can import nested messages from their __init__.py.
1980
1981    Such setup is not trivial since at the time of processing of __init__.py one
1982    can't refer to its submodules by name in code, so expressions like
1983    google.protobuf.internal.import_test_package.inner_pb2
1984    don't work. They do work in imports, so we have assign an alias at import
1985    and then use that alias in generated code.
1986    """
1987    # We import here since it's the import that used to fail, and we want
1988    # the failure to have the right context.
1989    # pylint: disable=g-import-not-at-top
1990    from google.protobuf.internal import import_test_package
1991    # pylint: enable=g-import-not-at-top
1992    msg = import_test_package.myproto.Outer()
1993    # Just check the default value.
1994    self.assertEqual(57, msg.inner.value)
1995
1996#  Since we had so many tests for protocol buffer equality, we broke these out
1997#  into separate TestCase classes.
1998
1999
@testing_refleaks.TestCase
class TestAllTypesEqualityTest(unittest.TestCase):
  """Equality checks starting from two freshly constructed TestAllTypes."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()

  def testNotHashable(self):
    # hash() on a message must be refused.
    with self.assertRaises(TypeError):
      hash(self.first_proto)

  def testSelfEquality(self):
    # A message always equals itself.
    self.assertEqual(self.first_proto, self.first_proto)

  def testEmptyProtosEqual(self):
    # Two distinct empty messages compare equal.
    self.assertEqual(self.first_proto, self.second_proto)
2015
2016
@testing_refleaks.TestCase
class FullProtosEqualityTest(unittest.TestCase):

  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    # Two independently built messages with every field populated.
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.first_proto)
    test_util.SetAllFields(self.second_proto)

  def testNotHashable(self):
    # hash() on a populated message must be refused.
    self.assertRaises(TypeError, hash, self.first_proto)

  def testNoneNotEqual(self):
    # Comparison against None is unequal in either operand order.
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    # Messages of a different type never compare equal.
    third_proto = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, third_proto)
    self.assertNotEqual(third_proto, self.second_proto)

  def testAllFieldsFilledEquality(self):
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    # Nonrepeated scalar field change should cause inequality.
    self.first_proto.optional_int32 += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    # ...as should clearing a field.
    self.first_proto.ClearField('optional_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedComposite(self):
    # Change a nonrepeated composite field.
    self.first_proto.optional_nested_message.bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Clear a field in the nested message.
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb = (
        self.second_proto.optional_nested_message.bb)
    self.assertEqual(self.first_proto, self.second_proto)
    # Remove the nested message entirely.
    self.first_proto.ClearField('optional_nested_message')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedScalar(self):
    # Change a repeated scalar field.
    self.first_proto.repeated_int32.append(5)
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.ClearField('repeated_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedComposite(self):
    # Change value within a repeated composite field.
    self.first_proto.repeated_nested_message[0].bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.repeated_nested_message[0].bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Add a value to a repeated composite field.
    self.first_proto.repeated_nested_message.add()
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.second_proto.repeated_nested_message.add()
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalarHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated scalar field.
    # An unset field and a field explicitly set to its default (0) differ.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated composite field.
    # An absent submessage differs from a present-but-empty one.
    self.first_proto.ClearField('optional_nested_message')
    self.second_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    # Setting then clearing 'bb' leaves the submessage present but empty,
    # which matches second_proto.
    self.first_proto.optional_nested_message.bb = 0
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertEqual(self.first_proto, self.second_proto)
2102
2103
@testing_refleaks.TestCase
class ExtensionEqualityTest(unittest.TestCase):
  """Equality checks that involve extension values and their has-bits."""

  def testExtensionEquality(self):
    """Extensions take part in message equality, by value and by presence."""
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(first_proto, second_proto)
    test_util.SetAllExtensions(first_proto)
    self.assertNotEqual(first_proto, second_proto)
    test_util.SetAllExtensions(second_proto)
    self.assertEqual(first_proto, second_proto)

    # Ensure that we check value equality.
    first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
    self.assertEqual(first_proto, second_proto)

    # Ensure that we also look at "has" bits.
    # A cleared extension differs from one explicitly set to its default 0.
    first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
    second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertEqual(first_proto, second_proto)

    # Ensure that differences in cached values
    # don't matter if "has" bits are both false.
    # Reading the extension on first_proto only, without setting it, must
    # not make it differ from second_proto.
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(
        0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
    self.assertEqual(first_proto, second_proto)
2136
2137
@testing_refleaks.TestCase
class MutualRecursionEqualityTest(unittest.TestCase):

  def testEqualityWithMutualRecursion(self):
    # Equality must recurse through several levels of nesting (bb.a.bb).
    lhs, rhs = (unittest_pb2.TestMutualRecursionA() for _ in range(2))
    self.assertEqual(lhs, rhs)

    lhs.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(lhs, rhs)

    rhs.bb.a.bb.optional_int32 = 23
    self.assertEqual(lhs, rhs)
2149
2150
@testing_refleaks.TestCase
class ByteSizeTest(unittest.TestCase):
  """Checks ByteSize() against hand-computed wire-format sizes.

  Expected sizes are written as sums of tag bytes, length-prefix bytes and
  payload bytes so each figure can be checked by eye.
  """

  def setUp(self):
    self.proto = unittest_pb2.TestAllTypes()
    self.extended_proto = more_extensions_pb2.ExtendedMessage()
    self.packed_proto = unittest_pb2.TestPackedTypes()
    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()

  def Size(self):
    """Returns the current ByteSize() of self.proto."""
    return self.proto.ByteSize()

  def testEmptyMessage(self):
    self.assertEqual(0, self.proto.ByteSize())

  def testSizedOnKwargs(self):
    # Use a separate message to ensure testing right after creation.
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.ByteSize())
    proto_kwargs = unittest_pb2.TestAllTypes(optional_int64=1)
    # One byte for the tag, one to encode varint 1.
    self.assertEqual(2, proto_kwargs.ByteSize())

  def testVarints(self):
    def Test(i, expected_varint_size):
      self.proto.Clear()
      self.proto.optional_int64 = i
      # Add one byte for optional_int64's single-byte tag.
      self.assertEqual(expected_varint_size + 1, self.Size())
    Test(0, 1)
    Test(1, 1)
    # (1 << 7k) - 1 is the largest value whose varint fits in k bytes.
    for num_bytes, bits in enumerate(range(7, 63, 7), 1):
      Test((1 << bits) - 1, num_bytes)
    # Negative int64 values always take the maximum of 10 varint bytes.
    Test(-1, 10)
    Test(-2, 10)
    Test(-(1 << 63), 10)

  def testStrings(self):
    self.proto.optional_string = ''
    # One byte for tag info (tag #14) and one byte for the length.
    self.assertEqual(2, self.Size())

    self.proto.optional_string = 'abc'
    # One tag byte, one length byte, then the payload bytes.
    self.assertEqual(2 + len(self.proto.optional_string), self.Size())

    self.proto.optional_string = 'x' * 128
    # A length of 128 no longer fits in one varint byte, so the length
    # prefix takes TWO bytes.
    self.assertEqual(3 + len(self.proto.optional_string), self.Size())

  def testOtherNumerics(self):
    # (field name, value, expected ByteSize). Every field here has a
    # one-byte tag; each case runs on a fresh message.
    cases = (
        ('optional_fixed32', 1234, 1 + 4),  # fixed32: always 4 bytes.
        ('optional_fixed64', 1234, 1 + 8),  # fixed64: always 8 bytes.
        ('optional_float', 1.234, 1 + 4),   # float: always 4 bytes.
        ('optional_double', 1.234, 1 + 8),  # double: always 8 bytes.
        ('optional_sint32', 64, 1 + 2),     # zig-zag(64) = 128: 2 bytes.
    )
    for field_name, value, expected_size in cases:
      self.proto = unittest_pb2.TestAllTypes()
      setattr(self.proto, field_name, value)
      self.assertEqual(expected_size, self.Size(), field_name)

  def testComposites(self):
    # bb = 1 << 14 needs 15 bits: a 3-byte varint.
    self.proto.optional_nested_message.bb = (1 << 14)
    # 3 value bytes, plus 1 byte for the bb tag, plus 1 byte for the
    # optional_nested_message length prefix, plus 2 bytes for the
    # optional_nested_message tag.
    self.assertEqual(3 + 1 + 1 + 2, self.Size())

  def testGroups(self):
    # a = 1 << 21 needs 22 bits: a 4-byte varint.
    self.proto.optionalgroup.a = (1 << 21)
    # 4 value bytes, plus two bytes for |a|'s tag, plus two bytes each for
    # the START_GROUP and END_GROUP tags (no length prefix is counted).
    self.assertEqual(4 + 2 + 2*2, self.Size())

  def testRepeatedScalars(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsExtend(self):
    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsRemove(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())
    # Removing 128 drops its 2 value bytes and its 2-byte tag.
    self.proto.repeated_int32.remove(128)
    self.assertEqual(1 + 2, self.Size())

  def testRepeatedComposites(self):
    # Empty message: 2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag plus 1 byte bb value.
    self.proto.repeated_nested_message.add().bb = 7
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

  def testRepeatedCompositesDelete(self):
    # Empty message: 2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag plus 1 byte bb value.
    foreign_message_1 = self.proto.repeated_nested_message.add()
    foreign_message_1.bb = 9
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
    # Deep copy used at the end to exercise slice and negative-index deletion.
    repeated_nested_message = copy.deepcopy(
        self.proto.repeated_nested_message)

    # Deleting the empty element leaves only the bb=9 element:
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag plus 1 byte bb value.
    del self.proto.repeated_nested_message[0]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    # Now add a new message.
    foreign_message_2 = self.proto.repeated_nested_message.add()
    foreign_message_2.bb = 12

    # Two elements, each 2 bytes tag plus 1 byte length plus 1 byte bb tag
    # plus 1 byte bb value.
    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())

    # Deleting the second element leaves one 2 + 1 + 1 + 1 byte element.
    del self.proto.repeated_nested_message[1]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    del self.proto.repeated_nested_message[0]
    self.assertEqual(0, self.Size())

    self.assertEqual(2, len(repeated_nested_message))
    del repeated_nested_message[0:1]
    # TODO(jieluo): Fix cpp extension bug when delete repeated message.
    if api_implementation.Type() == 'python':
      self.assertEqual(1, len(repeated_nested_message))
    del repeated_nested_message[-1]
    # TODO(jieluo): Fix cpp extension bug when delete repeated message.
    if api_implementation.Type() == 'python':
      self.assertEqual(0, len(repeated_nested_message))

  def testRepeatedGroups(self):
    # Empty group: 2-byte START_GROUP plus 2-byte END_GROUP.
    self.proto.repeatedgroup.add()
    # 2-byte START_GROUP plus 2-byte |a| tag plus 1-byte |a|
    # plus 2-byte END_GROUP.
    self.proto.repeatedgroup.add().a = 7
    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())

  def testExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, proto.ByteSize())
    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
    proto.Extensions[extension] = 23
    # 1 byte for tag, 1 byte for value.
    self.assertEqual(2, proto.ByteSize())
    # A regular (non-extension) field descriptor is rejected as an
    # extension key.
    field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[
        'optional_int32']
    with self.assertRaises(KeyError):
      proto.Extensions[field] = 23

  def testCacheInvalidationForNonrepeatedScalar(self):
    # ByteSize() must track every mutation, so any cached size has to be
    # invalidated on set and on clear.
    # Test non-extension.
    self.proto.optional_int32 = 1
    self.assertEqual(2, self.proto.ByteSize())
    self.proto.optional_int32 = 128
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_int_extension
    self.extended_proto.Extensions[extension] = 1
    self.assertEqual(2, self.extended_proto.ByteSize())
    self.extended_proto.Extensions[extension] = 128
    self.assertEqual(3, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedScalar(self):
    # Test non-extension.
    self.proto.repeated_int32.append(1)
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_int32.append(1)
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.repeated_int32[1] = 128
    self.assertEqual(7, self.proto.ByteSize())
    self.proto.ClearField('repeated_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_int_extension
    repeated = self.extended_proto.Extensions[extension]
    repeated.append(1)
    self.assertEqual(2, self.extended_proto.ByteSize())
    repeated.append(1)
    self.assertEqual(4, self.extended_proto.ByteSize())
    repeated[1] = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForNonrepeatedMessage(self):
    # Test non-extension.
    self.proto.optional_foreign_message.c = 1
    self.assertEqual(5, self.proto.ByteSize())
    self.proto.optional_foreign_message.c = 128
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.optional_foreign_message.ClearField('c')
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    if api_implementation.Type() == 'python':
      # This is only possible in pure-Python implementation of the API.
      # A child detached by ClearField no longer contributes to the parent.
      child = self.proto.optional_foreign_message
      self.proto.ClearField('optional_foreign_message')
      child.c = 128
      self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_message_extension
    child = self.extended_proto.Extensions[extension]
    self.assertEqual(0, self.extended_proto.ByteSize())
    child.foreign_message_int = 1
    self.assertEqual(4, self.extended_proto.ByteSize())
    child.foreign_message_int = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedMessage(self):
    # Test non-extension.
    child0 = self.proto.repeated_foreign_message.add()
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_foreign_message.add()
    self.assertEqual(6, self.proto.ByteSize())
    child0.c = 1
    self.assertEqual(8, self.proto.ByteSize())
    self.proto.ClearField('repeated_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_message_extension
    child_list = self.extended_proto.Extensions[extension]
    child0 = child_list.add()
    self.assertEqual(2, self.extended_proto.ByteSize())
    child_list.add()
    self.assertEqual(4, self.extended_proto.ByteSize())
    child0.foreign_message_int = 1
    self.assertEqual(6, self.extended_proto.ByteSize())
    child0.ClearField('foreign_message_int')
    self.assertEqual(4, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testPackedRepeatedScalars(self):
    self.assertEqual(0, self.packed_proto.ByteSize())

    self.packed_proto.packed_int32.append(10)   # 1 byte.
    self.packed_proto.packed_int32.append(128)  # 2 bytes.
    # The tag is 2 bytes (the field number is 90), and the varint
    # storing the length is 1 byte.
    int_size = 1 + 2 + 3
    self.assertEqual(int_size, self.packed_proto.ByteSize())

    self.packed_proto.packed_double.append(4.2)   # 8 bytes
    self.packed_proto.packed_double.append(3.25)  # 8 bytes
    # 2 more tag bytes, 1 more length byte.
    double_size = 8 + 8 + 3
    self.assertEqual(int_size+double_size, self.packed_proto.ByteSize())

    self.packed_proto.ClearField('packed_int32')
    self.assertEqual(double_size, self.packed_proto.ByteSize())

  def testPackedExtensions(self):
    self.assertEqual(0, self.packed_extended_proto.ByteSize())
    extension = self.packed_extended_proto.Extensions[
        unittest_pb2.packed_fixed32_extension]
    extension.extend([1, 2, 3, 4])   # 16 bytes
    # 2-byte tag plus a 1-byte length prefix (16 fits in one varint byte).
    self.assertEqual(16 + 2 + 1, self.packed_extended_proto.ByteSize())
2451
2452
2453# Issues to be sure to cover include:
2454#   * Handling of unrecognized tags ("uninterpreted_bytes").
2455#   * Handling of MessageSets.
2456#   * Consistent ordering of tags in the wire format,
2457#     including ordering between extensions and non-extension
2458#     fields.
2459#   * Consistent serialization of negative numbers, especially
2460#     negative int32s.
2461#   * Handling of empty submessages (with and without "has"
2462#     bits set).
2463
2464@testing_refleaks.TestCase
2465class SerializationTest(unittest.TestCase):
2466
2467  def testSerializeEmtpyMessage(self):
2468    first_proto = unittest_pb2.TestAllTypes()
2469    second_proto = unittest_pb2.TestAllTypes()
2470    serialized = first_proto.SerializeToString()
2471    self.assertEqual(first_proto.ByteSize(), len(serialized))
2472    self.assertEqual(
2473        len(serialized),
2474        second_proto.MergeFromString(serialized))
2475    self.assertEqual(first_proto, second_proto)
2476
2477  def testSerializeAllFields(self):
2478    first_proto = unittest_pb2.TestAllTypes()
2479    second_proto = unittest_pb2.TestAllTypes()
2480    test_util.SetAllFields(first_proto)
2481    serialized = first_proto.SerializeToString()
2482    self.assertEqual(first_proto.ByteSize(), len(serialized))
2483    self.assertEqual(
2484        len(serialized),
2485        second_proto.MergeFromString(serialized))
2486    self.assertEqual(first_proto, second_proto)
2487
2488  def testSerializeAllExtensions(self):
2489    first_proto = unittest_pb2.TestAllExtensions()
2490    second_proto = unittest_pb2.TestAllExtensions()
2491    test_util.SetAllExtensions(first_proto)
2492    serialized = first_proto.SerializeToString()
2493    self.assertEqual(
2494        len(serialized),
2495        second_proto.MergeFromString(serialized))
2496    self.assertEqual(first_proto, second_proto)
2497
2498  def testSerializeWithOptionalGroup(self):
2499    first_proto = unittest_pb2.TestAllTypes()
2500    second_proto = unittest_pb2.TestAllTypes()
2501    first_proto.optionalgroup.a = 242
2502    serialized = first_proto.SerializeToString()
2503    self.assertEqual(
2504        len(serialized),
2505        second_proto.MergeFromString(serialized))
2506    self.assertEqual(first_proto, second_proto)
2507
2508  def testSerializeNegativeValues(self):
2509    first_proto = unittest_pb2.TestAllTypes()
2510
2511    first_proto.optional_int32 = -1
2512    first_proto.optional_int64 = -(2 << 40)
2513    first_proto.optional_sint32 = -3
2514    first_proto.optional_sint64 = -(4 << 40)
2515    first_proto.optional_sfixed32 = -5
2516    first_proto.optional_sfixed64 = -(6 << 40)
2517
2518    second_proto = unittest_pb2.TestAllTypes.FromString(
2519        first_proto.SerializeToString())
2520
2521    self.assertEqual(first_proto, second_proto)
2522
  def testParseTruncated(self):
    # Feeds every prefix of a fully-populated TestAllTypes encoding to two
    # parsers: TestAllTypes (fields known) and TestEmptyMessage (fields
    # unknown). Both must agree on each prefix. Either both stop exactly at
    # the truncation point, or both raise message.DecodeError.
    # This test is only applicable for the Python implementation of the API.
    # (Presumably because it calls the private _InternalParse directly.)
    if api_implementation.Type() != 'python':
      return

    first_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    # Parse from a memoryview of the encoded bytes, bounded by explicit
    # start/end offsets below rather than by slicing.
    serialized = memoryview(first_proto.SerializeToString())

    # Truncation points run from 0 (empty input) through the full length.
    for truncation_point in range(len(serialized) + 1):
      try:
        second_proto = unittest_pb2.TestAllTypes()
        unknown_fields = unittest_pb2.TestEmptyMessage()
        # _InternalParse(buffer, start, end) returns the position reached.
        pos = second_proto._InternalParse(serialized, 0, truncation_point)
        # If we didn't raise an error then we read exactly the amount expected.
        self.assertEqual(truncation_point, pos)

        # Parsing to unknown fields should not throw if parsing to known fields
        # did not.
        try:
          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
          self.assertEqual(truncation_point, pos2)
        except message.DecodeError:
          self.fail('Parsing unknown fields failed when parsing known fields '
                    'did not.')
      # Only DecodeError is caught here. Assertion failures raised above
      # (assertEqual / self.fail) are expected to propagate, assuming
      # DecodeError is not an AssertionError subclass.
      except message.DecodeError:
        # Parsing unknown fields should also fail.
        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
                          serialized, 0, truncation_point)
2552
2553  def testCanonicalSerializationOrder(self):
2554    proto = more_messages_pb2.OutOfOrderFields()
2555    # These are also their tag numbers.  Even though we're setting these in
2556    # reverse-tag order AND they're listed in reverse tag-order in the .proto
2557    # file, they should nonetheless be serialized in tag order.
2558    proto.optional_sint32 = 5
2559    proto.Extensions[more_messages_pb2.optional_uint64] = 4
2560    proto.optional_uint32 = 3
2561    proto.Extensions[more_messages_pb2.optional_int64] = 2
2562    proto.optional_int32 = 1
2563    serialized = proto.SerializeToString()
2564    self.assertEqual(proto.ByteSize(), len(serialized))
2565    d = _MiniDecoder(serialized)
2566    ReadTag = d.ReadFieldNumberAndWireType
2567    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
2568    self.assertEqual(1, d.ReadInt32())
2569    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
2570    self.assertEqual(2, d.ReadInt64())
2571    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
2572    self.assertEqual(3, d.ReadUInt32())
2573    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
2574    self.assertEqual(4, d.ReadUInt64())
2575    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
2576    self.assertEqual(5, d.ReadSInt32())
2577
2578  def testCanonicalSerializationOrderSameAsCpp(self):
2579    # Copy of the same test we use for C++.
2580    proto = unittest_pb2.TestFieldOrderings()
2581    test_util.SetAllFieldsAndExtensions(proto)
2582    serialized = proto.SerializeToString()
2583    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)
2584
2585  def testMergeFromStringWhenFieldsAlreadySet(self):
2586    first_proto = unittest_pb2.TestAllTypes()
2587    first_proto.repeated_string.append('foobar')
2588    first_proto.optional_int32 = 23
2589    first_proto.optional_nested_message.bb = 42
2590    serialized = first_proto.SerializeToString()
2591
2592    second_proto = unittest_pb2.TestAllTypes()
2593    second_proto.repeated_string.append('baz')
2594    second_proto.optional_int32 = 100
2595    second_proto.optional_nested_message.bb = 999
2596
2597    bytes_parsed = second_proto.MergeFromString(serialized)
2598    self.assertEqual(len(serialized), bytes_parsed)
2599
2600    # Ensure that we append to repeated fields.
2601    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
2602    # Ensure that we overwrite nonrepeatd scalars.
2603    self.assertEqual(23, second_proto.optional_int32)
2604    # Ensure that we recursively call MergeFromString() on
2605    # submessages.
2606    self.assertEqual(42, second_proto.optional_nested_message.bb)
2607
2608  def testMessageSetWireFormat(self):
2609    proto = message_set_extensions_pb2.TestMessageSet()
2610    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2611    extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
2612    extension1 = extension_message1.message_set_extension
2613    extension2 = extension_message2.message_set_extension
2614    extension3 = message_set_extensions_pb2.message_set_extension3
2615    proto.Extensions[extension1].i = 123
2616    proto.Extensions[extension2].str = 'foo'
2617    proto.Extensions[extension3].text = 'bar'
2618
2619    # Serialize using the MessageSet wire format (this is specified in the
2620    # .proto file).
2621    serialized = proto.SerializeToString()
2622
2623    raw = unittest_mset_pb2.RawMessageSet()
2624    self.assertEqual(False,
2625                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
2626    self.assertEqual(
2627        len(serialized),
2628        raw.MergeFromString(serialized))
2629    self.assertEqual(3, len(raw.item))
2630
2631    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
2632    self.assertEqual(
2633        len(raw.item[0].message),
2634        message1.MergeFromString(raw.item[0].message))
2635    self.assertEqual(123, message1.i)
2636
2637    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
2638    self.assertEqual(
2639        len(raw.item[1].message),
2640        message2.MergeFromString(raw.item[1].message))
2641    self.assertEqual('foo', message2.str)
2642
2643    message3 = message_set_extensions_pb2.TestMessageSetExtension3()
2644    self.assertEqual(
2645        len(raw.item[2].message),
2646        message3.MergeFromString(raw.item[2].message))
2647    self.assertEqual('bar', message3.text)
2648
2649    # Deserialize using the MessageSet wire format.
2650    proto2 = message_set_extensions_pb2.TestMessageSet()
2651    self.assertEqual(
2652        len(serialized),
2653        proto2.MergeFromString(serialized))
2654    self.assertEqual(123, proto2.Extensions[extension1].i)
2655    self.assertEqual('foo', proto2.Extensions[extension2].str)
2656    self.assertEqual('bar', proto2.Extensions[extension3].text)
2657
2658    # Check byte size.
2659    self.assertEqual(proto2.ByteSize(), len(serialized))
2660    self.assertEqual(proto.ByteSize(), len(serialized))
2661
2662  def testMessageSetWireFormatUnknownExtension(self):
2663    # Create a message using the message set wire format with an unknown
2664    # message.
2665    raw = unittest_mset_pb2.RawMessageSet()
2666
2667    # Add an item.
2668    item = raw.item.add()
2669    item.type_id = 98418603
2670    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2671    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
2672    message1.i = 12345
2673    item.message = message1.SerializeToString()
2674
2675    # Add a second, unknown extension.
2676    item = raw.item.add()
2677    item.type_id = 98418604
2678    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2679    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
2680    message1.i = 12346
2681    item.message = message1.SerializeToString()
2682
2683    # Add another unknown extension.
2684    item = raw.item.add()
2685    item.type_id = 98418605
2686    message1 = message_set_extensions_pb2.TestMessageSetExtension2()
2687    message1.str = 'foo'
2688    item.message = message1.SerializeToString()
2689
2690    serialized = raw.SerializeToString()
2691
2692    # Parse message using the message set wire format.
2693    proto = message_set_extensions_pb2.TestMessageSet()
2694    self.assertEqual(
2695        len(serialized),
2696        proto.MergeFromString(serialized))
2697
2698    # Check that the message parsed well.
2699    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2700    extension1 = extension_message1.message_set_extension
2701    self.assertEqual(12345, proto.Extensions[extension1].i)
2702
2703  def testUnknownFields(self):
2704    proto = unittest_pb2.TestAllTypes()
2705    test_util.SetAllFields(proto)
2706
2707    serialized = proto.SerializeToString()
2708
2709    # The empty message should be parsable with all of the fields
2710    # unknown.
2711    proto2 = unittest_pb2.TestEmptyMessage()
2712
2713    # Parsing this message should succeed.
2714    self.assertEqual(
2715        len(serialized),
2716        proto2.MergeFromString(serialized))
2717
2718    # Now test with a int64 field set.
2719    proto = unittest_pb2.TestAllTypes()
2720    proto.optional_int64 = 0x0fffffffffffffff
2721    serialized = proto.SerializeToString()
2722    # The empty message should be parsable with all of the fields
2723    # unknown.
2724    proto2 = unittest_pb2.TestEmptyMessage()
2725    # Parsing this message should succeed.
2726    self.assertEqual(
2727        len(serialized),
2728        proto2.MergeFromString(serialized))
2729
2730  def _CheckRaises(self, exc_class, callable_obj, exception):
2731    """This method checks if the excpetion type and message are as expected."""
2732    try:
2733      callable_obj()
2734    except exc_class as ex:
2735      # Check if the exception message is the right one.
2736      self.assertEqual(exception, str(ex))
2737      return
2738    else:
2739      raise self.failureException('%s not raised' % str(exc_class))
2740
  def testSerializeUninitialized(self):
    # SerializeToString() must refuse a TestRequired whose required fields
    # are unset, raising EncodeError that lists the missing ones.
    # SerializePartialToString() and ParseFromString() must accept the
    # partial message.
    proto = unittest_pb2.TestRequired()
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: '
        'a,b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertFalse(proto2.HasField('a'))
    # proto2 ParseFromString does not check that required fields are set.
    proto2.ParseFromString(partial)
    self.assertFalse(proto2.HasField('a'))

    # Setting the required fields one at a time shrinks the reported list.
    proto.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    # All required fields are now set, so the strict serializer succeeds.
    proto.c = 3
    serialized = proto.SerializeToString()
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    # Both encodings were produced after a, b and c were set. Parsing
    # `serialized` yields 1/2/3, and merging `partial` on top keeps them.
    proto2 = unittest_pb2.TestRequired()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)
    self.assertEqual(
        len(partial),
        proto2.MergeFromString(partial))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)
2791
2792  def testSerializeUninitializedSubMessage(self):
2793    proto = unittest_pb2.TestRequiredForeign()
2794
2795    # Sub-message doesn't exist yet, so this succeeds.
2796    proto.SerializeToString()
2797
2798    proto.optional_message.a = 1
2799    self._CheckRaises(
2800        message.EncodeError,
2801        proto.SerializeToString,
2802        'Message protobuf_unittest.TestRequiredForeign '
2803        'is missing required fields: '
2804        'optional_message.b,optional_message.c')
2805
2806    proto.optional_message.b = 2
2807    proto.optional_message.c = 3
2808    proto.SerializeToString()
2809
2810    proto.repeated_message.add().a = 1
2811    proto.repeated_message.add().b = 2
2812    self._CheckRaises(
2813        message.EncodeError,
2814        proto.SerializeToString,
2815        'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
2816        'repeated_message[0].b,repeated_message[0].c,'
2817        'repeated_message[1].a,repeated_message[1].c')
2818
2819    proto.repeated_message[0].b = 2
2820    proto.repeated_message[0].c = 3
2821    proto.repeated_message[1].a = 1
2822    proto.repeated_message[1].c = 3
2823    proto.SerializeToString()
2824
2825  def testSerializeAllPackedFields(self):
2826    first_proto = unittest_pb2.TestPackedTypes()
2827    second_proto = unittest_pb2.TestPackedTypes()
2828    test_util.SetAllPackedFields(first_proto)
2829    serialized = first_proto.SerializeToString()
2830    self.assertEqual(first_proto.ByteSize(), len(serialized))
2831    bytes_read = second_proto.MergeFromString(serialized)
2832    self.assertEqual(second_proto.ByteSize(), bytes_read)
2833    self.assertEqual(first_proto, second_proto)
2834
2835  def testSerializeAllPackedExtensions(self):
2836    first_proto = unittest_pb2.TestPackedExtensions()
2837    second_proto = unittest_pb2.TestPackedExtensions()
2838    test_util.SetAllPackedExtensions(first_proto)
2839    serialized = first_proto.SerializeToString()
2840    bytes_read = second_proto.MergeFromString(serialized)
2841    self.assertEqual(second_proto.ByteSize(), bytes_read)
2842    self.assertEqual(first_proto, second_proto)
2843
2844  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
2845    first_proto = unittest_pb2.TestPackedTypes()
2846    first_proto.packed_int32.extend([1, 2])
2847    first_proto.packed_double.append(3.0)
2848    serialized = first_proto.SerializeToString()
2849
2850    second_proto = unittest_pb2.TestPackedTypes()
2851    second_proto.packed_int32.append(3)
2852    second_proto.packed_double.extend([1.0, 2.0])
2853    second_proto.packed_sint32.append(4)
2854
2855    self.assertEqual(
2856        len(serialized),
2857        second_proto.MergeFromString(serialized))
2858    self.assertEqual([3, 1, 2], second_proto.packed_int32)
2859    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
2860    self.assertEqual([4], second_proto.packed_sint32)
2861
2862  def testPackedFieldsWireFormat(self):
2863    proto = unittest_pb2.TestPackedTypes()
2864    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
2865    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
2866    proto.packed_float.append(2.0)             # 4 bytes, will be before double
2867    serialized = proto.SerializeToString()
2868    self.assertEqual(proto.ByteSize(), len(serialized))
2869    d = _MiniDecoder(serialized)
2870    ReadTag = d.ReadFieldNumberAndWireType
2871    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
2872    self.assertEqual(1+1+1+2, d.ReadInt32())
2873    self.assertEqual(1, d.ReadInt32())
2874    self.assertEqual(2, d.ReadInt32())
2875    self.assertEqual(150, d.ReadInt32())
2876    self.assertEqual(3, d.ReadInt32())
2877    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
2878    self.assertEqual(4, d.ReadInt32())
2879    self.assertEqual(2.0, d.ReadFloat())
2880    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
2881    self.assertEqual(8+8, d.ReadInt32())
2882    self.assertEqual(1.0, d.ReadDouble())
2883    self.assertEqual(1000.0, d.ReadDouble())
2884    self.assertTrue(d.EndOfStream())
2885
2886  def testParsePackedFromUnpacked(self):
2887    unpacked = unittest_pb2.TestUnpackedTypes()
2888    test_util.SetAllUnpackedFields(unpacked)
2889    packed = unittest_pb2.TestPackedTypes()
2890    serialized = unpacked.SerializeToString()
2891    self.assertEqual(
2892        len(serialized),
2893        packed.MergeFromString(serialized))
2894    expected = unittest_pb2.TestPackedTypes()
2895    test_util.SetAllPackedFields(expected)
2896    self.assertEqual(expected, packed)
2897
2898  def testParseUnpackedFromPacked(self):
2899    packed = unittest_pb2.TestPackedTypes()
2900    test_util.SetAllPackedFields(packed)
2901    unpacked = unittest_pb2.TestUnpackedTypes()
2902    serialized = packed.SerializeToString()
2903    self.assertEqual(
2904        len(serialized),
2905        unpacked.MergeFromString(serialized))
2906    expected = unittest_pb2.TestUnpackedTypes()
2907    test_util.SetAllUnpackedFields(expected)
2908    self.assertEqual(expected, unpacked)
2909
2910  def testFieldNumbers(self):
2911    proto = unittest_pb2.TestAllTypes()
2912    self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
2913    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
2914    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
2915    self.assertEqual(
2916      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
2917    self.assertEqual(
2918      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
2919    self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
2920    self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
2921    self.assertEqual(
2922      unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
2923    self.assertEqual(
2924      unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)
2925
2926  def testExtensionFieldNumbers(self):
2927    self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
2928    self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
2929    self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
2930    self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
2931    self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
2932    self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
2933    self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
2934    self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
2935    self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
2936    self.assertEqual(
2937      unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
2938    self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
2939    self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
2940      21)
2941    self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
2942    self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
2943    self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
2944    self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
2945    self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
2946    self.assertEqual(
2947      unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
2948    self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
2949    self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
2950      51)
2951
2952  def testFieldProperties(self):
2953    cls = unittest_pb2.TestAllTypes
2954    self.assertIs(cls.optional_int32.DESCRIPTOR,
2955                  cls.DESCRIPTOR.fields_by_name['optional_int32'])
2956    self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER,
2957                     cls.optional_int32.DESCRIPTOR.number)
2958    self.assertIs(cls.optional_nested_message.DESCRIPTOR,
2959                  cls.DESCRIPTOR.fields_by_name['optional_nested_message'])
2960    self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER,
2961                     cls.optional_nested_message.DESCRIPTOR.number)
2962    self.assertIs(cls.repeated_int32.DESCRIPTOR,
2963                  cls.DESCRIPTOR.fields_by_name['repeated_int32'])
2964    self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER,
2965                     cls.repeated_int32.DESCRIPTOR.number)
2966
2967  def testFieldDataDescriptor(self):
2968    msg = unittest_pb2.TestAllTypes()
2969    msg.optional_int32 = 42
2970    self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42)
2971    unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25)
2972    self.assertEqual(msg.optional_int32, 25)
2973    with self.assertRaises(AttributeError):
2974      del msg.optional_int32
2975    try:
2976      unittest_pb2.ForeignMessage.c.__get__(msg)
2977    except TypeError:
2978      pass  # The cpp implementation cannot mix fields from other messages.
2979            # This test exercises a specific check that avoids a crash.
2980    else:
2981      pass  # The python implementation allows fields from other messages.
2982            # This is useless, but works.
2983
2984  def testInitKwargs(self):
2985    proto = unittest_pb2.TestAllTypes(
2986        optional_int32=1,
2987        optional_string='foo',
2988        optional_bool=True,
2989        optional_bytes=b'bar',
2990        optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
2991        optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
2992        optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
2993        optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
2994        repeated_int32=[1, 2, 3])
2995    self.assertTrue(proto.IsInitialized())
2996    self.assertTrue(proto.HasField('optional_int32'))
2997    self.assertTrue(proto.HasField('optional_string'))
2998    self.assertTrue(proto.HasField('optional_bool'))
2999    self.assertTrue(proto.HasField('optional_bytes'))
3000    self.assertTrue(proto.HasField('optional_nested_message'))
3001    self.assertTrue(proto.HasField('optional_foreign_message'))
3002    self.assertTrue(proto.HasField('optional_nested_enum'))
3003    self.assertTrue(proto.HasField('optional_foreign_enum'))
3004    self.assertEqual(1, proto.optional_int32)
3005    self.assertEqual('foo', proto.optional_string)
3006    self.assertEqual(True, proto.optional_bool)
3007    self.assertEqual(b'bar', proto.optional_bytes)
3008    self.assertEqual(1, proto.optional_nested_message.bb)
3009    self.assertEqual(1, proto.optional_foreign_message.c)
3010    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
3011                     proto.optional_nested_enum)
3012    self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
3013    self.assertEqual([1, 2, 3], proto.repeated_int32)
3014
3015  def testInitArgsUnknownFieldName(self):
3016    def InitalizeEmptyMessageWithExtraKeywordArg():
3017      unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
3018    self._CheckRaises(
3019        ValueError,
3020        InitalizeEmptyMessageWithExtraKeywordArg,
3021        'Protocol message TestEmptyMessage has no "unknown" field.')
3022
3023  def testInitRequiredKwargs(self):
3024    proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
3025    self.assertTrue(proto.IsInitialized())
3026    self.assertTrue(proto.HasField('a'))
3027    self.assertTrue(proto.HasField('b'))
3028    self.assertTrue(proto.HasField('c'))
3029    self.assertTrue(not proto.HasField('dummy2'))
3030    self.assertEqual(1, proto.a)
3031    self.assertEqual(1, proto.b)
3032    self.assertEqual(1, proto.c)
3033
3034  def testInitRequiredForeignKwargs(self):
3035    proto = unittest_pb2.TestRequiredForeign(
3036        optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
3037    self.assertTrue(proto.IsInitialized())
3038    self.assertTrue(proto.HasField('optional_message'))
3039    self.assertTrue(proto.optional_message.IsInitialized())
3040    self.assertTrue(proto.optional_message.HasField('a'))
3041    self.assertTrue(proto.optional_message.HasField('b'))
3042    self.assertTrue(proto.optional_message.HasField('c'))
3043    self.assertTrue(not proto.optional_message.HasField('dummy2'))
3044    self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
3045                     proto.optional_message)
3046    self.assertEqual(1, proto.optional_message.a)
3047    self.assertEqual(1, proto.optional_message.b)
3048    self.assertEqual(1, proto.optional_message.c)
3049
3050  def testInitRepeatedKwargs(self):
3051    proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
3052    self.assertTrue(proto.IsInitialized())
3053    self.assertEqual(1, proto.repeated_int32[0])
3054    self.assertEqual(2, proto.repeated_int32[1])
3055    self.assertEqual(3, proto.repeated_int32[2])
3056
3057
@testing_refleaks.TestCase
class OptionsTest(unittest.TestCase):
  """Checks descriptor options reported through GetOptions()."""

  def testMessageOptions(self):
    """message_set_wire_format is set only on the MessageSet test message."""
    expectations = (
        (message_set_extensions_pb2.TestMessageSet(), True),
        (unittest_pb2.TestAllTypes(), False),
    )
    for msg, is_message_set in expectations:
      self.assertEqual(is_message_set,
                       msg.DESCRIPTOR.GetOptions().message_set_wire_format)

  def testPackedOptions(self):
    """Only fields declared packed report packed=True (and are repeated)."""
    plain = unittest_pb2.TestAllTypes()
    plain.optional_int32 = 1
    plain.optional_double = 3.0
    for field, _ in plain.ListFields():
      self.assertEqual(False, field.GetOptions().packed)

    packed = unittest_pb2.TestPackedTypes()
    packed.packed_int32.append(1)
    packed.packed_double.append(3.0)
    for field, _ in packed.ListFields():
      self.assertEqual(True, field.GetOptions().packed)
      self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED, field.label)
3083
3084
3085
@testing_refleaks.TestCase
class ClassAPITest(unittest.TestCase):
  """Tests for building message classes from descriptors at runtime."""

  @unittest.skipIf(
      api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
      'C++ implementation requires a call to MakeDescriptor()')
  @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable')
  def testMakeClassWithNestedDescriptor(self):
    """MakeClass accepts a hand-built descriptor tree with nested types."""
    leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '',
                                      containing_type=None, fields=[],
                                      nested_types=[], enum_types=[],
                                      extensions=[])
    child_desc = descriptor.Descriptor('child', 'package.parent.child', '',
                                       containing_type=None, fields=[],
                                       nested_types=[leaf_desc], enum_types=[],
                                       extensions=[])
    sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling',
                                         '', containing_type=None, fields=[],
                                         nested_types=[], enum_types=[],
                                         extensions=[])
    parent_desc = descriptor.Descriptor('parent', 'package.parent', '',
                                        containing_type=None, fields=[],
                                        nested_types=[child_desc, sibling_desc],
                                        enum_types=[], extensions=[])
    reflection.MakeClass(parent_desc)

  def _GetSerializedFileDescriptor(self, name):
    """Get a serialized representation of a test FileDescriptorProto.

    The file declares one top-level message `name` with a repeated uint32
    field `flat` and a `bar` field of nested type Bar; Bar holds a `baz`
    field of nested type Baz, which declares the enum `deep_enum` and a
    uint32 field `deep`.

    Args:
      name: All calls to this must use a unique message name, to avoid
          collisions in the cpp descriptor pool.
    Returns:
      The serialized FileDescriptorProto, as returned by SerializeToString()
      (bytes on Python 3).
    """
    file_descriptor_str = (
        'message_type {'
        '  name: "' + name + '"'
        '  field {'
        '    name: "flat"'
        '    number: 1'
        '    label: LABEL_REPEATED'
        '    type: TYPE_UINT32'
        '  }'
        '  field {'
        '    name: "bar"'
        '    number: 2'
        '    label: LABEL_OPTIONAL'
        '    type: TYPE_MESSAGE'
        '    type_name: "Bar"'
        '  }'
        '  nested_type {'
        '    name: "Bar"'
        '    field {'
        '      name: "baz"'
        '      number: 3'
        '      label: LABEL_OPTIONAL'
        '      type: TYPE_MESSAGE'
        '      type_name: "Baz"'
        '    }'
        '    nested_type {'
        '      name: "Baz"'
        '      enum_type {'
        '        name: "deep_enum"'
        '        value {'
        '          name: "VALUE_A"'
        '          number: 0'
        '        }'
        '      }'
        '      field {'
        '        name: "deep"'
        '        number: 4'
        '        label: LABEL_OPTIONAL'
        '        type: TYPE_UINT32'
        '      }'
        '    }'
        '  }'
        '}')
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    text_format.Merge(file_descriptor_str, file_descriptor)
    return file_descriptor.SerializeToString()

  # This test can only run once; the second time, it raises errors about
  # conflicting message descriptors.
  # TODO(xiaofeng): This test fails with the cpp implementation in the call
  # of six.with_metaclass(). The other two callsites of with_metaclass
  # in this file are both excluded from cpp test, so it might be expected
  # to fail. Need someone more familiar with the python code to take a
  # look at this.
  @unittest.skipIf(
      api_implementation.Type() != 'python',
      'Explicit class declaration via six.with_metaclass() is only exercised '
      'with the python implementation')
  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  def testParsingFlatClassWithExplicitClassDeclaration(self):
    """Test that the generated class can parse a flat message."""
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A'))
    msg_descriptor = descriptor.MakeDescriptor(
        file_descriptor.message_type[0])

    class MessageClass(six.with_metaclass(
        reflection.GeneratedProtocolMessageType, message.Message)):
      DESCRIPTOR = msg_descriptor
    msg = MessageClass()
    msg_str = (
        'flat: 0 '
        'flat: 1 '
        'flat: 2 ')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.flat, [0, 1, 2])

  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  def testParsingFlatClass(self):
    """Test that the generated class can parse a flat message."""
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B'))
    msg_descriptor = descriptor.MakeDescriptor(
        file_descriptor.message_type[0])
    msg_class = reflection.MakeClass(msg_descriptor)
    msg = msg_class()
    msg_str = (
        'flat: 0 '
        'flat: 1 '
        'flat: 2 ')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.flat, [0, 1, 2])

  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  def testParsingNestedClass(self):
    """Test that the generated class can parse a nested message."""
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
    msg_descriptor = descriptor.MakeDescriptor(
        file_descriptor.message_type[0])
    msg_class = reflection.MakeClass(msg_descriptor)
    msg = msg_class()
    msg_str = (
        'bar {'
        '  baz {'
        '    deep: 4'
        '  }'
        '}')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.bar.baz.deep, 4)
3228
if __name__ == '__main__':
  # Run the test cases defined in this module when executed as a script.
  unittest.main()
3231