1# Copyright (c) Twisted Matrix Laboratories.
2# See LICENSE for details.
3
4"""
5Tests for error handling in PB.
6"""
7
8from io import StringIO
9
10from twisted.internet import defer, reactor
11from twisted.python import log
12from twisted.python.reflect import qual
13from twisted.spread import flavors, jelly, pb
14from twisted.trial import unittest
15
16# Test exceptions
17
18
class AsynchronousException(Exception):
    """
    Ordinary (non-L{pb.Error}) exception delivered through a failed Deferred.

    L{SimpleRoot.remote_asynchronousException} hands this back wrapped in
    C{defer.fail}, so tests can check how PB reports plain exceptions that
    arrive asynchronously.
    """
24
25
class SynchronousException(Exception):
    """
    Ordinary (non-L{pb.Error}) exception raised directly by a remote method.

    L{SimpleRoot.remote_synchronousException} raises this inline, so tests
    can check how PB reports plain exceptions raised synchronously.
    """
31
32
class AsynchronousError(pb.Error):
    """
    L{pb.Error} subclass delivered through a failed Deferred.

    Returned (wrapped in C{defer.fail}) by
    L{SimpleRoot.remote_asynchronousError}.
    """
38
39
class SynchronousError(pb.Error):
    """
    L{pb.Error} subclass raised directly by a remote method.

    Raised inline by L{SimpleRoot.remote_synchronousError}.
    """
45
46
class JellyError(flavors.Jellyable, pb.Error, pb.RemoteCopy):
    """
    L{pb.Error} that is also L{flavors.Jellyable} and L{pb.RemoteCopy}.

    Raised by L{SimpleRoot.raiseJelly}. Tests expect the calling side to
    receive a failure whose type is this class rather than a string.
    """
49
50
class SecurityError(pb.Error, pb.RemoteCopy):
    """
    L{pb.Error} and L{pb.RemoteCopy}, but I{not} L{flavors.Jellyable}.

    Raised by L{SimpleRoot.raiseSecurity}. Tests expect the calling side to
    receive a failure whose type is this class rather than a string.
    """
53
54
# Register each RemoteCopy error class as its own unjellyable so that
# instances arriving over PB are rebuilt as that same class.  SecurityError is
# additionally whitelisted through globalSecurity.allowInstancesOf.
for _remoteErrorClass in (JellyError, SecurityError):
    pb.setUnjellyableForClass(_remoteErrorClass, _remoteErrorClass)
del _remoteErrorClass
pb.globalSecurity.allowInstancesOf(SecurityError)
58
59
60# Server-side
class SimpleRoot(pb.Root):
    """
    Root object published by the test server.

    Every C{remote_*} method fails, each in a different way, so the
    client-side tests can observe how PB transports that failure.
    """

    def remote_asynchronousException(self):
        """
        Return an already-failed Deferred wrapping an
        L{AsynchronousException}, which is not a L{pb.Error}.
        """
        exc = AsynchronousException("remote asynchronous exception")
        return defer.fail(exc)

    def remote_synchronousException(self):
        """
        Raise a L{SynchronousException}, which is not a L{pb.Error}, directly.
        """
        raise SynchronousException("remote synchronous exception")

    def remote_asynchronousError(self):
        """
        Return an already-failed Deferred wrapping an L{AsynchronousError},
        which is a L{pb.Error}.
        """
        exc = AsynchronousError("remote asynchronous error")
        return defer.fail(exc)

    def remote_synchronousError(self):
        """
        Raise a L{SynchronousError}, which is a L{pb.Error}, directly.
        """
        raise SynchronousError("remote synchronous error")

    def remote_unknownError(self):
        """
        Raise a L{pb.Error} subclass defined only inside this method, which
        the client has no class for.
        """

        # Keep this class name: the client-side test expects the failure
        # type to be this module's name followed by ".UnknownError".
        class UnknownError(pb.Error):
            pass

        raise UnknownError("I'm not known to client!")

    def remote_jelly(self):
        """
        Raise a L{JellyError} synchronously.
        """
        self.raiseJelly()

    def remote_security(self):
        """
        Raise a L{SecurityError} synchronously.
        """
        self.raiseSecurity()

    def remote_deferredJelly(self):
        """
        Return a Deferred that has failed because its callback raised a
        L{JellyError}.
        """
        return defer.succeed(None).addCallback(self.raiseJelly)

    def remote_deferredSecurity(self):
        """
        Return a Deferred that has failed because its callback raised a
        L{SecurityError}.
        """
        return defer.succeed(None).addCallback(self.raiseSecurity)

    def raiseJelly(self, results=None):
        """
        Raise a L{JellyError}.

        @param results: Ignored; accepted so this can serve as a Deferred
            callback.
        """
        raise JellyError("I'm jellyable!")

    def raiseSecurity(self, results=None):
        """
        Raise a L{SecurityError}.

        @param results: Ignored; accepted so this can serve as a Deferred
            callback.
        """
        raise SecurityError("I'm secure!")
