1############################################################################### 2# 3# The MIT License (MIT) 4# 5# Copyright (c) Crossbar.io Technologies GmbH 6# 7# Permission is hereby granted, free of charge, to any person obtaining a copy 8# of this software and associated documentation files (the "Software"), to deal 9# in the Software without restriction, including without limitation the rights 10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11# copies of the Software, and to permit persons to whom the Software is 12# furnished to do so, subject to the following conditions: 13# 14# The above copyright notice and this permission notice shall be included in 15# all copies or substantial portions of the Software. 16# 17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23# THE SOFTWARE. 24# 25############################################################################### 26 27from __future__ import absolute_import, print_function 28 29import os 30import unittest 31from hashlib import sha1 32from base64 import b64encode 33 34from autobahn.websocket.protocol import WebSocketServerProtocol 35from autobahn.websocket.protocol import WebSocketServerFactory 36from autobahn.websocket.protocol import WebSocketClientProtocol 37from autobahn.websocket.protocol import WebSocketClientFactory 38from autobahn.websocket.protocol import WebSocketProtocol 39from autobahn.websocket.types import ConnectingRequest 40from autobahn.test import FakeTransport 41 42import txaio 43 44from mock import Mock 45 46 47class WebSocketClientProtocolTests(unittest.TestCase): 48 49 def setUp(self): 50 t = FakeTransport() 51 f = WebSocketClientFactory() 52 p = WebSocketClientProtocol() 53 p.factory = f 54 p.transport = t 55 p._create_transport_details = Mock() 56 57 p._connectionMade() 58 p.state = p.STATE_OPEN 59 p.websocket_version = 18 60 61 self.protocol = p 62 self.transport = t 63 64 def tearDown(self): 65 for call in [ 66 self.protocol.autoPingPendingCall, 67 self.protocol.autoPingTimeoutCall, 68 self.protocol.openHandshakeTimeoutCall, 69 self.protocol.closeHandshakeTimeoutCall, 70 ]: 71 if call is not None: 72 call.cancel() 73 74 def test_auto_ping(self): 75 self.protocol.autoPingInterval = 1 76 self.protocol.websocket_protocols = [Mock()] 77 self.protocol.websocket_extensions = [] 78 self.protocol._onOpen = lambda: None 79 self.protocol._wskey = '0' * 24 80 self.protocol.peer = Mock() 81 82 # usually provided by the Twisted or asyncio specific 83 # subclass, but we're testing the parent here... 84 self.protocol._onConnect = Mock() 85 self.protocol._closeConnection = Mock() 86 87 # set up a connection 88 self.protocol._actuallyStartHandshake( 89 ConnectingRequest( 90 host="example.com", 91 port=80, 92 resource="/ws", 93 ) 94 ) 95 96 key = self.protocol.websocket_key + WebSocketProtocol._WS_MAGIC 97 self.protocol.data = ( 98 b"HTTP/1.1 101 Switching Protocols\x0d\x0a" 99 b"Upgrade: websocket\x0d\x0a" 100 b"Connection: upgrade\x0d\x0a" 101 b"Sec-Websocket-Accept: " + b64encode(sha1(key).digest()) + b"\x0d\x0a\x0d\x0a" 102 ) 103 self.protocol.processHandshake() 104 105 self.assertTrue(self.protocol.autoPingPendingCall is not None) 106 107 108class WebSocketServerProtocolTests(unittest.TestCase): 109 """ 110 Tests for autobahn.websocket.protocol.WebSocketProtocol. 111 """ 112 def setUp(self): 113 t = FakeTransport() 114 f = WebSocketServerFactory() 115 p = WebSocketServerProtocol() 116 p.factory = f 117 p.transport = t 118 119 p._connectionMade() 120 p.state = p.STATE_OPEN 121 p.websocket_version = 18 122 123 self.protocol = p 124 self.transport = t 125 126 def tearDown(self): 127 for call in [ 128 self.protocol.autoPingPendingCall, 129 self.protocol.autoPingTimeoutCall, 130 self.protocol.openHandshakeTimeoutCall, 131 self.protocol.closeHandshakeTimeoutCall, 132 ]: 133 if call is not None: 134 call.cancel() 135 136 def test_auto_ping(self): 137 proto = Mock() 138 proto._get_seconds = Mock(return_value=1) 139 self.protocol.autoPingInterval = 1 140 self.protocol.websocket_protocols = [proto] 141 self.protocol.websocket_extensions = [] 142 self.protocol._onOpen = lambda: None 143 self.protocol._wskey = '0' * 24 144 self.protocol.succeedHandshake(proto) 145 146 self.assertTrue(self.protocol.autoPingPendingCall is not None) 147 148 def test_sendClose_none(self): 149 """ 150 sendClose with no code or reason works. 151 """ 152 self.protocol.sendClose() 153 154 # We closed properly 155 self.assertEqual(self.transport._written, b"\x88\x00") 156 self.assertEqual(self.protocol.state, self.protocol.STATE_CLOSING) 157 158 def test_sendClose_str_reason(self): 159 """ 160 sendClose with a str reason works. 161 """ 162 self.protocol.sendClose(code=1000, reason=u"oh no") 163 164 # We closed properly 165 self.assertEqual(self.transport._written[2:], b"\x03\xe8oh no") 166 self.assertEqual(self.protocol.state, self.protocol.STATE_CLOSING) 167 168 def test_sendClose_unicode_reason(self): 169 """ 170 sendClose with a unicode reason works. 171 """ 172 self.protocol.sendClose(code=1000, reason=u"oh no") 173 174 # We closed properly 175 self.assertEqual(self.transport._written[2:], b"\x03\xe8oh no") 176 self.assertEqual(self.protocol.state, self.protocol.STATE_CLOSING) 177 178 def test_sendClose_toolong(self): 179 """ 180 sendClose with a too-long reason will truncate it. 181 """ 182 self.protocol.sendClose(code=1000, reason=u"abc" * 1000) 183 184 # We closed properly 185 self.assertEqual(self.transport._written[2:], 186 b"\x03\xe8" + (b"abc" * 41)) 187 self.assertEqual(self.protocol.state, self.protocol.STATE_CLOSING) 188 189 def test_sendClose_reason_with_no_code(self): 190 """ 191 Trying to sendClose with a reason but no code will raise an Exception. 192 """ 193 with self.assertRaises(Exception) as e: 194 self.protocol.sendClose(reason=u"abc") 195 196 self.assertIn("close reason without close code", str(e.exception)) 197 198 # We shouldn't have closed 199 self.assertEqual(self.transport._written, b"") 200 self.assertEqual(self.protocol.state, self.protocol.STATE_OPEN) 201 202 def test_sendClose_invalid_code_type(self): 203 """ 204 Trying to sendClose with a non-int code will raise an Exception. 205 """ 206 with self.assertRaises(Exception) as e: 207 self.protocol.sendClose(code="134") 208 209 self.assertIn("invalid type", str(e.exception)) 210 211 # We shouldn't have closed 212 self.assertEqual(self.transport._written, b"") 213 self.assertEqual(self.protocol.state, self.protocol.STATE_OPEN) 214 215 def test_sendClose_invalid_code_value(self): 216 """ 217 Trying to sendClose with a non-valid int code will raise an Exception. 218 """ 219 with self.assertRaises(Exception) as e: 220 self.protocol.sendClose(code=10) 221 222 self.assertIn("invalid close code 10", str(e.exception)) 223 224 # We shouldn't have closed 225 self.assertEqual(self.transport._written, b"") 226 self.assertEqual(self.protocol.state, self.protocol.STATE_OPEN) 227 228 229if os.environ.get('USE_TWISTED', False): 230 class TwistedProtocolTests(unittest.TestCase): 231 """ 232 Tests which require a specific framework's protocol class to work 233 (in this case, using Twisted) 234 """ 235 def setUp(self): 236 from autobahn.twisted.websocket import WebSocketServerProtocol 237 from autobahn.twisted.websocket import WebSocketServerFactory 238 t = FakeTransport() 239 f = WebSocketServerFactory() 240 p = WebSocketServerProtocol() 241 p.factory = f 242 p.transport = t 243 244 p._connectionMade() 245 p.state = p.STATE_OPEN 246 p.websocket_version = 18 247 248 self.protocol = p 249 self.transport = t 250 251 def tearDown(self): 252 for call in [ 253 self.protocol.autoPingPendingCall, 254 self.protocol.autoPingTimeoutCall, 255 self.protocol.openHandshakeTimeoutCall, 256 self.protocol.closeHandshakeTimeoutCall, 257 ]: 258 if call is not None: 259 call.cancel() 260 261 def test_loseConnection(self): 262 """ 263 If we lose our connection before openHandshakeTimeout fires, it is 264 cleaned up 265 """ 266 # so, I guess a little cheezy, but we depend on the asyncio or 267 # twisted class to call _connectionLost at some point; faking 268 # that here 269 self.protocol._connectionLost(txaio.create_failure(RuntimeError("testing"))) 270 self.assertTrue(self.protocol.openHandshakeTimeoutCall is None) 271