1#!/usr/bin/env python
2
3#
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements. See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership. The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License. You may obtain a copy of the License at
11#
12#   http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing,
15# software distributed under the License is distributed on an
16# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17# KIND, either express or implied. See the License for the
18# specific language governing permissions and limitations
19# under the License.
20#
21
22from DebugProtoTest.ttypes import CompactProtoTestStruct, Empty, Wrapper
23from thrift.Thrift import TFrozenDict
24from thrift.transport import TTransport
25from thrift.protocol import TBinaryProtocol, TCompactProtocol
26import collections
27import unittest
28
29
class TestFrozenBase(unittest.TestCase):
    """Shared round-trip tests for structs whose collection fields are frozen.

    This base class does not define ``protocol``. Each concrete subclass must
    provide ``protocol(self, trans)`` that returns a Thrift protocol bound to
    the transport ``trans``. Any test that calls ``_roundtrip`` fails with
    AttributeError if it is run on this base class directly. ``suite()`` in
    this module loads only the concrete subclasses.
    """

    def _roundtrip(self, src, dst):
        """Serialize ``src`` with ``self.protocol`` and deserialize it into ``dst``.

        Args:
            src: object exposing ``write(protocol)``; it is serialized into
                an in-memory buffer.
            dst: object exposing ``read(protocol)``. In ``test_struct`` this
                is a struct class rather than an instance.

        Returns:
            The result of ``dst.read(...)`` if it is truthy, otherwise ``dst``.
            This covers both kinds of ``read``. Some fill ``dst`` in place and
            return None. Others return a new object; presumably frozen
            structs, whose ``read`` is called on the class (TODO confirm
            against the generated code).
        """
        otrans = TTransport.TMemoryBuffer()
        oproto = self.protocol(otrans)
        src.write(oproto)
        # Re-read the exact bytes just written through a fresh transport.
        itrans = TTransport.TMemoryBuffer(otrans.getvalue())
        iproto = self.protocol(itrans)
        return dst.read(iproto) or dst

    def test_dict_is_hashable_only_after_frozen(self):
        """A plain dict is not Hashable; the same dict wrapped in TFrozenDict is."""
        # ``collections.Hashable`` was a deprecated alias removed in Python
        # 3.10. The ABC lives in ``collections.abc``. Fall back to the old
        # location only on interpreters that lack ``collections.abc``.
        try:
            from collections.abc import Hashable
        except ImportError:
            from collections import Hashable
        d0 = {}
        self.assertFalse(isinstance(d0, Hashable))
        d1 = TFrozenDict(d0)
        self.assertTrue(isinstance(d1, Hashable))

    def test_struct_with_collection_fields(self):
        # TODO: placeholder with no assertions; it always passes.
        pass

    def test_set(self):
        """Test that annotated set field can be serialized and deserialized"""
        # frozenset keys are hashable, so they can be used as map keys.
        x = CompactProtoTestStruct(set_byte_map={
            frozenset([42, 100, -100]): 99,
            frozenset([0]): 100,
            frozenset([]): 0,
        })
        x2 = self._roundtrip(x, CompactProtoTestStruct())
        self.assertEqual(x2.set_byte_map[frozenset([42, 100, -100])], 99)
        self.assertEqual(x2.set_byte_map[frozenset([0])], 100)
        self.assertEqual(x2.set_byte_map[frozenset([])], 0)

    def test_map(self):
        """Test that annotated map field can be serialized and deserialized"""
        # TFrozenDict keys are hashable (see the hashability test above).
        x = CompactProtoTestStruct(map_byte_map={
            TFrozenDict({42: 42, 100: -100}): 99,
            TFrozenDict({0: 0}): 100,
            TFrozenDict({}): 0,
        })
        x2 = self._roundtrip(x, CompactProtoTestStruct())
        self.assertEqual(x2.map_byte_map[TFrozenDict({42: 42, 100: -100})], 99)
        self.assertEqual(x2.map_byte_map[TFrozenDict({0: 0})], 100)
        self.assertEqual(x2.map_byte_map[TFrozenDict({})], 0)

    def test_list(self):
        """Test that annotated list field can be serialized and deserialized"""
        # Tuples stand in for list keys because lists are unhashable.
        x = CompactProtoTestStruct(list_byte_map={
            (42, 100, -100): 99,
            (0,): 100,
            (): 0,
        })
        x2 = self._roundtrip(x, CompactProtoTestStruct())
        self.assertEqual(x2.list_byte_map[(42, 100, -100)], 99)
        self.assertEqual(x2.list_byte_map[(0,)], 100)
        self.assertEqual(x2.list_byte_map[()], 0)

    def test_empty_struct(self):
        """Test that annotated empty struct can be serialized and deserialized"""
        x = CompactProtoTestStruct(empty_struct_field=Empty())
        x2 = self._roundtrip(x, CompactProtoTestStruct())
        self.assertEqual(x2.empty_struct_field, Empty())

    def test_struct(self):
        """Test that annotated struct can be serialized and deserialized"""
        x = Wrapper(foo=Empty())
        self.assertEqual(x.foo, Empty())
        # The class itself (not an instance) is passed as ``dst``; _roundtrip
        # returns whatever ``Wrapper.read`` produces.
        x2 = self._roundtrip(x, Wrapper)
        self.assertEqual(x2.foo, Empty())
96
97
class TestFrozen(TestFrozenBase):
    """Run the shared frozen-struct tests over the plain TBinaryProtocol."""

    def protocol(self, trans):
        """Return a TBinaryProtocol bound to ``trans``."""
        factory = TBinaryProtocol.TBinaryProtocolFactory()
        return factory.getProtocol(trans)
101
102
class TestFrozenAcceleratedBinary(TestFrozenBase):
    """Run the shared frozen-struct tests over the accelerated binary protocol."""

    def protocol(self, trans):
        """Return an accelerated binary protocol bound to ``trans``.

        ``fallback=False`` is passed to the factory.
        """
        factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False)
        return factory.getProtocol(trans)
106
107
class TestFrozenAcceleratedCompact(TestFrozenBase):
    """Run the shared frozen-struct tests over the accelerated compact protocol."""

    def protocol(self, trans):
        """Return an accelerated compact protocol bound to ``trans``.

        ``fallback=False`` is passed to the factory.
        """
        factory = TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False)
        return factory.getProtocol(trans)
111
112
def suite():
    """Build a TestSuite covering every concrete protocol variant, in order."""
    test_loader = unittest.TestLoader()
    combined = unittest.TestSuite()
    for case_class in (
        TestFrozen,
        TestFrozenAcceleratedBinary,
        TestFrozenAcceleratedCompact,
    ):
        combined.addTest(test_loader.loadTestsFromTestCase(case_class))
    return combined
120
121
if __name__ == "__main__":
    # Run the module-level ``suite`` with verbose per-test output.
    verbose_runner = unittest.TextTestRunner(verbosity=2)
    unittest.main(defaultTest="suite", testRunner=verbose_runner)
124