1#! /usr/bin/env python 2# -*- coding: utf-8 -*- 3# 4# Protocol Buffers - Google's data interchange format 5# Copyright 2008 Google Inc. All rights reserved. 6# https://developers.google.com/protocol-buffers/ 7# 8# Redistribution and use in source and binary forms, with or without 9# modification, are permitted provided that the following conditions are 10# met: 11# 12# * Redistributions of source code must retain the above copyright 13# notice, this list of conditions and the following disclaimer. 14# * Redistributions in binary form must reproduce the above 15# copyright notice, this list of conditions and the following disclaimer 16# in the documentation and/or other materials provided with the 17# distribution. 18# * Neither the name of Google Inc. nor the names of its 19# contributors may be used to endorse or promote products derived from 20# this software without specific prior written permission. 21# 22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 34"""Unittest for reflection.py, which also indirectly tests the output of the 35pure-Python protocol compiler. 
36""" 37 38import copy 39import gc 40import operator 41import six 42import struct 43import warnings 44 45try: 46 import unittest2 as unittest #PY26 47except ImportError: 48 import unittest 49 50from google.protobuf import unittest_import_pb2 51from google.protobuf import unittest_mset_pb2 52from google.protobuf import unittest_pb2 53from google.protobuf import unittest_proto3_arena_pb2 54from google.protobuf import descriptor_pb2 55from google.protobuf import descriptor 56from google.protobuf import message 57from google.protobuf import reflection 58from google.protobuf import text_format 59from google.protobuf.internal import api_implementation 60from google.protobuf.internal import more_extensions_pb2 61from google.protobuf.internal import more_messages_pb2 62from google.protobuf.internal import message_set_extensions_pb2 63from google.protobuf.internal import wire_format 64from google.protobuf.internal import test_util 65from google.protobuf.internal import testing_refleaks 66from google.protobuf.internal import decoder 67from google.protobuf.internal import _parameterized 68 69 70if six.PY3: 71 long = int # pylint: disable=redefined-builtin,invalid-name 72 73 74warnings.simplefilter('error', DeprecationWarning) 75 76 77class _MiniDecoder(object): 78 """Decodes a stream of values from a string. 79 80 Once upon a time we actually had a class called decoder.Decoder. Then we 81 got rid of it during a redesign that made decoding much, much faster overall. 82 But a couple tests in this file used it to check that the serialized form of 83 a message was correct. So, this class implements just the methods that were 84 used by said tests, so that we don't have to rewrite the tests. 
85 """ 86 87 def __init__(self, bytes): 88 self._bytes = bytes 89 self._pos = 0 90 91 def ReadVarint(self): 92 result, self._pos = decoder._DecodeVarint(self._bytes, self._pos) 93 return result 94 95 ReadInt32 = ReadVarint 96 ReadInt64 = ReadVarint 97 ReadUInt32 = ReadVarint 98 ReadUInt64 = ReadVarint 99 100 def ReadSInt64(self): 101 return wire_format.ZigZagDecode(self.ReadVarint()) 102 103 ReadSInt32 = ReadSInt64 104 105 def ReadFieldNumberAndWireType(self): 106 return wire_format.UnpackTag(self.ReadVarint()) 107 108 def ReadFloat(self): 109 result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0] 110 self._pos += 4 111 return result 112 113 def ReadDouble(self): 114 result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0] 115 self._pos += 8 116 return result 117 118 def EndOfStream(self): 119 return self._pos == len(self._bytes) 120 121 122@_parameterized.named_parameters( 123 ('_proto2', unittest_pb2), 124 ('_proto3', unittest_proto3_arena_pb2)) 125@testing_refleaks.TestCase 126class ReflectionTest(unittest.TestCase): 127 128 def assertListsEqual(self, values, others): 129 self.assertEqual(len(values), len(others)) 130 for i in range(len(values)): 131 self.assertEqual(values[i], others[i]) 132 133 def testScalarConstructor(self, message_module): 134 # Constructor with only scalar types should succeed. 135 proto = message_module.TestAllTypes( 136 optional_int32=24, 137 optional_double=54.321, 138 optional_string='optional_string', 139 optional_float=None) 140 141 self.assertEqual(24, proto.optional_int32) 142 self.assertEqual(54.321, proto.optional_double) 143 self.assertEqual('optional_string', proto.optional_string) 144 if message_module is unittest_pb2: 145 self.assertFalse(proto.HasField("optional_float")) 146 147 def testRepeatedScalarConstructor(self, message_module): 148 # Constructor with only repeated scalar types should succeed. 
149 proto = message_module.TestAllTypes( 150 repeated_int32=[1, 2, 3, 4], 151 repeated_double=[1.23, 54.321], 152 repeated_bool=[True, False, False], 153 repeated_string=["optional_string"], 154 repeated_float=None) 155 156 self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32)) 157 self.assertEqual([1.23, 54.321], list(proto.repeated_double)) 158 self.assertEqual([True, False, False], list(proto.repeated_bool)) 159 self.assertEqual(["optional_string"], list(proto.repeated_string)) 160 self.assertEqual([], list(proto.repeated_float)) 161 162 def testMixedConstructor(self, message_module): 163 # Constructor with only mixed types should succeed. 164 proto = message_module.TestAllTypes( 165 optional_int32=24, 166 optional_string='optional_string', 167 repeated_double=[1.23, 54.321], 168 repeated_bool=[True, False, False], 169 repeated_nested_message=[ 170 message_module.TestAllTypes.NestedMessage( 171 bb=message_module.TestAllTypes.FOO), 172 message_module.TestAllTypes.NestedMessage( 173 bb=message_module.TestAllTypes.BAR)], 174 repeated_foreign_message=[ 175 message_module.ForeignMessage(c=-43), 176 message_module.ForeignMessage(c=45324), 177 message_module.ForeignMessage(c=12)], 178 optional_nested_message=None) 179 180 self.assertEqual(24, proto.optional_int32) 181 self.assertEqual('optional_string', proto.optional_string) 182 self.assertEqual([1.23, 54.321], list(proto.repeated_double)) 183 self.assertEqual([True, False, False], list(proto.repeated_bool)) 184 self.assertEqual( 185 [message_module.TestAllTypes.NestedMessage( 186 bb=message_module.TestAllTypes.FOO), 187 message_module.TestAllTypes.NestedMessage( 188 bb=message_module.TestAllTypes.BAR)], 189 list(proto.repeated_nested_message)) 190 self.assertEqual( 191 [message_module.ForeignMessage(c=-43), 192 message_module.ForeignMessage(c=45324), 193 message_module.ForeignMessage(c=12)], 194 list(proto.repeated_foreign_message)) 195 self.assertFalse(proto.HasField("optional_nested_message")) 196 197 def 
testConstructorTypeError(self, message_module): 198 self.assertRaises( 199 TypeError, message_module.TestAllTypes, optional_int32='foo') 200 self.assertRaises( 201 TypeError, message_module.TestAllTypes, optional_string=1234) 202 self.assertRaises( 203 TypeError, message_module.TestAllTypes, optional_nested_message=1234) 204 self.assertRaises( 205 TypeError, message_module.TestAllTypes, repeated_int32=1234) 206 self.assertRaises( 207 TypeError, message_module.TestAllTypes, repeated_int32=['foo']) 208 self.assertRaises( 209 TypeError, message_module.TestAllTypes, repeated_string=1234) 210 self.assertRaises( 211 TypeError, message_module.TestAllTypes, repeated_string=[1234]) 212 self.assertRaises( 213 TypeError, message_module.TestAllTypes, repeated_nested_message=1234) 214 self.assertRaises( 215 TypeError, message_module.TestAllTypes, repeated_nested_message=[1234]) 216 217 def testConstructorInvalidatesCachedByteSize(self, message_module): 218 message = message_module.TestAllTypes(optional_int32=12) 219 self.assertEqual(2, message.ByteSize()) 220 221 message = message_module.TestAllTypes( 222 optional_nested_message=message_module.TestAllTypes.NestedMessage()) 223 self.assertEqual(3, message.ByteSize()) 224 225 message = message_module.TestAllTypes(repeated_int32=[12]) 226 # TODO(user): Add this test back for proto3 227 if message_module is unittest_pb2: 228 self.assertEqual(3, message.ByteSize()) 229 230 message = message_module.TestAllTypes( 231 repeated_nested_message=[message_module.TestAllTypes.NestedMessage()]) 232 self.assertEqual(3, message.ByteSize()) 233 234 def testReferencesToNestedMessage(self, message_module): 235 proto = message_module.TestAllTypes() 236 nested = proto.optional_nested_message 237 del proto 238 # A previous version had a bug where this would raise an exception when 239 # hitting a now-dead weak reference. 
240 nested.bb = 23 241 242 def testOneOf(self, message_module): 243 proto = message_module.TestAllTypes() 244 proto.oneof_uint32 = 10 245 proto.oneof_nested_message.bb = 11 246 self.assertEqual(11, proto.oneof_nested_message.bb) 247 self.assertFalse(proto.HasField('oneof_uint32')) 248 nested = proto.oneof_nested_message 249 proto.oneof_string = 'abc' 250 self.assertEqual('abc', proto.oneof_string) 251 self.assertEqual(11, nested.bb) 252 self.assertFalse(proto.HasField('oneof_nested_message')) 253 254 def testGetDefaultMessageAfterDisconnectingDefaultMessage( 255 self, message_module): 256 proto = message_module.TestAllTypes() 257 nested = proto.optional_nested_message 258 proto.ClearField('optional_nested_message') 259 del proto 260 del nested 261 # Force a garbage collect so that the underlying CMessages are freed along 262 # with the Messages they point to. This is to make sure we're not deleting 263 # default message instances. 264 gc.collect() 265 proto = message_module.TestAllTypes() 266 nested = proto.optional_nested_message 267 268 def testDisconnectingNestedMessageAfterSettingField(self, message_module): 269 proto = message_module.TestAllTypes() 270 nested = proto.optional_nested_message 271 nested.bb = 5 272 self.assertTrue(proto.HasField('optional_nested_message')) 273 proto.ClearField('optional_nested_message') # Should disconnect from parent 274 self.assertEqual(5, nested.bb) 275 self.assertEqual(0, proto.optional_nested_message.bb) 276 self.assertIsNot(nested, proto.optional_nested_message) 277 nested.bb = 23 278 self.assertFalse(proto.HasField('optional_nested_message')) 279 self.assertEqual(0, proto.optional_nested_message.bb) 280 281 def testDisconnectingNestedMessageBeforeGettingField(self, message_module): 282 proto = message_module.TestAllTypes() 283 self.assertFalse(proto.HasField('optional_nested_message')) 284 proto.ClearField('optional_nested_message') 285 self.assertFalse(proto.HasField('optional_nested_message')) 286 287 def 
testDisconnectingNestedMessageAfterMerge(self, message_module): 288 # This test exercises the code path that does not use ReleaseMessage(). 289 # The underlying fear is that if we use ReleaseMessage() incorrectly, 290 # we will have memory leaks. It's hard to check that that doesn't happen, 291 # but at least we can exercise that code path to make sure it works. 292 proto1 = message_module.TestAllTypes() 293 proto2 = message_module.TestAllTypes() 294 proto2.optional_nested_message.bb = 5 295 proto1.MergeFrom(proto2) 296 self.assertTrue(proto1.HasField('optional_nested_message')) 297 proto1.ClearField('optional_nested_message') 298 self.assertFalse(proto1.HasField('optional_nested_message')) 299 300 def testDisconnectingLazyNestedMessage(self, message_module): 301 # This test exercises releasing a nested message that is lazy. This test 302 # only exercises real code in the C++ implementation as Python does not 303 # support lazy parsing, but the current C++ implementation results in 304 # memory corruption and a crash. 305 if api_implementation.Type() != 'python': 306 return 307 proto = message_module.TestAllTypes() 308 proto.optional_lazy_message.bb = 5 309 proto.ClearField('optional_lazy_message') 310 del proto 311 gc.collect() 312 313 def testSingularListFields(self, message_module): 314 proto = message_module.TestAllTypes() 315 proto.optional_fixed32 = 1 316 proto.optional_int32 = 5 317 proto.optional_string = 'foo' 318 # Access sub-message but don't set it yet. 
319 nested_message = proto.optional_nested_message 320 self.assertEqual( 321 [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), 322 (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), 323 (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ], 324 proto.ListFields()) 325 326 proto.optional_nested_message.bb = 123 327 self.assertEqual( 328 [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), 329 (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), 330 (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'), 331 (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ], 332 nested_message) ], 333 proto.ListFields()) 334 335 def testRepeatedListFields(self, message_module): 336 proto = message_module.TestAllTypes() 337 proto.repeated_fixed32.append(1) 338 proto.repeated_int32.append(5) 339 proto.repeated_int32.append(11) 340 proto.repeated_string.extend(['foo', 'bar']) 341 proto.repeated_string.extend([]) 342 proto.repeated_string.append('baz') 343 proto.repeated_string.extend(str(x) for x in range(2)) 344 proto.optional_int32 = 21 345 proto.repeated_bool # Access but don't set anything; should not be listed. 346 self.assertEqual( 347 [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21), 348 (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]), 349 (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]), 350 (proto.DESCRIPTOR.fields_by_name['repeated_string' ], 351 ['foo', 'bar', 'baz', '0', '1']) ], 352 proto.ListFields()) 353 354 def testClearFieldWithUnknownFieldName(self, message_module): 355 proto = message_module.TestAllTypes() 356 self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') 357 self.assertRaises(ValueError, proto.ClearField, b'nonexistent_field') 358 359 def testDisallowedAssignments(self, message_module): 360 # It's illegal to assign values directly to repeated fields 361 # or to nonrepeated composite fields. Ensure that this fails. 
362 proto = message_module.TestAllTypes() 363 # Repeated fields. 364 self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10) 365 # Lists shouldn't work, either. 366 self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10]) 367 # Composite fields. 368 self.assertRaises(AttributeError, setattr, proto, 369 'optional_nested_message', 23) 370 # Assignment to a repeated nested message field without specifying 371 # the index in the array of nested messages. 372 self.assertRaises(AttributeError, setattr, proto.repeated_nested_message, 373 'bb', 34) 374 # Assignment to an attribute of a repeated field. 375 self.assertRaises(AttributeError, setattr, proto.repeated_float, 376 'some_attribute', 34) 377 # proto.nonexistent_field = 23 should fail as well. 378 self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23) 379 380 def testSingleScalarTypeSafety(self, message_module): 381 proto = message_module.TestAllTypes() 382 self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1) 383 self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo') 384 self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) 385 self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) 386 self.assertRaises(TypeError, setattr, proto, 'optional_bool', 'foo') 387 self.assertRaises(TypeError, setattr, proto, 'optional_float', 'foo') 388 self.assertRaises(TypeError, setattr, proto, 'optional_double', 'foo') 389 # TODO(user): Fix type checking difference for python and c extension 390 if api_implementation.Type() == 'python': 391 self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1) 392 else: 393 proto.optional_bool = 1.1 394 395 def assertIntegerTypes(self, integer_fn, message_module): 396 """Verifies setting of scalar integers. 397 398 Args: 399 integer_fn: A function to wrap the integers that will be assigned. 
400 message_module: unittest_pb2 or unittest_proto3_arena_pb2 401 """ 402 def TestGetAndDeserialize(field_name, value, expected_type): 403 proto = message_module.TestAllTypes() 404 value = integer_fn(value) 405 setattr(proto, field_name, value) 406 self.assertIsInstance(getattr(proto, field_name), expected_type) 407 proto2 = message_module.TestAllTypes() 408 proto2.ParseFromString(proto.SerializeToString()) 409 self.assertIsInstance(getattr(proto2, field_name), expected_type) 410 411 TestGetAndDeserialize('optional_int32', 1, int) 412 TestGetAndDeserialize('optional_int32', 1 << 30, int) 413 TestGetAndDeserialize('optional_uint32', 1 << 30, int) 414 integer_64 = long 415 if struct.calcsize('L') == 4: 416 # Python only has signed ints, so 32-bit python can't fit an uint32 417 # in an int. 418 TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64) 419 else: 420 # 64-bit python can fit uint32 inside an int 421 TestGetAndDeserialize('optional_uint32', 1 << 31, int) 422 TestGetAndDeserialize('optional_int64', 1 << 30, integer_64) 423 TestGetAndDeserialize('optional_int64', 1 << 60, integer_64) 424 TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64) 425 TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64) 426 427 def testIntegerTypes(self, message_module): 428 self.assertIntegerTypes(lambda x: x, message_module) 429 430 def testNonStandardIntegerTypes(self, message_module): 431 self.assertIntegerTypes(test_util.NonStandardInteger, message_module) 432 433 def testIllegalValuesForIntegers(self, message_module): 434 pb = message_module.TestAllTypes() 435 436 # Strings are illegal, even when the represent an integer. 437 with self.assertRaises(TypeError): 438 pb.optional_uint64 = '2' 439 440 # The exact error should propagate with a poorly written custom integer. 
441 with self.assertRaisesRegexp(RuntimeError, 'my_error'): 442 pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error') 443 444 def assetIntegerBoundsChecking(self, integer_fn, message_module): 445 """Verifies bounds checking for scalar integer fields. 446 447 Args: 448 integer_fn: A function to wrap the integers that will be assigned. 449 message_module: unittest_pb2 or unittest_proto3_arena_pb2 450 """ 451 def TestMinAndMaxIntegers(field_name, expected_min, expected_max): 452 pb = message_module.TestAllTypes() 453 expected_min = integer_fn(expected_min) 454 expected_max = integer_fn(expected_max) 455 setattr(pb, field_name, expected_min) 456 self.assertEqual(expected_min, getattr(pb, field_name)) 457 setattr(pb, field_name, expected_max) 458 self.assertEqual(expected_max, getattr(pb, field_name)) 459 self.assertRaises((ValueError, TypeError), setattr, pb, field_name, 460 expected_min - 1) 461 self.assertRaises((ValueError, TypeError), setattr, pb, field_name, 462 expected_max + 1) 463 464 TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1) 465 TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) 466 TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) 467 TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) 468 # A bit of white-box testing since -1 is an int and not a long in C++ and 469 # so goes down a different path. 
470 pb = message_module.TestAllTypes() 471 with self.assertRaises((ValueError, TypeError)): 472 pb.optional_uint64 = integer_fn(-(1 << 63)) 473 474 pb = message_module.TestAllTypes() 475 pb.optional_nested_enum = integer_fn(1) 476 self.assertEqual(1, pb.optional_nested_enum) 477 478 def testSingleScalarBoundsChecking(self, message_module): 479 self.assetIntegerBoundsChecking(lambda x: x, message_module) 480 481 def testNonStandardSingleScalarBoundsChecking(self, message_module): 482 self.assetIntegerBoundsChecking( 483 test_util.NonStandardInteger, message_module) 484 485 def testRepeatedScalarTypeSafety(self, message_module): 486 proto = message_module.TestAllTypes() 487 self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) 488 self.assertRaises(TypeError, proto.repeated_int32.append, 'foo') 489 self.assertRaises(TypeError, proto.repeated_string, 10) 490 self.assertRaises(TypeError, proto.repeated_bytes, 10) 491 492 proto.repeated_int32.append(10) 493 proto.repeated_int32[0] = 23 494 self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23) 495 self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc') 496 self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, []) 497 self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 498 'index', 23) 499 500 proto.repeated_string.append('2') 501 self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10) 502 503 # Repeated enums tests. 
504 #proto.repeated_nested_enum.append(0) 505 506 def testSingleScalarGettersAndSetters(self, message_module): 507 proto = message_module.TestAllTypes() 508 self.assertEqual(0, proto.optional_int32) 509 proto.optional_int32 = 1 510 self.assertEqual(1, proto.optional_int32) 511 512 proto.optional_uint64 = 0xffffffffffff 513 self.assertEqual(0xffffffffffff, proto.optional_uint64) 514 proto.optional_uint64 = 0xffffffffffffffff 515 self.assertEqual(0xffffffffffffffff, proto.optional_uint64) 516 # TODO(user): Test all other scalar field types. 517 518 def testEnums(self, message_module): 519 proto = message_module.TestAllTypes() 520 self.assertEqual(1, proto.FOO) 521 self.assertEqual(1, message_module.TestAllTypes.FOO) 522 self.assertEqual(2, proto.BAR) 523 self.assertEqual(2, message_module.TestAllTypes.BAR) 524 self.assertEqual(3, proto.BAZ) 525 self.assertEqual(3, message_module.TestAllTypes.BAZ) 526 527 def testEnum_Name(self, message_module): 528 self.assertEqual( 529 'FOREIGN_FOO', 530 message_module.ForeignEnum.Name(message_module.FOREIGN_FOO)) 531 self.assertEqual( 532 'FOREIGN_BAR', 533 message_module.ForeignEnum.Name(message_module.FOREIGN_BAR)) 534 self.assertEqual( 535 'FOREIGN_BAZ', 536 message_module.ForeignEnum.Name(message_module.FOREIGN_BAZ)) 537 self.assertRaises(ValueError, 538 message_module.ForeignEnum.Name, 11312) 539 540 proto = message_module.TestAllTypes() 541 self.assertEqual('FOO', 542 proto.NestedEnum.Name(proto.FOO)) 543 self.assertEqual('FOO', 544 message_module.TestAllTypes.NestedEnum.Name(proto.FOO)) 545 self.assertEqual('BAR', 546 proto.NestedEnum.Name(proto.BAR)) 547 self.assertEqual('BAR', 548 message_module.TestAllTypes.NestedEnum.Name(proto.BAR)) 549 self.assertEqual('BAZ', 550 proto.NestedEnum.Name(proto.BAZ)) 551 self.assertEqual('BAZ', 552 message_module.TestAllTypes.NestedEnum.Name(proto.BAZ)) 553 self.assertRaises(ValueError, 554 proto.NestedEnum.Name, 11312) 555 self.assertRaises(ValueError, 556 
message_module.TestAllTypes.NestedEnum.Name, 11312) 557 558 # Check some coercion cases. 559 self.assertRaises(TypeError, message_module.TestAllTypes.NestedEnum.Name, 560 11312.0) 561 self.assertRaises(TypeError, message_module.TestAllTypes.NestedEnum.Name, 562 None) 563 self.assertEqual('FOO', message_module.TestAllTypes.NestedEnum.Name(True)) 564 565 def testEnum_Value(self, message_module): 566 self.assertEqual(message_module.FOREIGN_FOO, 567 message_module.ForeignEnum.Value('FOREIGN_FOO')) 568 self.assertEqual(message_module.FOREIGN_FOO, 569 message_module.ForeignEnum.FOREIGN_FOO) 570 571 self.assertEqual(message_module.FOREIGN_BAR, 572 message_module.ForeignEnum.Value('FOREIGN_BAR')) 573 self.assertEqual(message_module.FOREIGN_BAR, 574 message_module.ForeignEnum.FOREIGN_BAR) 575 576 self.assertEqual(message_module.FOREIGN_BAZ, 577 message_module.ForeignEnum.Value('FOREIGN_BAZ')) 578 self.assertEqual(message_module.FOREIGN_BAZ, 579 message_module.ForeignEnum.FOREIGN_BAZ) 580 581 self.assertRaises(ValueError, 582 message_module.ForeignEnum.Value, 'FO') 583 with self.assertRaises(AttributeError): 584 message_module.ForeignEnum.FO 585 586 proto = message_module.TestAllTypes() 587 self.assertEqual(proto.FOO, 588 proto.NestedEnum.Value('FOO')) 589 self.assertEqual(proto.FOO, 590 proto.NestedEnum.FOO) 591 592 self.assertEqual(proto.FOO, 593 message_module.TestAllTypes.NestedEnum.Value('FOO')) 594 self.assertEqual(proto.FOO, 595 message_module.TestAllTypes.NestedEnum.FOO) 596 597 self.assertEqual(proto.BAR, 598 proto.NestedEnum.Value('BAR')) 599 self.assertEqual(proto.BAR, 600 proto.NestedEnum.BAR) 601 602 self.assertEqual(proto.BAR, 603 message_module.TestAllTypes.NestedEnum.Value('BAR')) 604 self.assertEqual(proto.BAR, 605 message_module.TestAllTypes.NestedEnum.BAR) 606 607 self.assertEqual(proto.BAZ, 608 proto.NestedEnum.Value('BAZ')) 609 self.assertEqual(proto.BAZ, 610 proto.NestedEnum.BAZ) 611 612 self.assertEqual(proto.BAZ, 613 
message_module.TestAllTypes.NestedEnum.Value('BAZ')) 614 self.assertEqual(proto.BAZ, 615 message_module.TestAllTypes.NestedEnum.BAZ) 616 617 self.assertRaises(ValueError, 618 proto.NestedEnum.Value, 'Foo') 619 with self.assertRaises(AttributeError): 620 proto.NestedEnum.Value.Foo 621 622 self.assertRaises(ValueError, 623 message_module.TestAllTypes.NestedEnum.Value, 'Foo') 624 with self.assertRaises(AttributeError): 625 message_module.TestAllTypes.NestedEnum.Value.Foo 626 627 def testEnum_KeysAndValues(self, message_module): 628 if message_module == unittest_pb2: 629 keys = ['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'] 630 values = [4, 5, 6] 631 items = [('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), ('FOREIGN_BAZ', 6)] 632 else: 633 keys = ['FOREIGN_ZERO', 'FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'] 634 values = [0, 4, 5, 6] 635 items = [('FOREIGN_ZERO', 0), ('FOREIGN_FOO', 4), 636 ('FOREIGN_BAR', 5), ('FOREIGN_BAZ', 6)] 637 self.assertEqual(keys, 638 list(message_module.ForeignEnum.keys())) 639 self.assertEqual(values, 640 list(message_module.ForeignEnum.values())) 641 self.assertEqual(items, 642 list(message_module.ForeignEnum.items())) 643 644 proto = message_module.TestAllTypes() 645 if message_module == unittest_pb2: 646 keys = ['FOO', 'BAR', 'BAZ', 'NEG'] 647 values = [1, 2, 3, -1] 648 items = [('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)] 649 else: 650 keys = ['ZERO', 'FOO', 'BAR', 'BAZ', 'NEG'] 651 values = [0, 1, 2, 3, -1] 652 items = [('ZERO', 0), ('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)] 653 self.assertEqual(keys, list(proto.NestedEnum.keys())) 654 self.assertEqual(values, list(proto.NestedEnum.values())) 655 self.assertEqual(items, 656 list(proto.NestedEnum.items())) 657 658 def testStaticParseFrom(self, message_module): 659 proto1 = message_module.TestAllTypes() 660 test_util.SetAllFields(proto1) 661 662 string1 = proto1.SerializeToString() 663 proto2 = message_module.TestAllTypes.FromString(string1) 664 665 # Messages should be equal. 
666 self.assertEqual(proto2, proto1) 667 668 def testMergeFromSingularField(self, message_module): 669 # Test merge with just a singular field. 670 proto1 = message_module.TestAllTypes() 671 proto1.optional_int32 = 1 672 673 proto2 = message_module.TestAllTypes() 674 # This shouldn't get overwritten. 675 proto2.optional_string = 'value' 676 677 proto2.MergeFrom(proto1) 678 self.assertEqual(1, proto2.optional_int32) 679 self.assertEqual('value', proto2.optional_string) 680 681 def testMergeFromRepeatedField(self, message_module): 682 # Test merge with just a repeated field. 683 proto1 = message_module.TestAllTypes() 684 proto1.repeated_int32.append(1) 685 proto1.repeated_int32.append(2) 686 687 proto2 = message_module.TestAllTypes() 688 proto2.repeated_int32.append(0) 689 proto2.MergeFrom(proto1) 690 691 self.assertEqual(0, proto2.repeated_int32[0]) 692 self.assertEqual(1, proto2.repeated_int32[1]) 693 self.assertEqual(2, proto2.repeated_int32[2]) 694 695 def testMergeFromRepeatedNestedMessage(self, message_module): 696 # Test merge with a repeated nested message. 697 proto1 = message_module.TestAllTypes() 698 m = proto1.repeated_nested_message.add() 699 m.bb = 123 700 m = proto1.repeated_nested_message.add() 701 m.bb = 321 702 703 proto2 = message_module.TestAllTypes() 704 m = proto2.repeated_nested_message.add() 705 m.bb = 999 706 proto2.MergeFrom(proto1) 707 self.assertEqual(999, proto2.repeated_nested_message[0].bb) 708 self.assertEqual(123, proto2.repeated_nested_message[1].bb) 709 self.assertEqual(321, proto2.repeated_nested_message[2].bb) 710 711 proto3 = message_module.TestAllTypes() 712 proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message) 713 self.assertEqual(999, proto3.repeated_nested_message[0].bb) 714 self.assertEqual(123, proto3.repeated_nested_message[1].bb) 715 self.assertEqual(321, proto3.repeated_nested_message[2].bb) 716 717 def testMergeFromAllFields(self, message_module): 718 # With all fields set. 
719 proto1 = message_module.TestAllTypes() 720 test_util.SetAllFields(proto1) 721 proto2 = message_module.TestAllTypes() 722 proto2.MergeFrom(proto1) 723 724 # Messages should be equal. 725 self.assertEqual(proto2, proto1) 726 727 # Serialized string should be equal too. 728 string1 = proto1.SerializeToString() 729 string2 = proto2.SerializeToString() 730 self.assertEqual(string1, string2) 731 732 def testMergeFromBug(self, message_module): 733 message1 = message_module.TestAllTypes() 734 message2 = message_module.TestAllTypes() 735 736 # Cause optional_nested_message to be instantiated within message1, even 737 # though it is not considered to be "present". 738 message1.optional_nested_message 739 self.assertFalse(message1.HasField('optional_nested_message')) 740 741 # Merge into message2. This should not instantiate the field is message2. 742 message2.MergeFrom(message1) 743 self.assertFalse(message2.HasField('optional_nested_message')) 744 745 def testCopyFromSingularField(self, message_module): 746 # Test copy with just a singular field. 747 proto1 = message_module.TestAllTypes() 748 proto1.optional_int32 = 1 749 proto1.optional_string = 'important-text' 750 751 proto2 = message_module.TestAllTypes() 752 proto2.optional_string = 'value' 753 754 proto2.CopyFrom(proto1) 755 self.assertEqual(1, proto2.optional_int32) 756 self.assertEqual('important-text', proto2.optional_string) 757 758 def testCopyFromRepeatedField(self, message_module): 759 # Test copy with a repeated field. 760 proto1 = message_module.TestAllTypes() 761 proto1.repeated_int32.append(1) 762 proto1.repeated_int32.append(2) 763 764 proto2 = message_module.TestAllTypes() 765 proto2.repeated_int32.append(0) 766 proto2.CopyFrom(proto1) 767 768 self.assertEqual(1, proto2.repeated_int32[0]) 769 self.assertEqual(2, proto2.repeated_int32[1]) 770 771 def testCopyFromAllFields(self, message_module): 772 # With all fields set. 
773 proto1 = message_module.TestAllTypes() 774 test_util.SetAllFields(proto1) 775 proto2 = message_module.TestAllTypes() 776 proto2.CopyFrom(proto1) 777 778 # Messages should be equal. 779 self.assertEqual(proto2, proto1) 780 781 # Serialized string should be equal too. 782 string1 = proto1.SerializeToString() 783 string2 = proto2.SerializeToString() 784 self.assertEqual(string1, string2) 785 786 def testCopyFromSelf(self, message_module): 787 proto1 = message_module.TestAllTypes() 788 proto1.repeated_int32.append(1) 789 proto1.optional_int32 = 2 790 proto1.optional_string = 'important-text' 791 792 proto1.CopyFrom(proto1) 793 self.assertEqual(1, proto1.repeated_int32[0]) 794 self.assertEqual(2, proto1.optional_int32) 795 self.assertEqual('important-text', proto1.optional_string) 796 797 def testDeepCopy(self, message_module): 798 proto1 = message_module.TestAllTypes() 799 proto1.optional_int32 = 1 800 proto2 = copy.deepcopy(proto1) 801 self.assertEqual(1, proto2.optional_int32) 802 803 proto1.repeated_int32.append(2) 804 proto1.repeated_int32.append(3) 805 container = copy.deepcopy(proto1.repeated_int32) 806 self.assertEqual([2, 3], container) 807 container.remove(container[0]) 808 self.assertEqual([3], container) 809 810 message1 = proto1.repeated_nested_message.add() 811 message1.bb = 1 812 messages = copy.deepcopy(proto1.repeated_nested_message) 813 self.assertEqual(proto1.repeated_nested_message, messages) 814 message1.bb = 2 815 self.assertNotEqual(proto1.repeated_nested_message, messages) 816 messages.remove(messages[0]) 817 self.assertEqual(len(messages), 0) 818 819 # TODO(user): Implement deepcopy for extension dict 820 821 def testDisconnectingBeforeClear(self, message_module): 822 proto = message_module.TestAllTypes() 823 nested = proto.optional_nested_message 824 proto.Clear() 825 self.assertIsNot(nested, proto.optional_nested_message) 826 nested.bb = 23 827 self.assertFalse(proto.HasField('optional_nested_message')) 828 self.assertEqual(0, 
proto.optional_nested_message.bb) 829 830 proto = message_module.TestAllTypes() 831 nested = proto.optional_nested_message 832 nested.bb = 5 833 foreign = proto.optional_foreign_message 834 foreign.c = 6 835 proto.Clear() 836 self.assertIsNot(nested, proto.optional_nested_message) 837 self.assertIsNot(foreign, proto.optional_foreign_message) 838 self.assertEqual(5, nested.bb) 839 self.assertEqual(6, foreign.c) 840 nested.bb = 15 841 foreign.c = 16 842 self.assertFalse(proto.HasField('optional_nested_message')) 843 self.assertEqual(0, proto.optional_nested_message.bb) 844 self.assertFalse(proto.HasField('optional_foreign_message')) 845 self.assertEqual(0, proto.optional_foreign_message.c) 846 847 def testStringUTF8Encoding(self, message_module): 848 proto = message_module.TestAllTypes() 849 850 # Assignment of a unicode object to a field of type 'bytes' is not allowed. 851 self.assertRaises(TypeError, 852 setattr, proto, 'optional_bytes', u'unicode object') 853 854 # Check that the default value is of python's 'unicode' type. 855 self.assertEqual(type(proto.optional_string), six.text_type) 856 857 proto.optional_string = six.text_type('Testing') 858 self.assertEqual(proto.optional_string, str('Testing')) 859 860 # Assign a value of type 'str' which can be encoded in UTF-8. 861 proto.optional_string = str('Testing') 862 self.assertEqual(proto.optional_string, six.text_type('Testing')) 863 864 # Try to assign a 'bytes' object which contains non-UTF-8. 865 self.assertRaises(ValueError, 866 setattr, proto, 'optional_string', b'a\x80a') 867 # No exception: Assign already encoded UTF-8 bytes to a string field. 868 utf8_bytes = u'Тест'.encode('utf-8') 869 proto.optional_string = utf8_bytes 870 # No exception: Assign the a non-ascii unicode object. 871 proto.optional_string = u'Тест' 872 # No exception thrown (normal str assignment containing ASCII). 
873 proto.optional_string = 'abc' 874 875 def testBytesInTextFormat(self, message_module): 876 proto = message_module.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') 877 self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', 878 six.text_type(proto)) 879 880 def testEmptyNestedMessage(self, message_module): 881 proto = message_module.TestAllTypes() 882 proto.optional_nested_message.MergeFrom( 883 message_module.TestAllTypes.NestedMessage()) 884 self.assertTrue(proto.HasField('optional_nested_message')) 885 886 proto = message_module.TestAllTypes() 887 proto.optional_nested_message.CopyFrom( 888 message_module.TestAllTypes.NestedMessage()) 889 self.assertTrue(proto.HasField('optional_nested_message')) 890 891 proto = message_module.TestAllTypes() 892 bytes_read = proto.optional_nested_message.MergeFromString(b'') 893 self.assertEqual(0, bytes_read) 894 self.assertTrue(proto.HasField('optional_nested_message')) 895 896 proto = message_module.TestAllTypes() 897 proto.optional_nested_message.ParseFromString(b'') 898 self.assertTrue(proto.HasField('optional_nested_message')) 899 900 serialized = proto.SerializeToString() 901 proto2 = message_module.TestAllTypes() 902 self.assertEqual( 903 len(serialized), 904 proto2.MergeFromString(serialized)) 905 self.assertTrue(proto2.HasField('optional_nested_message')) 906 907 908# Class to test proto2-only features (required, extensions, etc.) 909@testing_refleaks.TestCase 910class Proto2ReflectionTest(unittest.TestCase): 911 912 def testRepeatedCompositeConstructor(self): 913 # Constructor with only repeated composite types should succeed. 
914 proto = unittest_pb2.TestAllTypes( 915 repeated_nested_message=[ 916 unittest_pb2.TestAllTypes.NestedMessage( 917 bb=unittest_pb2.TestAllTypes.FOO), 918 unittest_pb2.TestAllTypes.NestedMessage( 919 bb=unittest_pb2.TestAllTypes.BAR)], 920 repeated_foreign_message=[ 921 unittest_pb2.ForeignMessage(c=-43), 922 unittest_pb2.ForeignMessage(c=45324), 923 unittest_pb2.ForeignMessage(c=12)], 924 repeatedgroup=[ 925 unittest_pb2.TestAllTypes.RepeatedGroup(), 926 unittest_pb2.TestAllTypes.RepeatedGroup(a=1), 927 unittest_pb2.TestAllTypes.RepeatedGroup(a=2)]) 928 929 self.assertEqual( 930 [unittest_pb2.TestAllTypes.NestedMessage( 931 bb=unittest_pb2.TestAllTypes.FOO), 932 unittest_pb2.TestAllTypes.NestedMessage( 933 bb=unittest_pb2.TestAllTypes.BAR)], 934 list(proto.repeated_nested_message)) 935 self.assertEqual( 936 [unittest_pb2.ForeignMessage(c=-43), 937 unittest_pb2.ForeignMessage(c=45324), 938 unittest_pb2.ForeignMessage(c=12)], 939 list(proto.repeated_foreign_message)) 940 self.assertEqual( 941 [unittest_pb2.TestAllTypes.RepeatedGroup(), 942 unittest_pb2.TestAllTypes.RepeatedGroup(a=1), 943 unittest_pb2.TestAllTypes.RepeatedGroup(a=2)], 944 list(proto.repeatedgroup)) 945 946 def assertListsEqual(self, values, others): 947 self.assertEqual(len(values), len(others)) 948 for i in range(len(values)): 949 self.assertEqual(values[i], others[i]) 950 951 def testSimpleHasBits(self): 952 # Test a scalar. 953 proto = unittest_pb2.TestAllTypes() 954 self.assertFalse(proto.HasField('optional_int32')) 955 self.assertEqual(0, proto.optional_int32) 956 # HasField() shouldn't be true if all we've done is 957 # read the default value. 958 self.assertFalse(proto.HasField('optional_int32')) 959 proto.optional_int32 = 1 960 # Setting a value however *should* set the "has" bit. 961 self.assertTrue(proto.HasField('optional_int32')) 962 proto.ClearField('optional_int32') 963 # And clearing that value should unset the "has" bit. 
964 self.assertFalse(proto.HasField('optional_int32')) 965 966 def testHasBitsWithSinglyNestedScalar(self): 967 # Helper used to test foreign messages and groups. 968 # 969 # composite_field_name should be the name of a non-repeated 970 # composite (i.e., foreign or group) field in TestAllTypes, 971 # and scalar_field_name should be the name of an integer-valued 972 # scalar field within that composite. 973 # 974 # I never thought I'd miss C++ macros and templates so much. :( 975 # This helper is semantically just: 976 # 977 # assert proto.composite_field.scalar_field == 0 978 # assert not proto.composite_field.HasField('scalar_field') 979 # assert not proto.HasField('composite_field') 980 # 981 # proto.composite_field.scalar_field = 10 982 # old_composite_field = proto.composite_field 983 # 984 # assert proto.composite_field.scalar_field == 10 985 # assert proto.composite_field.HasField('scalar_field') 986 # assert proto.HasField('composite_field') 987 # 988 # proto.ClearField('composite_field') 989 # 990 # assert not proto.composite_field.HasField('scalar_field') 991 # assert not proto.HasField('composite_field') 992 # assert proto.composite_field.scalar_field == 0 993 # 994 # # Now ensure that ClearField('composite_field') disconnected 995 # # the old field object from the object tree... 996 # assert old_composite_field is not proto.composite_field 997 # old_composite_field.scalar_field = 20 998 # assert not proto.composite_field.HasField('scalar_field') 999 # assert not proto.HasField('composite_field') 1000 def TestCompositeHasBits(composite_field_name, scalar_field_name): 1001 proto = unittest_pb2.TestAllTypes() 1002 # First, check that we can get the scalar value, and see that it's the 1003 # default (0), but that proto.HasField('omposite') and 1004 # proto.composite.HasField('scalar') will still return False. 
1005 composite_field = getattr(proto, composite_field_name) 1006 original_scalar_value = getattr(composite_field, scalar_field_name) 1007 self.assertEqual(0, original_scalar_value) 1008 # Assert that the composite object does not "have" the scalar. 1009 self.assertFalse(composite_field.HasField(scalar_field_name)) 1010 # Assert that proto does not "have" the composite field. 1011 self.assertFalse(proto.HasField(composite_field_name)) 1012 1013 # Now set the scalar within the composite field. Ensure that the setting 1014 # is reflected, and that proto.HasField('composite') and 1015 # proto.composite.HasField('scalar') now both return True. 1016 new_val = 20 1017 setattr(composite_field, scalar_field_name, new_val) 1018 self.assertEqual(new_val, getattr(composite_field, scalar_field_name)) 1019 # Hold on to a reference to the current composite_field object. 1020 old_composite_field = composite_field 1021 # Assert that the has methods now return true. 1022 self.assertTrue(composite_field.HasField(scalar_field_name)) 1023 self.assertTrue(proto.HasField(composite_field_name)) 1024 1025 # Now call the clear method... 1026 proto.ClearField(composite_field_name) 1027 1028 # ...and ensure that the "has" bits are all back to False... 1029 composite_field = getattr(proto, composite_field_name) 1030 self.assertFalse(composite_field.HasField(scalar_field_name)) 1031 self.assertFalse(proto.HasField(composite_field_name)) 1032 # ...and ensure that the scalar field has returned to its default. 1033 self.assertEqual(0, getattr(composite_field, scalar_field_name)) 1034 1035 self.assertIsNot(old_composite_field, composite_field) 1036 setattr(old_composite_field, scalar_field_name, new_val) 1037 self.assertFalse(composite_field.HasField(scalar_field_name)) 1038 self.assertFalse(proto.HasField(composite_field_name)) 1039 self.assertEqual(0, getattr(composite_field, scalar_field_name)) 1040 1041 # Test simple, single-level nesting when we set a scalar. 
1042 TestCompositeHasBits('optionalgroup', 'a') 1043 TestCompositeHasBits('optional_nested_message', 'bb') 1044 TestCompositeHasBits('optional_foreign_message', 'c') 1045 TestCompositeHasBits('optional_import_message', 'd') 1046 1047 def testHasBitsWhenModifyingRepeatedFields(self): 1048 # Test nesting when we add an element to a repeated field in a submessage. 1049 proto = unittest_pb2.TestNestedMessageHasBits() 1050 proto.optional_nested_message.nestedmessage_repeated_int32.append(5) 1051 self.assertEqual( 1052 [5], proto.optional_nested_message.nestedmessage_repeated_int32) 1053 self.assertTrue(proto.HasField('optional_nested_message')) 1054 1055 # Do the same test, but with a repeated composite field within the 1056 # submessage. 1057 proto.ClearField('optional_nested_message') 1058 self.assertFalse(proto.HasField('optional_nested_message')) 1059 proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add() 1060 self.assertTrue(proto.HasField('optional_nested_message')) 1061 1062 def testHasBitsForManyLevelsOfNesting(self): 1063 # Test nesting many levels deep. 
1064 recursive_proto = unittest_pb2.TestMutualRecursionA() 1065 self.assertFalse(recursive_proto.HasField('bb')) 1066 self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32) 1067 self.assertFalse(recursive_proto.HasField('bb')) 1068 recursive_proto.bb.a.bb.a.bb.optional_int32 = 5 1069 self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32) 1070 self.assertTrue(recursive_proto.HasField('bb')) 1071 self.assertTrue(recursive_proto.bb.HasField('a')) 1072 self.assertTrue(recursive_proto.bb.a.HasField('bb')) 1073 self.assertTrue(recursive_proto.bb.a.bb.HasField('a')) 1074 self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb')) 1075 self.assertFalse(recursive_proto.bb.a.bb.a.bb.HasField('a')) 1076 self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32')) 1077 1078 def testSingularListExtensions(self): 1079 proto = unittest_pb2.TestAllExtensions() 1080 proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1 1081 proto.Extensions[unittest_pb2.optional_int32_extension ] = 5 1082 proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo' 1083 self.assertEqual( 1084 [ (unittest_pb2.optional_int32_extension , 5), 1085 (unittest_pb2.optional_fixed32_extension, 1), 1086 (unittest_pb2.optional_string_extension , 'foo') ], 1087 proto.ListFields()) 1088 del proto.Extensions[unittest_pb2.optional_fixed32_extension] 1089 self.assertEqual( 1090 [(unittest_pb2.optional_int32_extension, 5), 1091 (unittest_pb2.optional_string_extension, 'foo')], 1092 proto.ListFields()) 1093 1094 def testRepeatedListExtensions(self): 1095 proto = unittest_pb2.TestAllExtensions() 1096 proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1) 1097 proto.Extensions[unittest_pb2.repeated_int32_extension ].append(5) 1098 proto.Extensions[unittest_pb2.repeated_int32_extension ].append(11) 1099 proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo') 1100 proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar') 1101 
proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz') 1102 proto.Extensions[unittest_pb2.optional_int32_extension ] = 21 1103 self.assertEqual( 1104 [ (unittest_pb2.optional_int32_extension , 21), 1105 (unittest_pb2.repeated_int32_extension , [5, 11]), 1106 (unittest_pb2.repeated_fixed32_extension, [1]), 1107 (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ], 1108 proto.ListFields()) 1109 del proto.Extensions[unittest_pb2.repeated_int32_extension] 1110 del proto.Extensions[unittest_pb2.repeated_string_extension] 1111 self.assertEqual( 1112 [(unittest_pb2.optional_int32_extension, 21), 1113 (unittest_pb2.repeated_fixed32_extension, [1])], 1114 proto.ListFields()) 1115 1116 def testListFieldsAndExtensions(self): 1117 proto = unittest_pb2.TestFieldOrderings() 1118 test_util.SetAllFieldsAndExtensions(proto) 1119 unittest_pb2.my_extension_int 1120 self.assertEqual( 1121 [ (proto.DESCRIPTOR.fields_by_name['my_int' ], 1), 1122 (unittest_pb2.my_extension_int , 23), 1123 (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'), 1124 (unittest_pb2.my_extension_string , 'bar'), 1125 (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ], 1126 proto.ListFields()) 1127 1128 def testDefaultValues(self): 1129 proto = unittest_pb2.TestAllTypes() 1130 self.assertEqual(0, proto.optional_int32) 1131 self.assertEqual(0, proto.optional_int64) 1132 self.assertEqual(0, proto.optional_uint32) 1133 self.assertEqual(0, proto.optional_uint64) 1134 self.assertEqual(0, proto.optional_sint32) 1135 self.assertEqual(0, proto.optional_sint64) 1136 self.assertEqual(0, proto.optional_fixed32) 1137 self.assertEqual(0, proto.optional_fixed64) 1138 self.assertEqual(0, proto.optional_sfixed32) 1139 self.assertEqual(0, proto.optional_sfixed64) 1140 self.assertEqual(0.0, proto.optional_float) 1141 self.assertEqual(0.0, proto.optional_double) 1142 self.assertEqual(False, proto.optional_bool) 1143 self.assertEqual('', proto.optional_string) 1144 self.assertEqual(b'', 
proto.optional_bytes) 1145 1146 self.assertEqual(41, proto.default_int32) 1147 self.assertEqual(42, proto.default_int64) 1148 self.assertEqual(43, proto.default_uint32) 1149 self.assertEqual(44, proto.default_uint64) 1150 self.assertEqual(-45, proto.default_sint32) 1151 self.assertEqual(46, proto.default_sint64) 1152 self.assertEqual(47, proto.default_fixed32) 1153 self.assertEqual(48, proto.default_fixed64) 1154 self.assertEqual(49, proto.default_sfixed32) 1155 self.assertEqual(-50, proto.default_sfixed64) 1156 self.assertEqual(51.5, proto.default_float) 1157 self.assertEqual(52e3, proto.default_double) 1158 self.assertEqual(True, proto.default_bool) 1159 self.assertEqual('hello', proto.default_string) 1160 self.assertEqual(b'world', proto.default_bytes) 1161 self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) 1162 self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) 1163 self.assertEqual(unittest_import_pb2.IMPORT_BAR, 1164 proto.default_import_enum) 1165 1166 proto = unittest_pb2.TestExtremeDefaultValues() 1167 self.assertEqual(u'\u1234', proto.utf8_string) 1168 1169 def testHasFieldWithUnknownFieldName(self): 1170 proto = unittest_pb2.TestAllTypes() 1171 self.assertRaises(ValueError, proto.HasField, 'nonexistent_field') 1172 1173 def testClearRemovesChildren(self): 1174 # Make sure there aren't any implementation bugs that are only partially 1175 # clearing the message (which can happen in the more complex C++ 1176 # implementation which has parallel message lists). 1177 proto = unittest_pb2.TestRequiredForeign() 1178 for i in range(10): 1179 proto.repeated_message.add() 1180 proto2 = unittest_pb2.TestRequiredForeign() 1181 proto.CopyFrom(proto2) 1182 self.assertRaises(IndexError, lambda: proto.repeated_message[5]) 1183 1184 def testSingleScalarClearField(self): 1185 proto = unittest_pb2.TestAllTypes() 1186 # Should be allowed to clear something that's not there (a no-op). 
1187 proto.ClearField('optional_int32') 1188 proto.optional_int32 = 1 1189 self.assertTrue(proto.HasField('optional_int32')) 1190 proto.ClearField('optional_int32') 1191 self.assertEqual(0, proto.optional_int32) 1192 self.assertFalse(proto.HasField('optional_int32')) 1193 # TODO(user): Test all other scalar field types. 1194 1195 def testRepeatedScalars(self): 1196 proto = unittest_pb2.TestAllTypes() 1197 1198 self.assertFalse(proto.repeated_int32) 1199 self.assertEqual(0, len(proto.repeated_int32)) 1200 proto.repeated_int32.append(5) 1201 proto.repeated_int32.append(10) 1202 proto.repeated_int32.append(15) 1203 self.assertTrue(proto.repeated_int32) 1204 self.assertEqual(3, len(proto.repeated_int32)) 1205 1206 self.assertEqual([5, 10, 15], proto.repeated_int32) 1207 1208 # Test single retrieval. 1209 self.assertEqual(5, proto.repeated_int32[0]) 1210 self.assertEqual(15, proto.repeated_int32[-1]) 1211 # Test out-of-bounds indices. 1212 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234) 1213 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234) 1214 # Test incorrect types passed to __getitem__. 1215 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo') 1216 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None) 1217 1218 # Test single assignment. 1219 proto.repeated_int32[1] = 20 1220 self.assertEqual([5, 20, 15], proto.repeated_int32) 1221 1222 # Test insertion. 1223 proto.repeated_int32.insert(1, 25) 1224 self.assertEqual([5, 25, 20, 15], proto.repeated_int32) 1225 1226 # Test slice retrieval. 1227 proto.repeated_int32.append(30) 1228 self.assertEqual([25, 20, 15], proto.repeated_int32[1:4]) 1229 self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) 1230 1231 # Test slice assignment with an iterator 1232 proto.repeated_int32[1:4] = (i for i in range(3)) 1233 self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) 1234 1235 # Test slice assignment. 
1236 proto.repeated_int32[1:4] = [35, 40, 45] 1237 self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32) 1238 1239 # Test that we can use the field as an iterator. 1240 result = [] 1241 for i in proto.repeated_int32: 1242 result.append(i) 1243 self.assertEqual([5, 35, 40, 45, 30], result) 1244 1245 # Test single deletion. 1246 del proto.repeated_int32[2] 1247 self.assertEqual([5, 35, 45, 30], proto.repeated_int32) 1248 1249 # Test slice deletion. 1250 del proto.repeated_int32[2:] 1251 self.assertEqual([5, 35], proto.repeated_int32) 1252 1253 # Test extending. 1254 proto.repeated_int32.extend([3, 13]) 1255 self.assertEqual([5, 35, 3, 13], proto.repeated_int32) 1256 1257 # Test clearing. 1258 proto.ClearField('repeated_int32') 1259 self.assertFalse(proto.repeated_int32) 1260 self.assertEqual(0, len(proto.repeated_int32)) 1261 1262 proto.repeated_int32.append(1) 1263 self.assertEqual(1, proto.repeated_int32[-1]) 1264 # Test assignment to a negative index. 1265 proto.repeated_int32[-1] = 2 1266 self.assertEqual(2, proto.repeated_int32[-1]) 1267 1268 # Test deletion at negative indices. 
1269 proto.repeated_int32[:] = [0, 1, 2, 3] 1270 del proto.repeated_int32[-1] 1271 self.assertEqual([0, 1, 2], proto.repeated_int32) 1272 1273 del proto.repeated_int32[-2] 1274 self.assertEqual([0, 2], proto.repeated_int32) 1275 1276 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3) 1277 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300) 1278 1279 del proto.repeated_int32[-2:-1] 1280 self.assertEqual([2], proto.repeated_int32) 1281 1282 del proto.repeated_int32[100:10000] 1283 self.assertEqual([2], proto.repeated_int32) 1284 1285 def testRepeatedScalarsRemove(self): 1286 proto = unittest_pb2.TestAllTypes() 1287 1288 self.assertFalse(proto.repeated_int32) 1289 self.assertEqual(0, len(proto.repeated_int32)) 1290 proto.repeated_int32.append(5) 1291 proto.repeated_int32.append(10) 1292 proto.repeated_int32.append(5) 1293 proto.repeated_int32.append(5) 1294 1295 self.assertEqual(4, len(proto.repeated_int32)) 1296 proto.repeated_int32.remove(5) 1297 self.assertEqual(3, len(proto.repeated_int32)) 1298 self.assertEqual(10, proto.repeated_int32[0]) 1299 self.assertEqual(5, proto.repeated_int32[1]) 1300 self.assertEqual(5, proto.repeated_int32[2]) 1301 1302 proto.repeated_int32.remove(5) 1303 self.assertEqual(2, len(proto.repeated_int32)) 1304 self.assertEqual(10, proto.repeated_int32[0]) 1305 self.assertEqual(5, proto.repeated_int32[1]) 1306 1307 proto.repeated_int32.remove(10) 1308 self.assertEqual(1, len(proto.repeated_int32)) 1309 self.assertEqual(5, proto.repeated_int32[0]) 1310 1311 # Remove a non-existent element. 
1312 self.assertRaises(ValueError, proto.repeated_int32.remove, 123) 1313 1314 def testRepeatedComposites(self): 1315 proto = unittest_pb2.TestAllTypes() 1316 self.assertFalse(proto.repeated_nested_message) 1317 self.assertEqual(0, len(proto.repeated_nested_message)) 1318 m0 = proto.repeated_nested_message.add() 1319 m1 = proto.repeated_nested_message.add() 1320 self.assertTrue(proto.repeated_nested_message) 1321 self.assertEqual(2, len(proto.repeated_nested_message)) 1322 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 1323 self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage) 1324 1325 # Test out-of-bounds indices. 1326 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 1327 1234) 1328 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 1329 -1234) 1330 1331 # Test incorrect types passed to __getitem__. 1332 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 1333 'foo') 1334 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 1335 None) 1336 1337 # Test slice retrieval. 1338 m2 = proto.repeated_nested_message.add() 1339 m3 = proto.repeated_nested_message.add() 1340 m4 = proto.repeated_nested_message.add() 1341 self.assertListsEqual( 1342 [m1, m2, m3], proto.repeated_nested_message[1:4]) 1343 self.assertListsEqual( 1344 [m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) 1345 self.assertListsEqual( 1346 [m0, m1], proto.repeated_nested_message[:2]) 1347 self.assertListsEqual( 1348 [m2, m3, m4], proto.repeated_nested_message[2:]) 1349 self.assertEqual( 1350 m0, proto.repeated_nested_message[0]) 1351 self.assertListsEqual( 1352 [m0], proto.repeated_nested_message[:1]) 1353 1354 # Test that we can use the field as an iterator. 1355 result = [] 1356 for i in proto.repeated_nested_message: 1357 result.append(i) 1358 self.assertListsEqual([m0, m1, m2, m3, m4], result) 1359 1360 # Test single deletion. 
1361 del proto.repeated_nested_message[2] 1362 self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message) 1363 1364 # Test slice deletion. 1365 del proto.repeated_nested_message[2:] 1366 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 1367 1368 # Test extending. 1369 n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1) 1370 n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2) 1371 proto.repeated_nested_message.extend([n1,n2]) 1372 self.assertEqual(4, len(proto.repeated_nested_message)) 1373 self.assertEqual(n1, proto.repeated_nested_message[2]) 1374 self.assertEqual(n2, proto.repeated_nested_message[3]) 1375 self.assertRaises(TypeError, 1376 proto.repeated_nested_message.extend, n1) 1377 self.assertRaises(TypeError, 1378 proto.repeated_nested_message.extend, [0]) 1379 wrong_message_type = unittest_pb2.TestAllTypes() 1380 self.assertRaises(TypeError, 1381 proto.repeated_nested_message.extend, 1382 [wrong_message_type]) 1383 1384 # Test clearing. 1385 proto.ClearField('repeated_nested_message') 1386 self.assertFalse(proto.repeated_nested_message) 1387 self.assertEqual(0, len(proto.repeated_nested_message)) 1388 1389 # Test constructing an element while adding it. 
1390 proto.repeated_nested_message.add(bb=23) 1391 self.assertEqual(1, len(proto.repeated_nested_message)) 1392 self.assertEqual(23, proto.repeated_nested_message[0].bb) 1393 self.assertRaises(TypeError, proto.repeated_nested_message.add, 23) 1394 with self.assertRaises(Exception): 1395 proto.repeated_nested_message[0] = 23 1396 1397 def testRepeatedCompositeRemove(self): 1398 proto = unittest_pb2.TestAllTypes() 1399 1400 self.assertEqual(0, len(proto.repeated_nested_message)) 1401 m0 = proto.repeated_nested_message.add() 1402 # Need to set some differentiating variable so m0 != m1 != m2: 1403 m0.bb = len(proto.repeated_nested_message) 1404 m1 = proto.repeated_nested_message.add() 1405 m1.bb = len(proto.repeated_nested_message) 1406 self.assertTrue(m0 != m1) 1407 m2 = proto.repeated_nested_message.add() 1408 m2.bb = len(proto.repeated_nested_message) 1409 self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) 1410 1411 self.assertEqual(3, len(proto.repeated_nested_message)) 1412 proto.repeated_nested_message.remove(m0) 1413 self.assertEqual(2, len(proto.repeated_nested_message)) 1414 self.assertEqual(m1, proto.repeated_nested_message[0]) 1415 self.assertEqual(m2, proto.repeated_nested_message[1]) 1416 1417 # Removing m0 again or removing None should raise error 1418 self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0) 1419 self.assertRaises(ValueError, proto.repeated_nested_message.remove, None) 1420 self.assertEqual(2, len(proto.repeated_nested_message)) 1421 1422 proto.repeated_nested_message.remove(m2) 1423 self.assertEqual(1, len(proto.repeated_nested_message)) 1424 self.assertEqual(m1, proto.repeated_nested_message[0]) 1425 1426 def testHandWrittenReflection(self): 1427 # Hand written extensions are only supported by the pure-Python 1428 # implementation of the API. 
1429 if api_implementation.Type() != 'python': 1430 return 1431 1432 FieldDescriptor = descriptor.FieldDescriptor 1433 foo_field_descriptor = FieldDescriptor( 1434 name='foo_field', full_name='MyProto.foo_field', 1435 index=0, number=1, type=FieldDescriptor.TYPE_INT64, 1436 cpp_type=FieldDescriptor.CPPTYPE_INT64, 1437 label=FieldDescriptor.LABEL_OPTIONAL, default_value=0, 1438 containing_type=None, message_type=None, enum_type=None, 1439 is_extension=False, extension_scope=None, 1440 options=descriptor_pb2.FieldOptions(), 1441 # pylint: disable=protected-access 1442 create_key=descriptor._internal_create_key) 1443 mydescriptor = descriptor.Descriptor( 1444 name='MyProto', full_name='MyProto', filename='ignored', 1445 containing_type=None, nested_types=[], enum_types=[], 1446 fields=[foo_field_descriptor], extensions=[], 1447 options=descriptor_pb2.MessageOptions(), 1448 # pylint: disable=protected-access 1449 create_key=descriptor._internal_create_key) 1450 class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): 1451 DESCRIPTOR = mydescriptor 1452 myproto_instance = MyProtoClass() 1453 self.assertEqual(0, myproto_instance.foo_field) 1454 self.assertFalse(myproto_instance.HasField('foo_field')) 1455 myproto_instance.foo_field = 23 1456 self.assertEqual(23, myproto_instance.foo_field) 1457 self.assertTrue(myproto_instance.HasField('foo_field')) 1458 1459 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 1460 def testDescriptorProtoSupport(self): 1461 # Hand written descriptors/reflection are only supported by the pure-Python 1462 # implementation of the API. 
1463 if api_implementation.Type() != 'python': 1464 return 1465 1466 def AddDescriptorField(proto, field_name, field_type): 1467 AddDescriptorField.field_index += 1 1468 new_field = proto.field.add() 1469 new_field.name = field_name 1470 new_field.type = field_type 1471 new_field.number = AddDescriptorField.field_index 1472 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL 1473 1474 AddDescriptorField.field_index = 0 1475 1476 desc_proto = descriptor_pb2.DescriptorProto() 1477 desc_proto.name = 'Car' 1478 fdp = descriptor_pb2.FieldDescriptorProto 1479 AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING) 1480 AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64) 1481 AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL) 1482 AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE) 1483 # Add a repeated field 1484 AddDescriptorField.field_index += 1 1485 new_field = desc_proto.field.add() 1486 new_field.name = 'owners' 1487 new_field.type = fdp.TYPE_STRING 1488 new_field.number = AddDescriptorField.field_index 1489 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED 1490 1491 desc = descriptor.MakeDescriptor(desc_proto) 1492 self.assertTrue('name' in desc.fields_by_name) 1493 self.assertTrue('year' in desc.fields_by_name) 1494 self.assertTrue('automatic' in desc.fields_by_name) 1495 self.assertTrue('price' in desc.fields_by_name) 1496 self.assertTrue('owners' in desc.fields_by_name) 1497 1498 class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, 1499 message.Message)): 1500 DESCRIPTOR = desc 1501 1502 prius = CarMessage() 1503 prius.name = 'prius' 1504 prius.year = 2010 1505 prius.automatic = True 1506 prius.price = 25134.75 1507 prius.owners.extend(['bob', 'susan']) 1508 1509 serialized_prius = prius.SerializeToString() 1510 new_prius = reflection.ParseMessage(desc, serialized_prius) 1511 self.assertIsNot(new_prius, prius) 1512 self.assertEqual(prius, new_prius) 1513 1514 # these are unnecessary 
assuming message equality works as advertised but 1515 # explicitly check to be safe since we're mucking about in metaclass foo 1516 self.assertEqual(prius.name, new_prius.name) 1517 self.assertEqual(prius.year, new_prius.year) 1518 self.assertEqual(prius.automatic, new_prius.automatic) 1519 self.assertEqual(prius.price, new_prius.price) 1520 self.assertEqual(prius.owners, new_prius.owners) 1521 1522 def testExtensionDelete(self): 1523 extendee_proto = more_extensions_pb2.ExtendedMessage() 1524 1525 extension_int32 = more_extensions_pb2.optional_int_extension 1526 extendee_proto.Extensions[extension_int32] = 23 1527 1528 extension_repeated = more_extensions_pb2.repeated_int_extension 1529 extendee_proto.Extensions[extension_repeated].append(11) 1530 1531 extension_msg = more_extensions_pb2.optional_message_extension 1532 extendee_proto.Extensions[extension_msg].foreign_message_int = 56 1533 1534 self.assertEqual(len(extendee_proto.Extensions), 3) 1535 del extendee_proto.Extensions[extension_msg] 1536 self.assertEqual(len(extendee_proto.Extensions), 2) 1537 del extendee_proto.Extensions[extension_repeated] 1538 self.assertEqual(len(extendee_proto.Extensions), 1) 1539 # Delete a none exist extension. It is OK to "del m.Extensions[ext]" 1540 # even if the extension is not present in the message; we don't 1541 # raise KeyError. This is consistent with "m.Extensions[ext]" 1542 # returning a default value even if we did not set anything. 
1543 del extendee_proto.Extensions[extension_repeated] 1544 self.assertEqual(len(extendee_proto.Extensions), 1) 1545 del extendee_proto.Extensions[extension_int32] 1546 self.assertEqual(len(extendee_proto.Extensions), 0) 1547 1548 def testExtensionIter(self): 1549 extendee_proto = more_extensions_pb2.ExtendedMessage() 1550 1551 extension_int32 = more_extensions_pb2.optional_int_extension 1552 extendee_proto.Extensions[extension_int32] = 23 1553 1554 extension_repeated = more_extensions_pb2.repeated_int_extension 1555 extendee_proto.Extensions[extension_repeated].append(11) 1556 1557 extension_msg = more_extensions_pb2.optional_message_extension 1558 extendee_proto.Extensions[extension_msg].foreign_message_int = 56 1559 1560 # Set some normal fields. 1561 extendee_proto.optional_int32 = 1 1562 extendee_proto.repeated_string.append('hi') 1563 1564 expected = (extension_int32, extension_msg, extension_repeated) 1565 count = 0 1566 for item in extendee_proto.Extensions: 1567 self.assertEqual(item.name, expected[count].name) 1568 self.assertIn(item, extendee_proto.Extensions) 1569 count += 1 1570 self.assertEqual(count, 3) 1571 1572 def testExtensionContainsError(self): 1573 extendee_proto = more_extensions_pb2.ExtendedMessage() 1574 self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, 0) 1575 1576 field = more_extensions_pb2.ExtendedMessage.DESCRIPTOR.fields_by_name[ 1577 'optional_int32'] 1578 self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, field) 1579 1580 def testTopLevelExtensionsForOptionalScalar(self): 1581 extendee_proto = unittest_pb2.TestAllExtensions() 1582 extension = unittest_pb2.optional_int32_extension 1583 self.assertFalse(extendee_proto.HasExtension(extension)) 1584 self.assertNotIn(extension, extendee_proto.Extensions) 1585 self.assertEqual(0, extendee_proto.Extensions[extension]) 1586 # As with normal scalar fields, just doing a read doesn't actually set the 1587 # "has" bit. 
1588 self.assertFalse(extendee_proto.HasExtension(extension)) 1589 self.assertNotIn(extension, extendee_proto.Extensions) 1590 # Actually set the thing. 1591 extendee_proto.Extensions[extension] = 23 1592 self.assertEqual(23, extendee_proto.Extensions[extension]) 1593 self.assertTrue(extendee_proto.HasExtension(extension)) 1594 self.assertIn(extension, extendee_proto.Extensions) 1595 # Ensure that clearing works as well. 1596 extendee_proto.ClearExtension(extension) 1597 self.assertEqual(0, extendee_proto.Extensions[extension]) 1598 self.assertFalse(extendee_proto.HasExtension(extension)) 1599 self.assertNotIn(extension, extendee_proto.Extensions) 1600 1601 def testTopLevelExtensionsForRepeatedScalar(self): 1602 extendee_proto = unittest_pb2.TestAllExtensions() 1603 extension = unittest_pb2.repeated_string_extension 1604 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1605 self.assertNotIn(extension, extendee_proto.Extensions) 1606 extendee_proto.Extensions[extension].append('foo') 1607 self.assertEqual(['foo'], extendee_proto.Extensions[extension]) 1608 self.assertIn(extension, extendee_proto.Extensions) 1609 string_list = extendee_proto.Extensions[extension] 1610 extendee_proto.ClearExtension(extension) 1611 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1612 self.assertNotIn(extension, extendee_proto.Extensions) 1613 self.assertIsNot(string_list, extendee_proto.Extensions[extension]) 1614 # Shouldn't be allowed to do Extensions[extension] = 'a' 1615 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1616 extension, 'a') 1617 1618 def testTopLevelExtensionsForOptionalMessage(self): 1619 extendee_proto = unittest_pb2.TestAllExtensions() 1620 extension = unittest_pb2.optional_foreign_message_extension 1621 self.assertFalse(extendee_proto.HasExtension(extension)) 1622 self.assertNotIn(extension, extendee_proto.Extensions) 1623 self.assertEqual(0, extendee_proto.Extensions[extension].c) 1624 # As with normal 
(non-extension) fields, merely reading from the 1625 # thing shouldn't set the "has" bit. 1626 self.assertFalse(extendee_proto.HasExtension(extension)) 1627 self.assertNotIn(extension, extendee_proto.Extensions) 1628 extendee_proto.Extensions[extension].c = 23 1629 self.assertEqual(23, extendee_proto.Extensions[extension].c) 1630 self.assertTrue(extendee_proto.HasExtension(extension)) 1631 self.assertIn(extension, extendee_proto.Extensions) 1632 # Save a reference here. 1633 foreign_message = extendee_proto.Extensions[extension] 1634 extendee_proto.ClearExtension(extension) 1635 self.assertIsNot(foreign_message, extendee_proto.Extensions[extension]) 1636 # Setting a field on foreign_message now shouldn't set 1637 # any "has" bits on extendee_proto. 1638 foreign_message.c = 42 1639 self.assertEqual(42, foreign_message.c) 1640 self.assertTrue(foreign_message.HasField('c')) 1641 self.assertFalse(extendee_proto.HasExtension(extension)) 1642 self.assertNotIn(extension, extendee_proto.Extensions) 1643 # Shouldn't be allowed to do Extensions[extension] = 'a' 1644 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1645 extension, 'a') 1646 1647 def testTopLevelExtensionsForRepeatedMessage(self): 1648 extendee_proto = unittest_pb2.TestAllExtensions() 1649 extension = unittest_pb2.repeatedgroup_extension 1650 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1651 group = extendee_proto.Extensions[extension].add() 1652 group.a = 23 1653 self.assertEqual(23, extendee_proto.Extensions[extension][0].a) 1654 group.a = 42 1655 self.assertEqual(42, extendee_proto.Extensions[extension][0].a) 1656 group_list = extendee_proto.Extensions[extension] 1657 extendee_proto.ClearExtension(extension) 1658 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1659 self.assertIsNot(group_list, extendee_proto.Extensions[extension]) 1660 # Shouldn't be allowed to do Extensions[extension] = 'a' 1661 self.assertRaises(TypeError, operator.setitem, 
extendee_proto.Extensions, 1662 extension, 'a') 1663 1664 def testNestedExtensions(self): 1665 extendee_proto = unittest_pb2.TestAllExtensions() 1666 extension = unittest_pb2.TestRequired.single 1667 1668 # We just test the non-repeated case. 1669 self.assertFalse(extendee_proto.HasExtension(extension)) 1670 self.assertNotIn(extension, extendee_proto.Extensions) 1671 required = extendee_proto.Extensions[extension] 1672 self.assertEqual(0, required.a) 1673 self.assertFalse(extendee_proto.HasExtension(extension)) 1674 self.assertNotIn(extension, extendee_proto.Extensions) 1675 required.a = 23 1676 self.assertEqual(23, extendee_proto.Extensions[extension].a) 1677 self.assertTrue(extendee_proto.HasExtension(extension)) 1678 self.assertIn(extension, extendee_proto.Extensions) 1679 extendee_proto.ClearExtension(extension) 1680 self.assertIsNot(required, extendee_proto.Extensions[extension]) 1681 self.assertFalse(extendee_proto.HasExtension(extension)) 1682 self.assertNotIn(extension, extendee_proto.Extensions) 1683 1684 def testRegisteredExtensions(self): 1685 pool = unittest_pb2.DESCRIPTOR.pool 1686 self.assertTrue( 1687 pool.FindExtensionByNumber( 1688 unittest_pb2.TestAllExtensions.DESCRIPTOR, 1)) 1689 self.assertIs( 1690 pool.FindExtensionByName( 1691 'protobuf_unittest.optional_int32_extension').containing_type, 1692 unittest_pb2.TestAllExtensions.DESCRIPTOR) 1693 # Make sure extensions haven't been registered into types that shouldn't 1694 # have any. 1695 self.assertEqual(0, len( 1696 pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR))) 1697 1698 # If message A directly contains message B, and 1699 # a.HasField('b') is currently False, then mutating any 1700 # extension in B should change a.HasField('b') to True 1701 # (and so on up the object tree). 1702 def testHasBitsForAncestorsOfExtendedMessage(self): 1703 # Optional scalar extension. 
1704 toplevel = more_extensions_pb2.TopLevelMessage() 1705 self.assertFalse(toplevel.HasField('submessage')) 1706 self.assertEqual(0, toplevel.submessage.Extensions[ 1707 more_extensions_pb2.optional_int_extension]) 1708 self.assertFalse(toplevel.HasField('submessage')) 1709 toplevel.submessage.Extensions[ 1710 more_extensions_pb2.optional_int_extension] = 23 1711 self.assertEqual(23, toplevel.submessage.Extensions[ 1712 more_extensions_pb2.optional_int_extension]) 1713 self.assertTrue(toplevel.HasField('submessage')) 1714 1715 # Repeated scalar extension. 1716 toplevel = more_extensions_pb2.TopLevelMessage() 1717 self.assertFalse(toplevel.HasField('submessage')) 1718 self.assertEqual([], toplevel.submessage.Extensions[ 1719 more_extensions_pb2.repeated_int_extension]) 1720 self.assertFalse(toplevel.HasField('submessage')) 1721 toplevel.submessage.Extensions[ 1722 more_extensions_pb2.repeated_int_extension].append(23) 1723 self.assertEqual([23], toplevel.submessage.Extensions[ 1724 more_extensions_pb2.repeated_int_extension]) 1725 self.assertTrue(toplevel.HasField('submessage')) 1726 1727 # Optional message extension. 1728 toplevel = more_extensions_pb2.TopLevelMessage() 1729 self.assertFalse(toplevel.HasField('submessage')) 1730 self.assertEqual(0, toplevel.submessage.Extensions[ 1731 more_extensions_pb2.optional_message_extension].foreign_message_int) 1732 self.assertFalse(toplevel.HasField('submessage')) 1733 toplevel.submessage.Extensions[ 1734 more_extensions_pb2.optional_message_extension].foreign_message_int = 23 1735 self.assertEqual(23, toplevel.submessage.Extensions[ 1736 more_extensions_pb2.optional_message_extension].foreign_message_int) 1737 self.assertTrue(toplevel.HasField('submessage')) 1738 1739 # Repeated message extension. 
1740 toplevel = more_extensions_pb2.TopLevelMessage() 1741 self.assertFalse(toplevel.HasField('submessage')) 1742 self.assertEqual(0, len(toplevel.submessage.Extensions[ 1743 more_extensions_pb2.repeated_message_extension])) 1744 self.assertFalse(toplevel.HasField('submessage')) 1745 foreign = toplevel.submessage.Extensions[ 1746 more_extensions_pb2.repeated_message_extension].add() 1747 self.assertEqual(foreign, toplevel.submessage.Extensions[ 1748 more_extensions_pb2.repeated_message_extension][0]) 1749 self.assertTrue(toplevel.HasField('submessage')) 1750 1751 def testDisconnectionAfterClearingEmptyMessage(self): 1752 toplevel = more_extensions_pb2.TopLevelMessage() 1753 extendee_proto = toplevel.submessage 1754 extension = more_extensions_pb2.optional_message_extension 1755 extension_proto = extendee_proto.Extensions[extension] 1756 extendee_proto.ClearExtension(extension) 1757 extension_proto.foreign_message_int = 23 1758 1759 self.assertIsNot(extension_proto, extendee_proto.Extensions[extension]) 1760 1761 def testExtensionFailureModes(self): 1762 extendee_proto = unittest_pb2.TestAllExtensions() 1763 1764 # Try non-extension-handle arguments to HasExtension, 1765 # ClearExtension(), and Extensions[]... 1766 self.assertRaises(KeyError, extendee_proto.HasExtension, 1234) 1767 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234) 1768 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234) 1769 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5) 1770 1771 # Try something that *is* an extension handle, just not for 1772 # this message... 
1773 for unknown_handle in (more_extensions_pb2.optional_int_extension, 1774 more_extensions_pb2.optional_message_extension, 1775 more_extensions_pb2.repeated_int_extension, 1776 more_extensions_pb2.repeated_message_extension): 1777 self.assertRaises(KeyError, extendee_proto.HasExtension, 1778 unknown_handle) 1779 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1780 unknown_handle) 1781 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1782 unknown_handle) 1783 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1784 unknown_handle, 5) 1785 1786 # Try call HasExtension() with a valid handle, but for a 1787 # *repeated* field. (Just as with non-extension repeated 1788 # fields, Has*() isn't supported for extension repeated fields). 1789 self.assertRaises(KeyError, extendee_proto.HasExtension, 1790 unittest_pb2.repeated_string_extension) 1791 1792 def testMergeFromOptionalGroup(self): 1793 # Test merge with an optional group. 1794 proto1 = unittest_pb2.TestAllTypes() 1795 proto1.optionalgroup.a = 12 1796 proto2 = unittest_pb2.TestAllTypes() 1797 proto2.MergeFrom(proto1) 1798 self.assertEqual(12, proto2.optionalgroup.a) 1799 1800 def testMergeFromExtensionsSingular(self): 1801 proto1 = unittest_pb2.TestAllExtensions() 1802 proto1.Extensions[unittest_pb2.optional_int32_extension] = 1 1803 1804 proto2 = unittest_pb2.TestAllExtensions() 1805 proto2.MergeFrom(proto1) 1806 self.assertEqual( 1807 1, proto2.Extensions[unittest_pb2.optional_int32_extension]) 1808 1809 def testMergeFromExtensionsRepeated(self): 1810 proto1 = unittest_pb2.TestAllExtensions() 1811 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1) 1812 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2) 1813 1814 proto2 = unittest_pb2.TestAllExtensions() 1815 proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0) 1816 proto2.MergeFrom(proto1) 1817 self.assertEqual( 1818 3, 
len(proto2.Extensions[unittest_pb2.repeated_int32_extension])) 1819 self.assertEqual( 1820 0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0]) 1821 self.assertEqual( 1822 1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1]) 1823 self.assertEqual( 1824 2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2]) 1825 1826 def testMergeFromExtensionsNestedMessage(self): 1827 proto1 = unittest_pb2.TestAllExtensions() 1828 ext1 = proto1.Extensions[ 1829 unittest_pb2.repeated_nested_message_extension] 1830 m = ext1.add() 1831 m.bb = 222 1832 m = ext1.add() 1833 m.bb = 333 1834 1835 proto2 = unittest_pb2.TestAllExtensions() 1836 ext2 = proto2.Extensions[ 1837 unittest_pb2.repeated_nested_message_extension] 1838 m = ext2.add() 1839 m.bb = 111 1840 1841 proto2.MergeFrom(proto1) 1842 ext2 = proto2.Extensions[ 1843 unittest_pb2.repeated_nested_message_extension] 1844 self.assertEqual(3, len(ext2)) 1845 self.assertEqual(111, ext2[0].bb) 1846 self.assertEqual(222, ext2[1].bb) 1847 self.assertEqual(333, ext2[2].bb) 1848 1849 def testCopyFromBadType(self): 1850 # The python implementation doesn't raise an exception in this 1851 # case. In theory it should. 1852 if api_implementation.Type() == 'python': 1853 return 1854 proto1 = unittest_pb2.TestAllTypes() 1855 proto2 = unittest_pb2.TestAllExtensions() 1856 self.assertRaises(TypeError, proto1.CopyFrom, proto2) 1857 1858 def testClear(self): 1859 proto = unittest_pb2.TestAllTypes() 1860 # C++ implementation does not support lazy fields right now so leave it 1861 # out for now. 1862 if api_implementation.Type() == 'python': 1863 test_util.SetAllFields(proto) 1864 else: 1865 test_util.SetAllNonLazyFields(proto) 1866 # Clear the message. 1867 proto.Clear() 1868 self.assertEqual(proto.ByteSize(), 0) 1869 empty_proto = unittest_pb2.TestAllTypes() 1870 self.assertEqual(proto, empty_proto) 1871 1872 # Test if extensions which were set are cleared. 
1873 proto = unittest_pb2.TestAllExtensions() 1874 test_util.SetAllExtensions(proto) 1875 # Clear the message. 1876 proto.Clear() 1877 self.assertEqual(proto.ByteSize(), 0) 1878 empty_proto = unittest_pb2.TestAllExtensions() 1879 self.assertEqual(proto, empty_proto) 1880 1881 def testDisconnectingInOneof(self): 1882 m = unittest_pb2.TestOneof2() # This message has two messages in a oneof. 1883 m.foo_message.qux_int = 5 1884 sub_message = m.foo_message 1885 # Accessing another message's field does not clear the first one 1886 self.assertEqual(m.foo_lazy_message.qux_int, 0) 1887 self.assertEqual(m.foo_message.qux_int, 5) 1888 # But mutating another message in the oneof detaches the first one. 1889 m.foo_lazy_message.qux_int = 6 1890 self.assertEqual(m.foo_message.qux_int, 0) 1891 # The reference we got above was detached and is still valid. 1892 self.assertEqual(sub_message.qux_int, 5) 1893 sub_message.qux_int = 7 1894 1895 def assertInitialized(self, proto): 1896 self.assertTrue(proto.IsInitialized()) 1897 # Neither method should raise an exception. 1898 proto.SerializeToString() 1899 proto.SerializePartialToString() 1900 1901 def assertNotInitialized(self, proto, error_size=None): 1902 errors = [] 1903 self.assertFalse(proto.IsInitialized()) 1904 self.assertFalse(proto.IsInitialized(errors)) 1905 self.assertEqual(error_size, len(errors)) 1906 self.assertRaises(message.EncodeError, proto.SerializeToString) 1907 # "Partial" serialization doesn't care if message is uninitialized. 1908 proto.SerializePartialToString() 1909 1910 def testIsInitialized(self): 1911 # Trivial cases - all optional fields and extensions. 1912 proto = unittest_pb2.TestAllTypes() 1913 self.assertInitialized(proto) 1914 proto = unittest_pb2.TestAllExtensions() 1915 self.assertInitialized(proto) 1916 1917 # The case of uninitialized required fields. 
1918 proto = unittest_pb2.TestRequired() 1919 self.assertNotInitialized(proto, 3) 1920 proto.a = proto.b = proto.c = 2 1921 self.assertInitialized(proto) 1922 1923 # The case of uninitialized submessage. 1924 proto = unittest_pb2.TestRequiredForeign() 1925 self.assertInitialized(proto) 1926 proto.optional_message.a = 1 1927 self.assertNotInitialized(proto, 2) 1928 proto.optional_message.b = 0 1929 proto.optional_message.c = 0 1930 self.assertInitialized(proto) 1931 1932 # Uninitialized repeated submessage. 1933 message1 = proto.repeated_message.add() 1934 self.assertNotInitialized(proto, 3) 1935 message1.a = message1.b = message1.c = 0 1936 self.assertInitialized(proto) 1937 1938 # Uninitialized repeated group in an extension. 1939 proto = unittest_pb2.TestAllExtensions() 1940 extension = unittest_pb2.TestRequired.multi 1941 message1 = proto.Extensions[extension].add() 1942 message2 = proto.Extensions[extension].add() 1943 self.assertNotInitialized(proto, 6) 1944 message1.a = 1 1945 message1.b = 1 1946 message1.c = 1 1947 self.assertNotInitialized(proto, 3) 1948 message2.a = 2 1949 message2.b = 2 1950 message2.c = 2 1951 self.assertInitialized(proto) 1952 1953 # Uninitialized nonrepeated message in an extension. 1954 proto = unittest_pb2.TestAllExtensions() 1955 extension = unittest_pb2.TestRequired.single 1956 proto.Extensions[extension].a = 1 1957 self.assertNotInitialized(proto, 2) 1958 proto.Extensions[extension].b = 2 1959 proto.Extensions[extension].c = 3 1960 self.assertInitialized(proto) 1961 1962 # Try passing an errors list. 
1963 errors = [] 1964 proto = unittest_pb2.TestRequired() 1965 self.assertFalse(proto.IsInitialized(errors)) 1966 self.assertEqual(errors, ['a', 'b', 'c']) 1967 self.assertRaises(TypeError, proto.IsInitialized, 1, 2, 3) 1968 1969 @unittest.skipIf( 1970 api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, 1971 'Errors are only available from the most recent C++ implementation.') 1972 def testFileDescriptorErrors(self): 1973 file_name = 'test_file_descriptor_errors.proto' 1974 package_name = 'test_file_descriptor_errors.proto' 1975 file_descriptor_proto = descriptor_pb2.FileDescriptorProto() 1976 file_descriptor_proto.name = file_name 1977 file_descriptor_proto.package = package_name 1978 m1 = file_descriptor_proto.message_type.add() 1979 m1.name = 'msg1' 1980 # Compiles the proto into the C++ descriptor pool 1981 descriptor.FileDescriptor( 1982 file_name, 1983 package_name, 1984 serialized_pb=file_descriptor_proto.SerializeToString()) 1985 # Add a FileDescriptorProto that has duplicate symbols 1986 another_file_name = 'another_test_file_descriptor_errors.proto' 1987 file_descriptor_proto.name = another_file_name 1988 m2 = file_descriptor_proto.message_type.add() 1989 m2.name = 'msg2' 1990 with self.assertRaises(TypeError) as cm: 1991 descriptor.FileDescriptor( 1992 another_file_name, 1993 package_name, 1994 serialized_pb=file_descriptor_proto.SerializeToString()) 1995 self.assertTrue(hasattr(cm, 'exception'), '%s not raised' % 1996 getattr(cm.expected, '__name__', cm.expected)) 1997 self.assertIn('test_file_descriptor_errors.proto', str(cm.exception)) 1998 # Error message will say something about this definition being a 1999 # duplicate, though we don't check the message exactly to avoid a 2000 # dependency on the C++ logging code. 
2001 self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) 2002 2003 def testStringUTF8Serialization(self): 2004 proto = message_set_extensions_pb2.TestMessageSet() 2005 extension_message = message_set_extensions_pb2.TestMessageSetExtension2 2006 extension = extension_message.message_set_extension 2007 2008 test_utf8 = u'Тест' 2009 test_utf8_bytes = test_utf8.encode('utf-8') 2010 2011 # 'Test' in another language, using UTF-8 charset. 2012 proto.Extensions[extension].str = test_utf8 2013 2014 # Serialize using the MessageSet wire format (this is specified in the 2015 # .proto file). 2016 serialized = proto.SerializeToString() 2017 2018 # Check byte size. 2019 self.assertEqual(proto.ByteSize(), len(serialized)) 2020 2021 raw = unittest_mset_pb2.RawMessageSet() 2022 bytes_read = raw.MergeFromString(serialized) 2023 self.assertEqual(len(serialized), bytes_read) 2024 2025 message2 = message_set_extensions_pb2.TestMessageSetExtension2() 2026 2027 self.assertEqual(1, len(raw.item)) 2028 # Check that the type_id is the same as the tag ID in the .proto file. 2029 self.assertEqual(raw.item[0].type_id, 98418634) 2030 2031 # Check the actual bytes on the wire. 2032 self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes)) 2033 bytes_read = message2.MergeFromString(raw.item[0].message) 2034 self.assertEqual(len(raw.item[0].message), bytes_read) 2035 2036 self.assertEqual(type(message2.str), six.text_type) 2037 self.assertEqual(message2.str, test_utf8) 2038 2039 # The pure Python API throws an exception on MergeFromString(), 2040 # if any of the string fields of the message can't be UTF-8 decoded. 2041 # The C++ implementation of the API has no way to check that on 2042 # MergeFromString and thus has no way to throw the exception. 2043 # 2044 # The pure Python API always returns objects of type 'unicode' (UTF-8 2045 # encoded), or 'bytes' (in 7 bit ASCII). 
2046 badbytes = raw.item[0].message.replace( 2047 test_utf8_bytes, len(test_utf8_bytes) * b'\xff') 2048 2049 unicode_decode_failed = False 2050 try: 2051 message2.MergeFromString(badbytes) 2052 except UnicodeDecodeError: 2053 unicode_decode_failed = True 2054 string_field = message2.str 2055 self.assertTrue(unicode_decode_failed or type(string_field) is bytes) 2056 2057 def testSetInParent(self): 2058 proto = unittest_pb2.TestAllTypes() 2059 self.assertFalse(proto.HasField('optionalgroup')) 2060 proto.optionalgroup.SetInParent() 2061 self.assertTrue(proto.HasField('optionalgroup')) 2062 2063 def testPackageInitializationImport(self): 2064 """Test that we can import nested messages from their __init__.py. 2065 2066 Such setup is not trivial since at the time of processing of __init__.py one 2067 can't refer to its submodules by name in code, so expressions like 2068 google.protobuf.internal.import_test_package.inner_pb2 2069 don't work. They do work in imports, so we have assign an alias at import 2070 and then use that alias in generated code. 2071 """ 2072 # We import here since it's the import that used to fail, and we want 2073 # the failure to have the right context. 2074 # pylint: disable=g-import-not-at-top 2075 from google.protobuf.internal import import_test_package 2076 # pylint: enable=g-import-not-at-top 2077 msg = import_test_package.myproto.Outer() 2078 # Just check the default value. 2079 self.assertEqual(57, msg.inner.value) 2080 2081# Since we had so many tests for protocol buffer equality, we broke these out 2082# into separate TestCase classes. 


@testing_refleaks.TestCase
class TestAllTypesEqualityTest(unittest.TestCase):
  """Equality and hashing of empty TestAllTypes messages."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()

  def testNotHashable(self):
    """hash() on a message raises TypeError."""
    self.assertRaises(TypeError, hash, self.first_proto)

  def testSelfEquality(self):
    """A message compares equal to itself."""
    self.assertEqual(self.first_proto, self.first_proto)

  def testEmptyProtosEqual(self):
    """Two freshly constructed messages compare equal."""
    self.assertEqual(self.first_proto, self.second_proto)


