1# Copyright (c) Twisted Matrix Laboratories.
2# See LICENSE for details.
3
4"""
5Tests for Perspective Broker module.
6
7TODO: update protocol level tests to use new connection API, leaving
8only specific tests for old API.
9"""
10
11# issue1195 TODOs: replace pump.pump() with something involving Deferreds.
12# Clean up warning suppression.
13
14
15import gc
16import os
17import sys
18import time
19import weakref
20from collections import deque
21from io import BytesIO as StringIO
22from typing import Dict
23
24from zope.interface import Interface, implementer
25
26from twisted.cred import checkers, credentials, portal
27from twisted.cred.error import UnauthorizedLogin, UnhandledCredentials
28from twisted.internet import address, main, protocol, reactor
29from twisted.internet.defer import Deferred, gatherResults, succeed
30from twisted.internet.error import ConnectionRefusedError
31from twisted.protocols.policies import WrappingFactory
32from twisted.python import failure, log
33from twisted.python.compat import iterbytes
34from twisted.spread import jelly, pb, publish, util
35from twisted.test.proto_helpers import _FakeConnector
36from twisted.trial import unittest
37
38
class Dummy(pb.Viewable):
    """
    A L{pb.Viewable} whose C{doNothing} view answers differently depending on
    whether the viewer is a L{DummyPerspective}.
    """

    def view_doNothing(self, user):
        """
        @param user: presumably the viewing perspective supplied by
            L{pb.Viewable} — only its type is inspected here.
        @return: C{"hello world!"} for a L{DummyPerspective}, otherwise
            C{"goodbye, cruel world!"}.
        """
        return (
            "hello world!"
            if isinstance(user, DummyPerspective)
            else "goodbye, cruel world!"
        )
45
46
class DummyPerspective(pb.Avatar):
    """
    An L{IPerspective} avatar which will be used in some tests.
    """

    def perspective_getDummyViewPoint(self):
        """
        @return: a newly created L{Dummy} viewable.
        """
        viewable = Dummy()
        return viewable
54
55
@implementer(portal.IRealm)
class DummyRealm:
    """
    A realm that hands out L{DummyPerspective} avatars for
    L{pb.IPerspective} requests.
    """

    def requestAvatar(self, avatarId, mind, *interfaces):
        """
        @return: C{(pb.IPerspective, DummyPerspective(avatarId), logout)}
            with a no-op C{logout} when L{pb.IPerspective} (compared by
            identity) is among C{interfaces}; otherwise C{None}.
        """
        if not any(iface is pb.IPerspective for iface in interfaces):
            return None
        return pb.IPerspective, DummyPerspective(avatarId), lambda: None
62
63
class IOPump:
    """
    Utility to pump data between clients and servers for protocol testing.

    Bytes read from C{clientIO} are delivered to C{server}, and bytes read
    from C{serverIO} are delivered to C{client}, one byte at a time.

    Perhaps this is a utility worthy of being in protocol.py?
    """

    def __init__(self, client, server, clientIO, serverIO):
        """
        @param client: protocol that receives the data found in C{serverIO}.
        @param server: protocol that receives the data found in C{clientIO}.
        @param clientIO: seekable in-memory file holding data written by the
            client side (see L{connectServerAndClient}).
        @param serverIO: seekable in-memory file holding data written by the
            server side.
        """
        self.client = client
        self.server = server
        self.clientIO = clientIO
        self.serverIO = serverIO
        # Defined up front so the attribute exists even before flush()/stop().
        self._stop = False

    def flush(self):
        """
        Pump until there is no more input or output or until L{stop} is called.
        This does not run any timers, so don't use it with any code that calls
        reactor.callLater.
        """
        self._stop = False
        # Failsafe timeout of 5 seconds.  A monotonic clock is used so that
        # wall-clock adjustments cannot shorten or stretch the deadline.
        deadline = time.monotonic() + 5
        while not self._stop and self.pump():
            if time.monotonic() > deadline:
                return

    def stop(self):
        """
        Stop a running L{flush} operation, even if data remains to be
        transferred.
        """
        self._stop = True

    def pump(self):
        """
        Move data back and forth.

        @return: C{True} if any data was moved in either direction, C{False}
            otherwise.
        """
        # Drain both buffers completely, then empty them for new writes.
        self.clientIO.seek(0)
        self.serverIO.seek(0)
        cData = self.clientIO.read()
        sData = self.serverIO.read()
        self.clientIO.seek(0)
        self.serverIO.seek(0)
        self.clientIO.truncate()
        self.serverIO.truncate()
        # NOTE(review): presumably lets each FileWrapper transport resume a
        # registered producer now that its buffer is empty — confirm.
        self.client.transport._checkProducer()
        self.server.transport._checkProducer()
        # Deliver one byte at a time to exercise the protocols' framing.
        for byte in iterbytes(cData):
            self.server.dataReceived(byte)
        for byte in iterbytes(sData):
            self.client.dataReceived(byte)
        return bool(cData or sData)
121
122
def connectServerAndClient(test, clientFactory, serverFactory):
    """
    Build a client and a server broker, wire them together through in-memory
    files and an L{IOPump}, and register cleanups that disconnect them.

    @param test: the test case that will own the cleanups.
    @type test: L{twisted.trial.unittest.TestCase}

    @param clientFactory: The factory that creates the client object.
    @type clientFactory: L{twisted.spread.pb.PBClientFactory}

    @param serverFactory: The factory that creates the server object.
    @type serverFactory: L{twisted.spread.pb.PBServerFactory}

    @return: a 3-tuple of (client, server, pump)
    @rtype: (L{twisted.spread.pb.Broker}, L{twisted.spread.pb.Broker},
        L{IOPump})
    """
    fakeAddress = ("127.0.0.1",)
    clientBroker = clientFactory.buildProtocol(fakeAddress)
    serverBroker = serverFactory.buildProtocol(fakeAddress)

    clientIO, serverIO = StringIO(), StringIO()
    clientBroker.makeConnection(protocol.FileWrapper(clientIO))
    serverBroker.makeConnection(protocol.FileWrapper(serverIO))
    pump = IOPump(clientBroker, serverBroker, clientIO, serverIO)

    def maybeDisconnect(broker):
        # Only tear down brokers the test has not already disconnected.
        if broker.disconnected:
            return
        broker.connectionLost(failure.Failure(main.CONNECTION_DONE))

    def disconnectClientFactory():
        # No real connector exists here (only a FileWrapper driven by the
        # IOPump); PBClientFactory.clientConnectionLost ignores the connector,
        # so None is good enough.
        clientFactory.clientConnectionLost(
            connector=None, reason=failure.Failure(main.CONNECTION_DONE)
        )

    for broker in (clientBroker, serverBroker):
        test.addCleanup(maybeDisconnect, broker)
    test.addCleanup(disconnectClientFactory)

    # One pump exchanges the initial connection traffic.
    pump.pump()
    return clientBroker, serverBroker, pump
172
173
174class _ReconnectingFakeConnectorState:
175    """
176    Manages connection notifications for a
177    L{_ReconnectingFakeConnector} instance.
178
179    @ivar notifications: pending L{Deferreds} that will fire when the
180        L{_ReconnectingFakeConnector}'s connect method is called
181    """
182
183    def __init__(self):
184        self.notifications = deque()
185
186    def notifyOnConnect(self):
187        """
188        Connection notification.
189
190        @return: A L{Deferred} that fires when this instance's
191            L{twisted.internet.interfaces.IConnector.connect} method
192            is called.
193        @rtype: L{Deferred}
194        """
195        notifier = Deferred()
196        self.notifications.appendleft(notifier)
197        return notifier
198
199    def notifyAll(self):
200        """
201        Fire all pending notifications.
202        """
203        while self.notifications:
204            self.notifications.pop().callback(self)
205
206
class _ReconnectingFakeConnector(_FakeConnector):
    """
    A fake L{IConnector} that fires the L{Deferred}s pending on a shared
    L{_ReconnectingFakeConnectorState} whenever C{connect} is called.
    """

    def __init__(self, address, state):
        """
        @param address: An L{IAddress} provider that represents this
            connector's destination.
        @type address: An L{IAddress} provider.

        @param state: the object holding the pending connect notifications.
        @type state: L{_ReconnectingFakeConnectorState}
        """
        super().__init__(address)
        self._state = state

    def connect(self):
        """
        Run L{_FakeConnector.connect}, then fire all of the state's pending
        notifications.
        """
        super().connect()
        self._state.notifyAll()
231
232
def connectedServerAndClient(test, realm=None):
    """
    Connect a client and server L{Broker} together with an L{IOPump}.

    The server is a L{pb.PBServerFactory} over a portal with an in-memory
    checker that knows the single user C{guest}/C{guest}.

    @param test: the test case that will own the connection cleanups.
    @param realm: realm to use; any false value means a new L{DummyRealm}.

    @returns: a 3-tuple (client, server, pump).
    """
    if not realm:
        realm = DummyRealm()
    checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(guest=b"guest")
    serverFactory = pb.PBServerFactory(portal.Portal(realm, [checker]))
    return connectServerAndClient(test, pb.PBClientFactory(), serverFactory)
246
247
class SimpleRemote(pb.Referenceable):
    """
    Referenceable with one method that succeeds and one that always raises.
    """

    def remote_thunk(self, arg):
        """
        Remember C{arg} and return C{arg + 1}.
        """
        self.arg = arg
        return self.arg + 1

    def remote_knuth(self, arg):
        """
        Always raise a plain L{Exception}.
        """
        raise Exception()