119
120
class SaveProtocolServerFactory(pb.PBServerFactory):
    """
    L{pb.PBServerFactory} that remembers the protocol of the most recently
    connected client, so tests can close that connection during teardown.

    @ivar protocolInstance: The protocol given to the latest
        L{clientConnectionMade} call, or L{None} if no client has connected.
    """

    protocolInstance = None

    def clientConnectionMade(self, protocol):
        """
        Record C{protocol} as the latest connected client, replacing any
        previously recorded protocol.
        """
        self.protocolInstance = protocol
134
135
class PBConnTestCase(unittest.TestCase):
    """
    Base test case that, for each test, serves a L{SimpleRoot} on an
    OS-assigned loopback TCP port and connects a L{pb.PBClientFactory} to it.

    @cvar unsafeTracebacks: Copied onto the server factory's
        C{unsafeTracebacks} attribute; subclasses override it.
    """

    unsafeTracebacks = 0

    def setUp(self):
        """
        Start the server first, then connect the client to it.
        """
        self._setUpServer()
        self._setUpClient()

    def _setUpServer(self):
        """
        Listen on 127.0.0.1 (port 0, so the OS picks one) with a
        L{SaveProtocolServerFactory} serving a fresh L{SimpleRoot}.
        """
        factory = SaveProtocolServerFactory(SimpleRoot())
        self.serverFactory = factory
        factory.unsafeTracebacks = self.unsafeTracebacks
        self.serverPort = reactor.listenTCP(0, factory, interface="127.0.0.1")

    def _setUpClient(self):
        """
        Connect a new L{pb.PBClientFactory} to the port opened by
        L{_setUpServer}.
        """
        serverAddress = self.serverPort.getHost()
        self.clientFactory = pb.PBClientFactory()
        self.clientConnector = reactor.connectTCP(
            "127.0.0.1", serverAddress.port, self.clientFactory
        )

    def tearDown(self):
        """
        Drop the server side of the connection (if a client connected), stop
        listening, and disconnect the client.

        @return: A Deferred that fires once both the server and the client
            teardown Deferreds have fired.
        """
        serverProtocol = self.serverFactory.protocolInstance
        if serverProtocol is not None:
            serverProtocol.transport.loseConnection()
        cleanups = [self._tearDownServer(), self._tearDownClient()]
        return defer.gatherResults(cleanups)

    def _tearDownServer(self):
        """
        Stop listening.

        @return: A Deferred, even if C{stopListening} returns synchronously.
        """
        return defer.maybeDeferred(self.serverPort.stopListening)

    def _tearDownClient(self):
        """
        Disconnect the client connector.

        @return: An already-fired Deferred.
        """
        self.clientConnector.disconnect()
        return defer.succeed(None)
