#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.  All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Unittest for reflection.py, which also indirectly tests the output of the
pure-Python protocol compiler.
"""

import copy
import gc
import operator
import six
import struct

try:
  import unittest2 as unittest  #PY26
except ImportError:
  import unittest

from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import reflection
from google.protobuf import text_format
from google.protobuf.internal import api_implementation
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import wire_format
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import decoder


if six.PY3:
  long = int  # pylint: disable=redefined-builtin,invalid-name


class _MiniDecoder(object):
  """Decodes a stream of values from a string.

  Once upon a time we actually had a class called decoder.Decoder.  Then we
  got rid of it during a redesign that made decoding much, much faster overall.
  But a couple tests in this file used it to check that the serialized form of
  a message was correct.  So, this class implements just the methods that were
  used by said tests, so that we don't have to rewrite the tests.
  """

  def __init__(self, bytes):  # pylint: disable=redefined-builtin
    # The parameter name shadows the builtin `bytes`; it is kept unchanged so
    # that keyword callers keep working.
    self._bytes = bytes
    self._pos = 0

  def ReadVarint(self):
    """Reads one varint at the current position and advances past it."""
    result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
    return result

  # All unsigned/signed non-zigzag varint types share the same decoding.
  ReadInt32 = ReadVarint
  ReadInt64 = ReadVarint
  ReadUInt32 = ReadVarint
  ReadUInt64 = ReadVarint

  def ReadSInt64(self):
    """Reads a varint and undoes its ZigZag encoding."""
    return wire_format.ZigZagDecode(self.ReadVarint())

  ReadSInt32 = ReadSInt64

  def ReadFieldNumberAndWireType(self):
    """Reads a tag varint and returns wire_format.UnpackTag() of it."""
    return wire_format.UnpackTag(self.ReadVarint())

  def ReadFloat(self):
    """Reads a 4-byte little-endian float and advances past it."""
    result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0]
    self._pos += 4
    return result

  def ReadDouble(self):
    """Reads an 8-byte little-endian double and advances past it."""
    result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0]
    self._pos += 8
    return result

  def EndOfStream(self):
    """Returns True when every byte of the input has been consumed."""
    return self._pos == len(self._bytes)


@testing_refleaks.TestCase
class ReflectionTest(unittest.TestCase):

  def assertListsEqual(self, values, others):
    """Asserts that two sequences have equal length and pairwise-equal items."""
    self.assertEqual(len(values), len(others))
    for value, other in zip(values, others):
      self.assertEqual(value, other)

  def testScalarConstructor(self):
    # Constructor with only scalar types should succeed.
    proto = unittest_pb2.TestAllTypes(
        optional_int32=24,
        optional_double=54.321,
        optional_string='optional_string',
        optional_float=None)

    self.assertEqual(24, proto.optional_int32)
    self.assertEqual(54.321, proto.optional_double)
    self.assertEqual('optional_string', proto.optional_string)
    self.assertFalse(proto.HasField("optional_float"))

  def testRepeatedScalarConstructor(self):
    # Constructor with only repeated scalar types should succeed.
    proto = unittest_pb2.TestAllTypes(
        repeated_int32=[1, 2, 3, 4],
        repeated_double=[1.23, 54.321],
        repeated_bool=[True, False, False],
        repeated_string=["optional_string"],
        repeated_float=None)

    self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
    self.assertEqual([True, False, False], list(proto.repeated_bool))
    self.assertEqual(["optional_string"], list(proto.repeated_string))
    self.assertEqual([], list(proto.repeated_float))

  def testRepeatedCompositeConstructor(self):
    # Constructor with only repeated composite types should succeed.
    proto = unittest_pb2.TestAllTypes(
        repeated_nested_message=[
            unittest_pb2.TestAllTypes.NestedMessage(
                bb=unittest_pb2.TestAllTypes.FOO),
            unittest_pb2.TestAllTypes.NestedMessage(
                bb=unittest_pb2.TestAllTypes.BAR)],
        repeated_foreign_message=[
            unittest_pb2.ForeignMessage(c=-43),
            unittest_pb2.ForeignMessage(c=45324),
            unittest_pb2.ForeignMessage(c=12)],
        repeatedgroup=[
            unittest_pb2.TestAllTypes.RepeatedGroup(),
            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])

    self.assertEqual(
        [unittest_pb2.TestAllTypes.NestedMessage(
            bb=unittest_pb2.TestAllTypes.FOO),
         unittest_pb2.TestAllTypes.NestedMessage(
             bb=unittest_pb2.TestAllTypes.BAR)],
        list(proto.repeated_nested_message))
    self.assertEqual(
        [unittest_pb2.ForeignMessage(c=-43),
         unittest_pb2.ForeignMessage(c=45324),
         unittest_pb2.ForeignMessage(c=12)],
        list(proto.repeated_foreign_message))
    self.assertEqual(
        [unittest_pb2.TestAllTypes.RepeatedGroup(),
         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
        list(proto.repeatedgroup))

  def testMixedConstructor(self):
    # Constructor mixing scalar, repeated and composite fields should succeed.
    proto = unittest_pb2.TestAllTypes(
        optional_int32=24,
        optional_string='optional_string',
        repeated_double=[1.23, 54.321],
        repeated_bool=[True, False, False],
        repeated_nested_message=[
            unittest_pb2.TestAllTypes.NestedMessage(
                bb=unittest_pb2.TestAllTypes.FOO),
            unittest_pb2.TestAllTypes.NestedMessage(
                bb=unittest_pb2.TestAllTypes.BAR)],
        repeated_foreign_message=[
            unittest_pb2.ForeignMessage(c=-43),
            unittest_pb2.ForeignMessage(c=45324),
            unittest_pb2.ForeignMessage(c=12)],
        optional_nested_message=None)

    self.assertEqual(24, proto.optional_int32)
    self.assertEqual('optional_string', proto.optional_string)
    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
    self.assertEqual([True, False, False], list(proto.repeated_bool))
    self.assertEqual(
        [unittest_pb2.TestAllTypes.NestedMessage(
            bb=unittest_pb2.TestAllTypes.FOO),
         unittest_pb2.TestAllTypes.NestedMessage(
             bb=unittest_pb2.TestAllTypes.BAR)],
        list(proto.repeated_nested_message))
    self.assertEqual(
        [unittest_pb2.ForeignMessage(c=-43),
         unittest_pb2.ForeignMessage(c=45324),
         unittest_pb2.ForeignMessage(c=12)],
        list(proto.repeated_foreign_message))
    self.assertFalse(proto.HasField("optional_nested_message"))

  def testConstructorTypeError(self):
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, optional_int32="foo")
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, optional_string=1234)
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234)
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234)
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"])
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, repeated_string=1234)
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234])
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234)
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234])

  def testConstructorInvalidatesCachedByteSize(self):
    # Values passed to the constructor must be reflected in ByteSize().
    # (The local is named `proto` so it does not shadow the imported
    # `message` module.)
    proto = unittest_pb2.TestAllTypes(optional_int32=12)
    self.assertEqual(2, proto.ByteSize())

    proto = unittest_pb2.TestAllTypes(
        optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage())
    self.assertEqual(3, proto.ByteSize())

    proto = unittest_pb2.TestAllTypes(repeated_int32=[12])
    self.assertEqual(3, proto.ByteSize())

    proto = unittest_pb2.TestAllTypes(
        repeated_nested_message=[unittest_pb2.TestAllTypes.NestedMessage()])
    self.assertEqual(3, proto.ByteSize())

  def testSimpleHasBits(self):
    # Test a scalar.
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.HasField('optional_int32'))
    self.assertEqual(0, proto.optional_int32)
    # HasField() shouldn't be true if all we've done is
    # read the default value.
    self.assertFalse(proto.HasField('optional_int32'))
    proto.optional_int32 = 1
    # Setting a value however *should* set the "has" bit.
    self.assertTrue(proto.HasField('optional_int32'))
    proto.ClearField('optional_int32')
    # And clearing that value should unset the "has" bit.
    self.assertFalse(proto.HasField('optional_int32'))

  def testHasBitsWithSinglyNestedScalar(self):
    # Helper used to test foreign messages and groups.
    #
    # composite_field_name should be the name of a non-repeated
    # composite (i.e., foreign or group) field in TestAllTypes,
    # and scalar_field_name should be the name of an integer-valued
    # scalar field within that composite.
    #
    # I never thought I'd miss C++ macros and templates so much. :(
    # This helper is semantically just:
    #
    #   assert proto.composite_field.scalar_field == 0
    #   assert not proto.composite_field.HasField('scalar_field')
    #   assert not proto.HasField('composite_field')
    #
    #   proto.composite_field.scalar_field = 10
    #   old_composite_field = proto.composite_field
    #
    #   assert proto.composite_field.scalar_field == 10
    #   assert proto.composite_field.HasField('scalar_field')
    #   assert proto.HasField('composite_field')
    #
    #   proto.ClearField('composite_field')
    #
    #   assert not proto.composite_field.HasField('scalar_field')
    #   assert not proto.HasField('composite_field')
    #   assert proto.composite_field.scalar_field == 0
    #
    #   # Now ensure that ClearField('composite_field') disconnected
    #   # the old field object from the object tree...
    #   assert old_composite_field is not proto.composite_field
    #   old_composite_field.scalar_field = 20
    #   assert not proto.composite_field.HasField('scalar_field')
    #   assert not proto.HasField('composite_field')
    def TestCompositeHasBits(composite_field_name, scalar_field_name):
      proto = unittest_pb2.TestAllTypes()
      # First, check that we can get the scalar value, and see that it's the
      # default (0), but that proto.HasField('composite') and
      # proto.composite.HasField('scalar') will still return False.
      composite_field = getattr(proto, composite_field_name)
      original_scalar_value = getattr(composite_field, scalar_field_name)
      self.assertEqual(0, original_scalar_value)
      # Assert that the composite object does not "have" the scalar.
      self.assertFalse(composite_field.HasField(scalar_field_name))
      # Assert that proto does not "have" the composite field.
      self.assertFalse(proto.HasField(composite_field_name))

      # Now set the scalar within the composite field.  Ensure that the setting
      # is reflected, and that proto.HasField('composite') and
      # proto.composite.HasField('scalar') now both return True.
      new_val = 20
      setattr(composite_field, scalar_field_name, new_val)
      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
      # Hold on to a reference to the current composite_field object.
      old_composite_field = composite_field
      # Assert that the has methods now return true.
      self.assertTrue(composite_field.HasField(scalar_field_name))
      self.assertTrue(proto.HasField(composite_field_name))

      # Now call the clear method...
      proto.ClearField(composite_field_name)

      # ...and ensure that the "has" bits are all back to False...
      composite_field = getattr(proto, composite_field_name)
      self.assertFalse(composite_field.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))
      # ...and ensure that the scalar field has returned to its default.
      self.assertEqual(0, getattr(composite_field, scalar_field_name))

      # Writing through the detached old object must not affect proto.
      self.assertIsNot(old_composite_field, composite_field)
      setattr(old_composite_field, scalar_field_name, new_val)
      self.assertFalse(composite_field.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))
      self.assertEqual(0, getattr(composite_field, scalar_field_name))

    # Test simple, single-level nesting when we set a scalar.
    TestCompositeHasBits('optionalgroup', 'a')
    TestCompositeHasBits('optional_nested_message', 'bb')
    TestCompositeHasBits('optional_foreign_message', 'c')
    TestCompositeHasBits('optional_import_message', 'd')

  def testReferencesToNestedMessage(self):
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    del proto
    # A previous version had a bug where this would raise an exception when
    # hitting a now-dead weak reference.
    nested.bb = 23

  def testDisconnectingNestedMessageBeforeSettingField(self):
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    proto.ClearField('optional_nested_message')  # Should disconnect from parent
    self.assertIsNot(nested, proto.optional_nested_message)
    nested.bb = 23
    self.assertFalse(proto.HasField('optional_nested_message'))
    self.assertEqual(0, proto.optional_nested_message.bb)

  def testGetDefaultMessageAfterDisconnectingDefaultMessage(self):
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    proto.ClearField('optional_nested_message')
    del proto
    del nested
    # Force a garbage collect so that the underlying CMessages are freed along
    # with the Messages they point to. This is to make sure we're not deleting
    # default message instances.
    gc.collect()
    proto = unittest_pb2.TestAllTypes()
    # Reading the default submessage again must not crash.
    nested = proto.optional_nested_message

  def testDisconnectingNestedMessageAfterSettingField(self):
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    nested.bb = 5
    self.assertTrue(proto.HasField('optional_nested_message'))
    proto.ClearField('optional_nested_message')  # Should disconnect from parent
    self.assertEqual(5, nested.bb)
    self.assertEqual(0, proto.optional_nested_message.bb)
    self.assertIsNot(nested, proto.optional_nested_message)
    nested.bb = 23
    self.assertFalse(proto.HasField('optional_nested_message'))
    self.assertEqual(0, proto.optional_nested_message.bb)

  def testDisconnectingNestedMessageBeforeGettingField(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.HasField('optional_nested_message'))
    proto.ClearField('optional_nested_message')
    self.assertFalse(proto.HasField('optional_nested_message'))

  def testDisconnectingNestedMessageAfterMerge(self):
    # This test exercises the code path that does not use ReleaseMessage().
    # The underlying fear is that if we use ReleaseMessage() incorrectly,
    # we will have memory leaks.  It's hard to check that that doesn't happen,
    # but at least we can exercise that code path to make sure it works.
    proto1 = unittest_pb2.TestAllTypes()
    proto2 = unittest_pb2.TestAllTypes()
    proto2.optional_nested_message.bb = 5
    proto1.MergeFrom(proto2)
    self.assertTrue(proto1.HasField('optional_nested_message'))
    proto1.ClearField('optional_nested_message')
    self.assertFalse(proto1.HasField('optional_nested_message'))

  def testDisconnectingLazyNestedMessage(self):
    # This test exercises releasing a nested message that is lazy. This test
    # only exercises real code in the C++ implementation as Python does not
    # support lazy parsing, but the current C++ implementation results in
    # memory corruption and a crash.
    if api_implementation.Type() != 'python':
      return
    proto = unittest_pb2.TestAllTypes()
    proto.optional_lazy_message.bb = 5
    proto.ClearField('optional_lazy_message')
    del proto
    gc.collect()

  def testHasBitsWhenModifyingRepeatedFields(self):
    # Test nesting when we add an element to a repeated field in a submessage.
    proto = unittest_pb2.TestNestedMessageHasBits()
    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
    self.assertEqual(
        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
    self.assertTrue(proto.HasField('optional_nested_message'))

    # Do the same test, but with a repeated composite field within the
    # submessage.
    proto.ClearField('optional_nested_message')
    self.assertFalse(proto.HasField('optional_nested_message'))
    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
    self.assertTrue(proto.HasField('optional_nested_message'))

  def testHasBitsForManyLevelsOfNesting(self):
    # Test nesting many levels deep.
    recursive_proto = unittest_pb2.TestMutualRecursionA()
    self.assertFalse(recursive_proto.HasField('bb'))
    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
    self.assertFalse(recursive_proto.HasField('bb'))
    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
    self.assertTrue(recursive_proto.HasField('bb'))
    self.assertTrue(recursive_proto.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
    self.assertFalse(recursive_proto.bb.a.bb.a.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))

  def testSingularListFields(self):
    proto = unittest_pb2.TestAllTypes()
    proto.optional_fixed32 = 1
    proto.optional_int32 = 5
    proto.optional_string = 'foo'
    # Access sub-message but don't set it yet.
    nested_message = proto.optional_nested_message
    self.assertEqual(
        [(proto.DESCRIPTOR.fields_by_name['optional_int32'], 5),
         (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
         (proto.DESCRIPTOR.fields_by_name['optional_string'], 'foo')],
        proto.ListFields())

    proto.optional_nested_message.bb = 123
    self.assertEqual(
        [(proto.DESCRIPTOR.fields_by_name['optional_int32'], 5),
         (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
         (proto.DESCRIPTOR.fields_by_name['optional_string'], 'foo'),
         (proto.DESCRIPTOR.fields_by_name['optional_nested_message'],
          nested_message)],
        proto.ListFields())

  def testRepeatedListFields(self):
    proto = unittest_pb2.TestAllTypes()
    proto.repeated_fixed32.append(1)
    proto.repeated_int32.append(5)
    proto.repeated_int32.append(11)
    proto.repeated_string.extend(['foo', 'bar'])
    proto.repeated_string.extend([])
    proto.repeated_string.append('baz')
    proto.repeated_string.extend(str(x) for x in range(2))
    proto.optional_int32 = 21
    proto.repeated_bool  # Access but don't set anything; should not be listed.
    self.assertEqual(
        [(proto.DESCRIPTOR.fields_by_name['optional_int32'], 21),
         (proto.DESCRIPTOR.fields_by_name['repeated_int32'], [5, 11]),
         (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
         (proto.DESCRIPTOR.fields_by_name['repeated_string'],
          ['foo', 'bar', 'baz', '0', '1'])],
        proto.ListFields())

  def testSingularListExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
    proto.Extensions[unittest_pb2.optional_int32_extension] = 5
    proto.Extensions[unittest_pb2.optional_string_extension] = 'foo'
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 5),
         (unittest_pb2.optional_fixed32_extension, 1),
         (unittest_pb2.optional_string_extension, 'foo')],
        proto.ListFields())

  def testRepeatedListExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
    proto.Extensions[unittest_pb2.repeated_int32_extension].append(5)
    proto.Extensions[unittest_pb2.repeated_int32_extension].append(11)
    proto.Extensions[unittest_pb2.repeated_string_extension].append('foo')
    proto.Extensions[unittest_pb2.repeated_string_extension].append('bar')
    proto.Extensions[unittest_pb2.repeated_string_extension].append('baz')
    proto.Extensions[unittest_pb2.optional_int32_extension] = 21
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 21),
         (unittest_pb2.repeated_int32_extension, [5, 11]),
         (unittest_pb2.repeated_fixed32_extension, [1]),
         (unittest_pb2.repeated_string_extension, ['foo', 'bar', 'baz'])],
        proto.ListFields())

  def testListFieldsAndExtensions(self):
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    # NOTE(review): bare attribute access with no visible effect here;
    # kept as in the original — confirm whether it is still needed.
    unittest_pb2.my_extension_int
    self.assertEqual(
        [(proto.DESCRIPTOR.fields_by_name['my_int'], 1),
         (unittest_pb2.my_extension_int, 23),
         (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
         (unittest_pb2.my_extension_string, 'bar'),
         (proto.DESCRIPTOR.fields_by_name['my_float'], 1.0)],
        proto.ListFields())

  def testDefaultValues(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.optional_int32)
    self.assertEqual(0, proto.optional_int64)
    self.assertEqual(0, proto.optional_uint32)
    self.assertEqual(0, proto.optional_uint64)
    self.assertEqual(0, proto.optional_sint32)
    self.assertEqual(0, proto.optional_sint64)
    self.assertEqual(0, proto.optional_fixed32)
    self.assertEqual(0, proto.optional_fixed64)
    self.assertEqual(0, proto.optional_sfixed32)
    self.assertEqual(0, proto.optional_sfixed64)
    self.assertEqual(0.0, proto.optional_float)
    self.assertEqual(0.0, proto.optional_double)
    self.assertEqual(False, proto.optional_bool)
    self.assertEqual('', proto.optional_string)
    self.assertEqual(b'', proto.optional_bytes)

    self.assertEqual(41, proto.default_int32)
    self.assertEqual(42, proto.default_int64)
    self.assertEqual(43, proto.default_uint32)
    self.assertEqual(44, proto.default_uint64)
    self.assertEqual(-45, proto.default_sint32)
    self.assertEqual(46, proto.default_sint64)
    self.assertEqual(47, proto.default_fixed32)
    self.assertEqual(48, proto.default_fixed64)
    self.assertEqual(49, proto.default_sfixed32)
    self.assertEqual(-50, proto.default_sfixed64)
    self.assertEqual(51.5, proto.default_float)
    self.assertEqual(52e3, proto.default_double)
    self.assertEqual(True, proto.default_bool)
    self.assertEqual('hello', proto.default_string)
    self.assertEqual(b'world', proto.default_bytes)
    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
                     proto.default_import_enum)

    proto = unittest_pb2.TestExtremeDefaultValues()
    self.assertEqual(u'\u1234', proto.utf8_string)

  def testHasFieldWithUnknownFieldName(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')

  def testClearFieldWithUnknownFieldName(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')

  def testClearRemovesChildren(self):
    # Make sure there aren't any implementation bugs that are only partially
    # clearing the message (which can happen in the more complex C++
    # implementation which has parallel message lists).
    proto = unittest_pb2.TestRequiredForeign()
    for _ in range(10):
      proto.repeated_message.add()
    proto2 = unittest_pb2.TestRequiredForeign()
    proto.CopyFrom(proto2)
    self.assertRaises(IndexError, lambda: proto.repeated_message[5])

  def testDisallowedAssignments(self):
    # It's illegal to assign values directly to repeated fields
    # or to nonrepeated composite fields.  Ensure that this fails.
    proto = unittest_pb2.TestAllTypes()
    # Repeated fields.
    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
    # Lists shouldn't work, either.
    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
    # Composite fields.
    self.assertRaises(AttributeError, setattr, proto,
                      'optional_nested_message', 23)
    # Assignment to a repeated nested message field without specifying
    # the index in the array of nested messages.
    self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
                      'bb', 34)
    # Assignment to an attribute of a repeated field.
    self.assertRaises(AttributeError, setattr, proto.repeated_float,
                      'some_attribute', 34)
    # proto.nonexistent_field = 23 should fail as well.
    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)

  def testSingleScalarTypeSafety(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
    self.assertRaises(TypeError, setattr, proto, 'optional_bool', 'foo')
    self.assertRaises(TypeError, setattr, proto, 'optional_float', 'foo')
    self.assertRaises(TypeError, setattr, proto, 'optional_double', 'foo')
    # TODO(jieluo): Fix type checking difference for python and c extension
    if api_implementation.Type() == 'python':
      self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1)
    else:
      proto.optional_bool = 1.1

  def assertIntegerTypes(self, integer_fn):
    """Verifies setting of scalar integers.

    Args:
      integer_fn: A function to wrap the integers that will be assigned.
    """
    def TestGetAndDeserialize(field_name, value, expected_type):
      # The field must hold `expected_type` both after assignment and after a
      # serialize/parse round trip.
      proto = unittest_pb2.TestAllTypes()
      value = integer_fn(value)
      setattr(proto, field_name, value)
      self.assertIsInstance(getattr(proto, field_name), expected_type)
      proto2 = unittest_pb2.TestAllTypes()
      proto2.ParseFromString(proto.SerializeToString())
      self.assertIsInstance(getattr(proto2, field_name), expected_type)

    TestGetAndDeserialize('optional_int32', 1, int)
    TestGetAndDeserialize('optional_int32', 1 << 30, int)
    TestGetAndDeserialize('optional_uint32', 1 << 30, int)
    integer_64 = long
    if struct.calcsize('L') == 4:
      # Python 2's int is a C long.  Where a C long is only 32 bits wide it
      # cannot hold a uint32 >= 2**31, so such values come back as long.
      TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64)
    else:
      # A 64-bit C long can fit any uint32 inside an int.
      TestGetAndDeserialize('optional_uint32', 1 << 31, int)
    TestGetAndDeserialize('optional_int64', 1 << 30, integer_64)
    TestGetAndDeserialize('optional_int64', 1 << 60, integer_64)
    TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
    TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)

  def testIntegerTypes(self):
    self.assertIntegerTypes(lambda x: x)

  def testNonStandardIntegerTypes(self):
    self.assertIntegerTypes(test_util.NonStandardInteger)

  def testIllegalValuesForIntegers(self):
    pb = unittest_pb2.TestAllTypes()

    # Strings are illegal, even when the represent an integer.
    with self.assertRaises(TypeError):
      pb.optional_uint64 = '2'

    # assertRaisesRegexp was renamed assertRaisesRegex in Python 3.2 and the
    # old alias was removed in Python 3.12; use whichever one exists.
    assert_raises_regex = (getattr(self, 'assertRaisesRegex', None) or
                           self.assertRaisesRegexp)
    # The exact error should propagate with a poorly written custom integer.
    with assert_raises_regex(RuntimeError, 'my_error'):
      pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error')

  def assertIntegerBoundsChecking(self, integer_fn):
    """Verifies bounds checking for scalar integer fields.

    Args:
      integer_fn: A function to wrap the integers that will be assigned.
    """
    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
      # Both bounds must be accepted and round-trip; one past either bound
      # must be rejected.
      pb = unittest_pb2.TestAllTypes()
      expected_min = integer_fn(expected_min)
      expected_max = integer_fn(expected_max)
      setattr(pb, field_name, expected_min)
      self.assertEqual(expected_min, getattr(pb, field_name))
      setattr(pb, field_name, expected_max)
      self.assertEqual(expected_max, getattr(pb, field_name))
      self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
                        expected_min - 1)
      self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
                        expected_max + 1)

    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
    # A bit of white-box testing since -1 is an int and not a long in C++ and
    # so goes down a different path.
    pb = unittest_pb2.TestAllTypes()
    with self.assertRaises((ValueError, TypeError)):
      pb.optional_uint64 = integer_fn(-(1 << 63))

    pb = unittest_pb2.TestAllTypes()
    pb.optional_nested_enum = integer_fn(1)
    self.assertEqual(1, pb.optional_nested_enum)

  # Backward-compatible alias for the original, misspelled helper name.
  assetIntegerBoundsChecking = assertIntegerBoundsChecking

  def testSingleScalarBoundsChecking(self):
    self.assertIntegerBoundsChecking(lambda x: x)

  def testNonStandardSingleScalarBoundsChecking(self):
    self.assertIntegerBoundsChecking(test_util.NonStandardInteger)

  def testRepeatedScalarTypeSafety(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
    # These must go through append(): calling the container itself raises
    # TypeError merely because it is not callable, which tests nothing.
    self.assertRaises(TypeError, proto.repeated_string.append, 10)
    self.assertRaises(TypeError, proto.repeated_bytes.append, 10)

    proto.repeated_int32.append(10)
    proto.repeated_int32[0] = 23
    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, [])
    self.assertRaises(TypeError, proto.repeated_int32.__setitem__,
                      'index', 23)

    proto.repeated_string.append('2')
    self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10)

    # Repeated enums tests.
    #proto.repeated_nested_enum.append(0)

  def testSingleScalarGettersAndSetters(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.optional_int32)
    proto.optional_int32 = 1
    self.assertEqual(1, proto.optional_int32)

    proto.optional_uint64 = 0xffffffffffff
    self.assertEqual(0xffffffffffff, proto.optional_uint64)
    proto.optional_uint64 = 0xffffffffffffffff
    self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
    # TODO(robinson): Test all other scalar field types.

  def testSingleScalarClearField(self):
    proto = unittest_pb2.TestAllTypes()
    # Should be allowed to clear something that's not there (a no-op).
    proto.ClearField('optional_int32')
    proto.optional_int32 = 1
    self.assertTrue(proto.HasField('optional_int32'))
    proto.ClearField('optional_int32')
    self.assertEqual(0, proto.optional_int32)
    self.assertFalse(proto.HasField('optional_int32'))
    # TODO(robinson): Test all other scalar field types.
764 765 def testEnums(self): 766 proto = unittest_pb2.TestAllTypes() 767 self.assertEqual(1, proto.FOO) 768 self.assertEqual(1, unittest_pb2.TestAllTypes.FOO) 769 self.assertEqual(2, proto.BAR) 770 self.assertEqual(2, unittest_pb2.TestAllTypes.BAR) 771 self.assertEqual(3, proto.BAZ) 772 self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) 773 774 def testEnum_Name(self): 775 self.assertEqual('FOREIGN_FOO', 776 unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO)) 777 self.assertEqual('FOREIGN_BAR', 778 unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR)) 779 self.assertEqual('FOREIGN_BAZ', 780 unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ)) 781 self.assertRaises(ValueError, 782 unittest_pb2.ForeignEnum.Name, 11312) 783 784 proto = unittest_pb2.TestAllTypes() 785 self.assertEqual('FOO', 786 proto.NestedEnum.Name(proto.FOO)) 787 self.assertEqual('FOO', 788 unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO)) 789 self.assertEqual('BAR', 790 proto.NestedEnum.Name(proto.BAR)) 791 self.assertEqual('BAR', 792 unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR)) 793 self.assertEqual('BAZ', 794 proto.NestedEnum.Name(proto.BAZ)) 795 self.assertEqual('BAZ', 796 unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ)) 797 self.assertRaises(ValueError, 798 proto.NestedEnum.Name, 11312) 799 self.assertRaises(ValueError, 800 unittest_pb2.TestAllTypes.NestedEnum.Name, 11312) 801 802 def testEnum_Value(self): 803 self.assertEqual(unittest_pb2.FOREIGN_FOO, 804 unittest_pb2.ForeignEnum.Value('FOREIGN_FOO')) 805 self.assertEqual(unittest_pb2.FOREIGN_FOO, 806 unittest_pb2.ForeignEnum.FOREIGN_FOO) 807 808 self.assertEqual(unittest_pb2.FOREIGN_BAR, 809 unittest_pb2.ForeignEnum.Value('FOREIGN_BAR')) 810 self.assertEqual(unittest_pb2.FOREIGN_BAR, 811 unittest_pb2.ForeignEnum.FOREIGN_BAR) 812 813 self.assertEqual(unittest_pb2.FOREIGN_BAZ, 814 unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ')) 815 self.assertEqual(unittest_pb2.FOREIGN_BAZ, 816 unittest_pb2.ForeignEnum.FOREIGN_BAZ) 
817 818 self.assertRaises(ValueError, 819 unittest_pb2.ForeignEnum.Value, 'FO') 820 with self.assertRaises(AttributeError): 821 unittest_pb2.ForeignEnum.FO 822 823 proto = unittest_pb2.TestAllTypes() 824 self.assertEqual(proto.FOO, 825 proto.NestedEnum.Value('FOO')) 826 self.assertEqual(proto.FOO, 827 proto.NestedEnum.FOO) 828 829 self.assertEqual(proto.FOO, 830 unittest_pb2.TestAllTypes.NestedEnum.Value('FOO')) 831 self.assertEqual(proto.FOO, 832 unittest_pb2.TestAllTypes.NestedEnum.FOO) 833 834 self.assertEqual(proto.BAR, 835 proto.NestedEnum.Value('BAR')) 836 self.assertEqual(proto.BAR, 837 proto.NestedEnum.BAR) 838 839 self.assertEqual(proto.BAR, 840 unittest_pb2.TestAllTypes.NestedEnum.Value('BAR')) 841 self.assertEqual(proto.BAR, 842 unittest_pb2.TestAllTypes.NestedEnum.BAR) 843 844 self.assertEqual(proto.BAZ, 845 proto.NestedEnum.Value('BAZ')) 846 self.assertEqual(proto.BAZ, 847 proto.NestedEnum.BAZ) 848 849 self.assertEqual(proto.BAZ, 850 unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ')) 851 self.assertEqual(proto.BAZ, 852 unittest_pb2.TestAllTypes.NestedEnum.BAZ) 853 854 self.assertRaises(ValueError, 855 proto.NestedEnum.Value, 'Foo') 856 with self.assertRaises(AttributeError): 857 proto.NestedEnum.Value.Foo 858 859 self.assertRaises(ValueError, 860 unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo') 861 with self.assertRaises(AttributeError): 862 unittest_pb2.TestAllTypes.NestedEnum.Value.Foo 863 864 def testEnum_KeysAndValues(self): 865 self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'], 866 list(unittest_pb2.ForeignEnum.keys())) 867 self.assertEqual([4, 5, 6], 868 list(unittest_pb2.ForeignEnum.values())) 869 self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), 870 ('FOREIGN_BAZ', 6)], 871 list(unittest_pb2.ForeignEnum.items())) 872 873 proto = unittest_pb2.TestAllTypes() 874 self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], list(proto.NestedEnum.keys())) 875 self.assertEqual([1, 2, 3, -1], list(proto.NestedEnum.values())) 876 
self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], 877 list(proto.NestedEnum.items())) 878 879 def testRepeatedScalars(self): 880 proto = unittest_pb2.TestAllTypes() 881 882 self.assertTrue(not proto.repeated_int32) 883 self.assertEqual(0, len(proto.repeated_int32)) 884 proto.repeated_int32.append(5) 885 proto.repeated_int32.append(10) 886 proto.repeated_int32.append(15) 887 self.assertTrue(proto.repeated_int32) 888 self.assertEqual(3, len(proto.repeated_int32)) 889 890 self.assertEqual([5, 10, 15], proto.repeated_int32) 891 892 # Test single retrieval. 893 self.assertEqual(5, proto.repeated_int32[0]) 894 self.assertEqual(15, proto.repeated_int32[-1]) 895 # Test out-of-bounds indices. 896 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234) 897 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234) 898 # Test incorrect types passed to __getitem__. 899 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo') 900 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None) 901 902 # Test single assignment. 903 proto.repeated_int32[1] = 20 904 self.assertEqual([5, 20, 15], proto.repeated_int32) 905 906 # Test insertion. 907 proto.repeated_int32.insert(1, 25) 908 self.assertEqual([5, 25, 20, 15], proto.repeated_int32) 909 910 # Test slice retrieval. 911 proto.repeated_int32.append(30) 912 self.assertEqual([25, 20, 15], proto.repeated_int32[1:4]) 913 self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) 914 915 # Test slice assignment with an iterator 916 proto.repeated_int32[1:4] = (i for i in range(3)) 917 self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) 918 919 # Test slice assignment. 920 proto.repeated_int32[1:4] = [35, 40, 45] 921 self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32) 922 923 # Test that we can use the field as an iterator. 
924 result = [] 925 for i in proto.repeated_int32: 926 result.append(i) 927 self.assertEqual([5, 35, 40, 45, 30], result) 928 929 # Test single deletion. 930 del proto.repeated_int32[2] 931 self.assertEqual([5, 35, 45, 30], proto.repeated_int32) 932 933 # Test slice deletion. 934 del proto.repeated_int32[2:] 935 self.assertEqual([5, 35], proto.repeated_int32) 936 937 # Test extending. 938 proto.repeated_int32.extend([3, 13]) 939 self.assertEqual([5, 35, 3, 13], proto.repeated_int32) 940 941 # Test clearing. 942 proto.ClearField('repeated_int32') 943 self.assertTrue(not proto.repeated_int32) 944 self.assertEqual(0, len(proto.repeated_int32)) 945 946 proto.repeated_int32.append(1) 947 self.assertEqual(1, proto.repeated_int32[-1]) 948 # Test assignment to a negative index. 949 proto.repeated_int32[-1] = 2 950 self.assertEqual(2, proto.repeated_int32[-1]) 951 952 # Test deletion at negative indices. 953 proto.repeated_int32[:] = [0, 1, 2, 3] 954 del proto.repeated_int32[-1] 955 self.assertEqual([0, 1, 2], proto.repeated_int32) 956 957 del proto.repeated_int32[-2] 958 self.assertEqual([0, 2], proto.repeated_int32) 959 960 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3) 961 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300) 962 963 del proto.repeated_int32[-2:-1] 964 self.assertEqual([2], proto.repeated_int32) 965 966 del proto.repeated_int32[100:10000] 967 self.assertEqual([2], proto.repeated_int32) 968 969 def testRepeatedScalarsRemove(self): 970 proto = unittest_pb2.TestAllTypes() 971 972 self.assertTrue(not proto.repeated_int32) 973 self.assertEqual(0, len(proto.repeated_int32)) 974 proto.repeated_int32.append(5) 975 proto.repeated_int32.append(10) 976 proto.repeated_int32.append(5) 977 proto.repeated_int32.append(5) 978 979 self.assertEqual(4, len(proto.repeated_int32)) 980 proto.repeated_int32.remove(5) 981 self.assertEqual(3, len(proto.repeated_int32)) 982 self.assertEqual(10, proto.repeated_int32[0]) 983 self.assertEqual(5, 
proto.repeated_int32[1]) 984 self.assertEqual(5, proto.repeated_int32[2]) 985 986 proto.repeated_int32.remove(5) 987 self.assertEqual(2, len(proto.repeated_int32)) 988 self.assertEqual(10, proto.repeated_int32[0]) 989 self.assertEqual(5, proto.repeated_int32[1]) 990 991 proto.repeated_int32.remove(10) 992 self.assertEqual(1, len(proto.repeated_int32)) 993 self.assertEqual(5, proto.repeated_int32[0]) 994 995 # Remove a non-existent element. 996 self.assertRaises(ValueError, proto.repeated_int32.remove, 123) 997 998 def testRepeatedComposites(self): 999 proto = unittest_pb2.TestAllTypes() 1000 self.assertTrue(not proto.repeated_nested_message) 1001 self.assertEqual(0, len(proto.repeated_nested_message)) 1002 m0 = proto.repeated_nested_message.add() 1003 m1 = proto.repeated_nested_message.add() 1004 self.assertTrue(proto.repeated_nested_message) 1005 self.assertEqual(2, len(proto.repeated_nested_message)) 1006 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 1007 self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage) 1008 1009 # Test out-of-bounds indices. 1010 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 1011 1234) 1012 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 1013 -1234) 1014 1015 # Test incorrect types passed to __getitem__. 1016 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 1017 'foo') 1018 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 1019 None) 1020 1021 # Test slice retrieval. 
1022 m2 = proto.repeated_nested_message.add() 1023 m3 = proto.repeated_nested_message.add() 1024 m4 = proto.repeated_nested_message.add() 1025 self.assertListsEqual( 1026 [m1, m2, m3], proto.repeated_nested_message[1:4]) 1027 self.assertListsEqual( 1028 [m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) 1029 self.assertListsEqual( 1030 [m0, m1], proto.repeated_nested_message[:2]) 1031 self.assertListsEqual( 1032 [m2, m3, m4], proto.repeated_nested_message[2:]) 1033 self.assertEqual( 1034 m0, proto.repeated_nested_message[0]) 1035 self.assertListsEqual( 1036 [m0], proto.repeated_nested_message[:1]) 1037 1038 # Test that we can use the field as an iterator. 1039 result = [] 1040 for i in proto.repeated_nested_message: 1041 result.append(i) 1042 self.assertListsEqual([m0, m1, m2, m3, m4], result) 1043 1044 # Test single deletion. 1045 del proto.repeated_nested_message[2] 1046 self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message) 1047 1048 # Test slice deletion. 1049 del proto.repeated_nested_message[2:] 1050 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 1051 1052 # Test extending. 1053 n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1) 1054 n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2) 1055 proto.repeated_nested_message.extend([n1,n2]) 1056 self.assertEqual(4, len(proto.repeated_nested_message)) 1057 self.assertEqual(n1, proto.repeated_nested_message[2]) 1058 self.assertEqual(n2, proto.repeated_nested_message[3]) 1059 self.assertRaises(TypeError, 1060 proto.repeated_nested_message.extend, n1) 1061 self.assertRaises(TypeError, 1062 proto.repeated_nested_message.extend, [0]) 1063 wrong_message_type = unittest_pb2.TestAllTypes() 1064 self.assertRaises(TypeError, 1065 proto.repeated_nested_message.extend, 1066 [wrong_message_type]) 1067 1068 # Test clearing. 
1069 proto.ClearField('repeated_nested_message') 1070 self.assertTrue(not proto.repeated_nested_message) 1071 self.assertEqual(0, len(proto.repeated_nested_message)) 1072 1073 # Test constructing an element while adding it. 1074 proto.repeated_nested_message.add(bb=23) 1075 self.assertEqual(1, len(proto.repeated_nested_message)) 1076 self.assertEqual(23, proto.repeated_nested_message[0].bb) 1077 self.assertRaises(TypeError, proto.repeated_nested_message.add, 23) 1078 with self.assertRaises(Exception): 1079 proto.repeated_nested_message[0] = 23 1080 1081 def testRepeatedCompositeRemove(self): 1082 proto = unittest_pb2.TestAllTypes() 1083 1084 self.assertEqual(0, len(proto.repeated_nested_message)) 1085 m0 = proto.repeated_nested_message.add() 1086 # Need to set some differentiating variable so m0 != m1 != m2: 1087 m0.bb = len(proto.repeated_nested_message) 1088 m1 = proto.repeated_nested_message.add() 1089 m1.bb = len(proto.repeated_nested_message) 1090 self.assertTrue(m0 != m1) 1091 m2 = proto.repeated_nested_message.add() 1092 m2.bb = len(proto.repeated_nested_message) 1093 self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) 1094 1095 self.assertEqual(3, len(proto.repeated_nested_message)) 1096 proto.repeated_nested_message.remove(m0) 1097 self.assertEqual(2, len(proto.repeated_nested_message)) 1098 self.assertEqual(m1, proto.repeated_nested_message[0]) 1099 self.assertEqual(m2, proto.repeated_nested_message[1]) 1100 1101 # Removing m0 again or removing None should raise error 1102 self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0) 1103 self.assertRaises(ValueError, proto.repeated_nested_message.remove, None) 1104 self.assertEqual(2, len(proto.repeated_nested_message)) 1105 1106 proto.repeated_nested_message.remove(m2) 1107 self.assertEqual(1, len(proto.repeated_nested_message)) 1108 self.assertEqual(m1, proto.repeated_nested_message[0]) 1109 1110 def testHandWrittenReflection(self): 1111 # Hand written extensions are only 
supported by the pure-Python 1112 # implementation of the API. 1113 if api_implementation.Type() != 'python': 1114 return 1115 1116 FieldDescriptor = descriptor.FieldDescriptor 1117 foo_field_descriptor = FieldDescriptor( 1118 name='foo_field', full_name='MyProto.foo_field', 1119 index=0, number=1, type=FieldDescriptor.TYPE_INT64, 1120 cpp_type=FieldDescriptor.CPPTYPE_INT64, 1121 label=FieldDescriptor.LABEL_OPTIONAL, default_value=0, 1122 containing_type=None, message_type=None, enum_type=None, 1123 is_extension=False, extension_scope=None, 1124 options=descriptor_pb2.FieldOptions()) 1125 mydescriptor = descriptor.Descriptor( 1126 name='MyProto', full_name='MyProto', filename='ignored', 1127 containing_type=None, nested_types=[], enum_types=[], 1128 fields=[foo_field_descriptor], extensions=[], 1129 options=descriptor_pb2.MessageOptions()) 1130 class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): 1131 DESCRIPTOR = mydescriptor 1132 myproto_instance = MyProtoClass() 1133 self.assertEqual(0, myproto_instance.foo_field) 1134 self.assertTrue(not myproto_instance.HasField('foo_field')) 1135 myproto_instance.foo_field = 23 1136 self.assertEqual(23, myproto_instance.foo_field) 1137 self.assertTrue(myproto_instance.HasField('foo_field')) 1138 1139 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 1140 def testDescriptorProtoSupport(self): 1141 # Hand written descriptors/reflection are only supported by the pure-Python 1142 # implementation of the API. 
1143 if api_implementation.Type() != 'python': 1144 return 1145 1146 def AddDescriptorField(proto, field_name, field_type): 1147 AddDescriptorField.field_index += 1 1148 new_field = proto.field.add() 1149 new_field.name = field_name 1150 new_field.type = field_type 1151 new_field.number = AddDescriptorField.field_index 1152 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL 1153 1154 AddDescriptorField.field_index = 0 1155 1156 desc_proto = descriptor_pb2.DescriptorProto() 1157 desc_proto.name = 'Car' 1158 fdp = descriptor_pb2.FieldDescriptorProto 1159 AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING) 1160 AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64) 1161 AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL) 1162 AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE) 1163 # Add a repeated field 1164 AddDescriptorField.field_index += 1 1165 new_field = desc_proto.field.add() 1166 new_field.name = 'owners' 1167 new_field.type = fdp.TYPE_STRING 1168 new_field.number = AddDescriptorField.field_index 1169 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED 1170 1171 desc = descriptor.MakeDescriptor(desc_proto) 1172 self.assertTrue('name' in desc.fields_by_name) 1173 self.assertTrue('year' in desc.fields_by_name) 1174 self.assertTrue('automatic' in desc.fields_by_name) 1175 self.assertTrue('price' in desc.fields_by_name) 1176 self.assertTrue('owners' in desc.fields_by_name) 1177 1178 class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, 1179 message.Message)): 1180 DESCRIPTOR = desc 1181 1182 prius = CarMessage() 1183 prius.name = 'prius' 1184 prius.year = 2010 1185 prius.automatic = True 1186 prius.price = 25134.75 1187 prius.owners.extend(['bob', 'susan']) 1188 1189 serialized_prius = prius.SerializeToString() 1190 new_prius = reflection.ParseMessage(desc, serialized_prius) 1191 self.assertTrue(new_prius is not prius) 1192 self.assertEqual(prius, new_prius) 1193 1194 # these are 
unnecessary assuming message equality works as advertised but 1195 # explicitly check to be safe since we're mucking about in metaclass foo 1196 self.assertEqual(prius.name, new_prius.name) 1197 self.assertEqual(prius.year, new_prius.year) 1198 self.assertEqual(prius.automatic, new_prius.automatic) 1199 self.assertEqual(prius.price, new_prius.price) 1200 self.assertEqual(prius.owners, new_prius.owners) 1201 1202 def testExtensionIter(self): 1203 extendee_proto = more_extensions_pb2.ExtendedMessage() 1204 1205 extension_int32 = more_extensions_pb2.optional_int_extension 1206 extendee_proto.Extensions[extension_int32] = 23 1207 1208 extension_repeated = more_extensions_pb2.repeated_int_extension 1209 extendee_proto.Extensions[extension_repeated].append(11) 1210 1211 extension_msg = more_extensions_pb2.optional_message_extension 1212 extendee_proto.Extensions[extension_msg].foreign_message_int = 56 1213 1214 # Set some normal fields. 1215 extendee_proto.optional_int32 = 1 1216 extendee_proto.repeated_string.append('hi') 1217 1218 expected = (extension_int32, extension_msg, extension_repeated) 1219 count = 0 1220 for item in extendee_proto.Extensions: 1221 self.assertEqual(item.name, expected[count].name) 1222 self.assertIn(item, extendee_proto.Extensions) 1223 count += 1 1224 self.assertEqual(count, 3) 1225 1226 def testExtensionContainsError(self): 1227 extendee_proto = more_extensions_pb2.ExtendedMessage() 1228 self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, 0) 1229 1230 field = more_extensions_pb2.ExtendedMessage.DESCRIPTOR.fields_by_name[ 1231 'optional_int32'] 1232 self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, field) 1233 1234 def testTopLevelExtensionsForOptionalScalar(self): 1235 extendee_proto = unittest_pb2.TestAllExtensions() 1236 extension = unittest_pb2.optional_int32_extension 1237 self.assertTrue(not extendee_proto.HasExtension(extension)) 1238 self.assertNotIn(extension, extendee_proto.Extensions) 1239 
self.assertEqual(0, extendee_proto.Extensions[extension]) 1240 # As with normal scalar fields, just doing a read doesn't actually set the 1241 # "has" bit. 1242 self.assertTrue(not extendee_proto.HasExtension(extension)) 1243 self.assertNotIn(extension, extendee_proto.Extensions) 1244 # Actually set the thing. 1245 extendee_proto.Extensions[extension] = 23 1246 self.assertEqual(23, extendee_proto.Extensions[extension]) 1247 self.assertTrue(extendee_proto.HasExtension(extension)) 1248 self.assertIn(extension, extendee_proto.Extensions) 1249 # Ensure that clearing works as well. 1250 extendee_proto.ClearExtension(extension) 1251 self.assertEqual(0, extendee_proto.Extensions[extension]) 1252 self.assertTrue(not extendee_proto.HasExtension(extension)) 1253 self.assertNotIn(extension, extendee_proto.Extensions) 1254 1255 def testTopLevelExtensionsForRepeatedScalar(self): 1256 extendee_proto = unittest_pb2.TestAllExtensions() 1257 extension = unittest_pb2.repeated_string_extension 1258 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1259 self.assertNotIn(extension, extendee_proto.Extensions) 1260 extendee_proto.Extensions[extension].append('foo') 1261 self.assertEqual(['foo'], extendee_proto.Extensions[extension]) 1262 self.assertIn(extension, extendee_proto.Extensions) 1263 string_list = extendee_proto.Extensions[extension] 1264 extendee_proto.ClearExtension(extension) 1265 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1266 self.assertNotIn(extension, extendee_proto.Extensions) 1267 self.assertTrue(string_list is not extendee_proto.Extensions[extension]) 1268 # Shouldn't be allowed to do Extensions[extension] = 'a' 1269 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1270 extension, 'a') 1271 1272 def testTopLevelExtensionsForOptionalMessage(self): 1273 extendee_proto = unittest_pb2.TestAllExtensions() 1274 extension = unittest_pb2.optional_foreign_message_extension 1275 self.assertTrue(not 
extendee_proto.HasExtension(extension)) 1276 self.assertNotIn(extension, extendee_proto.Extensions) 1277 self.assertEqual(0, extendee_proto.Extensions[extension].c) 1278 # As with normal (non-extension) fields, merely reading from the 1279 # thing shouldn't set the "has" bit. 1280 self.assertTrue(not extendee_proto.HasExtension(extension)) 1281 self.assertNotIn(extension, extendee_proto.Extensions) 1282 extendee_proto.Extensions[extension].c = 23 1283 self.assertEqual(23, extendee_proto.Extensions[extension].c) 1284 self.assertTrue(extendee_proto.HasExtension(extension)) 1285 self.assertIn(extension, extendee_proto.Extensions) 1286 # Save a reference here. 1287 foreign_message = extendee_proto.Extensions[extension] 1288 extendee_proto.ClearExtension(extension) 1289 self.assertTrue(foreign_message is not extendee_proto.Extensions[extension]) 1290 # Setting a field on foreign_message now shouldn't set 1291 # any "has" bits on extendee_proto. 1292 foreign_message.c = 42 1293 self.assertEqual(42, foreign_message.c) 1294 self.assertTrue(foreign_message.HasField('c')) 1295 self.assertTrue(not extendee_proto.HasExtension(extension)) 1296 self.assertNotIn(extension, extendee_proto.Extensions) 1297 # Shouldn't be allowed to do Extensions[extension] = 'a' 1298 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1299 extension, 'a') 1300 1301 def testTopLevelExtensionsForRepeatedMessage(self): 1302 extendee_proto = unittest_pb2.TestAllExtensions() 1303 extension = unittest_pb2.repeatedgroup_extension 1304 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1305 group = extendee_proto.Extensions[extension].add() 1306 group.a = 23 1307 self.assertEqual(23, extendee_proto.Extensions[extension][0].a) 1308 group.a = 42 1309 self.assertEqual(42, extendee_proto.Extensions[extension][0].a) 1310 group_list = extendee_proto.Extensions[extension] 1311 extendee_proto.ClearExtension(extension) 1312 self.assertEqual(0, 
len(extendee_proto.Extensions[extension])) 1313 self.assertTrue(group_list is not extendee_proto.Extensions[extension]) 1314 # Shouldn't be allowed to do Extensions[extension] = 'a' 1315 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1316 extension, 'a') 1317 1318 def testNestedExtensions(self): 1319 extendee_proto = unittest_pb2.TestAllExtensions() 1320 extension = unittest_pb2.TestRequired.single 1321 1322 # We just test the non-repeated case. 1323 self.assertTrue(not extendee_proto.HasExtension(extension)) 1324 self.assertNotIn(extension, extendee_proto.Extensions) 1325 required = extendee_proto.Extensions[extension] 1326 self.assertEqual(0, required.a) 1327 self.assertTrue(not extendee_proto.HasExtension(extension)) 1328 self.assertNotIn(extension, extendee_proto.Extensions) 1329 required.a = 23 1330 self.assertEqual(23, extendee_proto.Extensions[extension].a) 1331 self.assertTrue(extendee_proto.HasExtension(extension)) 1332 self.assertIn(extension, extendee_proto.Extensions) 1333 extendee_proto.ClearExtension(extension) 1334 self.assertTrue(required is not extendee_proto.Extensions[extension]) 1335 self.assertTrue(not extendee_proto.HasExtension(extension)) 1336 self.assertNotIn(extension, extendee_proto.Extensions) 1337 1338 def testRegisteredExtensions(self): 1339 pool = unittest_pb2.DESCRIPTOR.pool 1340 self.assertTrue( 1341 pool.FindExtensionByNumber( 1342 unittest_pb2.TestAllExtensions.DESCRIPTOR, 1)) 1343 self.assertIs( 1344 pool.FindExtensionByName( 1345 'protobuf_unittest.optional_int32_extension').containing_type, 1346 unittest_pb2.TestAllExtensions.DESCRIPTOR) 1347 # Make sure extensions haven't been registered into types that shouldn't 1348 # have any. 
1349 self.assertEqual(0, len( 1350 pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR))) 1351 1352 # If message A directly contains message B, and 1353 # a.HasField('b') is currently False, then mutating any 1354 # extension in B should change a.HasField('b') to True 1355 # (and so on up the object tree). 1356 def testHasBitsForAncestorsOfExtendedMessage(self): 1357 # Optional scalar extension. 1358 toplevel = more_extensions_pb2.TopLevelMessage() 1359 self.assertTrue(not toplevel.HasField('submessage')) 1360 self.assertEqual(0, toplevel.submessage.Extensions[ 1361 more_extensions_pb2.optional_int_extension]) 1362 self.assertTrue(not toplevel.HasField('submessage')) 1363 toplevel.submessage.Extensions[ 1364 more_extensions_pb2.optional_int_extension] = 23 1365 self.assertEqual(23, toplevel.submessage.Extensions[ 1366 more_extensions_pb2.optional_int_extension]) 1367 self.assertTrue(toplevel.HasField('submessage')) 1368 1369 # Repeated scalar extension. 1370 toplevel = more_extensions_pb2.TopLevelMessage() 1371 self.assertTrue(not toplevel.HasField('submessage')) 1372 self.assertEqual([], toplevel.submessage.Extensions[ 1373 more_extensions_pb2.repeated_int_extension]) 1374 self.assertTrue(not toplevel.HasField('submessage')) 1375 toplevel.submessage.Extensions[ 1376 more_extensions_pb2.repeated_int_extension].append(23) 1377 self.assertEqual([23], toplevel.submessage.Extensions[ 1378 more_extensions_pb2.repeated_int_extension]) 1379 self.assertTrue(toplevel.HasField('submessage')) 1380 1381 # Optional message extension. 
1382 toplevel = more_extensions_pb2.TopLevelMessage() 1383 self.assertTrue(not toplevel.HasField('submessage')) 1384 self.assertEqual(0, toplevel.submessage.Extensions[ 1385 more_extensions_pb2.optional_message_extension].foreign_message_int) 1386 self.assertTrue(not toplevel.HasField('submessage')) 1387 toplevel.submessage.Extensions[ 1388 more_extensions_pb2.optional_message_extension].foreign_message_int = 23 1389 self.assertEqual(23, toplevel.submessage.Extensions[ 1390 more_extensions_pb2.optional_message_extension].foreign_message_int) 1391 self.assertTrue(toplevel.HasField('submessage')) 1392 1393 # Repeated message extension. 1394 toplevel = more_extensions_pb2.TopLevelMessage() 1395 self.assertTrue(not toplevel.HasField('submessage')) 1396 self.assertEqual(0, len(toplevel.submessage.Extensions[ 1397 more_extensions_pb2.repeated_message_extension])) 1398 self.assertTrue(not toplevel.HasField('submessage')) 1399 foreign = toplevel.submessage.Extensions[ 1400 more_extensions_pb2.repeated_message_extension].add() 1401 self.assertEqual(foreign, toplevel.submessage.Extensions[ 1402 more_extensions_pb2.repeated_message_extension][0]) 1403 self.assertTrue(toplevel.HasField('submessage')) 1404 1405 def testDisconnectionAfterClearingEmptyMessage(self): 1406 toplevel = more_extensions_pb2.TopLevelMessage() 1407 extendee_proto = toplevel.submessage 1408 extension = more_extensions_pb2.optional_message_extension 1409 extension_proto = extendee_proto.Extensions[extension] 1410 extendee_proto.ClearExtension(extension) 1411 extension_proto.foreign_message_int = 23 1412 1413 self.assertTrue(extension_proto is not extendee_proto.Extensions[extension]) 1414 1415 def testExtensionFailureModes(self): 1416 extendee_proto = unittest_pb2.TestAllExtensions() 1417 1418 # Try non-extension-handle arguments to HasExtension, 1419 # ClearExtension(), and Extensions[]... 
1420 self.assertRaises(KeyError, extendee_proto.HasExtension, 1234) 1421 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234) 1422 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234) 1423 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5) 1424 1425 # Try something that *is* an extension handle, just not for 1426 # this message... 1427 for unknown_handle in (more_extensions_pb2.optional_int_extension, 1428 more_extensions_pb2.optional_message_extension, 1429 more_extensions_pb2.repeated_int_extension, 1430 more_extensions_pb2.repeated_message_extension): 1431 self.assertRaises(KeyError, extendee_proto.HasExtension, 1432 unknown_handle) 1433 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1434 unknown_handle) 1435 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1436 unknown_handle) 1437 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1438 unknown_handle, 5) 1439 1440 # Try call HasExtension() with a valid handle, but for a 1441 # *repeated* field. (Just as with non-extension repeated 1442 # fields, Has*() isn't supported for extension repeated fields). 1443 self.assertRaises(KeyError, extendee_proto.HasExtension, 1444 unittest_pb2.repeated_string_extension) 1445 1446 def testStaticParseFrom(self): 1447 proto1 = unittest_pb2.TestAllTypes() 1448 test_util.SetAllFields(proto1) 1449 1450 string1 = proto1.SerializeToString() 1451 proto2 = unittest_pb2.TestAllTypes.FromString(string1) 1452 1453 # Messages should be equal. 1454 self.assertEqual(proto2, proto1) 1455 1456 def testMergeFromSingularField(self): 1457 # Test merge with just a singular field. 1458 proto1 = unittest_pb2.TestAllTypes() 1459 proto1.optional_int32 = 1 1460 1461 proto2 = unittest_pb2.TestAllTypes() 1462 # This shouldn't get overwritten. 
1463 proto2.optional_string = 'value' 1464 1465 proto2.MergeFrom(proto1) 1466 self.assertEqual(1, proto2.optional_int32) 1467 self.assertEqual('value', proto2.optional_string) 1468 1469 def testMergeFromRepeatedField(self): 1470 # Test merge with just a repeated field. 1471 proto1 = unittest_pb2.TestAllTypes() 1472 proto1.repeated_int32.append(1) 1473 proto1.repeated_int32.append(2) 1474 1475 proto2 = unittest_pb2.TestAllTypes() 1476 proto2.repeated_int32.append(0) 1477 proto2.MergeFrom(proto1) 1478 1479 self.assertEqual(0, proto2.repeated_int32[0]) 1480 self.assertEqual(1, proto2.repeated_int32[1]) 1481 self.assertEqual(2, proto2.repeated_int32[2]) 1482 1483 def testMergeFromOptionalGroup(self): 1484 # Test merge with an optional group. 1485 proto1 = unittest_pb2.TestAllTypes() 1486 proto1.optionalgroup.a = 12 1487 proto2 = unittest_pb2.TestAllTypes() 1488 proto2.MergeFrom(proto1) 1489 self.assertEqual(12, proto2.optionalgroup.a) 1490 1491 def testMergeFromRepeatedNestedMessage(self): 1492 # Test merge with a repeated nested message. 1493 proto1 = unittest_pb2.TestAllTypes() 1494 m = proto1.repeated_nested_message.add() 1495 m.bb = 123 1496 m = proto1.repeated_nested_message.add() 1497 m.bb = 321 1498 1499 proto2 = unittest_pb2.TestAllTypes() 1500 m = proto2.repeated_nested_message.add() 1501 m.bb = 999 1502 proto2.MergeFrom(proto1) 1503 self.assertEqual(999, proto2.repeated_nested_message[0].bb) 1504 self.assertEqual(123, proto2.repeated_nested_message[1].bb) 1505 self.assertEqual(321, proto2.repeated_nested_message[2].bb) 1506 1507 proto3 = unittest_pb2.TestAllTypes() 1508 proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message) 1509 self.assertEqual(999, proto3.repeated_nested_message[0].bb) 1510 self.assertEqual(123, proto3.repeated_nested_message[1].bb) 1511 self.assertEqual(321, proto3.repeated_nested_message[2].bb) 1512 1513 def testMergeFromAllFields(self): 1514 # With all fields set. 
1515 proto1 = unittest_pb2.TestAllTypes() 1516 test_util.SetAllFields(proto1) 1517 proto2 = unittest_pb2.TestAllTypes() 1518 proto2.MergeFrom(proto1) 1519 1520 # Messages should be equal. 1521 self.assertEqual(proto2, proto1) 1522 1523 # Serialized string should be equal too. 1524 string1 = proto1.SerializeToString() 1525 string2 = proto2.SerializeToString() 1526 self.assertEqual(string1, string2) 1527 1528 def testMergeFromExtensionsSingular(self): 1529 proto1 = unittest_pb2.TestAllExtensions() 1530 proto1.Extensions[unittest_pb2.optional_int32_extension] = 1 1531 1532 proto2 = unittest_pb2.TestAllExtensions() 1533 proto2.MergeFrom(proto1) 1534 self.assertEqual( 1535 1, proto2.Extensions[unittest_pb2.optional_int32_extension]) 1536 1537 def testMergeFromExtensionsRepeated(self): 1538 proto1 = unittest_pb2.TestAllExtensions() 1539 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1) 1540 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2) 1541 1542 proto2 = unittest_pb2.TestAllExtensions() 1543 proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0) 1544 proto2.MergeFrom(proto1) 1545 self.assertEqual( 1546 3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension])) 1547 self.assertEqual( 1548 0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0]) 1549 self.assertEqual( 1550 1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1]) 1551 self.assertEqual( 1552 2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2]) 1553 1554 def testMergeFromExtensionsNestedMessage(self): 1555 proto1 = unittest_pb2.TestAllExtensions() 1556 ext1 = proto1.Extensions[ 1557 unittest_pb2.repeated_nested_message_extension] 1558 m = ext1.add() 1559 m.bb = 222 1560 m = ext1.add() 1561 m.bb = 333 1562 1563 proto2 = unittest_pb2.TestAllExtensions() 1564 ext2 = proto2.Extensions[ 1565 unittest_pb2.repeated_nested_message_extension] 1566 m = ext2.add() 1567 m.bb = 111 1568 1569 proto2.MergeFrom(proto1) 1570 ext2 = 
proto2.Extensions[ 1571 unittest_pb2.repeated_nested_message_extension] 1572 self.assertEqual(3, len(ext2)) 1573 self.assertEqual(111, ext2[0].bb) 1574 self.assertEqual(222, ext2[1].bb) 1575 self.assertEqual(333, ext2[2].bb) 1576 1577 def testMergeFromBug(self): 1578 message1 = unittest_pb2.TestAllTypes() 1579 message2 = unittest_pb2.TestAllTypes() 1580 1581 # Cause optional_nested_message to be instantiated within message1, even 1582 # though it is not considered to be "present". 1583 message1.optional_nested_message 1584 self.assertFalse(message1.HasField('optional_nested_message')) 1585 1586 # Merge into message2. This should not instantiate the field is message2. 1587 message2.MergeFrom(message1) 1588 self.assertFalse(message2.HasField('optional_nested_message')) 1589 1590 def testCopyFromSingularField(self): 1591 # Test copy with just a singular field. 1592 proto1 = unittest_pb2.TestAllTypes() 1593 proto1.optional_int32 = 1 1594 proto1.optional_string = 'important-text' 1595 1596 proto2 = unittest_pb2.TestAllTypes() 1597 proto2.optional_string = 'value' 1598 1599 proto2.CopyFrom(proto1) 1600 self.assertEqual(1, proto2.optional_int32) 1601 self.assertEqual('important-text', proto2.optional_string) 1602 1603 def testCopyFromRepeatedField(self): 1604 # Test copy with a repeated field. 1605 proto1 = unittest_pb2.TestAllTypes() 1606 proto1.repeated_int32.append(1) 1607 proto1.repeated_int32.append(2) 1608 1609 proto2 = unittest_pb2.TestAllTypes() 1610 proto2.repeated_int32.append(0) 1611 proto2.CopyFrom(proto1) 1612 1613 self.assertEqual(1, proto2.repeated_int32[0]) 1614 self.assertEqual(2, proto2.repeated_int32[1]) 1615 1616 def testCopyFromAllFields(self): 1617 # With all fields set. 1618 proto1 = unittest_pb2.TestAllTypes() 1619 test_util.SetAllFields(proto1) 1620 proto2 = unittest_pb2.TestAllTypes() 1621 proto2.CopyFrom(proto1) 1622 1623 # Messages should be equal. 1624 self.assertEqual(proto2, proto1) 1625 1626 # Serialized string should be equal too. 
1627 string1 = proto1.SerializeToString() 1628 string2 = proto2.SerializeToString() 1629 self.assertEqual(string1, string2) 1630 1631 def testCopyFromSelf(self): 1632 proto1 = unittest_pb2.TestAllTypes() 1633 proto1.repeated_int32.append(1) 1634 proto1.optional_int32 = 2 1635 proto1.optional_string = 'important-text' 1636 1637 proto1.CopyFrom(proto1) 1638 self.assertEqual(1, proto1.repeated_int32[0]) 1639 self.assertEqual(2, proto1.optional_int32) 1640 self.assertEqual('important-text', proto1.optional_string) 1641 1642 def testCopyFromBadType(self): 1643 # The python implementation doesn't raise an exception in this 1644 # case. In theory it should. 1645 if api_implementation.Type() == 'python': 1646 return 1647 proto1 = unittest_pb2.TestAllTypes() 1648 proto2 = unittest_pb2.TestAllExtensions() 1649 self.assertRaises(TypeError, proto1.CopyFrom, proto2) 1650 1651 def testDeepCopy(self): 1652 proto1 = unittest_pb2.TestAllTypes() 1653 proto1.optional_int32 = 1 1654 proto2 = copy.deepcopy(proto1) 1655 self.assertEqual(1, proto2.optional_int32) 1656 1657 proto1.repeated_int32.append(2) 1658 proto1.repeated_int32.append(3) 1659 container = copy.deepcopy(proto1.repeated_int32) 1660 self.assertEqual([2, 3], container) 1661 container.remove(container[0]) 1662 self.assertEqual([3], container) 1663 1664 message1 = proto1.repeated_nested_message.add() 1665 message1.bb = 1 1666 messages = copy.deepcopy(proto1.repeated_nested_message) 1667 self.assertEqual(proto1.repeated_nested_message, messages) 1668 message1.bb = 2 1669 self.assertNotEqual(proto1.repeated_nested_message, messages) 1670 messages.remove(messages[0]) 1671 self.assertEqual(len(messages), 0) 1672 1673 # TODO(anuraag): Implement deepcopy for extension dict 1674 1675 def testClear(self): 1676 proto = unittest_pb2.TestAllTypes() 1677 # C++ implementation does not support lazy fields right now so leave it 1678 # out for now. 
1679 if api_implementation.Type() == 'python': 1680 test_util.SetAllFields(proto) 1681 else: 1682 test_util.SetAllNonLazyFields(proto) 1683 # Clear the message. 1684 proto.Clear() 1685 self.assertEqual(proto.ByteSize(), 0) 1686 empty_proto = unittest_pb2.TestAllTypes() 1687 self.assertEqual(proto, empty_proto) 1688 1689 # Test if extensions which were set are cleared. 1690 proto = unittest_pb2.TestAllExtensions() 1691 test_util.SetAllExtensions(proto) 1692 # Clear the message. 1693 proto.Clear() 1694 self.assertEqual(proto.ByteSize(), 0) 1695 empty_proto = unittest_pb2.TestAllExtensions() 1696 self.assertEqual(proto, empty_proto) 1697 1698 def testDisconnectingBeforeClear(self): 1699 proto = unittest_pb2.TestAllTypes() 1700 nested = proto.optional_nested_message 1701 proto.Clear() 1702 self.assertTrue(nested is not proto.optional_nested_message) 1703 nested.bb = 23 1704 self.assertTrue(not proto.HasField('optional_nested_message')) 1705 self.assertEqual(0, proto.optional_nested_message.bb) 1706 1707 proto = unittest_pb2.TestAllTypes() 1708 nested = proto.optional_nested_message 1709 nested.bb = 5 1710 foreign = proto.optional_foreign_message 1711 foreign.c = 6 1712 1713 proto.Clear() 1714 self.assertTrue(nested is not proto.optional_nested_message) 1715 self.assertTrue(foreign is not proto.optional_foreign_message) 1716 self.assertEqual(5, nested.bb) 1717 self.assertEqual(6, foreign.c) 1718 nested.bb = 15 1719 foreign.c = 16 1720 self.assertFalse(proto.HasField('optional_nested_message')) 1721 self.assertEqual(0, proto.optional_nested_message.bb) 1722 self.assertFalse(proto.HasField('optional_foreign_message')) 1723 self.assertEqual(0, proto.optional_foreign_message.c) 1724 1725 def testDisconnectingInOneof(self): 1726 m = unittest_pb2.TestOneof2() # This message has two messages in a oneof. 
1727 m.foo_message.qux_int = 5 1728 sub_message = m.foo_message 1729 # Accessing another message's field does not clear the first one 1730 self.assertEqual(m.foo_lazy_message.qux_int, 0) 1731 self.assertEqual(m.foo_message.qux_int, 5) 1732 # But mutating another message in the oneof detaches the first one. 1733 m.foo_lazy_message.qux_int = 6 1734 self.assertEqual(m.foo_message.qux_int, 0) 1735 # The reference we got above was detached and is still valid. 1736 self.assertEqual(sub_message.qux_int, 5) 1737 sub_message.qux_int = 7 1738 1739 def testOneOf(self): 1740 proto = unittest_pb2.TestAllTypes() 1741 proto.oneof_uint32 = 10 1742 proto.oneof_nested_message.bb = 11 1743 self.assertEqual(11, proto.oneof_nested_message.bb) 1744 self.assertFalse(proto.HasField('oneof_uint32')) 1745 nested = proto.oneof_nested_message 1746 proto.oneof_string = 'abc' 1747 self.assertEqual('abc', proto.oneof_string) 1748 self.assertEqual(11, nested.bb) 1749 self.assertFalse(proto.HasField('oneof_nested_message')) 1750 1751 def assertInitialized(self, proto): 1752 self.assertTrue(proto.IsInitialized()) 1753 # Neither method should raise an exception. 1754 proto.SerializeToString() 1755 proto.SerializePartialToString() 1756 1757 def assertNotInitialized(self, proto, error_size=None): 1758 errors = [] 1759 self.assertFalse(proto.IsInitialized()) 1760 self.assertFalse(proto.IsInitialized(errors)) 1761 self.assertEqual(error_size, len(errors)) 1762 self.assertRaises(message.EncodeError, proto.SerializeToString) 1763 # "Partial" serialization doesn't care if message is uninitialized. 1764 proto.SerializePartialToString() 1765 1766 def testIsInitialized(self): 1767 # Trivial cases - all optional fields and extensions. 1768 proto = unittest_pb2.TestAllTypes() 1769 self.assertInitialized(proto) 1770 proto = unittest_pb2.TestAllExtensions() 1771 self.assertInitialized(proto) 1772 1773 # The case of uninitialized required fields. 
1774 proto = unittest_pb2.TestRequired() 1775 self.assertNotInitialized(proto, 3) 1776 proto.a = proto.b = proto.c = 2 1777 self.assertInitialized(proto) 1778 1779 # The case of uninitialized submessage. 1780 proto = unittest_pb2.TestRequiredForeign() 1781 self.assertInitialized(proto) 1782 proto.optional_message.a = 1 1783 self.assertNotInitialized(proto, 2) 1784 proto.optional_message.b = 0 1785 proto.optional_message.c = 0 1786 self.assertInitialized(proto) 1787 1788 # Uninitialized repeated submessage. 1789 message1 = proto.repeated_message.add() 1790 self.assertNotInitialized(proto, 3) 1791 message1.a = message1.b = message1.c = 0 1792 self.assertInitialized(proto) 1793 1794 # Uninitialized repeated group in an extension. 1795 proto = unittest_pb2.TestAllExtensions() 1796 extension = unittest_pb2.TestRequired.multi 1797 message1 = proto.Extensions[extension].add() 1798 message2 = proto.Extensions[extension].add() 1799 self.assertNotInitialized(proto, 6) 1800 message1.a = 1 1801 message1.b = 1 1802 message1.c = 1 1803 self.assertNotInitialized(proto, 3) 1804 message2.a = 2 1805 message2.b = 2 1806 message2.c = 2 1807 self.assertInitialized(proto) 1808 1809 # Uninitialized nonrepeated message in an extension. 1810 proto = unittest_pb2.TestAllExtensions() 1811 extension = unittest_pb2.TestRequired.single 1812 proto.Extensions[extension].a = 1 1813 self.assertNotInitialized(proto, 2) 1814 proto.Extensions[extension].b = 2 1815 proto.Extensions[extension].c = 3 1816 self.assertInitialized(proto) 1817 1818 # Try passing an errors list. 
1819 errors = [] 1820 proto = unittest_pb2.TestRequired() 1821 self.assertFalse(proto.IsInitialized(errors)) 1822 self.assertEqual(errors, ['a', 'b', 'c']) 1823 1824 @unittest.skipIf( 1825 api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, 1826 'Errors are only available from the most recent C++ implementation.') 1827 def testFileDescriptorErrors(self): 1828 file_name = 'test_file_descriptor_errors.proto' 1829 package_name = 'test_file_descriptor_errors.proto' 1830 file_descriptor_proto = descriptor_pb2.FileDescriptorProto() 1831 file_descriptor_proto.name = file_name 1832 file_descriptor_proto.package = package_name 1833 m1 = file_descriptor_proto.message_type.add() 1834 m1.name = 'msg1' 1835 # Compiles the proto into the C++ descriptor pool 1836 descriptor.FileDescriptor( 1837 file_name, 1838 package_name, 1839 serialized_pb=file_descriptor_proto.SerializeToString()) 1840 # Add a FileDescriptorProto that has duplicate symbols 1841 another_file_name = 'another_test_file_descriptor_errors.proto' 1842 file_descriptor_proto.name = another_file_name 1843 m2 = file_descriptor_proto.message_type.add() 1844 m2.name = 'msg2' 1845 with self.assertRaises(TypeError) as cm: 1846 descriptor.FileDescriptor( 1847 another_file_name, 1848 package_name, 1849 serialized_pb=file_descriptor_proto.SerializeToString()) 1850 self.assertTrue(hasattr(cm, 'exception'), '%s not raised' % 1851 getattr(cm.expected, '__name__', cm.expected)) 1852 self.assertIn('test_file_descriptor_errors.proto', str(cm.exception)) 1853 # Error message will say something about this definition being a 1854 # duplicate, though we don't check the message exactly to avoid a 1855 # dependency on the C++ logging code. 1856 self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) 1857 1858 def testStringUTF8Encoding(self): 1859 proto = unittest_pb2.TestAllTypes() 1860 1861 # Assignment of a unicode object to a field of type 'bytes' is not allowed. 
1862 self.assertRaises(TypeError, 1863 setattr, proto, 'optional_bytes', u'unicode object') 1864 1865 # Check that the default value is of python's 'unicode' type. 1866 self.assertEqual(type(proto.optional_string), six.text_type) 1867 1868 proto.optional_string = six.text_type('Testing') 1869 self.assertEqual(proto.optional_string, str('Testing')) 1870 1871 # Assign a value of type 'str' which can be encoded in UTF-8. 1872 proto.optional_string = str('Testing') 1873 self.assertEqual(proto.optional_string, six.text_type('Testing')) 1874 1875 # Try to assign a 'bytes' object which contains non-UTF-8. 1876 self.assertRaises(ValueError, 1877 setattr, proto, 'optional_string', b'a\x80a') 1878 # No exception: Assign already encoded UTF-8 bytes to a string field. 1879 utf8_bytes = u'Тест'.encode('utf-8') 1880 proto.optional_string = utf8_bytes 1881 # No exception: Assign the a non-ascii unicode object. 1882 proto.optional_string = u'Тест' 1883 # No exception thrown (normal str assignment containing ASCII). 1884 proto.optional_string = 'abc' 1885 1886 def testStringUTF8Serialization(self): 1887 proto = message_set_extensions_pb2.TestMessageSet() 1888 extension_message = message_set_extensions_pb2.TestMessageSetExtension2 1889 extension = extension_message.message_set_extension 1890 1891 test_utf8 = u'Тест' 1892 test_utf8_bytes = test_utf8.encode('utf-8') 1893 1894 # 'Test' in another language, using UTF-8 charset. 1895 proto.Extensions[extension].str = test_utf8 1896 1897 # Serialize using the MessageSet wire format (this is specified in the 1898 # .proto file). 1899 serialized = proto.SerializeToString() 1900 1901 # Check byte size. 
1902 self.assertEqual(proto.ByteSize(), len(serialized)) 1903 1904 raw = unittest_mset_pb2.RawMessageSet() 1905 bytes_read = raw.MergeFromString(serialized) 1906 self.assertEqual(len(serialized), bytes_read) 1907 1908 message2 = message_set_extensions_pb2.TestMessageSetExtension2() 1909 1910 self.assertEqual(1, len(raw.item)) 1911 # Check that the type_id is the same as the tag ID in the .proto file. 1912 self.assertEqual(raw.item[0].type_id, 98418634) 1913 1914 # Check the actual bytes on the wire. 1915 self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes)) 1916 bytes_read = message2.MergeFromString(raw.item[0].message) 1917 self.assertEqual(len(raw.item[0].message), bytes_read) 1918 1919 self.assertEqual(type(message2.str), six.text_type) 1920 self.assertEqual(message2.str, test_utf8) 1921 1922 # The pure Python API throws an exception on MergeFromString(), 1923 # if any of the string fields of the message can't be UTF-8 decoded. 1924 # The C++ implementation of the API has no way to check that on 1925 # MergeFromString and thus has no way to throw the exception. 1926 # 1927 # The pure Python API always returns objects of type 'unicode' (UTF-8 1928 # encoded), or 'bytes' (in 7 bit ASCII). 
1929 badbytes = raw.item[0].message.replace( 1930 test_utf8_bytes, len(test_utf8_bytes) * b'\xff') 1931 1932 unicode_decode_failed = False 1933 try: 1934 message2.MergeFromString(badbytes) 1935 except UnicodeDecodeError: 1936 unicode_decode_failed = True 1937 string_field = message2.str 1938 self.assertTrue(unicode_decode_failed or type(string_field) is bytes) 1939 1940 def testBytesInTextFormat(self): 1941 proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') 1942 self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', 1943 six.text_type(proto)) 1944 1945 def testEmptyNestedMessage(self): 1946 proto = unittest_pb2.TestAllTypes() 1947 proto.optional_nested_message.MergeFrom( 1948 unittest_pb2.TestAllTypes.NestedMessage()) 1949 self.assertTrue(proto.HasField('optional_nested_message')) 1950 1951 proto = unittest_pb2.TestAllTypes() 1952 proto.optional_nested_message.CopyFrom( 1953 unittest_pb2.TestAllTypes.NestedMessage()) 1954 self.assertTrue(proto.HasField('optional_nested_message')) 1955 1956 proto = unittest_pb2.TestAllTypes() 1957 bytes_read = proto.optional_nested_message.MergeFromString(b'') 1958 self.assertEqual(0, bytes_read) 1959 self.assertTrue(proto.HasField('optional_nested_message')) 1960 1961 proto = unittest_pb2.TestAllTypes() 1962 proto.optional_nested_message.ParseFromString(b'') 1963 self.assertTrue(proto.HasField('optional_nested_message')) 1964 1965 serialized = proto.SerializeToString() 1966 proto2 = unittest_pb2.TestAllTypes() 1967 self.assertEqual( 1968 len(serialized), 1969 proto2.MergeFromString(serialized)) 1970 self.assertTrue(proto2.HasField('optional_nested_message')) 1971 1972 def testSetInParent(self): 1973 proto = unittest_pb2.TestAllTypes() 1974 self.assertFalse(proto.HasField('optionalgroup')) 1975 proto.optionalgroup.SetInParent() 1976 self.assertTrue(proto.HasField('optionalgroup')) 1977 1978 def testPackageInitializationImport(self): 1979 """Test that we can import nested messages from their __init__.py. 
1980 1981 Such setup is not trivial since at the time of processing of __init__.py one 1982 can't refer to its submodules by name in code, so expressions like 1983 google.protobuf.internal.import_test_package.inner_pb2 1984 don't work. They do work in imports, so we have assign an alias at import 1985 and then use that alias in generated code. 1986 """ 1987 # We import here since it's the import that used to fail, and we want 1988 # the failure to have the right context. 1989 # pylint: disable=g-import-not-at-top 1990 from google.protobuf.internal import import_test_package 1991 # pylint: enable=g-import-not-at-top 1992 msg = import_test_package.myproto.Outer() 1993 # Just check the default value. 1994 self.assertEqual(57, msg.inner.value) 1995 1996# Since we had so many tests for protocol buffer equality, we broke these out 1997# into separate TestCase classes. 1998 1999 2000@testing_refleaks.TestCase 2001class TestAllTypesEqualityTest(unittest.TestCase): 2002 2003 def setUp(self): 2004 self.first_proto = unittest_pb2.TestAllTypes() 2005 self.second_proto = unittest_pb2.TestAllTypes() 2006 2007 def testNotHashable(self): 2008 self.assertRaises(TypeError, hash, self.first_proto) 2009 2010 def testSelfEquality(self): 2011 self.assertEqual(self.first_proto, self.first_proto) 2012 2013 def testEmptyProtosEqual(self): 2014 self.assertEqual(self.first_proto, self.second_proto) 2015 2016 2017@testing_refleaks.TestCase 2018class FullProtosEqualityTest(unittest.TestCase): 2019 2020 """Equality tests using completely-full protos as a starting point.""" 2021 2022 def setUp(self): 2023 self.first_proto = unittest_pb2.TestAllTypes() 2024 self.second_proto = unittest_pb2.TestAllTypes() 2025 test_util.SetAllFields(self.first_proto) 2026 test_util.SetAllFields(self.second_proto) 2027 2028 def testNotHashable(self): 2029 self.assertRaises(TypeError, hash, self.first_proto) 2030 2031 def testNoneNotEqual(self): 2032 self.assertNotEqual(self.first_proto, None) 2033 
self.assertNotEqual(None, self.second_proto) 2034 2035 def testNotEqualToOtherMessage(self): 2036 third_proto = unittest_pb2.TestRequired() 2037 self.assertNotEqual(self.first_proto, third_proto) 2038 self.assertNotEqual(third_proto, self.second_proto) 2039 2040 def testAllFieldsFilledEquality(self): 2041 self.assertEqual(self.first_proto, self.second_proto) 2042 2043 def testNonRepeatedScalar(self): 2044 # Nonrepeated scalar field change should cause inequality. 2045 self.first_proto.optional_int32 += 1 2046 self.assertNotEqual(self.first_proto, self.second_proto) 2047 # ...as should clearing a field. 2048 self.first_proto.ClearField('optional_int32') 2049 self.assertNotEqual(self.first_proto, self.second_proto) 2050 2051 def testNonRepeatedComposite(self): 2052 # Change a nonrepeated composite field. 2053 self.first_proto.optional_nested_message.bb += 1 2054 self.assertNotEqual(self.first_proto, self.second_proto) 2055 self.first_proto.optional_nested_message.bb -= 1 2056 self.assertEqual(self.first_proto, self.second_proto) 2057 # Clear a field in the nested message. 2058 self.first_proto.optional_nested_message.ClearField('bb') 2059 self.assertNotEqual(self.first_proto, self.second_proto) 2060 self.first_proto.optional_nested_message.bb = ( 2061 self.second_proto.optional_nested_message.bb) 2062 self.assertEqual(self.first_proto, self.second_proto) 2063 # Remove the nested message entirely. 2064 self.first_proto.ClearField('optional_nested_message') 2065 self.assertNotEqual(self.first_proto, self.second_proto) 2066 2067 def testRepeatedScalar(self): 2068 # Change a repeated scalar field. 2069 self.first_proto.repeated_int32.append(5) 2070 self.assertNotEqual(self.first_proto, self.second_proto) 2071 self.first_proto.ClearField('repeated_int32') 2072 self.assertNotEqual(self.first_proto, self.second_proto) 2073 2074 def testRepeatedComposite(self): 2075 # Change value within a repeated composite field. 
2076 self.first_proto.repeated_nested_message[0].bb += 1 2077 self.assertNotEqual(self.first_proto, self.second_proto) 2078 self.first_proto.repeated_nested_message[0].bb -= 1 2079 self.assertEqual(self.first_proto, self.second_proto) 2080 # Add a value to a repeated composite field. 2081 self.first_proto.repeated_nested_message.add() 2082 self.assertNotEqual(self.first_proto, self.second_proto) 2083 self.second_proto.repeated_nested_message.add() 2084 self.assertEqual(self.first_proto, self.second_proto) 2085 2086 def testNonRepeatedScalarHasBits(self): 2087 # Ensure that we test "has" bits as well as value for 2088 # nonrepeated scalar field. 2089 self.first_proto.ClearField('optional_int32') 2090 self.second_proto.optional_int32 = 0 2091 self.assertNotEqual(self.first_proto, self.second_proto) 2092 2093 def testNonRepeatedCompositeHasBits(self): 2094 # Ensure that we test "has" bits as well as value for 2095 # nonrepeated composite field. 2096 self.first_proto.ClearField('optional_nested_message') 2097 self.second_proto.optional_nested_message.ClearField('bb') 2098 self.assertNotEqual(self.first_proto, self.second_proto) 2099 self.first_proto.optional_nested_message.bb = 0 2100 self.first_proto.optional_nested_message.ClearField('bb') 2101 self.assertEqual(self.first_proto, self.second_proto) 2102 2103 2104@testing_refleaks.TestCase 2105class ExtensionEqualityTest(unittest.TestCase): 2106 2107 def testExtensionEquality(self): 2108 first_proto = unittest_pb2.TestAllExtensions() 2109 second_proto = unittest_pb2.TestAllExtensions() 2110 self.assertEqual(first_proto, second_proto) 2111 test_util.SetAllExtensions(first_proto) 2112 self.assertNotEqual(first_proto, second_proto) 2113 test_util.SetAllExtensions(second_proto) 2114 self.assertEqual(first_proto, second_proto) 2115 2116 # Ensure that we check value equality. 
2117 first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1 2118 self.assertNotEqual(first_proto, second_proto) 2119 first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1 2120 self.assertEqual(first_proto, second_proto) 2121 2122 # Ensure that we also look at "has" bits. 2123 first_proto.ClearExtension(unittest_pb2.optional_int32_extension) 2124 second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0 2125 self.assertNotEqual(first_proto, second_proto) 2126 first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0 2127 self.assertEqual(first_proto, second_proto) 2128 2129 # Ensure that differences in cached values 2130 # don't matter if "has" bits are both false. 2131 first_proto = unittest_pb2.TestAllExtensions() 2132 second_proto = unittest_pb2.TestAllExtensions() 2133 self.assertEqual( 2134 0, first_proto.Extensions[unittest_pb2.optional_int32_extension]) 2135 self.assertEqual(first_proto, second_proto) 2136 2137 2138@testing_refleaks.TestCase 2139class MutualRecursionEqualityTest(unittest.TestCase): 2140 2141 def testEqualityWithMutualRecursion(self): 2142 first_proto = unittest_pb2.TestMutualRecursionA() 2143 second_proto = unittest_pb2.TestMutualRecursionA() 2144 self.assertEqual(first_proto, second_proto) 2145 first_proto.bb.a.bb.optional_int32 = 23 2146 self.assertNotEqual(first_proto, second_proto) 2147 second_proto.bb.a.bb.optional_int32 = 23 2148 self.assertEqual(first_proto, second_proto) 2149 2150 2151@testing_refleaks.TestCase 2152class ByteSizeTest(unittest.TestCase): 2153 2154 def setUp(self): 2155 self.proto = unittest_pb2.TestAllTypes() 2156 self.extended_proto = more_extensions_pb2.ExtendedMessage() 2157 self.packed_proto = unittest_pb2.TestPackedTypes() 2158 self.packed_extended_proto = unittest_pb2.TestPackedExtensions() 2159 2160 def Size(self): 2161 return self.proto.ByteSize() 2162 2163 def testEmptyMessage(self): 2164 self.assertEqual(0, self.proto.ByteSize()) 2165 2166 def testSizedOnKwargs(self): 
2167 # Use a separate message to ensure testing right after creation. 2168 proto = unittest_pb2.TestAllTypes() 2169 self.assertEqual(0, proto.ByteSize()) 2170 proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1) 2171 # One byte for the tag, one to encode varint 1. 2172 self.assertEqual(2, proto_kwargs.ByteSize()) 2173 2174 def testVarints(self): 2175 def Test(i, expected_varint_size): 2176 self.proto.Clear() 2177 self.proto.optional_int64 = i 2178 # Add one to the varint size for the tag info 2179 # for tag 1. 2180 self.assertEqual(expected_varint_size + 1, self.Size()) 2181 Test(0, 1) 2182 Test(1, 1) 2183 for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)): 2184 Test((1 << i) - 1, num_bytes) 2185 Test(-1, 10) 2186 Test(-2, 10) 2187 Test(-(1 << 63), 10) 2188 2189 def testStrings(self): 2190 self.proto.optional_string = '' 2191 # Need one byte for tag info (tag #14), and one byte for length. 2192 self.assertEqual(2, self.Size()) 2193 2194 self.proto.optional_string = 'abc' 2195 # Need one byte for tag info (tag #14), and one byte for length. 2196 self.assertEqual(2 + len(self.proto.optional_string), self.Size()) 2197 2198 self.proto.optional_string = 'x' * 128 2199 # Need one byte for tag info (tag #14), and TWO bytes for length. 2200 self.assertEqual(3 + len(self.proto.optional_string), self.Size()) 2201 2202 def testOtherNumerics(self): 2203 self.proto.optional_fixed32 = 1234 2204 # One byte for tag and 4 bytes for fixed32. 2205 self.assertEqual(5, self.Size()) 2206 self.proto = unittest_pb2.TestAllTypes() 2207 2208 self.proto.optional_fixed64 = 1234 2209 # One byte for tag and 8 bytes for fixed64. 2210 self.assertEqual(9, self.Size()) 2211 self.proto = unittest_pb2.TestAllTypes() 2212 2213 self.proto.optional_float = 1.234 2214 # One byte for tag and 4 bytes for float. 2215 self.assertEqual(5, self.Size()) 2216 self.proto = unittest_pb2.TestAllTypes() 2217 2218 self.proto.optional_double = 1.234 2219 # One byte for tag and 8 bytes for float. 
2220 self.assertEqual(9, self.Size()) 2221 self.proto = unittest_pb2.TestAllTypes() 2222 2223 self.proto.optional_sint32 = 64 2224 # One byte for tag and 2 bytes for zig-zag-encoded 64. 2225 self.assertEqual(3, self.Size()) 2226 self.proto = unittest_pb2.TestAllTypes() 2227 2228 def testComposites(self): 2229 # 3 bytes. 2230 self.proto.optional_nested_message.bb = (1 << 14) 2231 # Plus one byte for bb tag. 2232 # Plus 1 byte for optional_nested_message serialized size. 2233 # Plus two bytes for optional_nested_message tag. 2234 self.assertEqual(3 + 1 + 1 + 2, self.Size()) 2235 2236 def testGroups(self): 2237 # 4 bytes. 2238 self.proto.optionalgroup.a = (1 << 21) 2239 # Plus two bytes for |a| tag. 2240 # Plus 2 * two bytes for START_GROUP and END_GROUP tags. 2241 self.assertEqual(4 + 2 + 2*2, self.Size()) 2242 2243 def testRepeatedScalars(self): 2244 self.proto.repeated_int32.append(10) # 1 byte. 2245 self.proto.repeated_int32.append(128) # 2 bytes. 2246 # Also need 2 bytes for each entry for tag. 2247 self.assertEqual(1 + 2 + 2*2, self.Size()) 2248 2249 def testRepeatedScalarsExtend(self): 2250 self.proto.repeated_int32.extend([10, 128]) # 3 bytes. 2251 # Also need 2 bytes for each entry for tag. 2252 self.assertEqual(1 + 2 + 2*2, self.Size()) 2253 2254 def testRepeatedScalarsRemove(self): 2255 self.proto.repeated_int32.append(10) # 1 byte. 2256 self.proto.repeated_int32.append(128) # 2 bytes. 2257 # Also need 2 bytes for each entry for tag. 2258 self.assertEqual(1 + 2 + 2*2, self.Size()) 2259 self.proto.repeated_int32.remove(128) 2260 self.assertEqual(1 + 2, self.Size()) 2261 2262 def testRepeatedComposites(self): 2263 # Empty message. 2 bytes tag plus 1 byte length. 2264 foreign_message_0 = self.proto.repeated_nested_message.add() 2265 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 
2266 foreign_message_1 = self.proto.repeated_nested_message.add() 2267 foreign_message_1.bb = 7 2268 self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size()) 2269 2270 def testRepeatedCompositesDelete(self): 2271 # Empty message. 2 bytes tag plus 1 byte length. 2272 foreign_message_0 = self.proto.repeated_nested_message.add() 2273 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2274 foreign_message_1 = self.proto.repeated_nested_message.add() 2275 foreign_message_1.bb = 9 2276 self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size()) 2277 repeated_nested_message = copy.deepcopy( 2278 self.proto.repeated_nested_message) 2279 2280 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2281 del self.proto.repeated_nested_message[0] 2282 self.assertEqual(2 + 1 + 1 + 1, self.Size()) 2283 2284 # Now add a new message. 2285 foreign_message_2 = self.proto.repeated_nested_message.add() 2286 foreign_message_2.bb = 12 2287 2288 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2289 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2290 self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size()) 2291 2292 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2293 del self.proto.repeated_nested_message[1] 2294 self.assertEqual(2 + 1 + 1 + 1, self.Size()) 2295 2296 del self.proto.repeated_nested_message[0] 2297 self.assertEqual(0, self.Size()) 2298 2299 self.assertEqual(2, len(repeated_nested_message)) 2300 del repeated_nested_message[0:1] 2301 # TODO(jieluo): Fix cpp extension bug when delete repeated message. 2302 if api_implementation.Type() == 'python': 2303 self.assertEqual(1, len(repeated_nested_message)) 2304 del repeated_nested_message[-1] 2305 # TODO(jieluo): Fix cpp extension bug when delete repeated message. 2306 if api_implementation.Type() == 'python': 2307 self.assertEqual(0, len(repeated_nested_message)) 2308 2309 def testRepeatedGroups(self): 2310 # 2-byte START_GROUP plus 2-byte END_GROUP. 
2311 group_0 = self.proto.repeatedgroup.add() 2312 # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a| 2313 # plus 2-byte END_GROUP. 2314 group_1 = self.proto.repeatedgroup.add() 2315 group_1.a = 7 2316 self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size()) 2317 2318 def testExtensions(self): 2319 proto = unittest_pb2.TestAllExtensions() 2320 self.assertEqual(0, proto.ByteSize()) 2321 extension = unittest_pb2.optional_int32_extension # Field #1, 1 byte. 2322 proto.Extensions[extension] = 23 2323 # 1 byte for tag, 1 byte for value. 2324 self.assertEqual(2, proto.ByteSize()) 2325 field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[ 2326 'optional_int32'] 2327 with self.assertRaises(KeyError): 2328 proto.Extensions[field] = 23 2329 2330 def testCacheInvalidationForNonrepeatedScalar(self): 2331 # Test non-extension. 2332 self.proto.optional_int32 = 1 2333 self.assertEqual(2, self.proto.ByteSize()) 2334 self.proto.optional_int32 = 128 2335 self.assertEqual(3, self.proto.ByteSize()) 2336 self.proto.ClearField('optional_int32') 2337 self.assertEqual(0, self.proto.ByteSize()) 2338 2339 # Test within extension. 2340 extension = more_extensions_pb2.optional_int_extension 2341 self.extended_proto.Extensions[extension] = 1 2342 self.assertEqual(2, self.extended_proto.ByteSize()) 2343 self.extended_proto.Extensions[extension] = 128 2344 self.assertEqual(3, self.extended_proto.ByteSize()) 2345 self.extended_proto.ClearExtension(extension) 2346 self.assertEqual(0, self.extended_proto.ByteSize()) 2347 2348 def testCacheInvalidationForRepeatedScalar(self): 2349 # Test non-extension. 2350 self.proto.repeated_int32.append(1) 2351 self.assertEqual(3, self.proto.ByteSize()) 2352 self.proto.repeated_int32.append(1) 2353 self.assertEqual(6, self.proto.ByteSize()) 2354 self.proto.repeated_int32[1] = 128 2355 self.assertEqual(7, self.proto.ByteSize()) 2356 self.proto.ClearField('repeated_int32') 2357 self.assertEqual(0, self.proto.ByteSize()) 2358 2359 # Test within extension. 
2360 extension = more_extensions_pb2.repeated_int_extension 2361 repeated = self.extended_proto.Extensions[extension] 2362 repeated.append(1) 2363 self.assertEqual(2, self.extended_proto.ByteSize()) 2364 repeated.append(1) 2365 self.assertEqual(4, self.extended_proto.ByteSize()) 2366 repeated[1] = 128 2367 self.assertEqual(5, self.extended_proto.ByteSize()) 2368 self.extended_proto.ClearExtension(extension) 2369 self.assertEqual(0, self.extended_proto.ByteSize()) 2370 2371 def testCacheInvalidationForNonrepeatedMessage(self): 2372 # Test non-extension. 2373 self.proto.optional_foreign_message.c = 1 2374 self.assertEqual(5, self.proto.ByteSize()) 2375 self.proto.optional_foreign_message.c = 128 2376 self.assertEqual(6, self.proto.ByteSize()) 2377 self.proto.optional_foreign_message.ClearField('c') 2378 self.assertEqual(3, self.proto.ByteSize()) 2379 self.proto.ClearField('optional_foreign_message') 2380 self.assertEqual(0, self.proto.ByteSize()) 2381 2382 if api_implementation.Type() == 'python': 2383 # This is only possible in pure-Python implementation of the API. 2384 child = self.proto.optional_foreign_message 2385 self.proto.ClearField('optional_foreign_message') 2386 child.c = 128 2387 self.assertEqual(0, self.proto.ByteSize()) 2388 2389 # Test within extension. 2390 extension = more_extensions_pb2.optional_message_extension 2391 child = self.extended_proto.Extensions[extension] 2392 self.assertEqual(0, self.extended_proto.ByteSize()) 2393 child.foreign_message_int = 1 2394 self.assertEqual(4, self.extended_proto.ByteSize()) 2395 child.foreign_message_int = 128 2396 self.assertEqual(5, self.extended_proto.ByteSize()) 2397 self.extended_proto.ClearExtension(extension) 2398 self.assertEqual(0, self.extended_proto.ByteSize()) 2399 2400 def testCacheInvalidationForRepeatedMessage(self): 2401 # Test non-extension. 
2402 child0 = self.proto.repeated_foreign_message.add() 2403 self.assertEqual(3, self.proto.ByteSize()) 2404 self.proto.repeated_foreign_message.add() 2405 self.assertEqual(6, self.proto.ByteSize()) 2406 child0.c = 1 2407 self.assertEqual(8, self.proto.ByteSize()) 2408 self.proto.ClearField('repeated_foreign_message') 2409 self.assertEqual(0, self.proto.ByteSize()) 2410 2411 # Test within extension. 2412 extension = more_extensions_pb2.repeated_message_extension 2413 child_list = self.extended_proto.Extensions[extension] 2414 child0 = child_list.add() 2415 self.assertEqual(2, self.extended_proto.ByteSize()) 2416 child_list.add() 2417 self.assertEqual(4, self.extended_proto.ByteSize()) 2418 child0.foreign_message_int = 1 2419 self.assertEqual(6, self.extended_proto.ByteSize()) 2420 child0.ClearField('foreign_message_int') 2421 self.assertEqual(4, self.extended_proto.ByteSize()) 2422 self.extended_proto.ClearExtension(extension) 2423 self.assertEqual(0, self.extended_proto.ByteSize()) 2424 2425 def testPackedRepeatedScalars(self): 2426 self.assertEqual(0, self.packed_proto.ByteSize()) 2427 2428 self.packed_proto.packed_int32.append(10) # 1 byte. 2429 self.packed_proto.packed_int32.append(128) # 2 bytes. 2430 # The tag is 2 bytes (the field number is 90), and the varint 2431 # storing the length is 1 byte. 2432 int_size = 1 + 2 + 3 2433 self.assertEqual(int_size, self.packed_proto.ByteSize()) 2434 2435 self.packed_proto.packed_double.append(4.2) # 8 bytes 2436 self.packed_proto.packed_double.append(3.25) # 8 bytes 2437 # 2 more tag bytes, 1 more length byte. 
2438 double_size = 8 + 8 + 3 2439 self.assertEqual(int_size+double_size, self.packed_proto.ByteSize()) 2440 2441 self.packed_proto.ClearField('packed_int32') 2442 self.assertEqual(double_size, self.packed_proto.ByteSize()) 2443 2444 def testPackedExtensions(self): 2445 self.assertEqual(0, self.packed_extended_proto.ByteSize()) 2446 extension = self.packed_extended_proto.Extensions[ 2447 unittest_pb2.packed_fixed32_extension] 2448 extension.extend([1, 2, 3, 4]) # 16 bytes 2449 # Tag is 3 bytes. 2450 self.assertEqual(19, self.packed_extended_proto.ByteSize()) 2451 2452 2453# Issues to be sure to cover include: 2454# * Handling of unrecognized tags ("uninterpreted_bytes"). 2455# * Handling of MessageSets. 2456# * Consistent ordering of tags in the wire format, 2457# including ordering between extensions and non-extension 2458# fields. 2459# * Consistent serialization of negative numbers, especially 2460# negative int32s. 2461# * Handling of empty submessages (with and without "has" 2462# bits set). 
@testing_refleaks.TestCase
class SerializationTest(unittest.TestCase):
  """Tests for serializing, parsing and merging protocol messages."""

  # NOTE: "Emtpy" is a long-standing typo; the name is kept so the test ID
  # stays stable for anyone selecting or tracking it.
  def testSerializeEmtpyMessage(self):
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllFields(self):
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllExtensions(self):
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeWithOptionalGroup(self):
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    first_proto.optionalgroup.a = 242
    serialized = first_proto.SerializeToString()
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeNegativeValues(self):
    first_proto = unittest_pb2.TestAllTypes()

    first_proto.optional_int32 = -1
    first_proto.optional_int64 = -(2 << 40)
    first_proto.optional_sint32 = -3
    first_proto.optional_sint64 = -(4 << 40)
    first_proto.optional_sfixed32 = -5
    first_proto.optional_sfixed64 = -(6 << 40)

    second_proto = unittest_pb2.TestAllTypes.FromString(
        first_proto.SerializeToString())

    self.assertEqual(first_proto, second_proto)

  def testParseTruncated(self):
    """Every prefix of a valid encoding parses or fails consistently.

    For each truncation point, parsing into TestAllTypes and into
    TestEmptyMessage (where every field is unknown) must either both succeed,
    consuming exactly the truncated length, or both raise DecodeError.
    """
    # _InternalParse(buffer, pos, end) is only exercised on the pure-Python
    # implementation. Report a skip rather than a silent pass elsewhere.
    if api_implementation.Type() != 'python':
      self.skipTest('Only applicable to the pure-Python implementation.')

    first_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = memoryview(first_proto.SerializeToString())

    for truncation_point in range(len(serialized) + 1):
      try:
        second_proto = unittest_pb2.TestAllTypes()
        unknown_fields = unittest_pb2.TestEmptyMessage()
        pos = second_proto._InternalParse(serialized, 0, truncation_point)
        # If we didn't raise an error then we read exactly the amount expected.
        self.assertEqual(truncation_point, pos)

        # Parsing to unknown fields should not throw if parsing to known fields
        # did not.
        try:
          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
          self.assertEqual(truncation_point, pos2)
        except message.DecodeError:
          self.fail('Parsing unknown fields failed when parsing known fields '
                    'did not.')
      except message.DecodeError:
        # Parsing unknown fields should also fail.
        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
                          serialized, 0, truncation_point)

  def testCanonicalSerializationOrder(self):
    proto = more_messages_pb2.OutOfOrderFields()
    # These are also their tag numbers. Even though we're setting these in
    # reverse-tag order AND they're listed in reverse tag-order in the .proto
    # file, they should nonetheless be serialized in tag order.
    proto.optional_sint32 = 5
    proto.Extensions[more_messages_pb2.optional_uint64] = 4
    proto.optional_uint32 = 3
    proto.Extensions[more_messages_pb2.optional_int64] = 2
    proto.optional_int32 = 1
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(2, d.ReadInt64())
    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(3, d.ReadUInt32())
    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(4, d.ReadUInt64())
    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(5, d.ReadSInt32())

  def testCanonicalSerializationOrderSameAsCpp(self):
    # Copy of the same test we use for C++.
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    serialized = proto.SerializeToString()
    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)

  def testMergeFromStringWhenFieldsAlreadySet(self):
    first_proto = unittest_pb2.TestAllTypes()
    first_proto.repeated_string.append('foobar')
    first_proto.optional_int32 = 23
    first_proto.optional_nested_message.bb = 42
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestAllTypes()
    second_proto.repeated_string.append('baz')
    second_proto.optional_int32 = 100
    second_proto.optional_nested_message.bb = 999

    bytes_parsed = second_proto.MergeFromString(serialized)
    self.assertEqual(len(serialized), bytes_parsed)

    # Ensure that we append to repeated fields.
    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
    # Ensure that we overwrite nonrepeated scalars.
    self.assertEqual(23, second_proto.optional_int32)
    # Ensure that we recursively call MergeFromString() on
    # submessages.
    self.assertEqual(42, second_proto.optional_nested_message.bb)

  def testMessageSetWireFormat(self):
    proto = message_set_extensions_pb2.TestMessageSet()
    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
    extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
    extension1 = extension_message1.message_set_extension
    extension2 = extension_message2.message_set_extension
    extension3 = message_set_extensions_pb2.message_set_extension3
    proto.Extensions[extension1].i = 123
    proto.Extensions[extension2].str = 'foo'
    proto.Extensions[extension3].text = 'bar'

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    # RawMessageSet does not use the MessageSet wire format itself, so it
    # exposes the encoded items directly.
    raw = unittest_mset_pb2.RawMessageSet()
    self.assertFalse(raw.DESCRIPTOR.GetOptions().message_set_wire_format)
    self.assertEqual(
        len(serialized),
        raw.MergeFromString(serialized))
    self.assertEqual(3, len(raw.item))

    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    self.assertEqual(
        len(raw.item[0].message),
        message1.MergeFromString(raw.item[0].message))
    self.assertEqual(123, message1.i)

    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
    self.assertEqual(
        len(raw.item[1].message),
        message2.MergeFromString(raw.item[1].message))
    self.assertEqual('foo', message2.str)

    message3 = message_set_extensions_pb2.TestMessageSetExtension3()
    self.assertEqual(
        len(raw.item[2].message),
        message3.MergeFromString(raw.item[2].message))
    self.assertEqual('bar', message3.text)

    # Deserialize using the MessageSet wire format.
    proto2 = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(123, proto2.Extensions[extension1].i)
    self.assertEqual('foo', proto2.Extensions[extension2].str)
    self.assertEqual('bar', proto2.Extensions[extension3].text)

    # Check byte size.
    self.assertEqual(proto2.ByteSize(), len(serialized))
    self.assertEqual(proto.ByteSize(), len(serialized))

  def testMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an item for the known TestMessageSetExtension1 extension; its
    # value is checked after parsing below.
    item = raw.item.add()
    item.type_id = 98418603
    known_message = message_set_extensions_pb2.TestMessageSetExtension1()
    known_message.i = 12345
    item.message = known_message.SerializeToString()

    # Add a second, unknown extension.
    item = raw.item.add()
    item.type_id = 98418604
    unknown_message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    unknown_message1.i = 12346
    item.message = unknown_message1.SerializeToString()

    # Add another unknown extension.
    item = raw.item.add()
    item.type_id = 98418605
    unknown_message2 = message_set_extensions_pb2.TestMessageSetExtension2()
    unknown_message2.str = 'foo'
    item.message = unknown_message2.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(
        len(serialized),
        proto.MergeFromString(serialized))

    # Check that the message parsed well.
    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
    extension1 = extension_message1.message_set_extension
    self.assertEqual(12345, proto.Extensions[extension1].i)

  def testUnknownFields(self):
    proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(proto)

    serialized = proto.SerializeToString()

    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()

    # Parsing this message should succeed.
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))

    # Now test with a int64 field set.
    proto = unittest_pb2.TestAllTypes()
    proto.optional_int64 = 0x0fffffffffffffff
    serialized = proto.SerializeToString()
    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()
    # Parsing this message should succeed.
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))

  def _CheckRaises(self, exc_class, callable_obj, exception):
    """Checks that callable_obj() raises exc_class with the given message.

    Args:
      exc_class: The exception type that must be raised.
      callable_obj: A zero-argument callable expected to raise exc_class.
      exception: The expected str() of the raised exception.

    Raises:
      self.failureException: if exc_class is not raised or its message
        differs. Exceptions of other types propagate unchanged.
    """
    with self.assertRaises(exc_class) as context:
      callable_obj()
    self.assertEqual(exception, str(context.exception))

  def testSerializeUninitialized(self):
    proto = unittest_pb2.TestRequired()
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: '
        'a,b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertFalse(proto2.HasField('a'))
    # proto2 ParseFromString does not check that required fields are set.
    proto2.ParseFromString(partial)
    self.assertFalse(proto2.HasField('a'))

    proto.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.c = 3
    serialized = proto.SerializeToString()
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)
    self.assertEqual(
        len(partial),
        proto2.MergeFromString(partial))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)

  def testSerializeUninitializedSubMessage(self):
    proto = unittest_pb2.TestRequiredForeign()

    # Sub-message doesn't exist yet, so this succeeds.
    proto.SerializeToString()

    proto.optional_message.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign '
        'is missing required fields: '
        'optional_message.b,optional_message.c')

    proto.optional_message.b = 2
    proto.optional_message.c = 3
    proto.SerializeToString()

    proto.repeated_message.add().a = 1
    proto.repeated_message.add().b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
        'repeated_message[0].b,repeated_message[0].c,'
        'repeated_message[1].a,repeated_message[1].c')

    proto.repeated_message[0].b = 2
    proto.repeated_message[0].c = 3
    proto.repeated_message[1].a = 1
    proto.repeated_message[1].c = 3
    proto.SerializeToString()

  def testSerializeAllPackedFields(self):
    first_proto = unittest_pb2.TestPackedTypes()
    second_proto = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllPackedExtensions(self):
    first_proto = unittest_pb2.TestPackedExtensions()
    second_proto = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
    first_proto = unittest_pb2.TestPackedTypes()
    first_proto.packed_int32.extend([1, 2])
    first_proto.packed_double.append(3.0)
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestPackedTypes()
    second_proto.packed_int32.append(3)
    second_proto.packed_double.extend([1.0, 2.0])
    second_proto.packed_sint32.append(4)

    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual([3, 1, 2], second_proto.packed_int32)
    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
    self.assertEqual([4], second_proto.packed_sint32)

  def testPackedFieldsWireFormat(self):
    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
    proto.packed_float.append(2.0)             # 4 bytes, will be before double
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(1+1+1+2, d.ReadInt32())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual(2, d.ReadInt32())
    self.assertEqual(150, d.ReadInt32())
    self.assertEqual(3, d.ReadInt32())
    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(4, d.ReadInt32())
    self.assertEqual(2.0, d.ReadFloat())
    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(8+8, d.ReadInt32())
    self.assertEqual(1.0, d.ReadDouble())
    self.assertEqual(1000.0, d.ReadDouble())
    self.assertTrue(d.EndOfStream())

  def testParsePackedFromUnpacked(self):
    unpacked = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(unpacked)
    packed = unittest_pb2.TestPackedTypes()
    serialized = unpacked.SerializeToString()
    self.assertEqual(
        len(serialized),
        packed.MergeFromString(serialized))
    expected = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(expected)
    self.assertEqual(expected, packed)

  def testParseUnpackedFromPacked(self):
    packed = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(packed)
    unpacked = unittest_pb2.TestUnpackedTypes()
    serialized = packed.SerializeToString()
    self.assertEqual(
        len(serialized),
        unpacked.MergeFromString(serialized))
    expected = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(expected)
    self.assertEqual(expected, unpacked)

  def testFieldNumbers(self):
    self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
    self.assertEqual(
        unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
    self.assertEqual(
        unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
    self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
    self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
    self.assertEqual(
        unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
    self.assertEqual(
        unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)

  def testExtensionFieldNumbers(self):
    self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
    self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
    self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
    self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
    self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
    self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
    self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
    self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
    self.assertEqual(
        unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
    self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
    self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
                     21)
    self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
    self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
    self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
    self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
    self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
    self.assertEqual(
        unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
    self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
    self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
                     51)

  def testFieldProperties(self):
    cls = unittest_pb2.TestAllTypes
    self.assertIs(cls.optional_int32.DESCRIPTOR,
                  cls.DESCRIPTOR.fields_by_name['optional_int32'])
    self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER,
                     cls.optional_int32.DESCRIPTOR.number)
    self.assertIs(cls.optional_nested_message.DESCRIPTOR,
                  cls.DESCRIPTOR.fields_by_name['optional_nested_message'])
    self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER,
                     cls.optional_nested_message.DESCRIPTOR.number)
    self.assertIs(cls.repeated_int32.DESCRIPTOR,
                  cls.DESCRIPTOR.fields_by_name['repeated_int32'])
    self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER,
                     cls.repeated_int32.DESCRIPTOR.number)

  def testFieldDataDescriptor(self):
    msg = unittest_pb2.TestAllTypes()
    msg.optional_int32 = 42
    self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42)
    unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25)
    self.assertEqual(msg.optional_int32, 25)
    with self.assertRaises(AttributeError):
      del msg.optional_int32
    try:
      unittest_pb2.ForeignMessage.c.__get__(msg)
    except TypeError:
      pass  # The cpp implementation cannot mix fields from other messages.
            # This test exercises a specific check that avoids a crash.
    else:
      pass  # The python implementation allows fields from other messages.
            # This is useless, but works.
2983 2984 def testInitKwargs(self): 2985 proto = unittest_pb2.TestAllTypes( 2986 optional_int32=1, 2987 optional_string='foo', 2988 optional_bool=True, 2989 optional_bytes=b'bar', 2990 optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), 2991 optional_foreign_message=unittest_pb2.ForeignMessage(c=1), 2992 optional_nested_enum=unittest_pb2.TestAllTypes.FOO, 2993 optional_foreign_enum=unittest_pb2.FOREIGN_FOO, 2994 repeated_int32=[1, 2, 3]) 2995 self.assertTrue(proto.IsInitialized()) 2996 self.assertTrue(proto.HasField('optional_int32')) 2997 self.assertTrue(proto.HasField('optional_string')) 2998 self.assertTrue(proto.HasField('optional_bool')) 2999 self.assertTrue(proto.HasField('optional_bytes')) 3000 self.assertTrue(proto.HasField('optional_nested_message')) 3001 self.assertTrue(proto.HasField('optional_foreign_message')) 3002 self.assertTrue(proto.HasField('optional_nested_enum')) 3003 self.assertTrue(proto.HasField('optional_foreign_enum')) 3004 self.assertEqual(1, proto.optional_int32) 3005 self.assertEqual('foo', proto.optional_string) 3006 self.assertEqual(True, proto.optional_bool) 3007 self.assertEqual(b'bar', proto.optional_bytes) 3008 self.assertEqual(1, proto.optional_nested_message.bb) 3009 self.assertEqual(1, proto.optional_foreign_message.c) 3010 self.assertEqual(unittest_pb2.TestAllTypes.FOO, 3011 proto.optional_nested_enum) 3012 self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum) 3013 self.assertEqual([1, 2, 3], proto.repeated_int32) 3014 3015 def testInitArgsUnknownFieldName(self): 3016 def InitalizeEmptyMessageWithExtraKeywordArg(): 3017 unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown') 3018 self._CheckRaises( 3019 ValueError, 3020 InitalizeEmptyMessageWithExtraKeywordArg, 3021 'Protocol message TestEmptyMessage has no "unknown" field.') 3022 3023 def testInitRequiredKwargs(self): 3024 proto = unittest_pb2.TestRequired(a=1, b=1, c=1) 3025 self.assertTrue(proto.IsInitialized()) 3026 
self.assertTrue(proto.HasField('a')) 3027 self.assertTrue(proto.HasField('b')) 3028 self.assertTrue(proto.HasField('c')) 3029 self.assertTrue(not proto.HasField('dummy2')) 3030 self.assertEqual(1, proto.a) 3031 self.assertEqual(1, proto.b) 3032 self.assertEqual(1, proto.c) 3033 3034 def testInitRequiredForeignKwargs(self): 3035 proto = unittest_pb2.TestRequiredForeign( 3036 optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1)) 3037 self.assertTrue(proto.IsInitialized()) 3038 self.assertTrue(proto.HasField('optional_message')) 3039 self.assertTrue(proto.optional_message.IsInitialized()) 3040 self.assertTrue(proto.optional_message.HasField('a')) 3041 self.assertTrue(proto.optional_message.HasField('b')) 3042 self.assertTrue(proto.optional_message.HasField('c')) 3043 self.assertTrue(not proto.optional_message.HasField('dummy2')) 3044 self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1), 3045 proto.optional_message) 3046 self.assertEqual(1, proto.optional_message.a) 3047 self.assertEqual(1, proto.optional_message.b) 3048 self.assertEqual(1, proto.optional_message.c) 3049 3050 def testInitRepeatedKwargs(self): 3051 proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3]) 3052 self.assertTrue(proto.IsInitialized()) 3053 self.assertEqual(1, proto.repeated_int32[0]) 3054 self.assertEqual(2, proto.repeated_int32[1]) 3055 self.assertEqual(3, proto.repeated_int32[2]) 3056 3057 3058@testing_refleaks.TestCase 3059class OptionsTest(unittest.TestCase): 3060 3061 def testMessageOptions(self): 3062 proto = message_set_extensions_pb2.TestMessageSet() 3063 self.assertEqual(True, 3064 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 3065 proto = unittest_pb2.TestAllTypes() 3066 self.assertEqual(False, 3067 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 3068 3069 def testPackedOptions(self): 3070 proto = unittest_pb2.TestAllTypes() 3071 proto.optional_int32 = 1 3072 proto.optional_double = 3.0 3073 for field_descriptor, _ in proto.ListFields(): 3074 
self.assertEqual(False, field_descriptor.GetOptions().packed) 3075 3076 proto = unittest_pb2.TestPackedTypes() 3077 proto.packed_int32.append(1) 3078 proto.packed_double.append(3.0) 3079 for field_descriptor, _ in proto.ListFields(): 3080 self.assertEqual(True, field_descriptor.GetOptions().packed) 3081 self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED, 3082 field_descriptor.label) 3083 3084 3085 3086@testing_refleaks.TestCase 3087class ClassAPITest(unittest.TestCase): 3088 3089 @unittest.skipIf( 3090 api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, 3091 'C++ implementation requires a call to MakeDescriptor()') 3092 @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable') 3093 def testMakeClassWithNestedDescriptor(self): 3094 leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', 3095 containing_type=None, fields=[], 3096 nested_types=[], enum_types=[], 3097 extensions=[]) 3098 child_desc = descriptor.Descriptor('child', 'package.parent.child', '', 3099 containing_type=None, fields=[], 3100 nested_types=[leaf_desc], enum_types=[], 3101 extensions=[]) 3102 sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling', 3103 '', containing_type=None, fields=[], 3104 nested_types=[], enum_types=[], 3105 extensions=[]) 3106 parent_desc = descriptor.Descriptor('parent', 'package.parent', '', 3107 containing_type=None, fields=[], 3108 nested_types=[child_desc, sibling_desc], 3109 enum_types=[], extensions=[]) 3110 reflection.MakeClass(parent_desc) 3111 3112 def _GetSerializedFileDescriptor(self, name): 3113 """Get a serialized representation of a test FileDescriptorProto. 3114 3115 Args: 3116 name: All calls to this must use a unique message name, to avoid 3117 collisions in the cpp descriptor pool. 3118 Returns: 3119 A string containing the serialized form of a test FileDescriptorProto. 
3120 """ 3121 file_descriptor_str = ( 3122 'message_type {' 3123 ' name: "' + name + '"' 3124 ' field {' 3125 ' name: "flat"' 3126 ' number: 1' 3127 ' label: LABEL_REPEATED' 3128 ' type: TYPE_UINT32' 3129 ' }' 3130 ' field {' 3131 ' name: "bar"' 3132 ' number: 2' 3133 ' label: LABEL_OPTIONAL' 3134 ' type: TYPE_MESSAGE' 3135 ' type_name: "Bar"' 3136 ' }' 3137 ' nested_type {' 3138 ' name: "Bar"' 3139 ' field {' 3140 ' name: "baz"' 3141 ' number: 3' 3142 ' label: LABEL_OPTIONAL' 3143 ' type: TYPE_MESSAGE' 3144 ' type_name: "Baz"' 3145 ' }' 3146 ' nested_type {' 3147 ' name: "Baz"' 3148 ' enum_type {' 3149 ' name: "deep_enum"' 3150 ' value {' 3151 ' name: "VALUE_A"' 3152 ' number: 0' 3153 ' }' 3154 ' }' 3155 ' field {' 3156 ' name: "deep"' 3157 ' number: 4' 3158 ' label: LABEL_OPTIONAL' 3159 ' type: TYPE_UINT32' 3160 ' }' 3161 ' }' 3162 ' }' 3163 '}') 3164 file_descriptor = descriptor_pb2.FileDescriptorProto() 3165 text_format.Merge(file_descriptor_str, file_descriptor) 3166 return file_descriptor.SerializeToString() 3167 3168 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 3169 # This test can only run once; the second time, it raises errors about 3170 # conflicting message descriptors. 3171 def testParsingFlatClassWithExplicitClassDeclaration(self): 3172 """Test that the generated class can parse a flat message.""" 3173 # TODO(xiaofeng): This test fails with cpp implemetnation in the call 3174 # of six.with_metaclass(). The other two callsites of with_metaclass 3175 # in this file are both excluded from cpp test, so it might be expected 3176 # to fail. Need someone more familiar with the python code to take a 3177 # look at this. 
3178 if api_implementation.Type() != 'python': 3179 return 3180 file_descriptor = descriptor_pb2.FileDescriptorProto() 3181 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) 3182 msg_descriptor = descriptor.MakeDescriptor( 3183 file_descriptor.message_type[0]) 3184 3185 class MessageClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): 3186 DESCRIPTOR = msg_descriptor 3187 msg = MessageClass() 3188 msg_str = ( 3189 'flat: 0 ' 3190 'flat: 1 ' 3191 'flat: 2 ') 3192 text_format.Merge(msg_str, msg) 3193 self.assertEqual(msg.flat, [0, 1, 2]) 3194 3195 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 3196 def testParsingFlatClass(self): 3197 """Test that the generated class can parse a flat message.""" 3198 file_descriptor = descriptor_pb2.FileDescriptorProto() 3199 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) 3200 msg_descriptor = descriptor.MakeDescriptor( 3201 file_descriptor.message_type[0]) 3202 msg_class = reflection.MakeClass(msg_descriptor) 3203 msg = msg_class() 3204 msg_str = ( 3205 'flat: 0 ' 3206 'flat: 1 ' 3207 'flat: 2 ') 3208 text_format.Merge(msg_str, msg) 3209 self.assertEqual(msg.flat, [0, 1, 2]) 3210 3211 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 3212 def testParsingNestedClass(self): 3213 """Test that the generated class can parse a nested message.""" 3214 file_descriptor = descriptor_pb2.FileDescriptorProto() 3215 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) 3216 msg_descriptor = descriptor.MakeDescriptor( 3217 file_descriptor.message_type[0]) 3218 msg_class = reflection.MakeClass(msg_descriptor) 3219 msg = msg_class() 3220 msg_str = ( 3221 'bar {' 3222 ' baz {' 3223 ' deep: 4' 3224 ' }' 3225 '}') 3226 text_format.Merge(msg_str, msg) 3227 self.assertEqual(msg.bar.baz.deep, 4) 3228 3229if __name__ == '__main__': 3230 unittest.main() 3231