# Copyright 2005 Divmod, Inc. See LICENSE file for details

import six

from epsilon import juice
from epsilon.test import iosim
from twisted.trial import unittest
from twisted.internet import protocol, defer


class TestProto(protocol.Protocol):
    """
    Trivial protocol installed on both peers after a protocol switch.

    When connected it writes dataToSend. It records every chunk it receives
    and drops the connection after the first chunk. When the connection goes
    away it fires onConnLost with the list of chunks received.
    """

    def __init__(self, onConnLost, dataToSend):
        # Deferred fired with the received chunks when the connection is lost.
        self.onConnLost = onConnLost
        # Bytes written as soon as the connection is made.
        self.dataToSend = dataToSend
        # Also initialised here (not only in connectionMade) so that
        # connectionLost cannot fail with AttributeError if the connection
        # was never made.
        self.data = []

    def connectionMade(self):
        self.data = []
        self.transport.write(self.dataToSend)

    def dataReceived(self, data):
        # Record the chunk, then hang up: a single exchange is all the
        # protocol-switch tests need.
        self.data.append(data)
        self.transport.loseConnection()

    def connectionLost(self, reason):
        self.onConnLost.callback(self.data)


class SimpleSymmetricProtocol(juice.Juice):
    """
    Juice peer driven by raw sendCommand calls and juice_* box handlers.
    """

    def sendHello(self, text):
        return self.sendCommand("hello", hello=text)

    def sendGoodbye(self):
        return self.sendCommand("goodbye")

    def juice_HELLO(self, box):
        # Echo only the 'hello' value back; other headers are not returned.
        return juice.Box(hello=box['hello'])

    def juice_GOODBYE(self, box):
        # Answers with a juice.QuitBox rather than a plain juice.Box.
        return juice.QuitBox(goodbye='world')


class UnfriendlyGreeting(Exception):
    """Greeting was insufficiently kind.
    """


class UnknownProtocol(Exception):
    """Asked to switch to the wrong protocol.
    """


class Hello(juice.Command):
    """Send a 'hello' string; the response carries a 'hello' string back."""
    commandName = 'hello'
    arguments = [('hello', juice.String())]
    response = [('hello', juice.String())]

    # Maps UnfriendlyGreeting to the 'UNFRIENDLY' error code.
    errors = {UnfriendlyGreeting: 'UNFRIENDLY'}


class Goodbye(juice.Command):
    """Argument-less command whose response is a juice.QuitBox."""
    commandName = 'goodbye'
    responseType = juice.QuitBox


class GetList(juice.Command):
    """Request 'length' items; the response 'body' is a list of {'x': int}."""
    commandName = 'getlist'
    arguments = [('length', juice.Integer())]
    response = [('body', juice.JuiceList([('x', juice.Integer())]))]


class TestSwitchProto(juice.ProtocolSwitchCommand):
    """Ask the peer to switch the connection to the protocol named 'name'."""
    commandName = 'Switch-Proto'

    arguments = [
        ('name', juice.String()),
    ]
    errors = {UnknownProtocol: 'UNKNOWN'}


class SingleUseFactory(protocol.ClientFactory):
    """
    Client factory that hands out one pre-built protocol instance.

    The first buildProtocol call returns proto; every later call returns None.
    """

    def __init__(self, proto):
        self.proto = proto

    def buildProtocol(self, addr):
        p, self.proto = self.proto, None
        return p


class SimpleSymmetricCommandProtocol(juice.Juice):
    """
    Juice peer that answers Hello, Goodbye, GetList and TestSwitchProto
    through command_* responders.
    """

    # Only assigned by subclasses whose SWITCH_PROTO responder answers with
    # a Deferred (see DeferredSymmetricCommandProtocol).
    maybeLater = None

    def __init__(self, issueGreeting, onConnLost=None):
        juice.Juice.__init__(self, issueGreeting)
        # Deferred handed to the TestProto created by a protocol switch.
        self.onConnLost = onConnLost

    def sendHello(self, text):
        return Hello(hello=text).do(self)

    def sendGoodbye(self):
        return Goodbye().do(self)

    def command_HELLO(self, hello):
        # Greetings with this prefix are rejected with UnfriendlyGreeting.
        if hello.startswith('fuck'):
            raise UnfriendlyGreeting("Don't be a dick.")
        return dict(hello=hello)
    command_HELLO.command = Hello

    def command_GETLIST(self, length):
        return {'body': [dict(x=1)] * length}
    command_GETLIST.command = GetList

    def command_GOODBYE(self):
        return dict(goodbye='world')
    command_GOODBYE.command = Goodbye

    def switchToTestProtocol(self):
        """
        Ask the peer to switch to 'test-proto'.

        Returns the Deferred from TestSwitchProto(...).do(self), with a
        callback that replaces its result with the local TestProto.
        """
        p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA)
        d = TestSwitchProto(SingleUseFactory(p), name='test-proto').do(self)
        return d.addCallback(lambda ign: p)

    def command_SWITCH_PROTO(self, name):
        # Only 'test-proto' is known; anything else is reported as UNKNOWN.
        if name == 'test-proto':
            return TestProto(self.onConnLost, SWITCH_SERVER_DATA)
        raise UnknownProtocol(name)

    command_SWITCH_PROTO.command = TestSwitchProto