255
256
class NestedRemote(pb.Referenceable):
    """
    Referenceable that returns a new L{SimpleRemote} on each call.
    """

    def remote_getSimple(self):
        simple = SimpleRemote()
        return simple
260
261
class SimpleCopy(pb.Copyable):
    """
    A copyable with an int, a dict and a list attribute.
    """

    def __init__(self):
        self.x, self.y, self.z = 1, {"Hello": "World"}, ["test"]
267
268
class SimpleLocalCopy(pb.RemoteCopy):
    """
    Receiving-side counterpart of L{SimpleCopy}; adds no behaviour of its own.
    """


# Unjelly SimpleCopy state into SimpleLocalCopy instances.
pb.setUnjellyableForClass(SimpleCopy, SimpleLocalCopy)
274
275
class SimpleFactoryCopy(pb.Copyable):
    """
    A copyable that registers every instance by id, so that
    L{createFactoryCopy} can return the existing instance.

    @cvar allIDs: every instance created so far, keyed by its C{id}.
    @type allIDs: C{dict}
    """

    allIDs: Dict[int, "SimpleFactoryCopy"] = {}

    def __init__(self, id):
        """
        @param id: the key under which this instance is registered.
        """
        self.id = id
        # Registered on the class so the lookup works from anywhere.
        SimpleFactoryCopy.allIDs[id] = self
287
288
def createFactoryCopy(state):
    """
    Factory of L{SimpleFactoryCopy}, getting a created instance given the
    C{id} found in C{state}.

    @param state: a mapping holding the copied state; its C{"id"} entry is
        the lookup key.
    @return: the instance registered in L{SimpleFactoryCopy.allIDs} under
        that id.
    @raise RuntimeError: if C{state} has no (or a C{None}) C{"id"}, if the
        id is unknown, or if the registered object is false.
    """
    stateId = state.get("id")
    if stateId is None:
        raise RuntimeError(f"factory copy state has no 'id' member {state!r}")
    registry = SimpleFactoryCopy.allIDs
    if stateId not in registry:
        raise RuntimeError(f"factory class has no ID: {registry}")
    instance = registry[stateId]
    if not instance:
        raise RuntimeError("factory method found no object with id")
    return instance
303
304
# Register createFactoryCopy as the factory used when unjellying
# SimpleFactoryCopy state: it returns the already-registered instance from
# SimpleFactoryCopy.allIDs instead of building a new object.
pb.setUnjellyableFactoryForClass(SimpleFactoryCopy, createFactoryCopy)
306
307
class NestedCopy(pb.Referenceable):
    """
    Referenceable that hands out copyable objects.
    """

    def remote_getCopy(self):
        copy = SimpleCopy()
        return copy

    def remote_getFactory(self, value):
        """
        @return: a new L{SimpleFactoryCopy} registered under C{value}.
        """
        return SimpleFactoryCopy(id=value)
314
315
class SimpleCache(pb.Cacheable):
    """
    A minimal L{pb.Cacheable} with an int, a dict and a list attribute.
    """

    def __init__(self):
        # Spelled with exactly two trailing underscores so Python actually
        # runs it as the constructor (the misspelled ``__init___`` was never
        # called, leaving instances without these attributes).
        self.x = 1
        self.y = {"Hello": "World"}
        self.z = ["test"]
321
322
class NestedComplicatedCache(pb.Referenceable):
    """
    Referenceable that always returns the same
    L{VeryVeryComplicatedCacheable}, created once at construction.
    """

    def __init__(self):
        self.c = VeryVeryComplicatedCacheable()

    def remote_getCache(self):
        cacheable = self.c
        return cacheable
329
330
class VeryVeryComplicatedCacheable(pb.Cacheable):
    """
    A cacheable that remembers its most recent observer and pushes a C{foo}
    update to it from L{setFoo4}.
    """

    def __init__(self):
        self.x, self.y, self.foo = 1, 2, 3

    def setFoo4(self):
        """
        Set C{foo} to 4 and tell the current observer about it.
        """
        self.foo = 4
        self.observer.callRemote("foo", 4)

    def getStateToCacheAndObserveFor(self, perspective, observer):
        """
        Remember C{observer} and return the C{x}, C{y} and C{foo} state.
        """
        self.observer = observer
        return dict(x=self.x, y=self.y, foo=self.foo)

    def stoppedObserving(self, perspective, observer):
        """
        Send C{end} to the departing observer, forgetting it if it is the
        one currently remembered.
        """
        log.msg("stopped observing")
        observer.callRemote("end")
        # Equality (not identity) is what decides whether it is "ours".
        if observer == self.observer:
            self.observer = None
350
351
class RatherBaroqueCache(pb.RemoteCache):
    """
    Receiving-side cache for L{VeryVeryComplicatedCacheable}, reacting to
    the C{foo} and C{end} observer messages.
    """

    def observe_foo(self, newFoo):
        self.foo = newFoo

    def observe_end(self):
        log.msg("the end of things")


# Unjelly VeryVeryComplicatedCacheable state into RatherBaroqueCache.
pb.setUnjellyableForClass(VeryVeryComplicatedCacheable, RatherBaroqueCache)
361
362
class SimpleLocalCache(pb.RemoteCache):
    """
    Receiving-side cache for L{SimpleCache}, with helpers used to check that
    bound methods and C{self} refer to the cache object itself.
    """

    def setCopyableState(self, state):
        # Adopt every received state entry as an instance attribute.
        vars(self).update(state)

    def checkMethod(self):
        """
        @return: the bound method L{check}.
        """
        return self.check

    def checkSelf(self):
        """
        @return: this very object.
        """
        return self

    def check(self):
        return 1


# Unjelly SimpleCache state into SimpleLocalCache instances.
pb.setUnjellyableForClass(SimpleCache, SimpleLocalCache)
378
379
class NestedCache(pb.Referenceable):
    """
    Referenceable owning one L{SimpleCache}, which it returns and compares
    against.
    """

    def __init__(self):
        self.x = SimpleCache()

    def remote_getCache(self):
        """
        @return: a two-element list containing the same cacheable twice.
        """
        shared = self.x
        return [shared, shared]

    def remote_putCache(self, cache):
        """
        @return: whether C{cache} is this object's own L{SimpleCache}.
        """
        return cache is self.x
389
390
class Observable(pb.Referenceable):
    """
    Keeps a list of observers and sends each of them a C{notify} call.
    """

    def __init__(self):
        self.observers = []

    def remote_observe(self, obs):
        self.observers.append(obs)

    def remote_unobserve(self, obs):
        # list.remove matches by equality, so the reference passed here must
        # compare equal to the one passed to remote_observe.
        self.observers.remove(obs)

    def notify(self, obj):
        """
        Call C{notify(self, obj)} on every registered observer.
        """
        for watcher in self.observers:
            watcher.callRemote("notify", self, obj)
404
405
class DeferredRemote(pb.Referenceable):
    """
    Referenceable whose C{doItLater} returns an unfired L{Deferred} that the
    test fires by hand through the C{d} attribute.

    @ivar run: the argument L{runMe} last received, or C{0} before that.
    @ivar d: the L{Deferred} most recently returned by L{remote_doItLater}.
    """

    def __init__(self):
        self.run = 0

    def runMe(self, arg):
        """
        Success callback: record C{arg} in C{run} and return C{arg + 1}.
        """
        self.run = arg
        return arg + 1

    def dontRunMe(self, arg):
        """
        Errback that must never be reached.

        An explicit C{raise} is used instead of a bare C{assert}: C{python -O}
        strips assertions, which would make this errback return C{None} and
        silently turn the failure into a success.
        """
        raise AssertionError("shouldn't have been run!")

    def remote_doItLater(self):
        """
        Return a L{Deferred} to be fired on client side. When fired,
        C{self.runMe} is called.
        """
        d = Deferred()
        d.addCallbacks(self.runMe, self.dontRunMe)
        self.d = d
        return d
426
427
class Observer(pb.Referenceable):
    """
    Observer that records the first notification and then unsubscribes.
    """

    notified = 0
    obj = None

    def remote_notify(self, other, obj):
        """
        Store C{obj}, count the notification and ask C{other} to drop us.
        """
        self.obj = obj
        self.notified += 1
        other.callRemote("unobserve", self)
436
437
class NewStyleCopy(pb.Copyable, pb.RemoteCopy):
    """
    A class acting as both the copyable and its own remote copy.
    """

    def __init__(self, s):
        self.s = s


# The class unjellies into itself.
pb.setUnjellyableForClass(NewStyleCopy, NewStyleCopy)
444
445
class NewStyleCopy2(pb.Copyable, pb.RemoteCopy):
    """
    Self-unjellying copyable that counts calls to C{__new__} and
    C{__init__} separately (NewStyleTests.test_alloc inspects and
    NewStyleTests.tearDown resets these counters).
    """

    allocated = 0
    initialized = 0
    value = 1

    def __new__(cls):
        NewStyleCopy2.allocated += 1
        instance = object.__new__(cls)
        # Instance-level value distinguishes allocated objects from the class.
        instance.value = 2
        return instance

    def __init__(self):
        NewStyleCopy2.initialized += 1


pb.setUnjellyableForClass(NewStyleCopy2, NewStyleCopy2)
462
463
class NewStyleCacheCopy(pb.Cacheable, pb.RemoteCache):
    """
    A class acting as both the cacheable and its own remote cache; the
    instance dictionary itself is the cached state.
    """

    def getStateToCacheAndObserveFor(self, perspective, observer):
        return vars(self)


