1from pymemcache.client.hash import HashClient
2from pymemcache.client.base import Client, PooledClient
3from pymemcache.exceptions import MemcacheError, MemcacheUnknownError
4from pymemcache import pool
5
6from .test_client import ClientTestMixin, MockSocket
7import unittest
8import os
9import pytest
10import mock
11import socket
12
13
class TestHashClient(ClientTestMixin, unittest.TestCase):
    """Tests for :class:`HashClient`, including the shared client suite.

    The helpers below build a ``HashClient`` whose per-server clients are
    plain :class:`Client` objects backed by :class:`MockSocket`. Each test
    therefore scripts the exact bytes every "server" sends back.
    """

    def make_client_pool(self, hostname, mock_socket_values,
                         serializer=None, **kwargs):
        """Return a ``Client`` for *hostname* whose socket replays
        *mock_socket_values*.

        Despite the name, no pool is built. The plain ``Client`` is what
        gets registered in ``HashClient.clients``, and tests such as
        ``test_close``/``test_quit`` inspect its ``sock`` attribute
        directly.
        """
        mock_client = Client(hostname, serializer=serializer, **kwargs)
        mock_client.sock = MockSocket(mock_socket_values)
        return mock_client

    def make_client(self, *mock_socket_values, **kwargs):
        """Build a ``HashClient`` with one mocked TCP server per argument.

        Each positional argument is the list of responses for one server.
        Servers are registered under the keys ``'127.0.0.1:<port>'``, with
        ports assigned sequentially from 11012 in argument order.
        *kwargs* go to both the ``HashClient`` and every per-server client.
        """
        client = HashClient([], **kwargs)
        ip = '127.0.0.1'

        for port, vals in enumerate(mock_socket_values, start=11012):
            server_key = '%s:%s' % (ip, port)
            client.clients[server_key] = self.make_client_pool(
                (ip, port),
                vals,
                **kwargs
            )
            client.hasher.add_node(server_key)

        return client

    def make_unix_client(self, sockets, *mock_socket_values, **kwargs):
        """Build a ``HashClient`` with one mocked unix-socket server per path.

        *sockets* and *mock_socket_values* are paired positionally. The
        socket path itself is the key in ``HashClient.clients``.
        """
        client = HashClient([], **kwargs)

        for socket_, vals in zip(sockets, mock_socket_values):
            client.clients[socket_] = self.make_client_pool(
                socket_,
                vals,
                **kwargs
            )
            client.hasher.add_node(socket_)

        return client

    @staticmethod
    def _pin_key_routing(client, key3_server, other_server):
        """Replace hash routing with a fixed mapping.

        ``b'key3'`` goes to ``client.clients[key3_server]``; every other key
        goes to ``client.clients[other_server]``. This makes the order of
        mocked responses per server deterministic.
        """
        def get_client(key):
            if key == b'key3':
                return client.clients[key3_server]
            return client.clients[other_server]

        client._get_client = get_client

    def test_setup_client_without_pooling(self):
        client_class = 'pymemcache.client.hash.HashClient.client_class'
        with mock.patch(client_class) as internal_client:
            client = HashClient([], timeout=999, key_prefix='foo_bar_baz')
            client.add_server(('127.0.0.1', '11211'))

        # Constructor options must be forwarded to each per-server client.
        assert internal_client.call_args[0][0] == ('127.0.0.1', '11211')
        kwargs = internal_client.call_args[1]
        assert kwargs['timeout'] == 999
        assert kwargs['key_prefix'] == 'foo_bar_baz'

    def test_get_many_unix(self):
        pid = os.getpid()
        sockets = [
            '/tmp/pymemcache.1.%d' % pid,
            '/tmp/pymemcache.2.%d' % pid,
        ]
        client = self.make_unix_client(sockets, *[
            [b'STORED\r\n', b'VALUE key3 0 6\r\nvalue2\r\nEND\r\n', ],
            [b'STORED\r\n', b'VALUE key1 0 6\r\nvalue1\r\nEND\r\n', ],
        ])
        self._pin_key_routing(client, sockets[0], sockets[1])

        assert client.set(b'key1', b'value1', noreply=False) is True
        assert client.set(b'key3', b'value2', noreply=False) is True
        result = client.get_many([b'key1', b'key3'])
        assert result == {b'key1': b'value1', b'key3': b'value2'}

    def test_get_many_all_found(self):
        client = self.make_client(*[
            [b'STORED\r\n', b'VALUE key3 0 6\r\nvalue2\r\nEND\r\n', ],
            [b'STORED\r\n', b'VALUE key1 0 6\r\nvalue1\r\nEND\r\n', ],
        ])
        self._pin_key_routing(client, '127.0.0.1:11012', '127.0.0.1:11013')

        assert client.set(b'key1', b'value1', noreply=False) is True
        assert client.set(b'key3', b'value2', noreply=False) is True
        result = client.get_many([b'key1', b'key3'])
        assert result == {b'key1': b'value1', b'key3': b'value2'}

    def test_get_many_some_found(self):
        client = self.make_client(*[
            [b'END\r\n', ],
            [b'STORED\r\n', b'VALUE key1 0 6\r\nvalue1\r\nEND\r\n', ],
        ])
        self._pin_key_routing(client, '127.0.0.1:11012', '127.0.0.1:11013')

        assert client.set(b'key1', b'value1', noreply=False) is True
        result = client.get_many([b'key1', b'key3'])

        assert result == {b'key1': b'value1'}

    def test_get_many_bad_server_data(self):
        client = self.make_client(*[
            [b'STORED\r\n', b'VAXLUE key3 0 6\r\nvalue2\r\nEND\r\n', ],
            [b'STORED\r\n', b'VAXLUE key1 0 6\r\nvalue1\r\nEND\r\n', ],
        ])
        self._pin_key_routing(client, '127.0.0.1:11012', '127.0.0.1:11013')

        # The sets must succeed outside ``pytest.raises``. Otherwise an
        # error from ``set`` would satisfy the check without ``get_many``
        # ever running.
        assert client.set(b'key1', b'value1', noreply=False) is True
        assert client.set(b'key3', b'value2', noreply=False) is True
        with pytest.raises(MemcacheUnknownError):
            client.get_many([b'key1', b'key3'])

    def test_get_many_bad_server_data_ignore(self):
        client = self.make_client(*[
            [b'STORED\r\n', b'VAXLUE key3 0 6\r\nvalue2\r\nEND\r\n', ],
            [b'STORED\r\n', b'VAXLUE key1 0 6\r\nvalue1\r\nEND\r\n', ],
        ], ignore_exc=True)
        self._pin_key_routing(client, '127.0.0.1:11012', '127.0.0.1:11013')

        assert client.set(b'key1', b'value1', noreply=False) is True
        assert client.set(b'key3', b'value2', noreply=False) is True
        # With ignore_exc the malformed responses are swallowed.
        result = client.get_many([b'key1', b'key3'])
        assert result == {}

    def test_gets_many(self):
        client = self.make_client(*[
            [b'STORED\r\n', b'VALUE key3 0 6 1\r\nvalue2\r\nEND\r\n', ],
            [b'STORED\r\n', b'VALUE key1 0 6 1\r\nvalue1\r\nEND\r\n', ],
        ])
        self._pin_key_routing(client, '127.0.0.1:11012', '127.0.0.1:11013')

        assert client.set(b'key1', b'value1', noreply=False) is True
        assert client.set(b'key3', b'value2', noreply=False) is True
        result = client.gets_many([b'key1', b'key3'])
        assert (result ==
                {b'key1': (b'value1', b'1'), b'key3': (b'value2', b'1')})

    def test_touch_not_found(self):
        client = self.make_client([b'NOT_FOUND\r\n'])
        result = client.touch(b'key', noreply=False)
        assert result is False

    def test_touch_no_expiry_found(self):
        client = self.make_client([b'TOUCHED\r\n'])
        result = client.touch(b'key', noreply=False)
        assert result is True

    def test_touch_with_expiry_found(self):
        client = self.make_client([b'TOUCHED\r\n'])
        result = client.touch(b'key', 1, noreply=False)
        assert result is True

    def test_close(self):
        client = self.make_client([])
        assert all(c.sock is not None for c in client.clients.values())
        result = client.close()
        assert result is None
        assert all(c.sock is None for c in client.clients.values())

    def test_quit(self):
        client = self.make_client([])
        assert all(c.sock is not None for c in client.clients.values())
        result = client.quit()
        assert result is None
        assert all(c.sock is None for c in client.clients.values())

    def test_no_servers_left(self):
        client = HashClient(
            [], use_pooling=True,
            ignore_exc=True,
            timeout=1, connect_timeout=1
        )

        hashed_client = client._get_client('foo')
        assert hashed_client is None

    def test_no_servers_left_raise_exception(self):
        client = HashClient(
            [], use_pooling=True,
            ignore_exc=False,
            timeout=1, connect_timeout=1
        )

        with pytest.raises(MemcacheError) as e:
            client._get_client('foo')

        assert str(e.value) == 'All servers seem to be down right now'

    def test_unavailable_servers_zero_retry_raise_exception(self):
        # NOTE(review): this reaches out to example.com. Without network
        # the DNS lookup fails, and with network the connect presumably
        # times out; both are expected to surface as socket.error
        # (OSError) subclasses. Consider mocking to avoid the external
        # dependency.
        client = HashClient(
            [('example.com', 11211)], use_pooling=True,
            ignore_exc=False,
            retry_attempts=0, timeout=1, connect_timeout=1
        )

        with pytest.raises(socket.error):
            client.get('foo')

    def test_no_servers_left_with_commands_return_default_value(self):
        client = HashClient(
            [], use_pooling=True,
            ignore_exc=True,
            timeout=1, connect_timeout=1
        )

        result = client.get('foo')
        assert result is None
        result = client.set('foo', 'bar')
        assert result is False

    def test_no_servers_left_with_set_many(self):
        client = HashClient(
            [], use_pooling=True,
            ignore_exc=True,
            timeout=1, connect_timeout=1
        )

        # Every key fails to store, so every key is reported back.
        result = client.set_many({'foo': 'bar'})
        assert result == ['foo']

    def test_no_servers_left_with_get_many(self):
        client = HashClient(
            [], use_pooling=True,
            ignore_exc=True,
            timeout=1, connect_timeout=1
        )

        result = client.get_many(['foo', 'bar'])
        assert result == {}

    def test_ignore_exec_set_many(self):
        values = {
            'key1': 'value1',
            'key2': 'value2',
            'key3': 'value3'
        }

        # Build the client outside ``pytest.raises`` so only ``set_many``
        # can satisfy the expected exception.
        client = self.make_client(*[
            [b'STORED\r\n', b'UNKNOWN\r\n', b'STORED\r\n'],
            [b'STORED\r\n', b'UNKNOWN\r\n', b'STORED\r\n'],
        ])
        with pytest.raises(MemcacheUnknownError):
            client.set_many(values, noreply=False)

        client = self.make_client(*[
            [b'STORED\r\n', b'UNKNOWN\r\n', b'STORED\r\n'],
        ], ignore_exc=True)
        result = client.set_many(values, noreply=False)

        assert len(result) == 0

    def test_noreply_set_many(self):
        values = {
            'key1': 'value1',
            'key2': 'value2',
            'key3': 'value3'
        }

        client = self.make_client(*[
            [b'STORED\r\n', b'NOT_STORED\r\n', b'STORED\r\n'],
        ])
        result = client.set_many(values, noreply=False)
        assert len(result) == 1

        # With noreply no responses are read, so no key is reported failed.
        client = self.make_client(*[
            [b'STORED\r\n', b'NOT_STORED\r\n', b'STORED\r\n'],
        ])
        result = client.set_many(values, noreply=True)
        assert result == []

    def test_set_many_unix(self):
        values = {
            'key1': 'value1',
            'key2': 'value2',
            'key3': 'value3'
        }

        pid = os.getpid()
        sockets = ['/tmp/pymemcache.%d' % pid]
        client = self.make_unix_client(sockets, *[
            [b'STORED\r\n', b'NOT_STORED\r\n', b'STORED\r\n'],
        ])

        result = client.set_many(values, noreply=False)
        assert len(result) == 1

    def test_server_encoding_pooled(self):
        """
        Test that the encoding passed to HashClient reaches pooled clients.
        """
        encoding = 'utf8'
        hash_client = HashClient(
            [('example.com', 11211)], use_pooling=True,
            encoding=encoding
        )

        for client in hash_client.clients.values():
            assert client.encoding == encoding

    def test_server_encoding_client(self):
        """
        Test that the encoding passed to HashClient reaches plain clients.
        """
        encoding = 'utf8'
        hash_client = HashClient(
            [('example.com', 11211)], encoding=encoding
        )

        for client in hash_client.clients.values():
            assert client.encoding == encoding

    @mock.patch("pymemcache.client.hash.HashClient.client_class")
    def test_dead_server_comes_back(self, client_patch):
        client = HashClient([], dead_timeout=0, retry_attempts=0)
        client.add_server(("127.0.0.1", 11211))

        test_client = client_patch.return_value
        test_client.server = ("127.0.0.1", 11211)

        test_client.get.side_effect = socket.timeout()
        with pytest.raises(socket.timeout):
            client.get(b"key", noreply=False)
        # Client gets removed because of socket timeout
        assert ("127.0.0.1", 11211) in client._dead_clients

        test_client.get.side_effect = lambda *_: "Some value"
        # Client should be retried and brought back
        assert client.get(b"key") == "Some value"
        assert ("127.0.0.1", 11211) not in client._dead_clients

    @mock.patch("pymemcache.client.hash.HashClient.client_class")
    def test_failed_is_retried(self, client_patch):
        client = HashClient([], retry_attempts=1, retry_timeout=0)
        client.add_server(("127.0.0.1", 11211))

        assert client_patch.call_count == 1

        test_client = client_patch.return_value
        test_client.server = ("127.0.0.1", 11211)

        test_client.get.side_effect = socket.timeout()
        with pytest.raises(socket.timeout):
            client.get(b"key", noreply=False)

        test_client.get.side_effect = lambda *_: "Some value"
        assert client.get(b"key") == "Some value"

        # The retry must reuse the existing client, not construct a new one.
        assert client_patch.call_count == 1

    def test_custom_client(self):
        class MyClient(Client):
            pass

        client = HashClient([])
        client.client_class = MyClient
        client.add_server(('host', 11211))
        assert isinstance(client.clients['host:11211'], MyClient)

    def test_custom_client_with_pooling(self):
        class MyClient(Client):
            pass

        client = HashClient([], use_pooling=True)
        client.client_class = MyClient
        client.add_server(('host', 11211))
        assert isinstance(client.clients['host:11211'], PooledClient)

        # Renamed from ``pool`` to avoid shadowing the module-level import.
        client_pool = client.clients['host:11211'].client_pool
        with client_pool.get_and_release(destroy_on_fail=True) as c:
            assert isinstance(c, MyClient)

    def test_mixed_inet_and_unix_sockets(self):
        expected = {
            '/tmp/pymemcache.{pid}'.format(pid=os.getpid()),
            ('127.0.0.1', 11211),
            ('::1', 11211),
        }
        # Duplicate spellings of the same server must collapse to one entry.
        client = HashClient([
            '/tmp/pymemcache.{pid}'.format(pid=os.getpid()),
            '127.0.0.1',
            '127.0.0.1:11211',
            '[::1]',
            '[::1]:11211',
            ('127.0.0.1', 11211),
            ('::1', 11211),
        ])
        assert expected == {c.server for c in client.clients.values()}

    def test_legacy_add_remove_server_signature(self):
        server = ('127.0.0.1', 11211)
        client = HashClient([])
        assert client.clients == {}
        client.add_server(*server)  # Unpack (host, port) tuple.
        assert ('%s:%s' % server) in client.clients
        client._mark_failed_server(server)
        assert server in client._failed_clients
        client.remove_server(*server)  # Unpack (host, port) tuple.
        assert server in client._dead_clients
        assert server not in client._failed_clients

        # Ensure that server is a string if passing port argument:
        with pytest.raises(TypeError):
            client.add_server(server, server[-1])
        with pytest.raises(TypeError):
            client.remove_server(server, server[-1])

    # TODO: Test failover logic
464