@testing_refleaks.TestCase
class FullProtosEqualityTest(unittest.TestCase):

  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.first_proto)
    test_util.SetAllFields(self.second_proto)

  def testNotHashable(self):
    """hash() on a populated message raises TypeError."""
    self.assertRaises(TypeError, hash, self.first_proto)

  def testNoneNotEqual(self):
    """A message never equals None, whichever side None is on."""
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    """Messages of different types compare unequal."""
    third_proto = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, third_proto)
    self.assertNotEqual(third_proto, self.second_proto)

  def testAllFieldsFilledEquality(self):
    """Identically populated messages compare equal."""
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    """Changing or clearing an optional scalar breaks equality."""
    # Nonrepeated scalar field change should cause inequality.
    self.first_proto.optional_int32 += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    # ...as should clearing a field.
    self.first_proto.ClearField('optional_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedComposite(self):
    """Edits inside an optional sub-message affect equality."""
    # Change a nonrepeated composite field.
    self.first_proto.optional_nested_message.bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Clear a field in the nested message.
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb = (
        self.second_proto.optional_nested_message.bb)
    self.assertEqual(self.first_proto, self.second_proto)
    # Remove the nested message entirely.
    self.first_proto.ClearField('optional_nested_message')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedScalar(self):
    """Appending to or clearing a repeated scalar breaks equality."""
    # Change a repeated scalar field.
    self.first_proto.repeated_int32.append(5)
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.ClearField('repeated_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedComposite(self):
    """Edits to repeated sub-messages affect equality."""
    # Change value within a repeated composite field.
    self.first_proto.repeated_nested_message[0].bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.repeated_nested_message[0].bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Add a value to a repeated composite field.
    self.first_proto.repeated_nested_message.add()
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.second_proto.repeated_nested_message.add()
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalarHasBits(self):
    """An unset scalar differs from one explicitly set to its default."""
    # Ensure that we test "has" bits as well as value for
    # nonrepeated scalar field.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    """An empty-but-present sub-message differs from an absent one."""
    # Ensure that we test "has" bits as well as value for
    # nonrepeated composite field.
    self.first_proto.ClearField('optional_nested_message')
    self.second_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    # Setting then clearing bb leaves first_proto with a present, empty
    # optional_nested_message, matching second_proto.
    self.first_proto.optional_nested_message.bb = 0
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertEqual(self.first_proto, self.second_proto)


@testing_refleaks.TestCase
class ExtensionEqualityTest(unittest.TestCase):
  """Equality of messages that carry extensions."""

  def testExtensionEquality(self):
    """Equality considers both extension values and their has-bits."""
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(first_proto, second_proto)
    test_util.SetAllExtensions(first_proto)
    self.assertNotEqual(first_proto, second_proto)
    test_util.SetAllExtensions(second_proto)
    self.assertEqual(first_proto, second_proto)

    # Ensure that we check value equality.
    first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
    self.assertEqual(first_proto, second_proto)

    # Ensure that we also look at "has" bits.
    first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
    second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertEqual(first_proto, second_proto)

    # Ensure that differences in cached values
    # don't matter if "has" bits are both false.
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(
        0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
    self.assertEqual(first_proto, second_proto)


@testing_refleaks.TestCase
class MutualRecursionEqualityTest(unittest.TestCase):
  """Equality of mutually recursive message types."""

  def testEqualityWithMutualRecursion(self):
    """Edits made deep in the recursive chain affect equality."""
    first_proto = unittest_pb2.TestMutualRecursionA()
    second_proto = unittest_pb2.TestMutualRecursionA()
    self.assertEqual(first_proto, second_proto)
    first_proto.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(first_proto, second_proto)
    second_proto.bb.a.bb.optional_int32 = 23
    self.assertEqual(first_proto, second_proto)


@testing_refleaks.TestCase
class ByteSizeTest(unittest.TestCase):
  """Checks ByteSize() against hand-computed wire-format sizes."""

  def setUp(self):
    self.proto = unittest_pb2.TestAllTypes()
    self.extended_proto = more_extensions_pb2.ExtendedMessage()
    # NOTE(review): packed_proto / packed_extended_proto are not used in the
    # tests visible here; presumably later tests in this class use them.
    self.packed_proto = unittest_pb2.TestPackedTypes()
    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()

  def Size(self):
    """Returns the serialized size of self.proto."""
    return self.proto.ByteSize()

  def testEmptyMessage(self):
    """An empty message has size 0."""
    self.assertEqual(0, self.proto.ByteSize())

  def testSizedOnKwargs(self):
    """ByteSize() is correct for a message built from constructor kwargs."""
    # Use a separate message to ensure testing right after creation.
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.ByteSize())
    proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1)
    # One byte for the tag, one to encode varint 1.
    self.assertEqual(2, proto_kwargs.ByteSize())

  def testVarints(self):
    """Varint-encoded int64 values take the expected number of bytes."""
    def Test(i, expected_varint_size):
      self.proto.Clear()
      self.proto.optional_int64 = i
      # Add one to the varint size for the tag info
      # for tag 1.
      self.assertEqual(expected_varint_size + 1, self.Size())
    Test(0, 1)
    Test(1, 1)
    # For i = 7, 14, ..., 56 the value (1 << i) - 1 is expected to need
    # exactly i / 7 bytes; range(1, 10000) just supplies 1, 2, 3, ...
    for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)):
      Test((1 << i) - 1, num_bytes)
    # Negative int64 values are expected to take the full 10 bytes.
    Test(-1, 10)
    Test(-2, 10)
    Test(-(1 << 63), 10)

  def testStrings(self):
    """String size is tag + length prefix + payload."""
    self.proto.optional_string = ''
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2, self.Size())

    self.proto.optional_string = 'abc'
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2 + len(self.proto.optional_string), self.Size())

    self.proto.optional_string = 'x' * 128
    # Need one byte for tag info (tag #14), and TWO bytes for length.
    self.assertEqual(3 + len(self.proto.optional_string), self.Size())

  def testOtherNumerics(self):
    """Fixed-width and zig-zag numeric fields have the expected size."""
    self.proto.optional_fixed32 = 1234
    # One byte for tag and 4 bytes for fixed32.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_fixed64 = 1234
    # One byte for tag and 8 bytes for fixed64.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_float = 1.234
    # One byte for tag and 4 bytes for float.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_double = 1.234
    # One byte for tag and 8 bytes for double.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_sint32 = 64
    # One byte for tag and 2 bytes for zig-zag-encoded 64.
    self.assertEqual(3, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

  def testComposites(self):
    """An optional nested message adds its tag, length and contents."""
    # 3 bytes.
    self.proto.optional_nested_message.bb = (1 << 14)
    # Plus one byte for bb tag.
    # Plus 1 byte for optional_nested_message serialized size.
    # Plus two bytes for optional_nested_message tag.
    self.assertEqual(3 + 1 + 1 + 2, self.Size())

  def testGroups(self):
    """An optional group adds START_GROUP/END_GROUP tags and contents."""
    # 4 bytes.
    self.proto.optionalgroup.a = (1 << 21)
    # Plus two bytes for |a| tag.
    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
    self.assertEqual(4 + 2 + 2*2, self.Size())

  def testRepeatedScalars(self):
    """Each repeated scalar entry here carries its own tag."""
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsExtend(self):
    """extend() gives the same size as the equivalent append() calls."""
    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsRemove(self):
    """remove() shrinks the size of a repeated scalar field."""
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())
    self.proto.repeated_int32.remove(128)
    self.assertEqual(1 + 2, self.Size())

  def testRepeatedComposites(self):
    """Each repeated sub-message adds its tag, length and contents."""
    # Empty message.  2 bytes tag plus 1 byte length.
    foreign_message_0 = self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    foreign_message_1 = self.proto.repeated_nested_message.add()
    foreign_message_1.bb = 7
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

  def testRepeatedCompositesDelete(self):
    """Size tracks deletions from a repeated message field."""
    # Empty message.  2 bytes tag plus 1 byte length.
    foreign_message_0 = self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    foreign_message_1 = self.proto.repeated_nested_message.add()
    foreign_message_1.bb = 9
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
    # Independent copy used at the end to exercise slice/negative deletes.
    repeated_nested_message = copy.deepcopy(
        self.proto.repeated_nested_message)

    # Remaining element: 2 bytes tag plus 1 byte length plus 1 byte bb tag
    # 1 byte int.
    del self.proto.repeated_nested_message[0]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    # Now add a new message.
    foreign_message_2 = self.proto.repeated_nested_message.add()
    foreign_message_2.bb = 12

    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())

    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[1]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    del self.proto.repeated_nested_message[0]
    self.assertEqual(0, self.Size())

    self.assertEqual(2, len(repeated_nested_message))
    del repeated_nested_message[0:1]
    # TODO(user): Fix cpp extension bug when delete repeated message.
    if api_implementation.Type() == 'python':
      self.assertEqual(1, len(repeated_nested_message))
    del repeated_nested_message[-1]
    # TODO(user): Fix cpp extension bug when delete repeated message.
    if api_implementation.Type() == 'python':
      self.assertEqual(0, len(repeated_nested_message))

  def testRepeatedGroups(self):
    """Each repeated group adds START_GROUP/END_GROUP tags and contents."""
    # 2-byte START_GROUP plus 2-byte END_GROUP.
    group_0 = self.proto.repeatedgroup.add()
    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
    # plus 2-byte END_GROUP.
    group_1 = self.proto.repeatedgroup.add()
    group_1.a = 7
    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())

  def testExtensions(self):
    """Extensions count toward size; normal field descriptors are rejected."""
    proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, proto.ByteSize())
    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
    proto.Extensions[extension] = 23
    # 1 byte for tag, 1 byte for value.
    self.assertEqual(2, proto.ByteSize())
    field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[
        'optional_int32']
    with self.assertRaises(KeyError):
      proto.Extensions[field] = 23

  def testCacheInvalidationForNonrepeatedScalar(self):
    """ByteSize() updates after a scalar or extension value changes."""
    # Test non-extension.
    self.proto.optional_int32 = 1
    self.assertEqual(2, self.proto.ByteSize())
    self.proto.optional_int32 = 128
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_int_extension
    self.extended_proto.Extensions[extension] = 1
    self.assertEqual(2, self.extended_proto.ByteSize())
    self.extended_proto.Extensions[extension] = 128
    self.assertEqual(3, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedScalar(self):
    """ByteSize() updates after appends, element writes and clears."""
    # Test non-extension.
    self.proto.repeated_int32.append(1)
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_int32.append(1)
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.repeated_int32[1] = 128
    self.assertEqual(7, self.proto.ByteSize())
    self.proto.ClearField('repeated_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
2445 extension = more_extensions_pb2.repeated_int_extension 2446 repeated = self.extended_proto.Extensions[extension] 2447 repeated.append(1) 2448 self.assertEqual(2, self.extended_proto.ByteSize()) 2449 repeated.append(1) 2450 self.assertEqual(4, self.extended_proto.ByteSize()) 2451 repeated[1] = 128 2452 self.assertEqual(5, self.extended_proto.ByteSize()) 2453 self.extended_proto.ClearExtension(extension) 2454 self.assertEqual(0, self.extended_proto.ByteSize()) 2455 2456 def testCacheInvalidationForNonrepeatedMessage(self): 2457 # Test non-extension. 2458 self.proto.optional_foreign_message.c = 1 2459 self.assertEqual(5, self.proto.ByteSize()) 2460 self.proto.optional_foreign_message.c = 128 2461 self.assertEqual(6, self.proto.ByteSize()) 2462 self.proto.optional_foreign_message.ClearField('c') 2463 self.assertEqual(3, self.proto.ByteSize()) 2464 self.proto.ClearField('optional_foreign_message') 2465 self.assertEqual(0, self.proto.ByteSize()) 2466 2467 if api_implementation.Type() == 'python': 2468 # This is only possible in pure-Python implementation of the API. 2469 child = self.proto.optional_foreign_message 2470 self.proto.ClearField('optional_foreign_message') 2471 child.c = 128 2472 self.assertEqual(0, self.proto.ByteSize()) 2473 2474 # Test within extension. 2475 extension = more_extensions_pb2.optional_message_extension 2476 child = self.extended_proto.Extensions[extension] 2477 self.assertEqual(0, self.extended_proto.ByteSize()) 2478 child.foreign_message_int = 1 2479 self.assertEqual(4, self.extended_proto.ByteSize()) 2480 child.foreign_message_int = 128 2481 self.assertEqual(5, self.extended_proto.ByteSize()) 2482 self.extended_proto.ClearExtension(extension) 2483 self.assertEqual(0, self.extended_proto.ByteSize()) 2484 2485 def testCacheInvalidationForRepeatedMessage(self): 2486 # Test non-extension. 
2487 child0 = self.proto.repeated_foreign_message.add() 2488 self.assertEqual(3, self.proto.ByteSize()) 2489 self.proto.repeated_foreign_message.add() 2490 self.assertEqual(6, self.proto.ByteSize()) 2491 child0.c = 1 2492 self.assertEqual(8, self.proto.ByteSize()) 2493 self.proto.ClearField('repeated_foreign_message') 2494 self.assertEqual(0, self.proto.ByteSize()) 2495 2496 # Test within extension. 2497 extension = more_extensions_pb2.repeated_message_extension 2498 child_list = self.extended_proto.Extensions[extension] 2499 child0 = child_list.add() 2500 self.assertEqual(2, self.extended_proto.ByteSize()) 2501 child_list.add() 2502 self.assertEqual(4, self.extended_proto.ByteSize()) 2503 child0.foreign_message_int = 1 2504 self.assertEqual(6, self.extended_proto.ByteSize()) 2505 child0.ClearField('foreign_message_int') 2506 self.assertEqual(4, self.extended_proto.ByteSize()) 2507 self.extended_proto.ClearExtension(extension) 2508 self.assertEqual(0, self.extended_proto.ByteSize()) 2509 2510 def testPackedRepeatedScalars(self): 2511 self.assertEqual(0, self.packed_proto.ByteSize()) 2512 2513 self.packed_proto.packed_int32.append(10) # 1 byte. 2514 self.packed_proto.packed_int32.append(128) # 2 bytes. 2515 # The tag is 2 bytes (the field number is 90), and the varint 2516 # storing the length is 1 byte. 2517 int_size = 1 + 2 + 3 2518 self.assertEqual(int_size, self.packed_proto.ByteSize()) 2519 2520 self.packed_proto.packed_double.append(4.2) # 8 bytes 2521 self.packed_proto.packed_double.append(3.25) # 8 bytes 2522 # 2 more tag bytes, 1 more length byte. 
2523 double_size = 8 + 8 + 3 2524 self.assertEqual(int_size+double_size, self.packed_proto.ByteSize()) 2525 2526 self.packed_proto.ClearField('packed_int32') 2527 self.assertEqual(double_size, self.packed_proto.ByteSize()) 2528 2529 def testPackedExtensions(self): 2530 self.assertEqual(0, self.packed_extended_proto.ByteSize()) 2531 extension = self.packed_extended_proto.Extensions[ 2532 unittest_pb2.packed_fixed32_extension] 2533 extension.extend([1, 2, 3, 4]) # 16 bytes 2534 # Tag is 3 bytes. 2535 self.assertEqual(19, self.packed_extended_proto.ByteSize()) 2536 2537 2538# Issues to be sure to cover include: 2539# * Handling of unrecognized tags ("uninterpreted_bytes"). 2540# * Handling of MessageSets. 2541# * Consistent ordering of tags in the wire format, 2542# including ordering between extensions and non-extension 2543# fields. 2544# * Consistent serialization of negative numbers, especially 2545# negative int32s. 2546# * Handling of empty submessages (with and without "has" 2547# bits set). 

@testing_refleaks.TestCase
class SerializationTest(unittest.TestCase):
  """Round-trip, wire-format and error-reporting tests for serialization."""

  def testSerializeEmtpyMessage(self):
    """An empty message round-trips and ByteSize() matches its encoding."""
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllFields(self):
    """A fully populated TestAllTypes round-trips through the wire format."""
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllExtensions(self):
    """A fully populated TestAllExtensions round-trips through the wire format."""
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    # Same size invariant the sibling round-trip tests check.
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeWithOptionalGroup(self):
    """A message with only an optional group set round-trips."""
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    first_proto.optionalgroup.a = 242
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeNegativeValues(self):
    """Negative values of every signed integer kind survive a round trip."""
    first_proto = unittest_pb2.TestAllTypes()

    first_proto.optional_int32 = -1
    first_proto.optional_int64 = -(2 << 40)
    first_proto.optional_sint32 = -3
    first_proto.optional_sint64 = -(4 << 40)
    first_proto.optional_sfixed32 = -5
    first_proto.optional_sfixed64 = -(6 << 40)

    second_proto = unittest_pb2.TestAllTypes.FromString(
        first_proto.SerializeToString())

    self.assertEqual(first_proto, second_proto)

  def testParseTruncated(self):
    """Parsing every prefix either succeeds exactly or fails consistently.

    For each truncation point, parsing into TestAllTypes and into
    TestEmptyMessage (all fields unknown) must agree on success or failure.
    """
    # This test is only applicable for the Python implementation of the API.
    if api_implementation.Type() != 'python':
      return

    first_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = memoryview(first_proto.SerializeToString())

    for truncation_point in range(len(serialized) + 1):
      try:
        second_proto = unittest_pb2.TestAllTypes()
        unknown_fields = unittest_pb2.TestEmptyMessage()
        pos = second_proto._InternalParse(serialized, 0, truncation_point)
        # If we didn't raise an error then we read exactly the amount expected.
        self.assertEqual(truncation_point, pos)

        # Parsing to unknown fields should not throw if parsing to known fields
        # did not.
        try:
          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
          self.assertEqual(truncation_point, pos2)
        except message.DecodeError:
          self.fail('Parsing unknown fields failed when parsing known fields '
                    'did not.')
      except message.DecodeError:
        # Parsing unknown fields should also fail.
        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
                          serialized, 0, truncation_point)

  def testCanonicalSerializationOrder(self):
    """Fields and extensions are emitted in tag order regardless of set order."""
    proto = more_messages_pb2.OutOfOrderFields()
    # These are also their tag numbers.  Even though we're setting these in
    # reverse-tag order AND they're listed in reverse tag-order in the .proto
    # file, they should nonetheless be serialized in tag order.
    proto.optional_sint32 = 5
    proto.Extensions[more_messages_pb2.optional_uint64] = 4
    proto.optional_uint32 = 3
    proto.Extensions[more_messages_pb2.optional_int64] = 2
    proto.optional_int32 = 1
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(2, d.ReadInt64())
    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(3, d.ReadUInt32())
    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(4, d.ReadUInt64())
    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(5, d.ReadSInt32())

  def testCanonicalSerializationOrderSameAsCpp(self):
    """Serialized field/extension order matches what the C++ test expects."""
    # Copy of the same test we use for C++.
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    serialized = proto.SerializeToString()
    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)

  def testMergeFromStringWhenFieldsAlreadySet(self):
    """MergeFromString appends repeated, overwrites scalars, recurses messages."""
    first_proto = unittest_pb2.TestAllTypes()
    first_proto.repeated_string.append('foobar')
    first_proto.optional_int32 = 23
    first_proto.optional_nested_message.bb = 42
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestAllTypes()
    second_proto.repeated_string.append('baz')
    second_proto.optional_int32 = 100
    second_proto.optional_nested_message.bb = 999

    bytes_parsed = second_proto.MergeFromString(serialized)
    self.assertEqual(len(serialized), bytes_parsed)

    # Ensure that we append to repeated fields.
    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
    # Ensure that we overwrite nonrepeated scalars.
    self.assertEqual(23, second_proto.optional_int32)
    # Ensure that we recursively call MergeFromString() on
    # submessages.
    self.assertEqual(42, second_proto.optional_nested_message.bb)

  def testMessageSetWireFormat(self):
    """MessageSet extensions serialize as items readable by RawMessageSet."""
    proto = message_set_extensions_pb2.TestMessageSet()
    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
    extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
    extension1 = extension_message1.message_set_extension
    extension2 = extension_message2.message_set_extension
    extension3 = message_set_extensions_pb2.message_set_extension3
    proto.Extensions[extension1].i = 123
    proto.Extensions[extension2].str = 'foo'
    proto.Extensions[extension3].text = 'bar'

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    raw = unittest_mset_pb2.RawMessageSet()
    self.assertEqual(False,
                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
    self.assertEqual(
        len(serialized),
        raw.MergeFromString(serialized))
    self.assertEqual(3, len(raw.item))

    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    self.assertEqual(
        len(raw.item[0].message),
        message1.MergeFromString(raw.item[0].message))
    self.assertEqual(123, message1.i)

    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
    self.assertEqual(
        len(raw.item[1].message),
        message2.MergeFromString(raw.item[1].message))
    self.assertEqual('foo', message2.str)

    message3 = message_set_extensions_pb2.TestMessageSetExtension3()
    self.assertEqual(
        len(raw.item[2].message),
        message3.MergeFromString(raw.item[2].message))
    self.assertEqual('bar', message3.text)

    # Deserialize using the MessageSet wire format.
    proto2 = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(123, proto2.Extensions[extension1].i)
    self.assertEqual('foo', proto2.Extensions[extension2].str)
    self.assertEqual('bar', proto2.Extensions[extension3].text)

    # Check byte size.
    self.assertEqual(proto2.ByteSize(), len(serialized))
    self.assertEqual(proto.ByteSize(), len(serialized))

  def testMessageSetWireFormatUnknownExtension(self):
    """Unknown MessageSet type_ids are tolerated alongside a known one."""
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an item.
    item = raw.item.add()
    item.type_id = 98418603
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    # Add a second, unknown extension.
    item = raw.item.add()
    item.type_id = 98418604
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12346
    item.message = message1.SerializeToString()

    # Add another unknown extension.
    item = raw.item.add()
    item.type_id = 98418605
    message1 = message_set_extensions_pb2.TestMessageSetExtension2()
    message1.str = 'foo'
    item.message = message1.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(
        len(serialized),
        proto.MergeFromString(serialized))

    # Check that the message parsed well.
    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
    extension1 = extension_message1.message_set_extension
    self.assertEqual(12345, proto.Extensions[extension1].i)

  def testUnknownFields(self):
    """A message with no fields can consume arbitrary known-field payloads."""
    proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(proto)

    serialized = proto.SerializeToString()

    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()

    # Parsing this message should succeed.
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))

    # Now test with an int64 field set.
    proto = unittest_pb2.TestAllTypes()
    proto.optional_int64 = 0x0fffffffffffffff
    serialized = proto.SerializeToString()
    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()
    # Parsing this message should succeed.
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))

  def _CheckRaises(self, exc_class, callable_obj, exception):
    """Checks that callable_obj() raises exc_class with an exact message.

    Args:
      exc_class: Exception type callable_obj() is expected to raise.
      callable_obj: Zero-argument callable to invoke.
      exception: Expected value of str() of the raised exception.

    Raises:
      self.failureException: If nothing is raised or the message differs.
      Any exception that is not an instance of exc_class propagates unchanged.
    """
    with self.assertRaises(exc_class) as context:
      callable_obj()
    self.assertEqual(exception, str(context.exception))

  def testSerializeUninitialized(self):
    """Missing required fields block SerializeToString but not the partial form."""
    proto = unittest_pb2.TestRequired()
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: '
        'a,b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertFalse(proto2.HasField('a'))
    # proto2 ParseFromString does not check that required fields are set.
    proto2.ParseFromString(partial)
    self.assertFalse(proto2.HasField('a'))

    proto.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: b,c')
    # Shouldn't raise exceptions.
    proto.SerializePartialToString()

    proto.b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: c')
    # Shouldn't raise exceptions.
    proto.SerializePartialToString()

    proto.c = 3
    serialized = proto.SerializeToString()
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)
    self.assertEqual(
        len(partial),
        proto2.MergeFromString(partial))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)

  def testSerializeUninitializedSubMessage(self):
    """Missing required fields in sub-messages are reported with their paths."""
    proto = unittest_pb2.TestRequiredForeign()

    # Sub-message doesn't exist yet, so this succeeds.
    proto.SerializeToString()

    proto.optional_message.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign '
        'is missing required fields: '
        'optional_message.b,optional_message.c')

    proto.optional_message.b = 2
    proto.optional_message.c = 3
    proto.SerializeToString()

    proto.repeated_message.add().a = 1
    proto.repeated_message.add().b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
        'repeated_message[0].b,repeated_message[0].c,'
        'repeated_message[1].a,repeated_message[1].c')

    proto.repeated_message[0].b = 2
    proto.repeated_message[0].c = 3
    proto.repeated_message[1].a = 1
    proto.repeated_message[1].c = 3
    proto.SerializeToString()

  def testSerializeAllPackedFields(self):
    """A fully populated TestPackedTypes round-trips."""
    first_proto = unittest_pb2.TestPackedTypes()
    second_proto = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllPackedExtensions(self):
    """A fully populated TestPackedExtensions round-trips."""
    first_proto = unittest_pb2.TestPackedExtensions()
    second_proto = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    # Same size invariant testSerializeAllPackedFields checks.
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
    """Merging packed data appends to existing values and keeps other fields."""
    first_proto = unittest_pb2.TestPackedTypes()
    first_proto.packed_int32.extend([1, 2])
    first_proto.packed_double.append(3.0)
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestPackedTypes()
    second_proto.packed_int32.append(3)
    second_proto.packed_double.extend([1.0, 2.0])
    second_proto.packed_sint32.append(4)

    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual([3, 1, 2], second_proto.packed_int32)
    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
    self.assertEqual([4], second_proto.packed_sint32)

  def testPackedFieldsWireFormat(self):
    """Packed fields are encoded as length-delimited runs in tag order."""
    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
    proto.packed_float.append(2.0)             # 4 bytes, will be before double
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(1+1+1+2, d.ReadInt32())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual(2, d.ReadInt32())
    self.assertEqual(150, d.ReadInt32())
    self.assertEqual(3, d.ReadInt32())
    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(4, d.ReadInt32())
    self.assertEqual(2.0, d.ReadFloat())
    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(8+8, d.ReadInt32())
    self.assertEqual(1.0, d.ReadDouble())
    self.assertEqual(1000.0, d.ReadDouble())
    self.assertTrue(d.EndOfStream())

  def testParsePackedFromUnpacked(self):
    """Unpacked-encoded repeated fields parse into packed declarations."""
    unpacked = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(unpacked)
    packed = unittest_pb2.TestPackedTypes()
    serialized = unpacked.SerializeToString()
    self.assertEqual(
        len(serialized),
        packed.MergeFromString(serialized))
    expected = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(expected)
    self.assertEqual(expected, packed)

  def testParseUnpackedFromPacked(self):
    """Packed-encoded repeated fields parse into unpacked declarations."""
    packed = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(packed)
    unpacked = unittest_pb2.TestUnpackedTypes()
    serialized = packed.SerializeToString()
    self.assertEqual(
        len(serialized),
        unpacked.MergeFromString(serialized))
    expected = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(expected)
    self.assertEqual(expected, unpacked)

  def testFieldNumbers(self):
    """Generated *_FIELD_NUMBER class constants match the .proto numbers."""
    self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
    self.assertEqual(
        unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
    self.assertEqual(
        unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
    self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
    self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
    self.assertEqual(
        unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
    self.assertEqual(
        unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)

  def testExtensionFieldNumbers(self):
    """Extension descriptors and *_FIELD_NUMBER constants agree."""
    self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
    self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
    self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
    self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
    self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
    self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
    self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
    self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
    self.assertEqual(
        unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
    self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
    self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
                     21)
    self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
    self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
    self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
    self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
    self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
    self.assertEqual(
        unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
    self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
    self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
                     51)

  def testFieldProperties(self):
    """Class-level field properties expose the matching FieldDescriptor."""
    cls = unittest_pb2.TestAllTypes
    self.assertIs(cls.optional_int32.DESCRIPTOR,
                  cls.DESCRIPTOR.fields_by_name['optional_int32'])
    self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER,
                     cls.optional_int32.DESCRIPTOR.number)
    self.assertIs(cls.optional_nested_message.DESCRIPTOR,
                  cls.DESCRIPTOR.fields_by_name['optional_nested_message'])
    self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER,
                     cls.optional_nested_message.DESCRIPTOR.number)
    self.assertIs(cls.repeated_int32.DESCRIPTOR,
                  cls.DESCRIPTOR.fields_by_name['repeated_int32'])
    self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER,
                     cls.repeated_int32.DESCRIPTOR.number)

  def testFieldDataDescriptor(self):
    """Field properties work as data descriptors via __get__/__set__."""
    msg = unittest_pb2.TestAllTypes()
    msg.optional_int32 = 42
    self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42)
    unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25)
    self.assertEqual(msg.optional_int32, 25)
    with self.assertRaises(AttributeError):
      del msg.optional_int32
    try:
      unittest_pb2.ForeignMessage.c.__get__(msg)
    except TypeError:
      pass  # The cpp implementation cannot mix fields from other messages.
            # This test exercises a specific check that avoids a crash.
    else:
      pass  # The python implementation allows fields from other messages.
            # This is useless, but works.

  def testInitKwargs(self):
    """Keyword arguments to the constructor set singular and repeated fields."""
    proto = unittest_pb2.TestAllTypes(
        optional_int32=1,
        optional_string='foo',
        optional_bool=True,
        optional_bytes=b'bar',
        optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
        optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
        optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
        optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
        repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('optional_int32'))
    self.assertTrue(proto.HasField('optional_string'))
    self.assertTrue(proto.HasField('optional_bool'))
    self.assertTrue(proto.HasField('optional_bytes'))
    self.assertTrue(proto.HasField('optional_nested_message'))
    self.assertTrue(proto.HasField('optional_foreign_message'))
    self.assertTrue(proto.HasField('optional_nested_enum'))
    self.assertTrue(proto.HasField('optional_foreign_enum'))
    self.assertEqual(1, proto.optional_int32)
    self.assertEqual('foo', proto.optional_string)
    self.assertEqual(True, proto.optional_bool)
    self.assertEqual(b'bar', proto.optional_bytes)
    self.assertEqual(1, proto.optional_nested_message.bb)
    self.assertEqual(1, proto.optional_foreign_message.c)
    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
                     proto.optional_nested_enum)
    self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
    self.assertEqual([1, 2, 3], proto.repeated_int32)

  def testInitArgsUnknownFieldName(self):
    """An unknown keyword argument raises ValueError naming the field."""
    def InitalizeEmptyMessageWithExtraKeywordArg():
      unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
    self._CheckRaises(
        ValueError,
        InitalizeEmptyMessageWithExtraKeywordArg,
        'Protocol message TestEmptyMessage has no "unknown" field.')

  def testInitRequiredKwargs(self):
    """Setting all required fields via kwargs yields an initialized message."""
    proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('a'))
    self.assertTrue(proto.HasField('b'))
    self.assertTrue(proto.HasField('c'))
    self.assertFalse(proto.HasField('dummy2'))
    self.assertEqual(1, proto.a)
    self.assertEqual(1, proto.b)
    self.assertEqual(1, proto.c)

  def testInitRequiredForeignKwargs(self):
    """A fully initialized sub-message passed as a kwarg is stored intact."""
    proto = unittest_pb2.TestRequiredForeign(
        optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('optional_message'))
    self.assertTrue(proto.optional_message.IsInitialized())
    self.assertTrue(proto.optional_message.HasField('a'))
    self.assertTrue(proto.optional_message.HasField('b'))
    self.assertTrue(proto.optional_message.HasField('c'))
    self.assertFalse(proto.optional_message.HasField('dummy2'))
    self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
                     proto.optional_message)
    self.assertEqual(1, proto.optional_message.a)
    self.assertEqual(1, proto.optional_message.b)
    self.assertEqual(1, proto.optional_message.c)

  def testInitRepeatedKwargs(self):
    """A list passed as a kwarg populates a repeated scalar field in order."""
    proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())
    self.assertEqual(1, proto.repeated_int32[0])
    self.assertEqual(2, proto.repeated_int32[1])
    self.assertEqual(3, proto.repeated_int32[2])