pb.setUnjellyableForClass(NewStyleCacheCopy, NewStyleCacheCopy)
470
471
class Echoer(pb.Root):
    """
    Root object that sends its arguments straight back.
    """

    def remote_echo(self, st):
        return st

    def remote_echoWithKeywords(self, st, **kw):
        """
        @return: a C{(st, kw)} tuple.
        """
        return st, kw
478
479
class CachedReturner(pb.Root):
    """
    Root object that always returns the cacheable it was built with.
    """

    def __init__(self, cache):
        self.cache = cache

    def remote_giveMeCache(self, st):
        # The argument is accepted but not used.
        return self.cache
486
487
class NewStyleTests(unittest.SynchronousTestCase):
    """
    Tests for sending new-style copyable objects to an L{Echoer} over a
    connection driven synchronously by an L{IOPump}.

    Results are extracted with C{successResultOf} rather than by returning
    Deferreds: this is a synchronous test case, so an assertion failing
    inside a callback would otherwise be captured by the Deferred instead of
    failing the test, and a Deferred that never fired would go unnoticed.
    """

    def setUp(self):
        """
        Create a pb server using L{Echoer} protocol and connect a client to it.
        """
        self.serverFactory = pb.PBServerFactory(Echoer())
        clientFactory = pb.PBClientFactory()
        _client, self.server, self.pump = connectServerAndClient(
            test=self, clientFactory=clientFactory, serverFactory=self.serverFactory
        )
        self.ref = self.successResultOf(clientFactory.getRootObject())

    def tearDown(self):
        """
        Reset the class-level counters and default value of L{NewStyleCopy2}.

        The client and server connections are closed by the cleanups that
        L{connectServerAndClient} registers.
        """
        NewStyleCopy2.allocated = 0
        NewStyleCopy2.initialized = 0
        NewStyleCopy2.value = 1

    def test_newStyle(self):
        """
        Create a new style object, send it over the wire, and check the result.
        """
        orig = NewStyleCopy("value")
        d = self.ref.callRemote("echo", orig)
        self.pump.flush()
        res = self.successResultOf(d)
        self.assertIsInstance(res, NewStyleCopy)
        self.assertEqual(res.s, "value")
        self.assertIsNot(res, orig)  # no cheating :)

    def test_alloc(self):
        """
        Send a new style object and check the number of allocations.
        """
        orig = NewStyleCopy2()
        self.assertEqual(NewStyleCopy2.allocated, 1)
        self.assertEqual(NewStyleCopy2.initialized, 1)
        d = self.ref.callRemote("echo", orig)
        self.pump.flush()
        res = self.successResultOf(d)
        # Sending the object creates a second one on the far side and
        # receiving the response creates a third one on the way back; neither
        # copy goes through __init__.
        self.assertIsInstance(res, NewStyleCopy2)
        self.assertEqual(res.value, 2)
        self.assertEqual(NewStyleCopy2.allocated, 3)
        self.assertEqual(NewStyleCopy2.initialized, 1)
        self.assertIsNot(res, orig)  # No cheating :)

    def test_newStyleWithKeywords(self):
        """
        Create a new style object with keywords,
        send it over the wire, and check the result.
        """
        orig = NewStyleCopy("value1")
        d = self.ref.callRemote(
            "echoWithKeywords", orig, keyword1="one", keyword2="two"
        )
        self.pump.flush()
        res = self.successResultOf(d)
        self.assertIsInstance(res, tuple)
        self.assertIsInstance(res[0], NewStyleCopy)
        self.assertIsInstance(res[1], dict)
        self.assertEqual(res[0].s, "value1")
        self.assertIsNot(res[0], orig)
        self.assertEqual(res[1], {"keyword1": "one", "keyword2": "two"})
568
569
class ConnectionNotifyServerFactory(pb.PBServerFactory):
    """
    A server factory which stores the last connection and fires a
    L{Deferred} on connection made. This factory can handle only one
    client connection: the Deferred is discarded after it fires.

    @ivar protocolInstance: the last protocol instance.
    @type protocolInstance: C{pb.Broker}

    @ivar connectionMade: the deferred fired upon connection, or C{None}
        once it has fired.
    @type connectionMade: C{Deferred}
    """

    protocolInstance = None

    def __init__(self, root):
        """
        Initialize the factory.

        @param root: passed unchanged to L{pb.PBServerFactory}.
        """
        super().__init__(root)
        self.connectionMade = Deferred()

    def clientConnectionMade(self, protocol):
        """
        Store the protocol and fire the connection deferred, if still pending.
        """
        self.protocolInstance = protocol
        pending = self.connectionMade
        self.connectionMade = None
        if pending is not None:
            pending.callback(None)
600
601
class NewStyleCachedTests(unittest.TestCase):
    """
    Tests for new-style cacheable objects sent over a real TCP connection.
    """

    def setUp(self):
        """
        Create a pb server using L{CachedReturner} protocol and connect a
        client to it.

        @return: a L{Deferred} firing once the client holds the root object
            and the server factory has seen the connection.
        """
        self.orig = NewStyleCacheCopy()
        self.orig.s = "value"
        self.server = reactor.listenTCP(
            0, ConnectionNotifyServerFactory(CachedReturner(self.orig))
        )
        clientFactory = pb.PBClientFactory()
        serverPort = self.server.getHost().port
        reactor.connectTCP("localhost", serverPort, clientFactory)

        def storeRoot(root):
            self.ref = root

        rootReady = clientFactory.getRootObject()
        rootReady.addCallback(storeRoot)
        serverReady = self.server.factory.connectionMade
        return gatherResults([rootReady, serverReady])

    def tearDown(self):
        """
        Close client and server connections and stop listening.
        """
        self.server.factory.protocolInstance.transport.loseConnection()
        self.ref.broker.transport.loseConnection()
        return self.server.stopListening()

    def test_newStyleCache(self):
        """
        A new-style cacheable object can be retrieved and re-retrieved over a
        single connection.  The value of an attribute of the cacheable can be
        accessed on the receiving side.
        """

        def verify(res, fetchAgain):
            self.assertIsInstance(res, NewStyleCacheCopy)
            self.assertEqual("value", res.s)
            # no cheating :)
            self.assertIsNot(self.orig, res)
            if not fetchAgain:
                return None
            # Save a reference so it stays alive for the rest of this test
            self.res = res
            # And ask for it again to exercise the special re-jelly logic in
            # Cacheable.
            return self.ref.callRemote("giveMeCache", self.orig)

        d = self.ref.callRemote("giveMeCache", self.orig)
        d.addCallback(verify, True)
        d.addCallback(verify, False)
        return d
