1"""
 Session management for BOSH (HTTP binding) Jabber/XMPP client connections.
3
4"""
5from twisted.internet import defer,  reactor
6from twisted.python import failure, log
7from twisted.web import server
8from twisted.names.srvconnect import SRVConnector
9
10try:
11    from twisted.words.xish import domish, xmlstream
12    from twisted.words.protocols import jabber as jabber_protocol
13except ImportError:
14    from twisted.xish import domish, xmlstream
15
16
17import traceback
18import os
19import warnings
20from punjab import jabber
21from punjab.xmpp import ns
22
23import time
24import error
25
# Optional SSL support: ``ssl`` ends up as the twisted.internet.ssl module,
# or None when it cannot be imported or reports itself as unsupported.
# XMPPClientConnector only switches to SSL (port 5223) when ssl is not None.
try:
    from twisted.internet import ssl
except ImportError:
    ssl = None
if ssl and not ssl.supported:
    ssl = None
if not ssl:
    # Warn once at import time; connections still work without SSL.
    log.msg("SSL ERROR: You do not have ssl support this may cause problems with tls client connections.")
34
35
36
class XMPPClientConnector(SRVConnector):
    """
    A jabber connection to find srv records for xmpp client connections.

    Uses Twisted's SRVConnector with the 'xmpp-client' service to choose the
    target host and port.  When the chosen target uses port 5223 (the legacy
    SSL port) and SSL support is available, the connection is made with
    ``connectSSL``; otherwise the connector's original connect method
    (normally ``connectTCP``) is used.
    """
    def __init__(self, client_reactor, domain, factory):
        """
        client_reactor - reactor used to make the connection
        domain         - domain whose SRV records are looked up; should be
                         bytes (unicode is accepted with a warning and
                         ASCII-encoded)
        factory        - client factory for the resulting connection
        """
        if isinstance(domain, unicode):
            warnings.warn(
                "Domain argument to XMPPClientConnector should be bytes, "
                "not unicode",
                stacklevel=2)
            domain = domain.encode('ascii')
        SRVConnector.__init__(self, client_reactor, 'xmpp-client', domain, factory)
        # Remember the plain connect settings so pickServer can restore them
        # when a later attempt (e.g. the next SRV record after a failure)
        # targets a non-SSL port.
        self._plainConnectFuncName = getattr(self, 'connectFuncName', 'connectTCP')
        self._plainConnectFuncArgs = getattr(self, 'connectFuncArgs', ())
        # NOTE(review): presumably lookup timeouts in seconds consumed by the
        # SRVConnector of the Twisted version in use -- confirm.
        self.timeout = [1,3]

    def pickServer(self):
        """
        Pick a server and port to make the connection.

        Returns the (host, port) chosen by SRVConnector and selects the
        connect method for it: SSL for port 5223 when available, otherwise
        the plain method remembered in __init__.
        """
        host, port = SRVConnector.pickServer(self)

        if port == 5223 and ssl:
            context = ssl.ClientContextFactory()
            context.method = ssl.SSL.SSLv23_METHOD

            self.connectFuncName = 'connectSSL'
            self.connectFuncArgs = (context,)
        else:
            # Undo an SSL switch made for a previously picked 5223 target.
            self.connectFuncName = self._plainConnectFuncName
            self.connectFuncArgs = self._plainConnectFuncArgs
        return host, port
65
def make_session(pint, attrs, session_type='BOSH'):
    """
    Create a Session, start its XMPP connection and register it.

    pint         - punjab session interface class; ``v`` (verbose flag),
                   ``bindAddress`` and the ``sessions`` dict are used here
    attrs        - attributes sent from the body tag
    session_type - accepted for API compatibility; not used

    Returns a ``(session, deferred)`` tuple, where ``deferred`` belongs to the
    first request queued by the Session constructor.
    """
    session = Session(pint, attrs)

    if pint.v:
        log.msg('================================== %s connect to %s:%s ==================================' % (str(time.time()),session.hostname,session.port))

    # Dial directly when SRV lookups are disabled, when the client supplied
    # an explicit route, or when the target is the local machine.
    dial_directly = (not session.connect_srv
                     or attrs.has_key('route')
                     or session.hostname in ['localhost', '127.0.0.1'])
    if dial_directly:
        reactor.connectTCP(session.hostname, session.port, session,
                           bindAddress=pint.bindAddress)
    else:
        XMPPClientConnector(reactor, session.hostname, session).connect()

    # first inactivity check; checkExpired reschedules itself as needed
    reactor.callLater(session.inactivity, session.checkExpired)

    pint.sessions[session.sid] = session

    first_request = session.waiting_requests[0]
    return session, first_request.deferred
