1###############################################################################
2#
3# The MIT License (MIT)
4#
5# Copyright (c) Crossbar.io Technologies GmbH
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23# THE SOFTWARE.
24#
25###############################################################################
26
27from __future__ import absolute_import, print_function
28
29import os
30import unittest
31from hashlib import sha1
32from base64 import b64encode
33
34from autobahn.websocket.protocol import WebSocketServerProtocol
35from autobahn.websocket.protocol import WebSocketServerFactory
36from autobahn.websocket.protocol import WebSocketClientProtocol
37from autobahn.websocket.protocol import WebSocketClientFactory
38from autobahn.websocket.protocol import WebSocketProtocol
39from autobahn.websocket.types import ConnectingRequest
40from autobahn.test import FakeTransport
41
42import txaio
43
44from mock import Mock
45
46
class WebSocketClientProtocolTests(unittest.TestCase):
    """
    Tests for autobahn.websocket.protocol.WebSocketClientProtocol, exercised
    directly rather than through a Twisted or asyncio specific subclass.
    """

    def setUp(self):
        """
        Build a client protocol wired to a FakeTransport and put it into
        STATE_OPEN with websocket_version 18.
        """
        t = FakeTransport()
        f = WebSocketClientFactory()
        p = WebSocketClientProtocol()
        p.factory = f
        p.transport = t
        # Stubbed out: this test drives the generic base class, not a
        # framework-specific subclass.
        p._create_transport_details = Mock()

        p._connectionMade()
        p.state = p.STATE_OPEN
        p.websocket_version = 18

        self.protocol = p
        self.transport = t

    def tearDown(self):
        """
        Cancel every timer the protocol may have scheduled so nothing fires
        after the test has finished.
        """
        for call in [
                self.protocol.autoPingPendingCall,
                self.protocol.autoPingTimeoutCall,
                self.protocol.openHandshakeTimeoutCall,
                self.protocol.closeHandshakeTimeoutCall,
        ]:
            if call is not None:
                call.cancel()

    def test_auto_ping(self):
        """
        With autoPingInterval set, completing the opening handshake schedules
        an auto-ping call.
        """
        self.protocol.autoPingInterval = 1
        self.protocol.websocket_protocols = [Mock()]
        self.protocol.websocket_extensions = []
        self.protocol._onOpen = lambda: None
        self.protocol._wskey = '0' * 24
        self.protocol.peer = Mock()

        # usually provided by the Twisted or asyncio specific
        # subclass, but we're testing the parent here...
        self.protocol._onConnect = Mock()
        self.protocol._closeConnection = Mock()

        # set up a connection
        self.protocol._actuallyStartHandshake(
            ConnectingRequest(
                host="example.com",
                port=80,
                resource="/ws",
            )
        )

        # Feed back a minimal "101 Switching Protocols" response whose
        # Sec-Websocket-Accept is base64(SHA-1(client key + magic GUID)),
        # i.e. the value the client will accept for the key it just sent.
        key = self.protocol.websocket_key + WebSocketProtocol._WS_MAGIC
        self.protocol.data = (
            b"HTTP/1.1 101 Switching Protocols\x0d\x0a"
            b"Upgrade: websocket\x0d\x0a"
            b"Connection: upgrade\x0d\x0a"
            b"Sec-Websocket-Accept: " + b64encode(sha1(key).digest()) + b"\x0d\x0a\x0d\x0a"
        )
        self.protocol.processHandshake()

        self.assertIsNotNone(self.protocol.autoPingPendingCall)