655
656
class BrokerTests(unittest.TestCase):
    """
    Protocol-level tests for L{pb.Broker}, using broker pairs connected by an
    L{IOPump} through L{connectedServerAndClient}.  Most tests advance the
    conversation by calling C{pump.pump()} a fixed number of times.
    """

    # Last value stored by thunkResultGood or gotCopy.
    thunkResult = None

    def tearDown(self):
        """
        Remove C{None-None-TESTING.pub} if a test left it behind, ignoring
        its absence.
        """
        try:
            # from RemotePublished.getFileName
            os.unlink("None-None-TESTING.pub")
        except OSError:
            pass

    def thunkErrorBad(self, error):
        """
        Errback for calls that are expected to succeed: fail the test.
        """
        self.fail(f"This should cause a return value, not {error}")

    def thunkResultGood(self, result):
        """
        Callback for calls that are expected to succeed: store C{result} in
        C{thunkResult}.
        """
        self.thunkResult = result

    def thunkErrorGood(self, tb):
        """
        Errback for calls that are expected to fail: ignore the failure.
        """
        pass

    def thunkResultBad(self, result):
        """
        Callback for calls that are expected to fail: fail the test.
        """
        self.fail(f"This should cause an error, not {result}")

    def test_reference(self):
        """
        Referenceables passed as arguments can be called back by the
        receiver, and a Referenceable passed back to its origin arrives as
        the original object.  References to the same remote method compare
        equal.
        """
        c, s, pump = connectedServerAndClient(test=self)

        class X(pb.Referenceable):
            def remote_catch(self, arg):
                self.caught = arg

        class Y(pb.Referenceable):
            def remote_throw(self, a, b):
                a.callRemote("catch", b)

        s.setNameForLocal("y", Y())
        y = c.remoteForName("y")
        x = X()
        z = X()
        y.callRemote("throw", x, z)
        pump.pump()
        pump.pump()
        pump.pump()
        self.assertIs(x.caught, z, "X should have caught Z")

        # make sure references to remote methods are equals
        self.assertEqual(y.remoteMethod("throw"), y.remoteMethod("throw"))

    def test_result(self):
        """
        A remote call's return value reaches the caller after two pumps (one
        for the question, one for the answer), in both directions.
        """
        c, s, pump = connectedServerAndClient(test=self)
        for x, y in (c, s), (s, c):
            # test reflexivity
            foo = SimpleRemote()
            x.setNameForLocal("foo", foo)
            bar = y.remoteForName("foo")
            self.expectedThunkResult = 8
            bar.callRemote("thunk", self.expectedThunkResult - 1).addCallbacks(
                self.thunkResultGood, self.thunkErrorBad
            )
            # Send question.
            pump.pump()
            # Send response.
            pump.pump()
            # Shouldn't require any more pumping than that...
            self.assertEqual(
                self.thunkResult, self.expectedThunkResult, "result wasn't received."
            )

    def refcountResult(self, result):
        """
        Callback storing the C{getSimple} result for test_refcount.
        """
        self.nestedRemote = result

    def test_tooManyRefs(self):
        """
        Requesting more than L{pb.MAX_BROKER_REFS} references gets the
        server's transport closed, with exactly C{MAX_BROKER_REFS - 1}
        successful calls before that.
        """
        l = []
        e = []
        c, s, pump = connectedServerAndClient(test=self)
        foo = NestedRemote()
        s.setNameForLocal("foo", foo)
        x = c.remoteForName("foo")
        for igno in range(pb.MAX_BROKER_REFS + 10):
            # Stop as soon as either side has closed its transport.
            if s.transport.closed or c.transport.closed:
                break
            x.callRemote("getSimple").addCallbacks(l.append, e.append)
            pump.pump()
        expected = pb.MAX_BROKER_REFS - 1
        self.assertTrue(s.transport.closed, "transport was not closed")
        self.assertEqual(len(l), expected, f"expected {expected} got {len(l)}")

    def test_copy(self):
        """
        A L{pb.Copyable} returned by a remote call arrives carrying the
        original's state.
        """
        c, s, pump = connectedServerAndClient(test=self)
        foo = NestedCopy()
        s.setNameForLocal("foo", foo)
        x = c.remoteForName("foo")
        x.callRemote("getCopy").addCallbacks(self.thunkResultGood, self.thunkErrorBad)
        pump.pump()
        pump.pump()
        self.assertEqual(self.thunkResult.x, 1)
        self.assertEqual(self.thunkResult.y["Hello"], "World")
        self.assertEqual(self.thunkResult.z[0], "test")

    def test_observe(self):
        """
        An observer that unsubscribes on its first notification does not
        receive the second one.
        """
        c, s, pump = connectedServerAndClient(test=self)

        # this is really testing the comparison between remote objects, to make
        # sure that you can *UN*observe when you have an observer architecture.
        a = Observable()
        b = Observer()
        s.setNameForLocal("a", a)
        ra = c.remoteForName("a")
        ra.callRemote("observe", b)
        pump.pump()
        a.notify(1)
        pump.pump()
        pump.pump()
        a.notify(10)
        pump.pump()
        pump.pump()
        self.assertIsNotNone(b.obj, "didn't notify")
        self.assertEqual(b.obj, 1, "notified too much")

    def test_defer(self):
        """
        A remote method returning an unfired L{Deferred} runs its callback
        only when the server fires it, and the callback's result is what the
        client receives.
        """
        c, s, pump = connectedServerAndClient(test=self)
        d = DeferredRemote()
        s.setNameForLocal("d", d)
        e = c.remoteForName("d")
        pump.pump()
        pump.pump()
        results = []
        e.callRemote("doItLater").addCallback(results.append)
        pump.pump()
        pump.pump()
        self.assertFalse(d.run, "Deferred method run too early.")
        d.d.callback(5)
        self.assertEqual(d.run, 5, "Deferred method run too late.")
        pump.pump()
        pump.pump()
        self.assertEqual(results[0], 6, "Incorrect result.")

    def test_refcount(self):
        """
        Dropping the client's only reference to a remote object and
        collecting garbage removes the object from the server's
        C{localObjects}.
        """
        c, s, pump = connectedServerAndClient(test=self)
        foo = NestedRemote()
        s.setNameForLocal("foo", foo)
        bar = c.remoteForName("foo")
        bar.callRemote("getSimple").addCallbacks(
            self.refcountResult, self.thunkErrorBad
        )

        # send question
        pump.pump()
        # send response
        pump.pump()

        # delving into internal structures here, because GC is sort of
        # inherently internal.
        rluid = self.nestedRemote.luid
        self.assertIn(rluid, s.localObjects)
        del self.nestedRemote
        # nudge the gc
        # (this version check is always true on Python 3; it is a leftover.)
        if sys.hexversion >= 0x2000000:
            gc.collect()
        # try to nudge the GC even if we can't really
        pump.pump()
        pump.pump()
        pump.pump()
        self.assertNotIn(rluid, s.localObjects)

    def test_cache(self):
        """
        Exercise L{pb.Cacheable}: state delivery, observer updates, identity
        of a cache sent back to its origin, and removal from both sides'
        cache tables once the client drops its copies and GC runs.
        """
        c, s, pump = connectedServerAndClient(test=self)
        obj = NestedCache()
        obj2 = NestedComplicatedCache()
        vcc = obj2.c
        s.setNameForLocal("obj", obj)
        s.setNameForLocal("xxx", obj2)
        o2 = c.remoteForName("obj")
        o3 = c.remoteForName("xxx")
        coll = []
        o2.callRemote("getCache").addCallback(coll.append).addErrback(coll.append)
        o2.callRemote("getCache").addCallback(coll.append).addErrback(coll.append)
        complex = []
        o3.callRemote("getCache").addCallback(complex.append)
        o3.callRemote("getCache").addCallback(complex.append)
        pump.flush()
        # `worst things first'
        self.assertEqual(complex[0].x, 1)
        self.assertEqual(complex[0].y, 2)
        self.assertEqual(complex[0].foo, 3)

        # A server-side change is pushed to the client's cache.
        vcc.setFoo4()
        pump.flush()
        self.assertEqual(complex[0].foo, 4)
        self.assertEqual(len(coll), 2)
        cp = coll[0][0]
        self.assertIdentical(
            cp.checkMethod().__self__, cp, "potential refcounting issue"
        )
        self.assertIdentical(cp.checkSelf(), cp, "other potential refcounting issue")
        col2 = []
        o2.callRemote("putCache", cp).addCallback(col2.append)
        pump.flush()
        # The objects were the same (testing lcache identity)
        self.assertTrue(col2[0])
        # test equality of references to methods
        self.assertEqual(o2.remoteMethod("getCache"), o2.remoteMethod("getCache"))

        # now, refcounting (similar to testRefCount)
        luid = cp.luid
        baroqueLuid = complex[0].luid
        self.assertIn(luid, s.remotelyCachedObjects, "remote cache doesn't have it")
        # Drop every client-side reference before collecting.
        del coll
        del cp
        pump.flush()
        del complex
        del col2
        # extra nudge...
        pump.flush()
        # del vcc.observer
        # nudge the gc
        # (this version check is always true on Python 3; it is a leftover.)
        if sys.hexversion >= 0x2000000:
            gc.collect()
        # try to nudge the GC even if we can't really
        pump.flush()
        # The GC is done with it.
        self.assertNotIn(luid, s.remotelyCachedObjects, "Server still had it after GC")
        self.assertNotIn(luid, c.locallyCachedObjects, "Client still had it after GC")
        self.assertNotIn(
            baroqueLuid, s.remotelyCachedObjects, "Server still had complex after GC"
        )
        self.assertNotIn(
            baroqueLuid, c.locallyCachedObjects, "Client still had complex after GC"
        )
        # stoppedObserving cleared the cacheable's remembered observer.
        self.assertIsNone(vcc.observer, "observer was not removed")

    def test_publishable(self):
        """
        A published object fetched by a client is activated.  It is first
        loaded dirty (no cache file), and then loaded clean over a second
        connection.

        NOTE(review): C{GetPublisher} is not defined in this part of the file;
        presumably it serves a L{publish} object — confirm.
        """
        try:
            os.unlink("None-None-TESTING.pub")  # from RemotePublished.getFileName
        except OSError:
            pass  # Sometimes it's not there.
        c, s, pump = connectedServerAndClient(test=self)
        foo = GetPublisher()
        # foo.pub.timestamp = 1.0
        s.setNameForLocal("foo", foo)
        bar = c.remoteForName("foo")
        accum = []
        bar.callRemote("getPub").addCallbacks(accum.append, self.thunkErrorBad)
        pump.flush()
        obj = accum.pop()
        self.assertEqual(obj.activateCalled, 1)
        self.assertEqual(obj.isActivated, 1)
        self.assertEqual(obj.yayIGotPublished, 1)
        # timestamp's dirty, we don't have a cache file
        self.assertEqual(obj._wasCleanWhenLoaded, 0)
        # A second, fresh connection serving the same object.
        c, s, pump = connectedServerAndClient(test=self)
        s.setNameForLocal("foo", foo)
        bar = c.remoteForName("foo")
        bar.callRemote("getPub").addCallbacks(accum.append, self.thunkErrorBad)
        pump.flush()
        obj = accum.pop()
        # timestamp's clean, our cache file is up-to-date
        self.assertEqual(obj._wasCleanWhenLoaded, 1)

    def gotCopy(self, val):
        """
        Callback recording the C{id} of a received L{SimpleFactoryCopy} in
        C{thunkResult}.
        """
        self.thunkResult = val.id

    def test_factoryCopy(self):
        """
        A class registered with L{pb.setUnjellyableFactoryForClass} is
        received through its factory, and the result carries the sent id.
        """
        c, s, pump = connectedServerAndClient(test=self)
        ID = 99
        obj = NestedCopy()
        s.setNameForLocal("foo", obj)
        x = c.remoteForName("foo")
        x.callRemote("getFactory", ID).addCallbacks(self.gotCopy, self.thunkResultBad)
        pump.pump()
        pump.pump()
        pump.pump()
        self.assertEqual(
            self.thunkResult,
            ID,
            f"ID not correct on factory object {self.thunkResult}",
        )
932
933
# 500 bytes of payload served page by page by the pager helpers below.
bigString = b"helloworld" * 50

# Positional and keyword arguments captured by the most recent
# finishedCallback call (None until it is called).
callbackArgs = None
callbackKeyword = None


def finishedCallback(*args, **kw):
    """
    Store C{args} in the module global C{callbackArgs} and C{kw} in
    C{callbackKeyword}.
    """
    global callbackArgs
    global callbackKeyword
    callbackArgs, callbackKeyword = args, kw
