1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import itertools
16import pytest
17
18import proto
19
20
def test_message_constructor_instance():
    """Constructing a message from another instance yields an equal, distinct copy."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT64, number=1)

    source = Foo(bar=42)
    duplicate = Foo(source)

    # Same field value and equal by value, but not the same object.
    assert source.bar == duplicate.bar == 42
    assert source == duplicate
    assert source is not duplicate

    # Both are wrapper instances, and the copy is backed by the pb2 class.
    for msg in (source, duplicate):
        assert isinstance(msg, Foo)
    assert isinstance(Foo.pb(duplicate), Foo.pb())
33
34
def test_message_constructor_underlying_pb2():
    """Wrapping a raw pb2 instance copies it rather than aliasing it."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT64, number=1)

    raw = Foo.pb()(bar=42)
    wrapped = Foo(raw)
    inner = Foo.pb(wrapped)

    assert wrapped.bar == inner.bar == raw.bar == 42
    # Not commutative: only `wrapper == raw pb2` is asserted here, because
    # the reverse direction cannot be made to hold. Nothing we can do about that.
    assert wrapped == raw
    assert raw == inner
    # The wrapper holds its own pb2, distinct from the one passed in.
    assert raw is not inner

    assert isinstance(wrapped, Foo)
    assert isinstance(inner, Foo.pb())
    assert isinstance(raw, Foo.pb())
48
49
def test_message_constructor_underlying_pb2_and_kwargs():
    """Keyword arguments override fields of a positional pb2 without mutating it."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT64, number=1)

    pb_cls = Foo.pb()
    raw = pb_cls(bar=42)
    overridden = Foo(raw, bar=99)

    assert overridden.bar == Foo.pb(overridden).bar == 99
    # The pb2 that was passed in keeps its original value.
    assert raw.bar == 42

    assert isinstance(overridden, Foo)
    assert isinstance(Foo.pb(overridden), pb_cls)
    assert isinstance(raw, pb_cls)
61
62
def test_message_constructor_dict():
    """A plain dict can be passed positionally to construct a message."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT64, number=1)

    msg = Foo({"bar": 42})
    underlying = Foo.pb(msg)

    assert msg.bar == underlying.bar == 42
    # The message does not compare equal to a dict with the same contents.
    assert msg != {"bar": 42}
    assert isinstance(msg, Foo)
    assert isinstance(underlying, Foo.pb())
72
73
def test_message_constructor_kwargs():
    """Fields can be populated through keyword arguments."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT64, number=1)

    msg = Foo(bar=42)
    underlying = Foo.pb(msg)

    assert msg.bar == underlying.bar == 42
    assert isinstance(msg, Foo)
    assert isinstance(underlying, Foo.pb())
82
83
def test_message_constructor_invalid():
    """Passing an arbitrary ``object()`` to the constructor raises TypeError."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT64, number=1)

    bogus = object()
    with pytest.raises(TypeError):
        Foo(bogus)
90
91
def test_message_constructor_explicit_qualname():
    """Copy-construction still works when ``__qualname__`` is set explicitly."""

    class Foo(proto.Message):
        __qualname__ = "Foo"
        bar = proto.Field(proto.INT64, number=1)

    first = Foo(bar=42)
    second = Foo(first)

    assert first.bar == second.bar == 42
    assert first == second
    assert first is not second
    assert all(isinstance(m, Foo) for m in (first, second))
    assert isinstance(Foo.pb(second), Foo.pb())
105
106
def test_message_contains_primitive():
    """``in`` reports a scalar field only when it holds a non-default value."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT64, number=1)

    assert "bar" in Foo(bar=42)
    # Explicitly setting the default (0) is indistinguishable from unset.
    for unset in (Foo(bar=0), Foo()):
        assert "bar" not in unset
114
115
def test_message_contains_composite():
    """A message field counts as present once set, even to an empty message."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT64, number=1)

    class Baz(proto.Message):
        foo = proto.Field(proto.MESSAGE, number=1, message=Foo)

    populated = Baz(foo=Foo(bar=42))
    empty_child = Baz(foo=Foo())
    bare = Baz()

    assert "foo" in populated
    assert "foo" in empty_child
    assert "foo" not in bare
126
127
def test_message_contains_repeated_primitive():
    """A repeated scalar field is present when non-empty, even if it holds only 0."""

    class Foo(proto.Message):
        bar = proto.RepeatedField(proto.INT64, number=1)

    for values in ([1, 1, 2, 3, 5], [0]):
        assert "bar" in Foo(bar=values)

    assert "bar" not in Foo(bar=[])
    assert "bar" not in Foo()