class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol):
    """
    Variant whose SWITCH_PROTO responder answers asynchronously.

    The responder returns a new Deferred, stored in maybeLater. The TestProto
    to switch to is stored in maybeLaterProto so the caller can fire the
    Deferred with it later.
    """

    def command_SWITCH_PROTO(self, name):
        if name == 'test-proto':
            self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA)
            self.maybeLater = defer.Deferred()
            return self.maybeLater
        raise UnknownProtocol(name)

    command_SWITCH_PROTO.command = TestSwitchProto


class SSPF:
    # Mixin supplying the protocol class to the factories below.
    protocol = SimpleSymmetricProtocol


class SSSF(SSPF, protocol.ServerFactory):
    pass


class SSCF(SSPF, protocol.ClientFactory):
    pass


def connectedServerAndClient(ServerClass=lambda: SimpleSymmetricProtocol(True),
                             ClientClass=lambda: SimpleSymmetricProtocol(False),
                             *a, **kw):
    """Returns a 3-tuple: (client, server, pump)
    """
    return iosim.connectedServerAndClient(
        ServerClass, ClientClass,
        *a, **kw)


class TotallyDumbProtocol(protocol.Protocol):
    """Protocol that only accumulates every received byte in buf."""
    buf = b''

    def dataReceived(self, data):
        self.buf += data


class LiteralJuice(juice.Juice):
    """Juice peer that records every incoming box in boxes, unanswered."""

    def __init__(self, issueGreeting):
        juice.Juice.__init__(self, issueGreeting)
        self.boxes = []

    def juiceBoxReceived(self, box):
        self.boxes.append(box)


class LiteralParsingTest(unittest.TestCase):
    def testBasicRequestResponse(self):
        c, s, p = connectedServerAndClient(ClientClass=TotallyDumbProtocol)
        HELLO = b'abcdefg'
        ASKTOK = b'hand-crafted-ask'
        # Hand-written request with a mixed-case command name and an extra
        # 'World' header. The default server's hello responder echoes only
        # 'hello', so 'World' must not come back.
        request = (
            b"-Command: HeLlO\r\n"
            b"-Ask: %s\r\n"
            b"Hello: %s\r\n"
            b"World: this header is ignored\r\n"
            b"\r\n"
        ) % (ASKTOK, HELLO)
        c.transport.write(request)
        p.flush()
        asserts = {'hello': HELLO,
                   '-answer': ASKTOK}
        # The last two split items are dropped; the response presumably ends
        # with a CRLF-terminated blank line, which leaves two empty strings.
        hdrs = [j.split(b': ') for j in c.buf.split(b'\r\n')[:-2]]
        self.assertEqual(len(asserts), len(hdrs))
        for k, v in hdrs:
            self.assertEqual(v, asserts[six.ensure_str(k).lower()])

    def testParsingRoundTrip(self):
        c, s, p = connectedServerAndClient(ClientClass=lambda: LiteralJuice(False),
                                           ServerClass=lambda: LiteralJuice(True))

        SIMPLE = ('simple', 'test')
        CE = ('ceq', ': ')
        CR = ('crtest', 'test\r')
        LF = ('lftest', 'hello\n')
        NEWLINE = ('newline', 'test\r\none\r\ntwo')
        NEWLINE2 = ('newline2', 'test\r\none\r\n two')
        # NOTE(review): BLANKLINE is defined but never passed to onetest();
        # confirm whether blank lines inside values are meant to round-trip.
        BLANKLINE = ('newline3', 'test\r\n\r\nblank\r\n\r\nline')
        BODYTEST = (juice.BODY, 'blah\r\n\r\ntesttest')

        def onetest(test):
            # Send a box built from the (key, value) pairs and check that the
            # server received an equal box.
            jb = juice.Box()
            jb.update(dict(test))
            jb.sendTo(c)
            p.flush()
            self.assertEqual(s.boxes[-1], jb)

        onetest([SIMPLE])
        onetest([SIMPLE, BODYTEST])
        onetest([SIMPLE, CE])
        onetest([SIMPLE, CR])
        onetest([SIMPLE, CE, CR, LF])
        onetest([CE, CR, LF])
        onetest([SIMPLE, NEWLINE, CE, NEWLINE2])
        onetest([BODYTEST, SIMPLE, NEWLINE])