944
945
class Pagerizer(pb.Referenceable):
    """
    Referenceable that pages L{bigString} to a collector with
    L{util.StringPager}, using 100 as the pager's size argument.
    """

    def __init__(self, callback, *args, **kw):
        """
        @param callback: passed to the pager, along with C{args} and C{kw}.
        """
        self.callback = callback
        self.args = args
        self.kw = kw

    def remote_getPages(self, collector):
        util.StringPager(
            collector, bigString, 100, self.callback, *self.args, **self.kw
        )
        # The extra arguments are handed over once and then dropped.
        self.args = None
        self.kw = None
955
956
class FilePagerizer(pb.Referenceable):
    """
    Referenceable that pages the contents of C{filename} to a collector with
    L{util.FilePager}.

    @ivar pager: the most recently created L{util.FilePager}, or C{None}.
    """

    pager = None

    def __init__(self, filename, callback, *args, **kw):
        """
        @param filename: path of the file to page, opened in binary mode.
        @param callback: passed to the pager, along with C{args} and C{kw}.
        """
        self.filename = filename
        self.callback = callback
        self.args = args
        self.kw = kw

    def remote_getPages(self, collector):
        # NOTE(review): the file object is not closed here; presumably
        # util.FilePager takes ownership of it — confirm.
        self.pager = util.FilePager(
            collector, open(self.filename, "rb"), self.callback, *self.args, **self.kw
        )
        # The extra arguments are handed over once and then dropped.
        self.args = None
        self.kw = None
969
970
class PagingTests(unittest.TestCase):
    """
    Test pb objects sending data by pages.
    """

    def setUp(self):
        """
        Create a file used to test L{util.FilePager}.
        """
        self.filename = self.mktemp()
        with open(self.filename, "wb") as f:
            f.write(bigString)

    def _getAllPages(self, pump, remote, maxPumps=1000):
        """
        Request every page from C{remote}'s C{getPages} method and pump the
        connection until they have all arrived.

        @param pump: the L{IOPump} connecting the client and server.
        @param remote: a remote reference to a pagerizer object.
        @param maxPumps: the maximum number of times to pump before giving
            up, so a broken pager fails the test instead of hanging it.

        @return: the list of pages received.
        """
        received = []
        util.getAllPages(remote, "getPages").addCallback(received.append)
        for _ in range(maxPumps):
            if received:
                break
            pump.pump()
        # Judge success by whether a result arrived, not by how many pumps
        # were left: a result produced by the very last pump is still valid.
        if not received:
            self.fail("getAllPages timed out")
        return received[0]

    def test_pagingWithCallback(self):
        """
        Test L{util.StringPager}, passing a callback to fire when all pages
        are sent.
        """
        c, s, pump = connectedServerAndClient(test=self)
        s.setNameForLocal("foo", Pagerizer(finishedCallback, "hello", value=10))
        x = c.remoteForName("foo")
        pages = self._getAllPages(pump, x)
        self.assertEqual(
            b"".join(pages), bigString, "Pages received not equal to pages sent!"
        )
        self.assertEqual(callbackArgs, ("hello",), "Completed callback not invoked")
        self.assertEqual(
            callbackKeyword, {"value": 10}, "Completed callback not invoked"
        )

    def test_pagingWithoutCallback(self):
        """
        Test L{util.StringPager} without a callback.
        """
        c, s, pump = connectedServerAndClient(test=self)
        s.setNameForLocal("foo", Pagerizer(None))
        x = c.remoteForName("foo")
        pages = self._getAllPages(pump, x)
        self.assertEqual(
            b"".join(pages), bigString, "Pages received not equal to pages sent!"
        )

    def test_emptyFilePaging(self):
        """
        Test L{util.FilePager}, sending an empty file.
        """
        filenameEmpty = self.mktemp()
        with open(filenameEmpty, "w"):
            pass
        c, s, pump = connectedServerAndClient(test=self)
        pagerizer = FilePagerizer(filenameEmpty, None)
        s.setNameForLocal("bar", pagerizer)
        x = c.remoteForName("bar")
        pages = self._getAllPages(pump, x, maxPumps=10)
        self.assertEqual(b"".join(pages), b"", "Pages received not equal to pages sent!")

    def test_filePagingWithCallback(self):
        """
        Test L{util.FilePager}, passing a callback to fire when all pages
        are sent, and verify that the pager doesn't keep chunks in memory.
        """
        c, s, pump = connectedServerAndClient(test=self)
        pagerizer = FilePagerizer(self.filename, finishedCallback, "frodo", value=9)
        s.setNameForLocal("bar", pagerizer)
        x = c.remoteForName("bar")
        pages = self._getAllPages(pump, x)
        self.assertEqual(
            b"".join(pages), bigString, "Pages received not equal to pages sent!"
        )
        self.assertEqual(callbackArgs, ("frodo",), "Completed callback not invoked")
        self.assertEqual(
            callbackKeyword, {"value": 9}, "Completed callback not invoked"
        )
        self.assertEqual(pagerizer.pager.chunks, [])

    def test_filePagingWithoutCallback(self):
        """
        Test L{util.FilePager} without a callback.
        """
        c, s, pump = connectedServerAndClient(test=self)
        pagerizer = FilePagerizer(self.filename, None)
        s.setNameForLocal("bar", pagerizer)
        x = c.remoteForName("bar")
        pages = self._getAllPages(pump, x)
        self.assertEqual(
            b"".join(pages), bigString, "Pages received not equal to pages sent!"
        )
        self.assertEqual(pagerizer.pager.chunks, [])
1077
1078
class DumbPublishable(publish.Publishable):
    """
    Publishable whose published state is a single marker entry.
    """

    def getStateToPublish(self):
        return dict(yayIGotPublished=1)
1082
1083
class DumbPub(publish.RemotePublished):
    """
    Remote counterpart of L{DumbPublishable} which notes its activation.
    """

    def activated(self):
        # Record that the activation hook ran.
        self.activateCalled = 1
1087
1088
class GetPublisher(pb.Referenceable):
    """
    Referenceable that hands out one L{DumbPublishable}, created with the
    argument C{"TESTING"}.
    """

    def __init__(self):
        self.pub = DumbPublishable("TESTING")

    def remote_getPub(self):
        """
        Return the publishable created in C{__init__}.
        """
        return self.pub
1095
1096
# Register DumbPub as the class used to unjelly DumbPublishable instances
# received over PB.
pb.setUnjellyableForClass(DumbPublishable, DumbPub)
1098
1099
class DisconnectionTests(unittest.TestCase):
    """
    Test disconnection callbacks.

    @ivar gotCallback: set to C{1} by L{gotDisconnected}.
    @ivar objectCallback: set to C{1} by L{objectDisconnected}.
    """

    # Defaults so that a disconnect callback which never fires produces an
    # assertion failure rather than an AttributeError.
    gotCallback = 0
    objectCallback = 0

    def error(self, *args):
        raise RuntimeError(f"I shouldn't have been called: {args}")

    def gotDisconnected(self):
        """
        Called on broker disconnect.
        """
        self.gotCallback = 1

    def objectDisconnected(self, o):
        """
        Called on RemoteReference disconnect.
        """
        self.assertEqual(o, self.remoteObject)
        self.objectCallback = 1

    def test_badSerialization(self):
        """
        Passing an argument whose C{getStateToCopyFor} raises makes the
        remote call's Deferred errback.
        """
        c, s, pump = connectedServerAndClient(test=self)
        pump.pump()
        s.setNameForLocal("o", BadCopySet())
        g = c.remoteForName("o")
        l = []
        g.callRemote("setBadCopy", BadCopyable()).addErrback(l.append)
        pump.flush()
        self.assertEqual(len(l), 1)

    def test_disconnection(self):
        """
        Disconnect callbacks registered on a broker and on a
        L{RemoteReference} can be unregistered, and those left registered
        run when the broker loses its connection.
        """
        c, s, pump = connectedServerAndClient(test=self)
        pump.pump()
        s.setNameForLocal("o", SimpleRemote())

        # get a client reference to server object
        r = c.remoteForName("o")
        pump.pump()
        pump.pump()
        pump.pump()

        # register and then unregister disconnect callbacks
        # making sure they get unregistered
        c.notifyOnDisconnect(self.error)
        self.assertIn(self.error, c.disconnects)
        c.dontNotifyOnDisconnect(self.error)
        self.assertNotIn(self.error, c.disconnects)

        r.notifyOnDisconnect(self.error)
        self.assertIn(r._disconnected, c.disconnects)
        self.assertIn(self.error, r.disconnectCallbacks)
        r.dontNotifyOnDisconnect(self.error)
        self.assertNotIn(r._disconnected, c.disconnects)
        self.assertNotIn(self.error, r.disconnectCallbacks)

        # register disconnect callbacks
        c.notifyOnDisconnect(self.gotDisconnected)
        r.notifyOnDisconnect(self.objectDisconnected)
        self.remoteObject = r

        # disconnect
        c.connectionLost(failure.Failure(main.CONNECTION_DONE))
        self.assertTrue(self.gotCallback)
        self.assertTrue(self.objectCallback)
1165
1166
class FreakOut(Exception):
    """
    Exception raised by L{BadCopyable} when its copy state is requested,
    simulating a serialization failure.
    """
1169
1170
class BadCopyable(pb.Copyable):
    """
    A L{pb.Copyable} which refuses to produce its state, so it cannot be
    serialized.
    """

    def getStateToCopyFor(self, p):
        raise FreakOut