93
94
class WaitingRequest(object):
    """
    Book-keeping for one HTTP request that a session is holding open.

    deferred    - Deferred fired with the reply data (or errbacked)
    delayedcall - timeout handler registered for this request; it is only
                  stored here, never invoked by this class
    timeout     - seconds the request may be held (default 30)
    startup     - True for the request that created the session
    rid         - request id the request belongs to (may be None)
    """

    def __init__(self, deferred, delayedcall, timeout=30, startup=False, rid=None):
        """Store the request data and record when the wait started."""
        self.rid = rid
        self.timeout = timeout
        self.startup = startup
        self.deferred = deferred
        self.delayedcall = delayedcall
        # wall-clock time (seconds since the epoch) the wait began
        self.wait_start = time.time()

    def doCallback(self, data):
        """Fire the request's deferred successfully with *data*."""
        deferred = self.deferred
        deferred.callback(data)

    def doErrback(self, data):
        """Fire the request's deferred with the error *data*."""
        deferred = self.deferred
        deferred.errback(data)
114
115
class Session(jabber.JabberClientFactory, server.Session):
    """
    Jabber Client Session class for client XMPP connections.

    A Session is both the client factory handed to the reactor for the XMPP
    connection and a twisted.web session.  Stanzas received from the XMPP
    server are queued in ``elems`` and handed to the pending HTTP requests
    kept in ``waiting_requests`` (WaitingRequest objects, oldest first).
    """
    def __init__(self, pint, attrs):
        """
        Initialize the session.

        pint  - punjab session interface; ``v`` (verbose flag) is read here,
                ``use_raw`` and ``connect_srv`` are optional settings
        attrs - attributes sent from the body tag; ``to`` and ``rid`` are
                required, the others are optional

        Raises error.Error('internal-server-error') if a ``route`` attribute
        does not use the ``xmpp:`` scheme.
        """
        self.charset = str(attrs.get('charset', 'utf-8'))

        self.to = attrs['to']
        self.port = 5222
        self.inactivity = 900
        if ':' in self.to:
            # The 'to' attribute may carry a port ("host:port").  Always strip
            # the ':' suffix (a bare "host:" used to be kept verbatim) and
            # keep the default port when the port part is empty.
            to, _, port = self.to.rpartition(':')
            self.to = to
            if port:
                self.port = int(port)

        self.sid = "".join("%02x" % ord(i) for i in os.urandom(20))

        jabber.JabberClientFactory.__init__(self, self.to, pint.v)
        server.Session.__init__(self, pint, self.sid)
        self.pint = pint

        self.attrs = attrs
        self.s = None

        self.elems = []
        rid = int(attrs['rid'])

        self.waiting_requests = []

        self.raw_buffer = u""
        self.xmpp_node = ''
        self.success = 0
        self.mechanisms = []
        self.xmlstream = None
        self.features = None
        self.session = None

        self.cache_data = {}
        self.verbose = self.pint.v
        self.noisy = self.verbose

        self.version = attrs.get('version', 0.0)

        self.key = attrs.get('newkey')

        self.wait = int(attrs.get('wait', 0))

        self.hold = int(attrs.get('hold', 0))
        self.inactivity = int(attrs.get('inactivity', 900)) # default inactivity 15 mins

        if 'window' in attrs:
            self.window = int(attrs['window'])
        else:
            self.window = self.hold + 2

        self.polling = int(attrs.get('polling', 0))

        if 'port' in attrs:
            self.port = int(attrs['port'])

        self.hostname = attrs.get('hostname', self.to)

        # Raw mode is a property of the interface (pint); a 'raw' body
        # attribute is not consulted.
        self.use_raw = getattr(pint, 'use_raw', False) # use raw buffers

        self.connect_srv = getattr(pint, 'connect_srv', True)

        self.secure = attrs.get('secure') == 'true'
        self.authenticator.useTls = self.secure

        if 'route' in attrs:
            route = attrs['route']
            if not route.startswith("xmpp:"):
                raise error.Error('internal-server-error')
            route = route[5:]
            if route.startswith("//"):
                route = route[2:]
            self.route = route

            # route format change, see http://www.xmpp.org/extensions/xep-0124.html#session-request
            rhostname, rport = route.split(":")
            self.port = int(rport)
            self.hostname = rhostname
            self.resource = ''

        self.authid = 0
        self.rid = rid + 1
        self.connected = 0 # number of clients connected on this session

        self.notifyOnExpire(self.onExpire)
        self.stream_error = None
        if pint.v:
            log.msg('Session Created : %s %s' % (str(self.sid),str(time.time()), ))
        self.stream_error_called = False
        self.addBootstrap(xmlstream.STREAM_START_EVENT, self.streamStart)
        self.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT, self.connectEvent)
        self.addBootstrap(xmlstream.STREAM_ERROR_EVENT, self.streamError)
        self.addBootstrap(xmlstream.STREAM_END_EVENT, self.connectError)

        # Create the first waiting request, for the session creation request
        # itself (rid as sent by the client); _startup_timeout handles its
        # timeout.
        self.appendWaitingRequest(defer.Deferred(), self.rid - 1,
                                  timeout=30,
                                  poll=self._startup_timeout,
                                  startup=True,
                                  )

    def rawDataIn(self, buf):
        """
        Hook for data received on the xmlstream.

        Logs the data in verbose mode.  In raw mode, once a stream id
        (authid) is known, appends the data (decoded from UTF-8 when it is a
        byte string) to ``raw_buffer``.
        """
        if self.pint and self.pint.v:
            try:
                log.msg("SID: %s => RECV: %r" % (self.sid, buf,))
            except Exception:
                log.err()
        if self.use_raw and self.authid:
            if isinstance(buf, str):
                buf = unicode(buf, 'utf-8')
            # add some raw data
            self.raw_buffer = self.raw_buffer + buf

    def rawDataOut(self, buf):
        """Log outgoing data on the xmlstream (installed only in verbose mode)."""
        try:
            log.msg("SID: %s => SEND: %r" % (self.sid, buf,))
        except Exception:
            log.err()

    def _wrPop(self, data, i=0):
        """Pop waiting request *i*, fire it with *data* and cache the reply."""
        wr = self.waiting_requests.pop(i)
        wr.doCallback(data)
        self._cacheData(wr.rid, data)

    def clearWaitingRequests(self, hold = 0):
        """
        Answer the oldest waiting requests with empty replies.

        hold - number of requests to leave waiting, default is none
        """
        while len(self.waiting_requests) > hold:
            self._wrPop([])

    def _wrError(self, err, i = 0):
        """Pop waiting request *i* and errback it with *err*."""
        wr = self.waiting_requests.pop(i)
        wr.doErrback(err)

    def appendWaitingRequest(self, d, rid, timeout=None, poll=None, startup=False):
        """
        Queue a new waiting request.

        d       - Deferred to fire with the reply
        rid     - request id the reply belongs to
        timeout - seconds to hold the request, defaults to the session's wait
        poll    - timeout handler stored on the request, defaults to
                  _pollTimeout
        startup - True for the session creation request
        """
        if timeout is None:
            timeout = self.wait
        if poll is None:
            poll = self._pollTimeout
        self.waiting_requests.append(
            WaitingRequest(d,
                           poll,
                           timeout = timeout,
                           rid = rid,
                           startup=startup))

    def returnWaitingRequests(self):
        """Hand all queued stanzas to the oldest waiting request, if any."""
        while len(self.elems) > 0 and len(self.waiting_requests) > 0:
            data = self.elems
            self.elems = []
            self._wrPop(data)

    def onExpire(self):
        """When the session expires: notify the interface, log, disconnect."""
        # pint is None once disconnect() has run; getattr copes with that.
        on_expire = getattr(self.pint, 'onExpire', None)
        if on_expire is not None:
            on_expire(self.sid)
        if self.verbose and not getattr(self, 'terminated', False):
            log.msg('SESSION -> We have expired', self.sid, self.rid,
                    getattr(self, 'waiting_requests', None))
        self.disconnect()

    def _expireSession(self):
        """Expire via twisted.web; fall back to onExpire if expire() raises."""
        try:
            self.expire()
        except Exception:
            self.onExpire()

    def terminate(self):
        """
        Terminate the session.

        Flushes queued stanzas to waiting requests, answers the remaining
        requests with empty replies, expires the session and returns a
        Deferred that has already fired with ``self.elems``.
        """
        self.wait = 0
        self.terminated = True
        if self.verbose:
            log.msg('SESSION -> Terminate')

        # if there are any elements hanging around and waiting
        # requests, send those off
        self.returnWaitingRequests()

        self.clearWaitingRequests()

        self._expireSession()

        return defer.succeed(self.elems)

    def poll(self, d = None, rid = None):
        """Handles the responses to requests.

        This function is called for every request except session setup
        and session termination.  It handles the reply portion of the
        request by returning a deferred which will get called back
        when there is data or when the wait timeout expires.
        """
        # queue this request
        if d is None:
            d = defer.Deferred()
        if self.pint.error:
            d.addErrback(self.pint.error)
        if not rid:
            rid = self.rid - 1
        self.appendWaitingRequest(d, rid)
        # check if there is any data to send back to a request
        self.returnWaitingRequests()

        # make sure we aren't queueing too many requests
        self.clearWaitingRequests(self.hold)
        return d

    def _pollTimeout(self, d):
        """Handle request timeouts.

        Since the timeout function is called, we must return an empty
        reply as there is no data to send back.
        """
        # waiting_requests is deleted by disconnect(); a late timer is a no-op.
        waiting = getattr(self, 'waiting_requests', None) or []
        matches = [i for i, wr in enumerate(waiting) if wr.deferred == d]
        if matches:
            self.touch()
        # Pop from the back so earlier pops do not shift later indices.
        for i in reversed(matches):
            self._wrPop([], i)

    def _pollForId(self, d):
        """Record the stream id as authid when known, then act as _pollTimeout."""
        xs = getattr(self, 'xmlstream', None)
        if xs is not None and xs.sid:
            self.authid = xs.sid
        self._pollTimeout(d)

    def connectEvent(self, xs):
        """
        STREAM_CONNECTED_EVENT handler: keep the stream, install the raw data
        hooks and, for version '1.0' streams, watch for stream:features.
        """
        self.version = self.authenticator.version
        self.xmlstream = xs
        if self.pint and self.pint.v:
            # add logging for verbose output
            self.xmlstream.rawDataOutFn = self.rawDataOut
        self.xmlstream.rawDataInFn = self.rawDataIn

        if self.version == '1.0':
            self.xmlstream.addObserver("/features", self.featuresHandler)

    def streamStart(self, xs):
        """
        A xmpp stream has started.

        Records the stream id as authid and, unless the body carried
        'no_events', registers the stanza observers.  On failure the oldest
        waiting request (if any) is errbacked and the session disconnects.
        """
        # This is done to fix the stream id problem, I should submit a bug to twisted bugs
        try:
            self.authid = self.xmlstream.sid

            if 'no_events' not in self.attrs:
                self.xmlstream.addOnetimeObserver("/auth", self.stanzaHandler)
                self.xmlstream.addOnetimeObserver("/response", self.stanzaHandler)
                self.xmlstream.addOnetimeObserver("/success", self._saslSuccess)
                self.xmlstream.addOnetimeObserver("/failure", self._saslError)

                self.xmlstream.addObserver("/iq/bind", self.bindHandler)
                self.xmlstream.addObserver("/bind", self.stanzaHandler)

                self.xmlstream.addObserver("/challenge", self.stanzaHandler)
                self.xmlstream.addObserver("/message",  self.stanzaHandler)
                self.xmlstream.addObserver("/iq",  self.stanzaHandler)
                self.xmlstream.addObserver("/presence",  self.stanzaHandler)
                # TODO - we should do something like this
                # self.xmlstream.addObserver("/*",  self.stanzaHandler)
        except Exception:
            # log.err() with no argument logs the exception being handled.
            log.err()
            if getattr(self, 'waiting_requests', None):
                self._wrError(error.Error("remote-connection-failed"))
            self.disconnect()

    def featuresHandler(self, f):
        """
        Handle stream:features.

        Records the advertised features.  Unless STARTTLS is about to be
        negotiated, queues the features element for the client and flushes
        it to a waiting request.
        """
        f.prefixes = ns.XMPP_PREFIXES.copy()

        # index the advertised features by (namespace, name); check for tls
        self.f = dict(((feature.uri, feature.name), feature)
                      for feature in f.elements())
        starttls = (ns.TLS_XMLNS, 'starttls') in self.f

        initializers = getattr(self.xmlstream, 'initializers', [])
        self.features = f
        self.xmlstream.features = f

        # There is a tls initializer added by us, if it is available we need to try it
        if len(initializers) > 0 and starttls:
            self.secure = True

        # authid starts out as 0, so test falsiness; an `is None` test never
        # matched and this fallback never ran.
        if not self.authid:
            self.authid = self.xmlstream.sid

        # If we get tls, then we should start tls, wait and then return
        # Here we wait, the tls initializer will start it
        if starttls and self.secure:
            if self.verbose:
                log.msg("Wait until starttls is completed.")
                log.msg(initializers)
            return
        self.elems.append(f)
        if len(self.waiting_requests) > 0:
            self.returnWaitingRequests()
            self.elems = [] # reset elems
            self.raw_buffer = u"" # reset raw buffer, features should not be in it

    def bindHandler(self, stz):
        """bind debugger for punjab, this is temporary! """
        if self.verbose:
            try:
                log.msg('BIND: %s %s' % (str(self.sid), str(stz.bind.jid)))
            except Exception:
                log.err()
        if self.use_raw:
            self.raw_buffer = stz.toXml()

    def stanzaHandler(self, stz):
        """
        Generic stanza handler for httpbind and httppoll.

        Queues the stanza (or, in raw mode once authid is known, the raw
        text buffered so far) and flushes it to a waiting request.
        """
        stz.prefixes = ns.XMPP_PREFIXES
        if self.use_raw and self.authid:
            stz = domish.SerializedXML(self.raw_buffer)
            self.raw_buffer = u""

        self.elems.append(stz)
        if self.waiting_requests:
            # if there are any waiting requests, give them all the
            # data so far, plus this new data
            self.returnWaitingRequests()

    def _startup_timeout(self, d):
        """
        Timeout handler for the session creation request.

        This can be called if the connection failed, or if we connected but
        never got stream features before the timeout.  With a stream id
        (authid) the request is answered with whatever is queued; otherwise
        it is errbacked with remote-connection-failed.
        """
        if self.pint and self.pint.v:
            log.msg('================================== %s %s startup timeout ==================================' % (str(self.sid), str(time.time()),))

        waiting = getattr(self, 'waiting_requests', None) or []
        for i, wr in enumerate(waiting):
            if wr.deferred == d:
                # check if we really failed or not
                if self.authid:
                    # Hand over the queue and start a fresh one so the same
                    # list is neither resent nor mutated after delivery.
                    data, self.elems = self.elems, []
                    self._wrPop(data, i=i)
                else:
                    self._wrError(error.Error("remote-connection-failed"), i=i)
                # The pop shifted the remaining indices; stop here.
                break

    def buildRemoteError(self, err_elem=None):
        """
        Build the error.Error reported to the client for a remote failure.

        Without *err_elem* this may not be a stream error (e.g. an XML
        parsing error), so it is exposed as remote-connection-failed.  With
        an actual stream:error element the result is a remote-stream-error
        that carries the element in ``children``.
        """
        err = 'remote-connection-failed'
        if err_elem is not None:
            err = 'remote-stream-error'
        e = error.Error(err)
        e.error_stanza = err
        e.children = []
        if err_elem is not None:
            e.children.append(err_elem)
        return e

    def _reportRemoteError(self, e):
        """
        Report the remote error *e* to the client.

        If a request is waiting, errback the oldest one with *e* and expire
        the session.  Otherwise store *e* as ``stream_error`` on the
        registered session, so we wait for a new request and expire then.
        """
        # waiting_requests is deleted by disconnect(), which can run before
        # the STREAM_END_EVENT that calls connectError.
        waiting = getattr(self, 'waiting_requests', None)
        do_expire = bool(waiting)
        if do_expire:
            waiting.pop(0).doErrback(e)

        if self.pint and self.sid in self.pint.sessions:
            if do_expire:
                self._expireSession()
            else:
                self.pint.sessions[self.sid].stream_error = e

    def streamError(self, streamerror):
        """Called when we get a stream:error stanza."""
        self.stream_error_called = True
        try:
            err_elem = streamerror.value.getElement()
        except AttributeError:
            err_elem = None

        self._reportRemoteError(self.buildRemoteError(err_elem))

    def connectError(self, reason):
        """Called when we get disconnected (ignored after a stream:error)."""
        if self.stream_error_called:
            return
        # Before Twisted 11.x the xmlstream object was passed instead of the
        # disconnect reason. See http://twistedmatrix.com/trac/ticket/2618
        if isinstance(reason, failure.Failure):
            reason_str = str(reason)
        else:
            reason_str = 'Reason unknown'

        # If the connection was established and lost, then we need to report
        # the error back to the client, since he needs to reauthenticate.
        # FIXME: If the connection was lost before anything happened, we could
        # silently retry instead.
        if self.verbose:
            log.msg('connect ERROR: %s' % reason_str)

        self.stopTrying()

        self._reportRemoteError(error.Error('remote-connection-failed'))

    def sendRawXml(self, obj):
        """
        Send a raw xml string, not a domish.Element, and touch the session.
        """
        self.touch()
        self._send(obj)

    def _send(self, xml):
        """
        Send valid data over the xmlstream.

        Data is dropped when there is no stream: before the connection is up,
        or on an expired session (disconnect() deletes ``xmlstream``).
        """
        xs = getattr(self, 'xmlstream', None)
        if xs:
            if isinstance(xml, domish.Element):
                xml.localPrefixes = {}
            xs.send(xml)

    def _removeObservers(self, typ = ''):
        """
        Drop every observer of one kind from the xmlstream.

        typ - 'event' for event observers, anything else for xpath observers
        """
        if typ == 'event':
            observers = self.xmlstream._eventObservers
        else:
            observers = self.xmlstream._xpathObservers
        # collect first, delete afterwards: never mutate while iterating
        emptyLists = []
        for priority, priorityObservers in observers.items():
            for query, callbacklist in priorityObservers.items():
                callbacklist.callbacks = []
                emptyLists.append((priority, query))

        for priority, query in emptyLists:
            del observers[priority][query]

    def disconnect(self):
        """
        Disconnect from the xmpp server and release per-session state.

        Sends </stream:stream>, stops reconnection attempts, closes the
        transport, answers pending requests with empty replies and deletes
        ``xmlstream`` (and ``waiting_requests`` when it was non-empty).
        Without a stream it only stops reconnection attempts.
        """
        xs = getattr(self, 'xmlstream', None)
        if not xs:
            # Never connected (or already disconnected): still make sure the
            # factory stops retrying for this dead session.
            self.stopTrying()
            return

        xs.send("</stream:stream>")

        self.stopTrying()
        xs.transport.loseConnection()
        del self.xmlstream

        self.connected = 0
        self.pint = None
        self.elems = []

        if getattr(self, 'waiting_requests', None):
            self.clearWaitingRequests()
            del self.waiting_requests
        self.mechanisms = None
        self.features = None

    def checkExpired(self):
        """
        Check if the session or xmpp connection has expired.

        Sends a keep-alive space on the stream, then terminates the session
        when it has been idle longer than its inactivity period (plus
        ``wait`` while requests are pending); otherwise reschedules itself.
        """
        # send this so we do not timeout from servers
        xs = getattr(self, 'xmlstream', None)
        if xs:
            xs.send(' ')

        if self.inactivity is None:
            wait = 900
        elif self.inactivity == 0:
            # 0 means "no inactivity timeout": waiting the current epoch
            # time in seconds effectively never expires.
            wait = time.time()
        else:
            wait = self.inactivity

        if getattr(self, 'waiting_requests', None):
            wait += self.wait # if we have pending requests we need to add the wait time

        if time.time() - self.lastModified > wait + 0.1:
            if self.uid in self.site.sessions:
                self.terminate()
        else:
            reactor.callLater(wait, self.checkExpired)

    def _cacheData(self, rid, data):
        """
        Remember the reply *data* sent for request *rid*.

        At most three replies are kept; the lowest rid is evicted when a new
        rid is added to a full cache.  Re-caching an existing rid evicts
        nothing.
        """
        rid = int(rid)
        if rid not in self.cache_data and len(self.cache_data) >= 3:
            # remove the first one in
            del self.cache_data[min(self.cache_data)]

        self.cache_data[rid] = data

    # This stuff will leave when SASL and TLS are implemented correctly
    # session stuff

    def _sessionResultEvent(self, iq):
        """
        Handle an iq reply: pop the oldest waiting request (if any), fire it
        with this session on a "result" iq and errback it otherwise.
        """
        d = None
        if len(self.waiting_requests) > 0:
            d = self.waiting_requests.pop(0).deferred

        if d:
            if iq["type"] == "result":
                d.callback(self)
            else:
                d.errback(self)

    def _saslSuccess(self, s):
        """
        Handle SASL success: record it, return it to the oldest waiting
        request, reset the authenticator and clear the raw buffer.
        """
        self.success = 1
        self.s = s
        # return success to the client
        if len(self.waiting_requests) > 0:
            self._wrPop([s])

        self.authenticator._reset()
        if self.use_raw:
            self.raw_buffer = u""

    def _saslError(self, sasl_error, d = None):
        """
        SASL error: errback *d* (if given) and return the failure element
        to the oldest waiting request.
        """
        if d:
            d.errback(self)
        if len(self.waiting_requests) > 0:
            self._wrPop([sasl_error])