SWITCH_CLIENT_DATA = b'Success!'
SWITCH_SERVER_DATA = b'No, really. Success.'


class AppLevelTest(unittest.TestCase):
    def testHelloWorld(self):
        c, s, p = connectedServerAndClient()
        L = []
        HELLO = 'world'
        c.sendHello(HELLO).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0]['hello'], HELLO)

    def testHelloWorldCommand(self):
        c, s, p = connectedServerAndClient(
            ServerClass=lambda: SimpleSymmetricCommandProtocol(True),
            ClientClass=lambda: SimpleSymmetricCommandProtocol(False))
        L = []
        HELLO = 'world'
        c.sendHello(HELLO).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0]['hello'], HELLO)

    def testHelloErrorHandling(self):
        L = []
        c, s, p = connectedServerAndClient(ServerClass=lambda: SimpleSymmetricCommandProtocol(True),
                                           ClientClass=lambda: SimpleSymmetricCommandProtocol(False))
        HELLO = 'fuck you'
        c.sendHello(HELLO).addErrback(L.append)
        p.flush()
        # The server-side exception must arrive as the same exception type.
        L[0].trap(UnfriendlyGreeting)
        self.assertEqual(str(L[0].value), "Don't be a dick.")

    def testJuiceListCommand(self):
        c, s, p = connectedServerAndClient(ServerClass=lambda: SimpleSymmetricCommandProtocol(True),
                                           ClientClass=lambda: SimpleSymmetricCommandProtocol(False))
        L = []
        GetList(length=10).do(c).addCallback(L.append)
        p.flush()
        values = L.pop().get('body')
        self.assertEqual(values, [{'x': 1}] * 10)

    def testFailEarlyOnArgSending(self):
        # Constructing with the argument supplied must not raise; omitting
        # the required argument must raise RuntimeError immediately.
        Hello(Hello="What?")
        self.assertRaises(RuntimeError, Hello)

    def testSupportsVersion1(self):
        c, s, p = connectedServerAndClient(ServerClass=lambda: juice.Juice(True),
                                           ClientClass=lambda: juice.Juice(False))
        negotiatedVersion = []
        s.renegotiateVersion(1).addCallback(negotiatedVersion.append)
        p.flush()
        self.assertEqual(negotiatedVersion[0], 1)
        self.assertEqual(c.protocolVersion, 1)
        self.assertEqual(s.protocolVersion, 1)

    def testProtocolSwitch(self, switcher=SimpleSymmetricCommandProtocol):
        """
        Switch both peers to TestProto and check that each side received the
        other's payload. switcher selects the command protocol class, so the
        deferred-responder variant can reuse this test.
        """
        self.testSucceeded = False

        serverDeferred = defer.Deferred()
        serverProto = switcher(True, serverDeferred)
        clientDeferred = defer.Deferred()
        clientProto = switcher(False, clientDeferred)
        c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
                                           ClientClass=lambda: clientProto)

        switchDeferred = c.switchToTestProtocol()

        def cbConnsLost(results):
            (serverSuccess, serverData), (clientSuccess, clientData) = results
            self.assertTrue(serverSuccess)
            self.assertTrue(clientSuccess)
            self.assertEqual(b''.join(serverData), SWITCH_CLIENT_DATA)
            self.assertEqual(b''.join(clientData), SWITCH_SERVER_DATA)
            self.testSucceeded = True

        def cbSwitch(proto):
            return defer.DeferredList([serverDeferred, clientDeferred]).addCallback(cbConnsLost)

        switchDeferred.addCallback(cbSwitch)
        p.flush()
        # A deferred SWITCH_PROTO responder is still pending here; complete
        # it with the prepared protocol and pump the remaining traffic.
        if serverProto.maybeLater is not None:
            serverProto.maybeLater.callback(serverProto.maybeLaterProto)
            p.flush()
        self.assertTrue(self.testSucceeded)

    def testProtocolSwitchDeferred(self):
        return self.testProtocolSwitch(switcher=DeferredSymmetricCommandProtocol)