1174
1175
class BadCopySet(pb.Referenceable):
    """
    Remote target for L{BadCopyable} arguments.
    """

    def remote_setBadCopy(self, bc):
        """
        Accept C{bc}, ignore it, and return L{None}.
        """
1179
1180
class LocalRemoteTest(util.LocalAsRemote):
    """
    L{util.LocalAsRemote} subclass with a few synchronous and asynchronous
    methods for tests.
    """

    reportAllTracebacks = 0

    def sync_add1(self, x):
        """
        Return C{x} plus one.
        """
        return x + 1

    def async_add(self, x=0, y=1):
        """
        Return the sum of C{x} and C{y}.
        """
        return x + y

    def async_fail(self):
        """
        Always raise L{RuntimeError}.
        """
        raise RuntimeError
1192
1193
@implementer(pb.IPerspective)
class MyPerspective(pb.Avatar):
    """
    Avatar handed out by L{TestRealm} for authenticated (non-anonymous)
    logins.

    @ivar loggedIn: set to C{True} when the avatar is logged in.
    @type loggedIn: C{bool}

    @ivar loggedOut: set to C{True} when the avatar is logged out.
    @type loggedOut: C{bool}
    """

    loggedIn = False
    loggedOut = False

    def __init__(self, avatarId):
        self.avatarId = avatarId

    def perspective_getAvatarId(self):
        """
        Return the identifier this avatar was created for.
        """
        return self.avatarId

    def perspective_getViewPoint(self):
        """
        Return a new L{MyView}.
        """
        return MyView()

    def perspective_add(self, a, b):
        """
        Return C{a + b}.  L{Echoer} has no such method, so only
        authenticated users who obtained their avatar from L{TestRealm}
        can call it.
        """
        return a + b

    def logout(self):
        """
        Mark this avatar as logged out.
        """
        self.loggedOut = True
1228
1229
class TestRealm:
    """
    A realm which creates a new avatar via C{perspectiveFactory} for each
    non-anonymous login and a new L{Echoer} for each anonymous login.

    @ivar lastPerspective: The avatar most recently created by
        C{perspectiveFactory} and returned from C{requestAvatar}.

    @ivar perspectiveFactory: A one-argument callable which will be used to
        create avatars to be returned from C{requestAvatar}.
    """

    perspectiveFactory = MyPerspective

    lastPerspective = None

    def requestAvatar(self, avatarId, mind, interface):
        """
        Check that C{mind} and C{interface} have the values the tests use
        (a check that would better live inside the tests themselves), then
        return an avatar suited to C{avatarId}.
        """
        assert interface == pb.IPerspective
        assert mind == "BRAINS!"

        # Anonymous users get a throwaway echoer with a no-op logout.
        if avatarId is checkers.ANONYMOUS:
            return pb.IPerspective, Echoer(), lambda: None

        avatar = self.perspectiveFactory(avatarId)
        self.lastPerspective = avatar
        avatar.loggedIn = True
        return (pb.IPerspective, avatar, avatar.logout)
1261
1262
class MyView(pb.Viewable):
    """
    Viewable whose C{check} method reports whether the viewer is a
    L{MyPerspective}.
    """

    def view_check(self, user):
        isMyPerspective = isinstance(user, MyPerspective)
        return isMyPerspective
1266
1267
class LeakyRealm(TestRealm):
    """
    A realm whose logout function keeps a reference to the mind object.
    """

    def __init__(self, mindEater):
        """
        Create a L{LeakyRealm}.

        @param mindEater: a callable invoked with the C{mind} object when it
            becomes available
        """
        self._mindEater = mindEater

    def requestAvatar(self, avatarId, mind, interface):
        """
        Pass C{mind} to the mind eater, then return a fresh avatar together
        with a logout callable whose closure holds both C{mind} and the
        avatar.
        """
        self._mindEater(mind)
        avatar = self.perspectiveFactory(avatarId)

        def logout():
            # Referencing mind here is the deliberate "leak".
            return (mind, avatar.logout())

        return (pb.IPerspective, avatar, logout)
1287
1288
class NewCredLeakTests(unittest.TestCase):
    """
    Tests to try to trigger memory leaks.
    """

    def test_logoutLeak(self):
        """
        The server does not leak a reference when the client disconnects
        suddenly, even if the cred logout function forms a reference cycle with
        the perspective.
        """
        # keep a weak reference to the mind object, which we can verify later
        # evaluates to None, thereby ensuring the reference leak is fixed.
        self.mindRef = None

        def setMindRef(mind):
            # Handed to LeakyRealm, which calls it with the mind passed to
            # requestAvatar; only a weak reference is stored so the test
            # itself does not keep the mind alive.
            self.mindRef = weakref.ref(mind)

        clientBroker, serverBroker, pump = connectedServerAndClient(
            test=self, realm=LeakyRealm(setMindRef)
        )

        # log in from the client
        connectionBroken = []
        root = clientBroker.remoteForName("root")
        d = root.callRemote("login", b"guest")

        def cbResponse(x):
            # Answer the login challenge, supplying a fresh mind object that
            # is only referenced locally here.
            challenge, challenger = x
            mind = SimpleRemote()
            return challenger.callRemote(
                "respond", pb.respond(challenge, b"guest"), mind
            )

        d.addCallback(cbResponse)

        def connectionLost(_):
            pump.stop()  # don't try to pump data anymore - it won't work
            connectionBroken.append(1)
            # Simulate an abrupt disconnect on the server side.
            serverBroker.connectionLost(failure.Failure(RuntimeError("boom")))

        d.addCallback(connectionLost)

        # flush out the response and connectionLost
        pump.flush()
        self.assertEqual(connectionBroken, [1])

        # and check for lingering references - requestAvatar sets mindRef
        # to a weakref to the mind; this object should be gc'd, and thus
        # the ref should return None
        gc.collect()
        self.assertIsNone(self.mindRef())