@testing_refleaks.TestCase
class OptionsTest(unittest.TestCase):
  """Tests that descriptor options (message_set_wire_format, packed) are exposed."""

  def testMessageOptions(self):
    """message_set_wire_format is True only for the MessageSet type."""
    proto = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(True,
                     proto.DESCRIPTOR.GetOptions().message_set_wire_format)
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(False,
                     proto.DESCRIPTOR.GetOptions().message_set_wire_format)

  def testPackedOptions(self):
    """The packed option is False on ordinary fields, True on packed ones."""
    proto = unittest_pb2.TestAllTypes()
    proto.optional_int32 = 1
    proto.optional_double = 3.0
    for field_descriptor, _ in proto.ListFields():
      self.assertEqual(False, field_descriptor.GetOptions().packed)

    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.append(1)
    proto.packed_double.append(3.0)
    for field_descriptor, _ in proto.ListFields():
      self.assertEqual(True, field_descriptor.GetOptions().packed)
      self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED,
                       field_descriptor.label)



@testing_refleaks.TestCase
class ClassAPITest(unittest.TestCase):

  @unittest.skipIf(
      api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
      'C++ implementation requires a call to MakeDescriptor()')
  @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable')
  def testMakeClassWithNestedDescriptor(self):
    """reflection.MakeClass accepts a hand-built, nested descriptor tree.

    The test passes as long as MakeClass does not raise.
    """
    leaf_desc = descriptor.Descriptor(
        'leaf', 'package.parent.child.leaf', '',
        containing_type=None, fields=[],
        nested_types=[], enum_types=[],
        extensions=[],
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    child_desc = descriptor.Descriptor(
        'child', 'package.parent.child', '',
        containing_type=None, fields=[],
        nested_types=[leaf_desc], enum_types=[],
        extensions=[],
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    sibling_desc = descriptor.Descriptor(
        'sibling', 'package.parent.sibling',
        '', containing_type=None, fields=[],
        nested_types=[], enum_types=[],
        extensions=[],
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    parent_desc = descriptor.Descriptor(
        'parent', 'package.parent', '',
        containing_type=None, fields=[],
        nested_types=[child_desc, sibling_desc],
        enum_types=[], extensions=[],
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    reflection.MakeClass(parent_desc)

  def _GetSerializedFileDescriptor(self, name):
    """Get a serialized representation of a test FileDescriptorProto.

    The file holds one message type `name` with a repeated uint32 field
    "flat" and a nested Bar -> Baz message chain; Baz declares an enum and a
    uint32 field "deep".

    Args:
      name: All calls to this must use a unique message name, to avoid
          collisions in the cpp descriptor pool.
    Returns:
      A string containing the serialized form of a test FileDescriptorProto.
    """
    # Text-format FileDescriptorProto, parsed with text_format.Merge below.
    file_descriptor_str = (
        'message_type {'
        ' name: "' + name + '"'
        ' field {'
        ' name: "flat"'
        ' number: 1'
        ' label: LABEL_REPEATED'
        ' type: TYPE_UINT32'
        ' }'
        ' field {'
        ' name: "bar"'
        ' number: 2'
        ' label: LABEL_OPTIONAL'
        ' type: TYPE_MESSAGE'
        ' type_name: "Bar"'
        ' }'
        ' nested_type {'
        ' name: "Bar"'
        ' field {'
        ' name: "baz"'
        ' number: 3'
        ' label: LABEL_OPTIONAL'
        ' type: TYPE_MESSAGE'
        ' type_name: "Baz"'
        ' }'
        ' nested_type {'
        ' name: "Baz"'
        ' enum_type {'
        ' name: "deep_enum"'
        ' value {'
        ' name: "VALUE_A"'
        ' number: 0'
        ' }'
        ' }'
        ' field {'
        ' name: "deep"'
        ' number: 4'
        ' label: LABEL_OPTIONAL'
        ' type: TYPE_UINT32'
        ' }'
        ' }'
        ' }'
        '}')
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    text_format.Merge(file_descriptor_str, file_descriptor)
    return file_descriptor.SerializeToString()

  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  # This test can only run once; the second time, it raises errors about
  # conflicting message descriptors.
  def testParsingFlatClassWithExplicitClassDeclaration(self):
    """Test that the generated class can parse a flat message."""
    # TODO(user): This test fails with cpp implemetnation in the call
    # of six.with_metaclass(). The other two callsites of with_metaclass
    # in this file are both excluded from cpp test, so it might be expected
    # to fail. Need someone more familiar with the python code to take a
    # look at this.
3275 if api_implementation.Type() != 'python': 3276 return 3277 file_descriptor = descriptor_pb2.FileDescriptorProto() 3278 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) 3279 msg_descriptor = descriptor.MakeDescriptor( 3280 file_descriptor.message_type[0]) 3281 3282 class MessageClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): 3283 DESCRIPTOR = msg_descriptor 3284 msg = MessageClass() 3285 msg_str = ( 3286 'flat: 0 ' 3287 'flat: 1 ' 3288 'flat: 2 ') 3289 text_format.Merge(msg_str, msg) 3290 self.assertEqual(msg.flat, [0, 1, 2]) 3291 3292 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 3293 def testParsingFlatClass(self): 3294 """Test that the generated class can parse a flat message.""" 3295 file_descriptor = descriptor_pb2.FileDescriptorProto() 3296 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) 3297 msg_descriptor = descriptor.MakeDescriptor( 3298 file_descriptor.message_type[0]) 3299 msg_class = reflection.MakeClass(msg_descriptor) 3300 msg = msg_class() 3301 msg_str = ( 3302 'flat: 0 ' 3303 'flat: 1 ' 3304 'flat: 2 ') 3305 text_format.Merge(msg_str, msg) 3306 self.assertEqual(msg.flat, [0, 1, 2]) 3307 3308 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 3309 def testParsingNestedClass(self): 3310 """Test that the generated class can parse a nested message.""" 3311 file_descriptor = descriptor_pb2.FileDescriptorProto() 3312 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) 3313 msg_descriptor = descriptor.MakeDescriptor( 3314 file_descriptor.message_type[0]) 3315 msg_class = reflection.MakeClass(msg_descriptor) 3316 msg = msg_class() 3317 msg_str = ( 3318 'bar {' 3319 ' baz {' 3320 ' deep: 4' 3321 ' }' 3322 '}') 3323 text_format.Merge(msg_str, msg) 3324 self.assertEqual(msg.bar.baz.deep, 4) 3325 3326if __name__ == '__main__': 3327 unittest.main() 3328