1#!/usr/bin/env python
2#
3# Copyright 2016 Confluent Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18
19#
20# derived from https://github.com/verisign/python-confluent-schemaregistry.git
21#
22
23import unittest
24
25from tests.avro import mock_registry
26from tests.avro import data_gen
27from confluent_kafka.avro.cached_schema_registry_client import CachedSchemaRegistryClient
28from confluent_kafka import avro
29
30
class TestCacheSchemaRegistryClient(unittest.TestCase):
    """Tests for CachedSchemaRegistryClient against a local mock registry server."""

    def setUp(self):
        """Start a mock registry on an ephemeral port and create a client for it."""
        self.server = mock_registry.ServerThread(0)
        self.server.start()
        # Cleanups run LIFO (shutdown, then join) and also run when setUp fails
        # part-way, so the server thread is always stopped.
        self.addCleanup(self.server.join)
        self.addCleanup(self.server.shutdown)
        self.client = CachedSchemaRegistryClient('http://127.0.0.1:' + str(self.server.server.server_port))
        # Several tests rebind self.client; this cleanup stays bound to the
        # original client so its HTTP session is released regardless.
        self.addCleanup(self.client.close)

    def tearDown(self):
        """Close the client the test left in ``self.client`` (possibly a replacement)."""
        # Server shutdown/join and the original client's close are handled by
        # the cleanups registered in setUp, which run after this method.
        self.client.close()

    def test_register(self):
        """Registering a schema returns a positive id and caches one schema by id."""
        parsed = avro.loads(data_gen.BASIC_SCHEMA)
        client = self.client
        schema_id = client.register('test', parsed)
        self.assertGreater(schema_id, 0)
        self.assertEqual(len(client.id_to_schema), 1)

    def test_multi_subject_register(self):
        """The same schema registered under two subjects yields the same id."""
        parsed = avro.loads(data_gen.BASIC_SCHEMA)
        client = self.client
        schema_id = client.register('test', parsed)
        self.assertGreater(schema_id, 0)

        # register again under different subject
        dupe_id = client.register('other', parsed)
        self.assertEqual(schema_id, dupe_id)
        self.assertEqual(len(client.id_to_schema), 1)

    def test_dupe_register(self):
        """Re-registering under the same subject keeps the id and latest schema."""
        parsed = avro.loads(data_gen.BASIC_SCHEMA)
        subject = 'test'
        client = self.client
        schema_id = client.register(subject, parsed)
        self.assertGreater(schema_id, 0)
        latest = client.get_latest_schema(subject)

        # register again under same subject
        dupe_id = client.register(subject, parsed)
        self.assertEqual(schema_id, dupe_id)
        dupe_latest = client.get_latest_schema(subject)
        self.assertEqual(latest, dupe_latest)

    def assertLatest(self, meta_tuple, sid, schema, version):
        """Assert that ``meta_tuple`` equals ``(sid, schema, version)``.

        :param meta_tuple: result of ``get_latest_schema`` (id, schema, version).
        :param sid: expected schema id; must not be -1.
        :param schema: expected parsed schema.
        :param version: expected version; must not be -1.
        """
        self.assertNotEqual(sid, -1)
        self.assertNotEqual(version, -1)
        self.assertEqual(meta_tuple[0], sid)
        self.assertEqual(meta_tuple[1], schema)
        self.assertEqual(meta_tuple[2], version)

    def test_getters(self):
        """Getters return empty results before registration and match after it."""
        parsed = avro.loads(data_gen.BASIC_SCHEMA)
        client = self.client
        subject = 'test'
        version = client.get_version(subject, parsed)
        self.assertIsNone(version)
        schema = client.get_by_id(1)
        self.assertIsNone(schema)
        latest = client.get_latest_schema(subject)
        self.assertEqual(latest, (None, None, None))

        # register
        schema_id = client.register(subject, parsed)
        latest = client.get_latest_schema(subject)
        version = client.get_version(subject, parsed)
        self.assertLatest(latest, schema_id, parsed, version)

        fetched = client.get_by_id(schema_id)
        self.assertEqual(fetched, parsed)

    def test_multi_register(self):
        """Registering a new schema bumps the version; re-registering an old one does not."""
        basic = avro.loads(data_gen.BASIC_SCHEMA)
        adv = avro.loads(data_gen.ADVANCED_SCHEMA)
        subject = 'test'
        client = self.client

        id1 = client.register(subject, basic)
        latest1 = client.get_latest_schema(subject)
        v1 = client.get_version(subject, basic)
        self.assertLatest(latest1, id1, basic, v1)

        id2 = client.register(subject, adv)
        latest2 = client.get_latest_schema(subject)
        v2 = client.get_version(subject, adv)
        self.assertLatest(latest2, id2, adv, v2)

        self.assertNotEqual(id1, id2)
        self.assertNotEqual(latest1, latest2)
        # ensure version is higher
        self.assertLess(latest1[2], latest2[2])

        client.register(subject, basic)
        latest3 = client.get_latest_schema(subject)
        # latest should not change with a re-reg
        self.assertEqual(latest2, latest3)

    # NOTE(review): not referenced anywhere in this file; kept so the class
    # interface is unchanged. Confirm it is unused elsewhere before removing.
    def hash_func(self):
        return hash(str(self))

    def test_cert_no_key(self):
        """A client certificate without a key is rejected with ValueError."""
        with self.assertRaises(ValueError):
            self.client = CachedSchemaRegistryClient(url='https://127.0.0.1:65534',
                                                     cert_location='/path/to/cert')

    def test_cert_with_key(self):
        """cert_location/key_location are applied as the session's client cert tuple."""
        self.client = CachedSchemaRegistryClient(url='https://127.0.0.1:65534',
                                                 cert_location='/path/to/cert',
                                                 key_location='/path/to/key')
        self.assertTupleEqual(('/path/to/cert', '/path/to/key'), self.client._session.cert)

    def test_cert_path(self):
        """ca_location is applied as the session's TLS verify path."""
        self.client = CachedSchemaRegistryClient(url='https://127.0.0.1:65534',
                                                 ca_location='/path/to/ca')
        self.assertEqual('/path/to/ca', self.client._session.verify)

    def test_context(self):
        """The client works as a context manager."""
        with self.client as c:
            parsed = avro.loads(data_gen.BASIC_SCHEMA)
            schema_id = c.register('test', parsed)
            self.assertGreater(schema_id, 0)
            self.assertEqual(len(c.id_to_schema), 1)

    def test_init_with_dict(self):
        """A config dict sets the URL and the TLS client cert/key."""
        self.client = CachedSchemaRegistryClient({
            'url': 'https://127.0.0.1:65534',
            'ssl.certificate.location': '/path/to/cert',
            'ssl.key.location': '/path/to/key'
        })
        self.assertEqual('https://127.0.0.1:65534', self.client.url)
        # The dict form should configure TLS the same way as the kwargs form
        # checked in test_cert_with_key.
        self.assertTupleEqual(('/path/to/cert', '/path/to/key'), self.client._session.cert)

    def test_empty_url(self):
        """An empty URL is rejected with ValueError."""
        with self.assertRaises(ValueError):
            self.client = CachedSchemaRegistryClient({
                'url': ''
            })

    def test_invalid_type_url(self):
        """A non-string URL keyword argument is rejected with TypeError."""
        with self.assertRaises(TypeError):
            self.client = CachedSchemaRegistryClient(
                url=1)

    def test_invalid_type_url_dict(self):
        """A non-string URL in a config dict is rejected with TypeError."""
        with self.assertRaises(TypeError):
            self.client = CachedSchemaRegistryClient({
                "url": 1
            })

    def test_invalid_url(self):
        """A URL without a scheme is rejected with ValueError."""
        with self.assertRaises(ValueError):
            self.client = CachedSchemaRegistryClient({
                'url': 'example.com:65534'
            })

    def test_basic_auth_url(self):
        """Credentials embedded in the URL become the session's basic auth."""
        self.client = CachedSchemaRegistryClient({
            'url': 'https://user_url:secret_url@127.0.0.1:65534',
        })
        self.assertTupleEqual(('user_url', 'secret_url'), self.client._session.auth)

    def test_basic_auth_userinfo(self):
        """With source 'user_info', basic.auth.user.info overrides URL credentials."""
        self.client = CachedSchemaRegistryClient({
            'url': 'https://user_url:secret_url@127.0.0.1:65534',
            'basic.auth.credentials.source': 'user_info',
            'basic.auth.user.info': 'user_userinfo:secret_userinfo'
        })
        self.assertTupleEqual(('user_userinfo', 'secret_userinfo'), self.client._session.auth)

    def test_basic_auth_sasl_inherit(self):
        """With source 'SASL_INHERIT', sasl.username/password become basic auth."""
        self.client = CachedSchemaRegistryClient({
            'url': 'https://user_url:secret_url@127.0.0.1:65534',
            'basic.auth.credentials.source': 'SASL_INHERIT',
            'sasl.mechanism': 'PLAIN',
            'sasl.username': 'user_sasl',
            'sasl.password': 'secret_sasl'
        })
        self.assertTupleEqual(('user_sasl', 'secret_sasl'), self.client._session.auth)

    def test_basic_auth_invalid(self):
        """An unsupported credentials source ('VAULT') is rejected with ValueError."""
        with self.assertRaises(ValueError):
            self.client = CachedSchemaRegistryClient({
                'url': 'https://user_url:secret_url@127.0.0.1:65534',
                'basic.auth.credentials.source': 'VAULT',
            })

    def test_invalid_conf(self):
        """A config containing unknown keys ('invalid.conf*') is expected to raise ValueError."""
        with self.assertRaises(ValueError):
            self.client = CachedSchemaRegistryClient({
                'url': 'https://user_url:secret_url@127.0.0.1:65534',
                'basic.auth.credentials.source': 'SASL_INHERIT',
                'sasl.username': 'user_sasl',
                'sasl.password': 'secret_sasl',
                'invalid.conf': 1,
                'invalid.conf2': 2
            })
224