1341
1342
class NewCredTests(unittest.TestCase):
    """
    Tests related to the L{twisted.cred} support in PB.
    """

    def setUp(self):
        """
        Create a portal with no checkers and wrap it around a simple test
        realm.  Create a PB server factory which serves perspectives using
        that portal, and a PB client factory; the two are connected through
        an L{IOPump} by L{establishClientAndServer}.
        """
        self.realm = TestRealm()
        self.portal = portal.Portal(self.realm)
        self.serverFactory = ConnectionNotifyServerFactory(self.portal)
        self.clientFactory = pb.PBClientFactory()

    def establishClientAndServer(self, _ignored=None):
        """
        Connect a client obtained from C{clientFactory} and a server
        obtained from the current server factory via an L{IOPump},
        then assign them to the appropriate instance variables

        @ivar clientFactory: the broker client factory
        @type clientFactory: L{pb.PBClientFactory}

        @ivar client: the client broker
        @type client: L{pb.Broker}

        @ivar server: the server broker
        @type server: L{pb.Broker}

        @ivar pump: the IOPump connecting the client and server
        @type pump: L{IOPump}

        @ivar connector: A connector whose connect method recreates
            the above instance variables
        @type connector: L{twisted.internet.base.IConnector}
        """
        self.client, self.server, self.pump = connectServerAndClient(
            self, self.clientFactory, self.serverFactory
        )

        self.connectorState = _ReconnectingFakeConnectorState()
        self.connector = _ReconnectingFakeConnector(
            address.IPv4Address("TCP", "127.0.0.1", 4321), self.connectorState
        )
        self.connectorState.notifyOnConnect().addCallback(self.establishClientAndServer)

    def completeClientLostConnection(self, reason=None):
        """
        Asserts that the client broker's transport was closed and then
        mimics the event loop by calling the broker's connectionLost
        callback with C{reason}, followed by C{self.clientFactory}'s
        C{clientConnectionLost}

        @param reason: (optional) the reason to pass to the client
            broker's connectionLost callback; when omitted, a new
            C{Failure(main.CONNECTION_DONE)} is used.
        @type reason: L{Failure}
        """
        if reason is None:
            # Build a fresh Failure per call instead of sharing a single
            # instance created when the function was defined.
            reason = failure.Failure(main.CONNECTION_DONE)
        self.assertTrue(self.client.transport.closed)
        # simulate the reactor calling back the client's
        # connectionLost after the loseConnection implied by
        # clientFactory.disconnect
        self.client.connectionLost(reason)
        self.clientFactory.clientConnectionLost(self.connector, reason)

    def test_getRootObject(self):
        """
        Assert that L{PBClientFactory.getRootObject}'s Deferred fires with
        a L{RemoteReference}, and that disconnecting it runs its
        disconnection callbacks.
        """
        self.establishClientAndServer()
        rootObjDeferred = self.clientFactory.getRootObject()

        def gotRootObject(rootObj):
            self.assertIsInstance(rootObj, pb.RemoteReference)
            return rootObj

        def disconnect(rootObj):
            disconnectedDeferred = Deferred()
            rootObj.notifyOnDisconnect(disconnectedDeferred.callback)
            self.clientFactory.disconnect()

            self.completeClientLostConnection()

            return disconnectedDeferred

        rootObjDeferred.addCallback(gotRootObject)
        rootObjDeferred.addCallback(disconnect)

        return rootObjDeferred

    def test_deadReferenceError(self):
        """
        Test that when a connection is lost, calling a method on a
        RemoteReference obtained from it raises L{DeadReferenceError}.
        """
        self.establishClientAndServer()
        rootObjDeferred = self.clientFactory.getRootObject()

        def gotRootObject(rootObj):
            disconnectedDeferred = Deferred()
            rootObj.notifyOnDisconnect(disconnectedDeferred.callback)

            def lostConnection(ign):
                self.assertRaises(pb.DeadReferenceError, rootObj.callRemote, "method")

            disconnectedDeferred.addCallback(lostConnection)
            self.clientFactory.disconnect()

            self.completeClientLostConnection()

            return disconnectedDeferred

        return rootObjDeferred.addCallback(gotRootObject)

    def test_clientConnectionLost(self):
        """
        Test that if the L{reconnecting} flag is passed with a True value then
        a remote call made from a disconnection notification callback gets a
        result successfully.
        """

        class ReconnectOnce(pb.PBClientFactory):
            reconnectedAlready = False

            def clientConnectionLost(self, connector, reason):
                reconnecting = not self.reconnectedAlready
                self.reconnectedAlready = True
                result = pb.PBClientFactory.clientConnectionLost(
                    self, connector, reason, reconnecting
                )
                if reconnecting:
                    connector.connect()
                return result

        self.clientFactory = ReconnectOnce()
        self.establishClientAndServer()

        rootObjDeferred = self.clientFactory.getRootObject()

        def gotRootObject(rootObj):
            self.assertIsInstance(rootObj, pb.RemoteReference)

            d = Deferred()
            rootObj.notifyOnDisconnect(d.callback)
            # request a disconnection
            self.clientFactory.disconnect()
            self.completeClientLostConnection()

            def disconnected(ign):
                d = self.clientFactory.getRootObject()

                def gotAnotherRootObject(anotherRootObj):
                    self.assertIsInstance(anotherRootObj, pb.RemoteReference)
                    d = Deferred()
                    anotherRootObj.notifyOnDisconnect(d.callback)
                    self.clientFactory.disconnect()
                    self.completeClientLostConnection()
                    return d

                return d.addCallback(gotAnotherRootObject)

            return d.addCallback(disconnected)

        return rootObjDeferred.addCallback(gotRootObject)

    def test_immediateClose(self):
        """
        Test that if a Broker loses its connection without receiving any bytes,
        it doesn't raise any exceptions or log any errors.
        """
        self.establishClientAndServer()
        serverProto = self.serverFactory.buildProtocol(("127.0.0.1", 12345))
        serverProto.makeConnection(protocol.FileWrapper(StringIO()))
        serverProto.connectionLost(failure.Failure(main.CONNECTION_DONE))

    def test_loginConnectionRefused(self):
        """
        L{PBClientFactory.login} returns a L{Deferred} which is errbacked
        with the L{ConnectionRefusedError} if the underlying connection is
        refused.
        """
        clientFactory = pb.PBClientFactory()
        loginDeferred = clientFactory.login(
            credentials.UsernamePassword(b"foo", b"bar")
        )
        clientFactory.clientConnectionFailed(
            None,
            failure.Failure(
                ConnectionRefusedError("Test simulated refused connection")
            ),
        )
        return self.assertFailure(loginDeferred, ConnectionRefusedError)

    def test_loginLogout(self):
        """
        Test that login can be performed with IUsernamePassword credentials and
        that when the connection is dropped the avatar is logged out.
        """
        self.portal.registerChecker(
            checkers.InMemoryUsernamePasswordDatabaseDontUse(user=b"pass")
        )
        creds = credentials.UsernamePassword(b"user", b"pass")

        # NOTE: real code probably won't need anything where we have the
        # "BRAINS!" argument, passing None is fine. We just do it here to
        # test that it is being passed. It is used to give additional info to
        # the realm to aid perspective creation, if you don't need that,
        # ignore it.
        mind = "BRAINS!"

        loginCompleted = Deferred()

        d = self.clientFactory.login(creds, mind)

        def cbLogin(perspective):
            self.assertTrue(self.realm.lastPerspective.loggedIn)
            self.assertIsInstance(perspective, pb.RemoteReference)
            return loginCompleted

        def cbDisconnect(ignored):
            self.clientFactory.disconnect()
            self.completeClientLostConnection()

        d.addCallback(cbLogin)
        d.addCallback(cbDisconnect)

        def cbLogout(ignored):
            self.assertTrue(self.realm.lastPerspective.loggedOut)

        d.addCallback(cbLogout)

        self.establishClientAndServer()
        self.pump.flush()
        # The perspective passed to cbLogin has gone out of scope.
        # Ensure its __del__ runs...
        gc.collect()
        # ...and send its decref message to the server
        self.pump.flush()
        # Now allow the client to disconnect.
        loginCompleted.callback(None)
        return d

    def test_logoutAfterDecref(self):
        """
        If a L{RemoteReference} to an L{IPerspective} avatar is decrefed and
        there remain no other references to the avatar on the server, the
        avatar is garbage collected and the logout method called.
        """
        loggedOut = Deferred()

        class EventPerspective(pb.Avatar):
            """
            An avatar which fires a Deferred when it is logged out.
            """

            def __init__(self, avatarId):
                pass

            def logout(self):
                loggedOut.callback(None)

        self.realm.perspectiveFactory = EventPerspective

        self.portal.registerChecker(
            checkers.InMemoryUsernamePasswordDatabaseDontUse(foo=b"bar")
        )

        d = self.clientFactory.login(
            credentials.UsernamePassword(b"foo", b"bar"), "BRAINS!"
        )

        def cbLoggedIn(avatar):
            # Just wait for the logout to happen, as it should since the
            # reference to the avatar will shortly no longer exist.
            return loggedOut

        d.addCallback(cbLoggedIn)

        def cbLoggedOut(ignored):
            # Verify that the server broker's _localCleanup dict isn't growing
            # without bound.
            self.assertEqual(self.serverFactory.protocolInstance._localCleanup, {})

        d.addCallback(cbLoggedOut)

        self.establishClientAndServer()

        # complete authentication
        self.pump.flush()
        # _PortalAuthChallenger and our Avatar should be dead by now;
        # force a collection to trigger their __del__s
        gc.collect()
        # push their decref messages through
        self.pump.flush()
        return d

    def test_concurrentLogin(self):
        """
        Two different correct login attempts can be made on the same root
        object at the same time and produce two different resulting avatars.
        """
        self.portal.registerChecker(
            checkers.InMemoryUsernamePasswordDatabaseDontUse(foo=b"bar", baz=b"quux")
        )

        firstLogin = self.clientFactory.login(
            credentials.UsernamePassword(b"foo", b"bar"), "BRAINS!"
        )
        secondLogin = self.clientFactory.login(
            credentials.UsernamePassword(b"baz", b"quux"), "BRAINS!"
        )
        d = gatherResults([firstLogin, secondLogin])

        def cbLoggedIn(result):
            (first, second) = result
            return gatherResults(
                [first.callRemote("getAvatarId"), second.callRemote("getAvatarId")]
            )

        d.addCallback(cbLoggedIn)

        def cbAvatarIds(x):
            first, second = x
            self.assertEqual(first, b"foo")
            self.assertEqual(second, b"baz")

        d.addCallback(cbAvatarIds)

        self.establishClientAndServer()
        self.pump.flush()

        return d

    def test_badUsernamePasswordLogin(self):
        """
        Test that a login attempt with an invalid user or invalid password
        fails in the appropriate way.
        """
        self.portal.registerChecker(
            checkers.InMemoryUsernamePasswordDatabaseDontUse(user=b"pass")
        )

        firstLogin = self.clientFactory.login(
            credentials.UsernamePassword(b"nosuchuser", b"pass")
        )
        secondLogin = self.clientFactory.login(
            credentials.UsernamePassword(b"user", b"wrongpass")
        )

        self.assertFailure(firstLogin, UnauthorizedLogin)
        self.assertFailure(secondLogin, UnauthorizedLogin)
        d = gatherResults([firstLogin, secondLogin])

        def cleanup(ignore):
            errors = self.flushLoggedErrors(UnauthorizedLogin)
            self.assertEqual(len(errors), 2)

        d.addCallback(cleanup)

        self.establishClientAndServer()
        self.pump.flush()

        return d

    def test_anonymousLogin(self):
        """
        Verify that a PB server using a portal configured with a checker which
        allows IAnonymous credentials can be logged into using IAnonymous
        credentials.
        """
        self.portal.registerChecker(checkers.AllowAnonymousAccess())
        d = self.clientFactory.login(credentials.Anonymous(), "BRAINS!")

        def cbLoggedIn(perspective):
            return perspective.callRemote("echo", 123)

        d.addCallback(cbLoggedIn)

        d.addCallback(self.assertEqual, 123)

        self.establishClientAndServer()
        self.pump.flush()
        return d

    def test_anonymousLoginNotPermitted(self):
        """
        Verify that without an anonymous checker set up, anonymous login is
        rejected.
        """
        # Use bytes credentials, like every other checker in these tests.
        self.portal.registerChecker(
            checkers.InMemoryUsernamePasswordDatabaseDontUse(user=b"pass")
        )
        d = self.clientFactory.login(credentials.Anonymous(), "BRAINS!")
        self.assertFailure(d, UnhandledCredentials)

        def cleanup(ignore):
            errors = self.flushLoggedErrors(UnhandledCredentials)
            self.assertEqual(len(errors), 1)

        d.addCallback(cleanup)

        self.establishClientAndServer()
        self.pump.flush()

        return d

    def test_anonymousLoginWithMultipleCheckers(self):
        """
        Like L{test_anonymousLogin} but against a portal with a checker for
        both IAnonymous and IUsernamePassword.
        """
        self.portal.registerChecker(checkers.AllowAnonymousAccess())
        self.portal.registerChecker(
            checkers.InMemoryUsernamePasswordDatabaseDontUse(user=b"pass")
        )
        d = self.clientFactory.login(credentials.Anonymous(), "BRAINS!")

        def cbLogin(perspective):
            return perspective.callRemote("echo", 123)

        d.addCallback(cbLogin)

        d.addCallback(self.assertEqual, 123)

        self.establishClientAndServer()
        self.pump.flush()

        return d

    def test_authenticatedLoginWithMultipleCheckers(self):
        """
        Like L{test_anonymousLoginWithMultipleCheckers} but check that
        username/password authentication works.
        """
        self.portal.registerChecker(checkers.AllowAnonymousAccess())
        self.portal.registerChecker(
            checkers.InMemoryUsernamePasswordDatabaseDontUse(user=b"pass")
        )
        d = self.clientFactory.login(
            credentials.UsernamePassword(b"user", b"pass"), "BRAINS!"
        )

        def cbLogin(perspective):
            return perspective.callRemote("add", 100, 23)

        d.addCallback(cbLogin)

        d.addCallback(self.assertEqual, 123)

        self.establishClientAndServer()
        self.pump.flush()

        return d

    def test_view(self):
        """
        Verify that a viewpoint can be retrieved after authenticating with
        cred.
        """
        self.portal.registerChecker(
            checkers.InMemoryUsernamePasswordDatabaseDontUse(user=b"pass")
        )
        d = self.clientFactory.login(
            credentials.UsernamePassword(b"user", b"pass"), "BRAINS!"
        )

        def cbLogin(perspective):
            return perspective.callRemote("getViewPoint")

        d.addCallback(cbLogin)

        def cbView(viewpoint):
            return viewpoint.callRemote("check")

        d.addCallback(cbView)

        d.addCallback(self.assertTrue)

        self.establishClientAndServer()
        self.pump.flush()

        return d