136
137
def test_message_contains_repeated_composite():
    """A repeated message field is present when non-empty, even if its only
    element is an empty message."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT64, number=1)

    class Baz(proto.Message):
        foo = proto.RepeatedField(proto.MESSAGE, number=1, message=Foo)

    for children in ([Foo(bar=42)], [Foo()]):
        assert "foo" in Baz(foo=children)

    for absent in (Baz(foo=[]), Baz()):
        assert "foo" not in absent
149
150
def test_message_eq_primitives():
    """Equality over scalar fields; explicitly-set defaults equal unset fields."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT32, number=1)
        baz = proto.Field(proto.STRING, number=2)
        bacon = proto.Field(proto.BOOL, number=3)

    equal_pairs = [
        (Foo(), Foo()),
        (Foo(bar=42, baz="42"), Foo(bar=42, baz="42")),
        (Foo(bar=42, bacon=True), Foo(bar=42, bacon=True)),
        (Foo(bacon=False), Foo()),
        (Foo(bar=21 * 2), Foo(bar=42)),
        (Foo(), Foo(bar=0)),
        (Foo(), Foo(bar=0, baz="", bacon=False)),
    ]
    unequal_pairs = [
        (Foo(bar=42, baz="42"), Foo(baz="42")),
        (Foo(bar=42, bacon=True), Foo(bar=42)),
        (Foo(bar=42, baz="42", bacon=True), Foo(bar=42, bacon=True)),
        (Foo(bacon=True), Foo(bacon=False)),
        # "0" is a non-empty string, so it is not the default for baz.
        (Foo(), Foo(bar=0, baz="0", bacon=False)),
    ]

    for left, right in equal_pairs:
        assert left == right
    for left, right in unequal_pairs:
        assert left != right
169
170
def test_message_serialize():
    """``serialize`` produces the same bytes as the pb2's ``SerializeToString``."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT32, number=1)
        baz = proto.Field(proto.STRING, number=2)
        bacon = proto.Field(proto.BOOL, number=3)

    msg = Foo(bar=42, bacon=True)
    actual = Foo.serialize(msg)
    expected = Foo.pb(msg).SerializeToString()
    assert actual == expected
179
180
def test_message_dict_serialize():
    """``serialize`` accepts a dict and matches the coerced pb2's bytes."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT32, number=1)
        baz = proto.Field(proto.STRING, number=2)
        bacon = proto.Field(proto.BOOL, number=3)

    payload = {"bar": 42, "bacon": True}
    actual = Foo.serialize(payload)
    expected = Foo.pb(payload, coerce=True).SerializeToString()
    assert actual == expected
189
190
def test_message_deserialize():
    """Bytes written with an INT32 field decode into a schema declaring INT64."""

    class OldFoo(proto.Message):
        bar = proto.Field(proto.INT32, number=1)

    class NewFoo(proto.Message):
        bar = proto.Field(proto.INT64, number=1)

    wire_bytes = OldFoo.serialize(OldFoo(bar=42))
    decoded = NewFoo.deserialize(wire_bytes)

    assert isinstance(decoded, NewFoo)
    assert decoded.bar == 42
202
203
def test_message_pb():
    """``pb`` unwraps a wrapper instance and rejects an arbitrary object."""

    class Foo(proto.Message):
        bar = proto.Field(proto.INT32, number=1)

    pb_cls = Foo.pb()
    assert isinstance(Foo.pb(Foo()), pb_cls)

    with pytest.raises(TypeError):
        Foo.pb(object())
211
212
def test_invalid_field_access():
    """Reading an attribute that is not a declared field raises AttributeError."""

    class Squid(proto.Message):
        mass_kg = proto.Field(proto.INT32, number=1)

    squid = Squid()
    missing_name = "shell"
    with pytest.raises(AttributeError):
        getattr(squid, missing_name)
220
221
def test_setattr():
    """Assigning ``_pb`` directly swaps the message's backing protobuf."""

    class Squid(proto.Message):
        mass_kg = proto.Field(proto.INT32, number=1)

    target = Squid()
    donor = Squid(mass_kg=20)

    target._pb = donor._pb
    assert target.mass_kg == 20
