1# coding: utf-8
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import six
16
17from twisted.internet import defer, reactor
18from twisted.trial import unittest
19
20import txredisapi as redis
21
22from tests.mixins import REDIS_HOST, REDIS_PORT
23
24
class TestSubscriberProtocol(unittest.TestCase):
    """Pub/sub command tests against an unauthenticated Redis server.

    Each test gets a fresh ``SubscriberFactory`` connection in setUp; tearDown
    disconnects it again.
    """

    @defer.inlineCallbacks
    def setUp(self):
        subscriber_factory = redis.SubscriberFactory()
        # Presumably stops the factory from reconnecting once the test
        # disconnects -- confirm against SubscriberFactory.
        subscriber_factory.continueTrying = False
        reactor.connectTCP(REDIS_HOST, REDIS_PORT, subscriber_factory)
        # The factory's deferred fires with the object the tests drive.
        self.db = yield subscriber_factory.deferred

    @defer.inlineCallbacks
    def tearDown(self):
        yield self.db.disconnect()

    @defer.inlineCallbacks
    def testDisconnectErrors(self):
        """Commands issued around a dropped connection fail with ConnectionError."""
        # Reach into the handler's private factory to get the underlying
        # protocol instance, so its transport can be dropped directly.
        proto = yield self.db._factory.getConnection(True)

        # 1. A SUBSCRIBE whose reply is still pending: losing the
        #    connection should errback it with a ConnectionError.
        pending = self.db.subscribe('foo')
        proto.transport.loseConnection()
        try:
            yield pending
        except redis.ConnectionError:
            pass
        else:
            self.fail()

        # 2. Through the handler: getConnection finds no active
        #    connections in the factory, so this should fail immediately.
        try:
            yield self.db.subscribe('bar')
        except redis.ConnectionError:
            pass
        else:
            self.fail()

        # 3. Directly on the dead protocol: execute_command sees that it
        #    is not connected and should raise straight away.
        try:
            yield proto.subscribe('baz')
        except redis.ConnectionError:
            pass
        else:
            self.fail()

    @defer.inlineCallbacks
    def testSubscribe(self):
        # Each SUBSCRIBE reply carries the running subscription count.
        for expected_count, channel in enumerate(
                ("test_subscribe1", "test_subscribe2"), 1):
            response = yield self.db.subscribe(channel)
            self.assertEqual(response, ["subscribe", channel, expected_count])

    @defer.inlineCallbacks
    def testUnsubscribe(self):
        channels = ("test_unsubscribe1", "test_unsubscribe2")
        for channel in channels:
            yield self.db.subscribe(channel)

        # The count in each UNSUBSCRIBE reply drops as channels are released.
        for remaining, channel in zip((1, 0), channels):
            response = yield self.db.unsubscribe(channel)
            self.assertEqual(response, ["unsubscribe", channel, remaining])

    @defer.inlineCallbacks
    def testPSubscribe(self):
        # Each PSUBSCRIBE reply carries the running subscription count.
        for expected_count, pattern in enumerate(
                ("test_psubscribe1.*", "test_psubscribe2.*"), 1):
            response = yield self.db.psubscribe(pattern)
            self.assertEqual(response, ["psubscribe", pattern, expected_count])

    @defer.inlineCallbacks
    def testPUnsubscribe(self):
        patterns = ("test_punsubscribe1.*", "test_punsubscribe2.*")
        for pattern in patterns:
            yield self.db.psubscribe(pattern)

        # The count in each PUNSUBSCRIBE reply drops as patterns are released.
        for remaining, pattern in zip((1, 0), patterns):
            response = yield self.db.punsubscribe(pattern)
            self.assertEqual(response, ["punsubscribe", pattern, remaining])
115
116
class TestAuthenticatedSubscriberProtocol(unittest.TestCase):
    """Subscriber tests against a Redis server that requires a password.

    setUp temporarily turns on ``requirepass`` via ``CONFIG SET`` and
    registers ``removePassword`` as a cleanup, so that later tests talk to
    an unauthenticated server again.
    """

    timeout = 5

    @defer.inlineCallbacks
    def setUp(self):
        # Use a throwaway, unauthenticated connection to enable the password.
        meta = yield redis.Connection(REDIS_HOST, REDIS_PORT)
        try:
            yield meta.execute_command("config", "set", "requirepass",
                                       "password")
            # Register the restore step as soon as the password is in place,
            # so a failure below (including in disconnect) cannot leave the
            # server password-protected for the rest of the suite.
            self.addCleanup(self.removePassword)
        finally:
            # Always close the helper connection, even if CONFIG SET failed,
            # so a leaked connection does not mask the real error.
            yield meta.disconnect()

        factory = redis.RedisFactory(None, dbid=0, poolsize=1,
                                     password="password")
        # Pool connections use the subscriber protocol (testSubscribe calls
        # subscribe through the handler returned below).
        factory.protocol = redis.SubscriberProtocol
        # Presumably stops reconnect attempts once tearDown disconnects --
        # confirm against RedisFactory.
        factory.continueTrying = False
        reactor.connectTCP(REDIS_HOST, REDIS_PORT, factory)
        self.db = yield factory.deferred

    @defer.inlineCallbacks
    def removePassword(self):
        """Clear ``requirepass`` on the server (registered as a cleanup)."""
        meta = yield redis.Connection(REDIS_HOST, REDIS_PORT,
                                      password="password")
        try:
            yield meta.execute_command("config", "set", "requirepass", "")
        finally:
            # Close the helper connection even if CONFIG SET fails.
            yield meta.disconnect()

    @defer.inlineCallbacks
    def tearDown(self):
        yield self.db.disconnect()

    @defer.inlineCallbacks
    def testSubscribe(self):
        reply = yield self.db.subscribe("test_subscribe1")
        self.assertEqual(reply, [u"subscribe", u"test_subscribe1", 1])

        reply = yield self.db.subscribe("test_subscribe2")
        self.assertEqual(reply, [u"subscribe", u"test_subscribe2", 2])
152