168
169
class PBFailureTests(PBConnTestCase):
    """
    Tests for how failures produced by L{SimpleRoot}'s remote methods reach
    the calling side.

    @cvar compare: Assertion applied to the received failure's traceback
        against C{"Traceback unavailable"}; subclasses swap it to invert
        the expectation.
    """

    compare = unittest.TestCase.assertEqual

    def _exceptionTest(self, method, exceptionType, flush):
        """
        Call the remote method C{method} and assert that it fails with
        C{exceptionType}.

        @param method: Name of the remote method to call on the root object.
        @param exceptionType: Exception type the call is expected to fail
            with.
        @param flush: If true, also assert that exactly one C{exceptionType}
            error was logged, and flush it.
        @return: A Deferred firing with C{(type, value, traceback)} of the
            received failure. It fails if the call unexpectedly succeeds or
            fails with a different exception.
        """

        def unexpectedSuccess(result):
            # Without this, a call that succeeds would skip the errback and
            # the test would pass without checking anything.
            self.fail(
                "Remote method {!r} succeeded with {!r}; expected {}".format(
                    method, result, qual(exceptionType)
                )
            )

        def eb(err):
            err.trap(exceptionType)
            self.compare(err.traceback, "Traceback unavailable\n")
            if flush:
                errs = self.flushLoggedErrors(exceptionType)
                self.assertEqual(len(errs), 1)
            return (err.type, err.value, err.traceback)

        d = self.clientFactory.getRootObject()

        def gotRootObject(root):
            d = root.callRemote(method)
            d.addCallbacks(unexpectedSuccess, eb)
            return d

        d.addCallback(gotRootObject)
        return d

    def test_asynchronousException(self):
        """
        Test that a Deferred returned by a remote method which already has a
        Failure correctly has that error passed back to the calling side.
        """
        return self._exceptionTest("asynchronousException", AsynchronousException, True)

    def test_synchronousException(self):
        """
        Like L{test_asynchronousException}, but for a method which raises an
        exception synchronously.
        """
        return self._exceptionTest("synchronousException", SynchronousException, True)

    def test_asynchronousError(self):
        """
        Like L{test_asynchronousException}, but for a method which returns a
        Deferred failing with an L{pb.Error} subclass.
        """
        return self._exceptionTest("asynchronousError", AsynchronousError, False)

    def test_synchronousError(self):
        """
        Like L{test_asynchronousError}, but for a method which synchronously
        raises a L{pb.Error} subclass.
        """
        return self._exceptionTest("synchronousError", SynchronousError, False)

    def _success(self, result, expectedResult):
        """
        Assert that C{result} equals C{expectedResult} and pass it through.
        """
        self.assertEqual(result, expectedResult)
        return result

    def _addFailingCallbacks(self, remoteCall, expectedResult, eb):
        """
        Attach C{eb} as the errback of C{remoteCall}, plus a callback that
        requires its result to equal C{expectedResult}.

        @return: C{remoteCall}.
        """
        remoteCall.addCallbacks(self._success, eb, callbackArgs=(expectedResult,))
        return remoteCall

    def _testImpl(self, method, expected, eb, exc=None):
        """
        Call the given remote method and attach the given errback to the
        resulting Deferred.  If C{exc} is not None, also assert that one
        exception of that type was logged.
        """
        rootDeferred = self.clientFactory.getRootObject()

        def gotRootObj(obj):
            failureDeferred = self._addFailingCallbacks(
                obj.callRemote(method), expected, eb
            )
            if exc is not None:

                def gotFailure(err):
                    self.assertEqual(len(self.flushLoggedErrors(exc)), 1)
                    return err

                failureDeferred.addBoth(gotFailure)
            return failureDeferred

        rootDeferred.addCallback(gotRootObj)
        return rootDeferred

    def test_jellyFailure(self):
        """
        Test that an exception which is a subclass of L{pb.Error} has more
        information passed across the network to the calling side.
        """

        def failureJelly(fail):
            fail.trap(JellyError)
            self.assertNotIsInstance(fail.type, str)
            self.assertIsInstance(fail.value, fail.type)
            return 43

        return self._testImpl("jelly", 43, failureJelly)

    def test_deferredJellyFailure(self):
        """
        Test that a Deferred which fails with a L{pb.Error} is treated in
        the same way as a synchronously raised L{pb.Error}.
        """

        def failureDeferredJelly(fail):
            fail.trap(JellyError)
            self.assertNotIsInstance(fail.type, str)
            self.assertIsInstance(fail.value, fail.type)
            return 430

        return self._testImpl("deferredJelly", 430, failureDeferredJelly)

    def test_unjellyableFailure(self):
        """
        A non-jellyable L{pb.Error} subclass raised by a remote method is
        turned into a Failure with a type set to the FQPN of the exception
        type.
        """

        def failureUnjellyable(fail):
            self.assertEqual(
                fail.type, b"twisted.spread.test.test_pbfailure.SynchronousError"
            )
            return 431

        return self._testImpl("synchronousError", 431, failureUnjellyable)

    def test_unknownFailure(self):
        """
        Test that an exception which is a subclass of L{pb.Error} but not
        known on the client side has its type set properly.
        """

        def failureUnknown(fail):
            self.assertEqual(
                fail.type, b"twisted.spread.test.test_pbfailure.UnknownError"
            )
            return 4310

        return self._testImpl("unknownError", 4310, failureUnknown)

    def test_securityFailure(self):
        """
        Test that even if an exception is not explicitly jellyable (by being
        a L{pb.Jellyable} subclass), as long as it is an L{pb.Error}
        subclass it receives the same special treatment.
        """

        def failureSecurity(fail):
            fail.trap(SecurityError)
            self.assertNotIsInstance(fail.type, str)
            self.assertIsInstance(fail.value, fail.type)
            return 4300

        return self._testImpl("security", 4300, failureSecurity)

    def test_deferredSecurity(self):
        """
        Test that a Deferred which fails with a L{pb.Error} which is not
        also a L{pb.Jellyable} is treated in the same way as a synchronously
        raised exception of the same type.
        """

        def failureDeferredSecurity(fail):
            fail.trap(SecurityError)
            self.assertNotIsInstance(fail.type, str)
            self.assertIsInstance(fail.value, fail.type)
            return 43000

        return self._testImpl("deferredSecurity", 43000, failureDeferredSecurity)

    def test_noSuchMethodFailure(self):
        """
        Test that calling a method the root object does not define results in
        a L{pb.NoSuchMethod} failure on the calling side.  The test also
        requires exactly one logged AttributeError and flushes it.
        """

        def failureNoSuch(fail):
            fail.trap(pb.NoSuchMethod)
            self.compare(fail.traceback, "Traceback unavailable\n")
            return 42000

        return self._testImpl("nosuch", 42000, failureNoSuch, AttributeError)

    def test_copiedFailureLogging(self):
        """
        Test that a copied failure received from a PB call can be logged
        locally.

        Note: this test needs some serious help: all it really tests is that
        log.err(copiedFailure) doesn't raise an exception.
        """
        d = self.clientFactory.getRootObject()

        def connected(rootObj):
            return rootObj.callRemote("synchronousException")

        d.addCallback(connected)

        def unexpectedSuccess(result):
            # The remote method always raises; succeeding means the failure
            # path under test was never exercised.
            self.fail(
                "Remote method 'synchronousException' succeeded with "
                "{!r}".format(result)
            )

        def exception(failure):
            log.err(failure)
            # Two entries are expected: one logged while the remote method
            # raised (see test_synchronousException) and one from the
            # log.err call above.
            errs = self.flushLoggedErrors(SynchronousException)
            self.assertEqual(len(errs), 2)

        d.addCallbacks(unexpectedSuccess, exception)

        return d

    def test_throwExceptionIntoGenerator(self):
        """
        L{pb.CopiedFailure.throwExceptionIntoGenerator} will throw a
        L{RemoteError} into the given paused generator at the point where it
        last yielded.
        """
        original = pb.CopyableFailure(AttributeError("foo"))
        copy = jelly.unjelly(jelly.jelly(original, invoker=DummyInvoker()))
        exception = []

        def generatorFunc():
            try:
                yield None
            except pb.RemoteError as exc:
                exception.append(exc)
            else:
                self.fail("RemoteError not raised")

        gen = generatorFunc()
        gen.send(None)
        self.assertRaises(StopIteration, copy.throwExceptionIntoGenerator, gen)
        self.assertEqual(len(exception), 1)
        exc = exception[0]
        self.assertEqual(exc.remoteType, qual(AttributeError).encode("ascii"))
        self.assertEqual(exc.args, ("foo",))
        self.assertEqual(exc.remoteTraceback, "Traceback unavailable\n")