1830
1831
@implementer(pb.IPerspective)
class NonSubclassingPerspective:
    """
    An L{pb.IPerspective} provider which does not inherit from L{pb.Avatar}.
    It answers every remote message by echoing back the message name and
    its unserialized arguments.
    """

    def __init__(self, avatarId):
        # The avatar identifier is not used by this implementation.
        pass

    # IPerspective implementation
    def perspectiveMessageReceived(self, broker, message, args, kwargs):
        """
        Unserialize the call's arguments and return the serialized tuple
        C{(message, args, kwargs)}.
        """
        unpackedArgs = broker.unserialize(args, self)
        unpackedKwargs = broker.unserialize(kwargs, self)
        echoed = (message, unpackedArgs, unpackedKwargs)
        return broker.serialize(echoed)

    # Methods required by TestRealm
    def logout(self):
        """
        Mark this avatar as logged out.
        """
        self.loggedOut = True
1846
1847
class NSPTests(unittest.TestCase):
    """
    Tests for authentication against a realm where the L{IPerspective}
    implementation is not a subclass of L{Avatar}.
    """

    def setUp(self):
        """
        Listen on an ephemeral localhost TCP port with a PB server whose
        portal has a single C{user}/C{pass} account and whose realm creates
        L{NonSubclassingPerspective} avatars.
        """
        self.realm = TestRealm()
        self.realm.perspectiveFactory = NonSubclassingPerspective
        self.portal = portal.Portal(self.realm)
        self.checker = checkers.InMemoryUsernamePasswordDatabaseDontUse()
        self.checker.addUser(b"user", b"pass")
        self.portal.registerChecker(self.checker)
        self.factory = WrappingFactory(pb.PBServerFactory(self.portal))
        self.port = reactor.listenTCP(0, self.factory, interface="127.0.0.1")
        self.addCleanup(self.port.stopListening)
        self.portno = self.port.getHost().port

    def test_NSP(self):
        """
        An L{IPerspective} implementation which does not subclass
        L{Avatar} can expose remote methods for the client to call.
        """
        clientFactory = pb.PBClientFactory()
        d = clientFactory.login(
            credentials.UsernamePassword(b"user", b"pass"), "BRAINS!"
        )
        reactor.connectTCP("127.0.0.1", self.portno, clientFactory)

        def callAnything(perspective):
            return perspective.callRemote("ANYTHING", "here", bar="baz")

        def cleanup(ignored):
            clientFactory.disconnect()
            for serverProto in self.factory.protocols:
                serverProto.transport.loseConnection()

        d.addCallback(callAnything)
        d.addCallback(self.assertEqual, ("ANYTHING", ("here",), {"bar": "baz"}))
        d.addCallback(cleanup)
        return d
1884
1885
class IForwarded(Interface):
    """
    Interface used for testing L{util.LocalAsyncForwarder}.

    Only the methods declared here should be forwarded; methods that an
    implementation defines outside this interface should not be.
    """

    # zope.interface method declarations omit C{self}.
    def forwardMe():
        """
        Simple synchronous method; returns nothing of interest.
        """

    def forwardDeferred():
        """
        Simple asynchronous method; returns a L{Deferred}.
        """
1900
1901
@implementer(IForwarded)
class Forwarded:
    """
    Concrete L{IForwarded} provider used to observe which calls
    L{util.LocalAsyncForwarder} passes through.

    @ivar forwarded: becomes C{True} once C{forwardMe} has been invoked.
    @type forwarded: C{bool}
    @ivar unforwarded: becomes C{True} once C{dontForwardMe} has been
        invoked; since that method is not on the interface, a forwarder
        should never make that call.
    @type unforwarded: C{bool}
    """

    forwarded = unforwarded = False

    def forwardMe(self):
        """
        Part of L{IForwarded}: flag that this call arrived.
        """
        self.forwarded = True

    def dontForwardMe(self):
        """
        Not part of L{IForwarded}: flag the call so a test can detect that it
        was wrongly forwarded.
        """
        self.unforwarded = True

    def forwardDeferred(self):
        """
        Part of L{IForwarded}: answer asynchronously.

        @return: an already-fired L{Deferred} whose result is C{True}.
        """
        return succeed(True)
1934
1935
class SpreadUtilTests(unittest.TestCase):
    """
    Tests for L{twisted.spread.util}.
    """

    def test_sync(self):
        """
        Call a synchronous method of a L{util.LocalAsRemote} object and check
        the result.
        """
        o = LocalRemoteTest()
        self.assertEqual(o.callRemote("add1", 2), 3)

    def test_async(self):
        """
        Call an asynchronous method of a L{util.LocalAsRemote} object and check
        the result.
        """
        o = LocalRemoteTest()
        d = o.callRemote("add", 2, y=4)
        self.assertIsInstance(d, Deferred)
        d.addCallback(self.assertEqual, 6)
        return d

    def test_asyncFail(self):
        """
        Test an asynchronous failure on a remote method call.
        """
        o = LocalRemoteTest()
        d = o.callRemote("fail")

        def eb(f):
            self.assertIsInstance(f, failure.Failure)
            f.trap(RuntimeError)

        d.addCallbacks(lambda res: self.fail("supposed to fail"), eb)
        return d

    def test_remoteMethod(self):
        """
        Test the C{remoteMethod} facility of L{util.LocalAsRemote}.
        """
        o = LocalRemoteTest()
        m = o.remoteMethod("add1")
        self.assertEqual(m(3), 4)

    def test_localAsyncForwarder(self):
        """
        Test a call to L{util.LocalAsyncForwarder} using L{Forwarded} local
        object: interface methods are forwarded, non-interface methods are
        not, and a forwarded method's asynchronous result reaches the caller.
        """
        f = Forwarded()
        lf = util.LocalAsyncForwarder(f, IForwarded)
        lf.callRemote("forwardMe")
        self.assertTrue(f.forwarded)
        # dontForwardMe exists on Forwarded but is not declared on
        # IForwarded, so the forwarder must not invoke it.
        lf.callRemote("dontForwardMe")
        self.assertFalse(f.unforwarded)
        rr = lf.callRemote("forwardDeferred")
        results = []
        rr.addCallback(results.append)
        # Forwarded.forwardDeferred returns an already-fired Deferred, so the
        # result is expected synchronously. Compare the whole list so a
        # missing result fails cleanly instead of raising IndexError, and use
        # assertIs because 1 == True would otherwise also pass.
        self.assertEqual(results, [True])
        self.assertIs(results[0], True)
1998
1999
class PBWithSecurityOptionsTests(unittest.TestCase):
    """
    Tests that PB client and server factories hand the expected jelly
    security options to the brokers they build.
    """

    def _brokerSecurity(self, factory):
        """
        Build a broker from C{factory} (with no address) and return the
        security options it was given.
        """
        return factory.buildProtocol(None).security

    def test_clientDefaultSecurityOptions(self):
        """
        A client broker built without explicit security options uses
        C{jelly.globalSecurity}.
        """
        self.assertIs(
            self._brokerSecurity(pb.PBClientFactory()), jelly.globalSecurity
        )

    def test_serverDefaultSecurityOptions(self):
        """
        A server broker built without explicit security options uses
        C{jelly.globalSecurity}.
        """
        self.assertIs(
            self._brokerSecurity(pb.PBServerFactory(Echoer())), jelly.globalSecurity
        )

    def test_clientSecurityCustomization(self):
        """
        Security options given to the client factory are handed on unchanged
        to the broker it builds.
        """
        custom = jelly.SecurityOptions()
        self.assertIs(
            self._brokerSecurity(pb.PBClientFactory(security=custom)), custom
        )

    def test_serverSecurityCustomization(self):
        """
        Security options given to the server factory are handed on unchanged
        to the broker it builds.
        """
        custom = jelly.SecurityOptions()
        self.assertIs(
            self._brokerSecurity(pb.PBServerFactory(Echoer(), security=custom)),
            custom,
        )
2042