232
233
def test_serialize_to_dict():
    """``to_dict`` renders primitives, enums and repeated fields, and every
    dict form it produces can be fed back into the constructor."""

    class Squid(proto.Message):
        # Test primitives, enums, and repeated fields.
        class Chromatophore(proto.Message):
            class Color(proto.Enum):
                UNKNOWN = 0
                RED = 1
                BROWN = 2
                WHITE = 3
                BLUE = 4

            color = proto.Field(Color, number=1)

        mass_kg = proto.Field(proto.INT32, number=1)
        chromatophores = proto.RepeatedField(Chromatophore, number=2)

    s = Squid(mass_kg=20)
    colors = ["RED", "BROWN", "WHITE", "BLUE"]
    s.chromatophores = [
        {"color": c} for c in itertools.islice(itertools.cycle(colors), 10)
    ]

    # Default rendering: enums appear as integers; the dict round-trips.
    s_dict = Squid.to_dict(s)
    assert s_dict["chromatophores"][0]["color"] == 1
    assert Squid(s_dict) == s

    # Enums rendered by name; this form must round-trip as well. (Kept next
    # to the dict it checks rather than after an unrelated sub-case.)
    s_dict = Squid.to_dict(s, use_integers_for_enums=False)
    assert s_dict["chromatophores"][0]["color"] == "RED"
    assert Squid(s_dict) == s

    # With including_default_value_fields=False, fields left at their default
    # (here the empty repeated field) are omitted; the sparse dict must still
    # reconstruct an equal message.
    s_new_2 = Squid(mass_kg=20)
    s_dict_2 = Squid.to_dict(s_new_2, including_default_value_fields=False)
    expected_dict = {"mass_kg": 20}
    assert s_dict_2 == expected_dict
    assert Squid(s_dict_2) == s_new_2
272
273
def test_unknown_field_deserialize():
    """An older schema can decode bytes produced by a newer one.

    This is a somewhat common setup: a client uses an older proto definition
    while the server sends the newer definition, and the client still needs to
    be able to interact with the messages it receives.
    """

    class Octopus_Old(proto.Message):
        mass_kg = proto.Field(proto.INT32, number=1)

    class Octopus_New(proto.Message):
        mass_kg = proto.Field(proto.INT32, number=1)
        length_cm = proto.Field(proto.INT32, number=2)

    payload = Octopus_New.serialize(Octopus_New(mass_kg=20, length_cm=100))
    decoded = Octopus_Old.deserialize(payload)

    # The field the old schema does not declare is not exposed as an attribute.
    assert not hasattr(decoded, "length_cm")
291
292
def test_unknown_field_deserialize_keep_fields():
    """Unknown fields survive a decode/re-encode cycle through an older schema.

    A client on an older proto definition receives messages from a server on
    a newer one. If the client passes the message along, the newer fields it
    could not see must not be lost.
    """

    class Octopus_Old(proto.Message):
        mass_kg = proto.Field(proto.INT32, number=1)

    class Octopus_New(proto.Message):
        mass_kg = proto.Field(proto.INT32, number=1)
        length_cm = proto.Field(proto.INT32, number=2)

    newer = Octopus_New(mass_kg=20, length_cm=100)
    wire = Octopus_New.serialize(newer)

    older = Octopus_Old.deserialize(wire)
    assert not hasattr(older, "length_cm")

    # Re-serializing through the old schema carries the unknown field along,
    # so the new schema can read it back.
    round_tripped = Octopus_New.deserialize(Octopus_Old.serialize(older))
    assert round_tripped.length_cm == 100
313
314
def test_unknown_field_from_dict():
    """Unknown keys in a constructor dict raise ValueError unless ignored."""

    class Squid(proto.Message):
        mass_kg = proto.Field(proto.INT32, number=1)

    # By default we don't permit unknown fields.
    with pytest.raises(ValueError):
        Squid({"mass_kg": 20, "length_cm": 100})

    # Opting in drops the unknown key instead of failing.
    lenient = Squid({"mass_kg": 20, "length_cm": 100}, ignore_unknown_fields=True)
    assert not hasattr(lenient, "length_cm")
325
326
def test_copy_from():
    """``copy_from`` copies by value from a wrapper, a raw pb2, or a mapping,
    and rejects other argument types with TypeError."""

    class Mollusc(proto.Message):
        class Squid(proto.Message):
            mass_kg = proto.Field(proto.INT32, number=1)

        squid = proto.Field(Squid, number=1)

    m = Mollusc()
    s = Mollusc.Squid(mass_kg=20)

    # From a wrapper instance: the value is copied, not aliased.
    Mollusc.Squid.copy_from(m.squid, s)
    assert m.squid is not s
    assert m.squid == s

    # From the underlying pb2 message.
    s.mass_kg = 30
    Mollusc.Squid.copy_from(m.squid, Mollusc.Squid.pb(s))
    assert m.squid == s

    # From a plain mapping.
    Mollusc.Squid.copy_from(m.squid, {"mass_kg": 10})
    assert m.squid.mass_kg == 10

    # Non-mapping sources are rejected. The original spelling
    # `(("mass_kg", 20))` was only the 2-tuple ("mass_kg", 20); a sequence of
    # key/value pairs needs the trailing comma. Check both shapes.
    for bad_source in (("mass_kg", 20), (("mass_kg", 20),)):
        with pytest.raises(TypeError):
            Mollusc.Squid.copy_from(m.squid, bad_source)

    # A rejected copy leaves the target unchanged.
    assert m.squid.mass_kg == 10
349