402
403
class PBFailureUnsafeTests(PBFailureTests):
    """
    Re-run every L{PBFailureTests} test with the server factory's
    C{unsafeTracebacks} enabled.

    With unsafe tracebacks the received failures are expected to carry a
    real traceback, so C{compare} asserts the traceback is I{not} the
    "Traceback unavailable" placeholder.
    """

    # assertNotEqual is the canonical name; failIfEquals is a legacy alias.
    compare = unittest.TestCase.assertNotEqual
    unsafeTracebacks = True
407
408
class DummyInvoker:
    """
    Stand-in passed as the C{invoker} argument of L{jelly.jelly}.

    It has no behavior; its only attribute is C{serializingPerspective},
    which is always L{None}.
    """

    # The single attribute this stand-in provides.
    serializingPerspective = None
416
417
class FailureJellyingTests(unittest.TestCase):
    """
    Tests covering how a L{pb.CopyableFailure} behaves after a round trip
    through L{jelly}.
    """

    def _roundTrip(self, failure):
        """
        Jelly C{failure} using a L{DummyInvoker}, then unjelly the result.

        @return: The unjellied copy.
        """
        jellied = jelly.jelly(failure, invoker=DummyInvoker())
        return jelly.unjelly(jellied)

    def _assertChecksZeroDivision(self, failure):
        """
        Assert that C{failure.check} returns the queried type for both
        L{ZeroDivisionError} and its base L{ArithmeticError}.
        """
        for excType in (ZeroDivisionError, ArithmeticError):
            self.assertIs(failure.check(excType), excType)

    def test_unjelliedFailureCheck(self):
        """
        After one jelly/unjelly round trip, a L{CopyableFailure}'s C{check}
        method still answers the same way the original's does.
        """
        original = pb.CopyableFailure(ZeroDivisionError())
        self._assertChecksZeroDivision(original)
        copied = self._roundTrip(original)
        self._assertChecksZeroDivision(copied)

    def test_twiceUnjelliedFailureCheck(self):
        """
        Round-tripping a L{CopyableFailure}, wrapping the copy in a new
        L{CopyableFailure}, and round-tripping that as well still yields an
        object whose C{check} method answers like the original's.
        """
        original = pb.CopyableFailure(ZeroDivisionError())
        self._assertChecksZeroDivision(original)
        derivative = pb.CopyableFailure(self._roundTrip(original))
        copiedTwice = self._roundTrip(derivative)
        self._assertChecksZeroDivision(copiedTwice)

    def test_printTracebackIncludesValue(self):
        """
        L{CopiedFailure.printTraceback} on a copy of a L{CopyableFailure}
        made without unsafe tracebacks still prints the exception's string
        value, prefixed by its fully qualified type name.
        """
        copied = self._roundTrip(pb.CopyableFailure(Exception("some reason")))
        output = StringIO()
        copied.printTraceback(output)
        expectedOutput = "Traceback from remote host -- {}: some reason\n".format(
            qual(Exception)
        )
        self.assertEqual(expectedOutput, output.getvalue())
468