106
107
class WebSocketServerProtocolTests(unittest.TestCase):
    """
    Tests for autobahn.websocket.protocol.WebSocketServerProtocol, exercised
    directly rather than through a Twisted or asyncio specific subclass.
    """
    def setUp(self):
        """
        Build a server protocol wired to a FakeTransport and put it into
        STATE_OPEN with websocket_version 18.
        """
        t = FakeTransport()
        f = WebSocketServerFactory()
        p = WebSocketServerProtocol()
        p.factory = f
        p.transport = t

        p._connectionMade()
        p.state = p.STATE_OPEN
        p.websocket_version = 18

        self.protocol = p
        self.transport = t

    def tearDown(self):
        """
        Cancel every timer the protocol may have scheduled so nothing fires
        after the test has finished.
        """
        for call in [
                self.protocol.autoPingPendingCall,
                self.protocol.autoPingTimeoutCall,
                self.protocol.openHandshakeTimeoutCall,
                self.protocol.closeHandshakeTimeoutCall,
        ]:
            if call is not None:
                call.cancel()

    def test_auto_ping(self):
        """
        With autoPingInterval set, succeeding the opening handshake schedules
        an auto-ping call.
        """
        proto = Mock()
        proto._get_seconds = Mock(return_value=1)
        self.protocol.autoPingInterval = 1
        self.protocol.websocket_protocols = [proto]
        self.protocol.websocket_extensions = []
        self.protocol._onOpen = lambda: None
        self.protocol._wskey = '0' * 24
        self.protocol.succeedHandshake(proto)

        self.assertIsNotNone(self.protocol.autoPingPendingCall)

    def test_sendClose_none(self):
        """
        sendClose with no code or reason works.
        """
        self.protocol.sendClose()

        # We closed properly: a Close frame (0x88) with an empty payload.
        self.assertEqual(self.transport._written, b"\x88\x00")
        self.assertEqual(self.protocol.state, self.protocol.STATE_CLOSING)

    def test_sendClose_str_reason(self):
        """
        sendClose with a str reason works.
        """
        self.protocol.sendClose(code=1000, reason=u"oh no")

        # We closed properly. Skip the 2-byte frame header; the payload is
        # the big-endian close code (1000 == 0x03e8) followed by the reason.
        self.assertEqual(self.transport._written[2:], b"\x03\xe8oh no")
        self.assertEqual(self.protocol.state, self.protocol.STATE_CLOSING)

    def test_sendClose_unicode_reason(self):
        """
        sendClose with a non-ASCII unicode reason works, and the reason is
        sent UTF-8 encoded.
        """
        self.protocol.sendClose(code=1000, reason=u"oh n\xf6")

        # We closed properly; U+00F6 is encoded as the UTF-8 bytes C3 B6.
        self.assertEqual(self.transport._written[2:], b"\x03\xe8oh n\xc3\xb6")
        self.assertEqual(self.protocol.state, self.protocol.STATE_CLOSING)

    def test_sendClose_toolong(self):
        """
        sendClose with a too-long reason will truncate it.
        """
        self.protocol.sendClose(code=1000, reason=u"abc" * 1000)

        # We closed properly. Control-frame payloads are capped at 125 bytes
        # (RFC 6455), leaving 123 bytes (41 * "abc") after the 2-byte code.
        self.assertEqual(self.transport._written[2:],
                         b"\x03\xe8" + (b"abc" * 41))
        self.assertEqual(self.protocol.state, self.protocol.STATE_CLOSING)

    def test_sendClose_reason_with_no_code(self):
        """
        Trying to sendClose with a reason but no code will raise an Exception.
        """
        with self.assertRaises(Exception) as e:
            self.protocol.sendClose(reason=u"abc")

        self.assertIn("close reason without close code", str(e.exception))

        # We shouldn't have closed
        self.assertEqual(self.transport._written, b"")
        self.assertEqual(self.protocol.state, self.protocol.STATE_OPEN)

    def test_sendClose_invalid_code_type(self):
        """
        Trying to sendClose with a non-int code will raise an Exception.
        """
        with self.assertRaises(Exception) as e:
            self.protocol.sendClose(code="134")

        self.assertIn("invalid type", str(e.exception))

        # We shouldn't have closed
        self.assertEqual(self.transport._written, b"")
        self.assertEqual(self.protocol.state, self.protocol.STATE_OPEN)

    def test_sendClose_invalid_code_value(self):
        """
        Trying to sendClose with a non-valid int code will raise an Exception.
        """
        with self.assertRaises(Exception) as e:
            self.protocol.sendClose(code=10)

        self.assertIn("invalid close code 10", str(e.exception))

        # We shouldn't have closed
        self.assertEqual(self.transport._written, b"")
        self.assertEqual(self.protocol.state, self.protocol.STATE_OPEN)
227
228
if os.environ.get('USE_TWISTED', False):
    class TwistedProtocolTests(unittest.TestCase):
        """
        Tests which require a specific framework's protocol class to work
        (in this case, using Twisted). Only defined when the USE_TWISTED
        environment variable is set.
        """
        def setUp(self):
            """
            Build a Twisted server protocol wired to a FakeTransport and put
            it into STATE_OPEN with websocket_version 18.
            """
            # Imported here so the module still loads when Twisted is absent.
            from autobahn.twisted.websocket import WebSocketServerProtocol
            from autobahn.twisted.websocket import WebSocketServerFactory
            t = FakeTransport()
            f = WebSocketServerFactory()
            p = WebSocketServerProtocol()
            p.factory = f
            p.transport = t

            p._connectionMade()
            p.state = p.STATE_OPEN
            p.websocket_version = 18

            self.protocol = p
            self.transport = t

        def tearDown(self):
            """
            Cancel every timer the protocol may have scheduled so nothing
            fires after the test has finished.
            """
            for call in [
                    self.protocol.autoPingPendingCall,
                    self.protocol.autoPingTimeoutCall,
                    self.protocol.openHandshakeTimeoutCall,
                    self.protocol.closeHandshakeTimeoutCall,
            ]:
                if call is not None:
                    call.cancel()

        def test_loseConnection(self):
            """
            If we lose our connection before openHandshakeTimeout fires, it is
            cleaned up.
            """
            # A little cheesy, but we depend on the asyncio or Twisted class
            # to call _connectionLost at some point; faking that here.
            self.protocol._connectionLost(txaio.create_failure(RuntimeError("testing")))
            self.assertIsNone(self.protocol.openHandshakeTimeoutCall)
271