1"""
2    Slixmpp: The Slick XMPP Library
3    Copyright (C) 2011 Nathanael C. Fritz, Lance J.T. Stout
4    This file is part of Slixmpp.
5
6    See the file LICENSE for copying permission.
7"""
8
9import logging
10import hashlib
11import base64
12
13from slixmpp import __version__
14from slixmpp.stanza import StreamFeatures, Presence, Iq
15from slixmpp.xmlstream import register_stanza_plugin, JID
16from slixmpp.xmlstream.handler import Callback
17from slixmpp.xmlstream.matcher import StanzaPath
18from slixmpp.util import MemoryCache
19from slixmpp import asyncio
20from slixmpp.exceptions import XMPPError, IqError, IqTimeout
21from slixmpp.plugins import BasePlugin
22from slixmpp.plugins.xep_0115 import stanza, StaticCaps
23
24
log = logging.getLogger(name=__name__)  # module logger, named after this module
26
27
class XEP_0115(BasePlugin):

    """
    XEP-0115: Entity Capabilities

    Attaches this entity's caps (node, hash, verification string) to
    outgoing presence, and verifies and caches the capabilities that
    other entities advertise in presence or in stream features.

    Config options (see ``default_config``):

    * ``hash`` -- hash algorithm name used for our own verification
      string; must be a key of ``self.hashes`` (``'sha-1'`` by default).
    * ``caps_node`` -- node URI advertised in our caps; defaults to a
      slixmpp URL containing the library version.
    * ``broadcast`` -- whether outgoing presence gets caps attached and
      whether ``update_caps`` re-sends presence.
    * ``cache`` -- stored as ``self.cache``; a ``MemoryCache`` is
      created when left as None.
    """

    name = 'xep_0115'
    description = 'XEP-0115: Entity Capabilities'
    dependencies = {'xep_0030', 'xep_0128', 'xep_0004'}
    stanza = stanza
    default_config = {
        'hash': 'sha-1',
        'caps_node': None,
        'broadcast': True,
        'cache': None,
    }

    def plugin_init(self):
        """Register stanza plugins, handlers, the stream feature and the
        overridable API operations, and expose caps helpers on xep_0030."""
        # Hash names accepted in caps 'hash' attributes, mapped to hashlib
        # constructors. The dashed names follow the IANA Hash Function
        # Textual Names used by XEP-0115; 'sha1' is kept as an alias for
        # backwards compatibility.
        self.hashes = {'sha-1': hashlib.sha1,
                       'sha1': hashlib.sha1,
                       'sha-224': hashlib.sha224,
                       'sha-256': hashlib.sha256,
                       'sha-384': hashlib.sha384,
                       'sha-512': hashlib.sha512,
                       'md5': hashlib.md5}

        if self.caps_node is None:
            self.caps_node = 'http://slixmpp.com/ver/%s' % __version__

        if self.cache is None:
            self.cache = MemoryCache()

        register_stanza_plugin(Presence, stanza.Capabilities)
        register_stanza_plugin(StreamFeatures, stanza.Capabilities)

        self._disco_ops = ['cache_caps',
                           'get_caps',
                           'assign_verstring',
                           'get_verstring',
                           'supports',
                           'has_identity']

        # The name must match the one passed to remove_handler() in
        # plugin_end(), otherwise the handler is never unregistered.
        self.xmpp.register_handler(
                Callback('Entity Capabilities',
                         StanzaPath('presence/caps'),
                         self._handle_caps))

        self.xmpp.add_filter('out', self._filter_add_caps)

        self.xmpp.add_event_handler('entity_caps', self._process_caps)

        if not self.xmpp.is_component:
            self.xmpp.register_feature('caps',
                    self._handle_caps_feature,
                    restart=False,
                    order=10010)

        disco = self.xmpp['xep_0030']
        self.static = StaticCaps(self.xmpp, disco.static)

        # StaticCaps provides the default implementation of every API op.
        for op in self._disco_ops:
            self.api.register(getattr(self.static, op), op, default=True)

        # Let disco's supports/has_identity consult caps data as well.
        for op in ('supports', 'has_identity'):
            self.xmpp['xep_0030'].api.register(getattr(self.static, op), op)

        self._run_node_handler = disco._run_node_handler

        disco.cache_caps = self.cache_caps
        disco.update_caps = self.update_caps
        disco.assign_verstring = self.assign_verstring
        disco.get_verstring = self.get_verstring

    def plugin_end(self):
        """Undo the registrations made in plugin_init and session_bind.

        The helper methods attached to xep_0030 in plugin_init are not
        removed here.
        """
        self.xmpp['xep_0030'].del_feature(feature=stanza.Capabilities.namespace)
        self.xmpp.del_filter('out', self._filter_add_caps)
        self.xmpp.del_event_handler('entity_caps', self._process_caps)
        self.xmpp.remove_handler('Entity Capabilities')
        if not self.xmpp.is_component:
            self.xmpp.unregister_feature('caps', 10010)
        for op in ('supports', 'has_identity'):
            self.xmpp['xep_0030'].restore_defaults(op)

    def session_bind(self, jid):
        """Advertise the caps namespace as a local disco feature."""
        self.xmpp['xep_0030'].add_feature(stanza.Capabilities.namespace)

    def _filter_add_caps(self, stanza):
        """Outgoing filter: attach our caps to available-type presence.

        Non-presence stanzas, presence of other types, and all stanzas when
        ``broadcast`` is off are returned untouched. Caps are only added if
        a verification string is assigned to the presence's 'from' JID.
        """
        if not isinstance(stanza, Presence) or not self.broadcast:
            return stanza

        if stanza['type'] not in ('available', 'chat', 'away', 'dnd', 'xa'):
            return stanza

        ver = self.get_verstring(stanza['from'])
        if ver:
            stanza['caps']['node'] = self.caps_node
            stanza['caps']['hash'] = self.hash
            stanza['caps']['ver'] = ver
        return stanza

    def _handle_caps(self, presence):
        """Re-dispatch presence carrying caps as an 'entity_caps' event.

        As a client, our own reflected presence is ignored.
        """
        if not self.xmpp.is_component:
            if presence['from'] == self.xmpp.boundjid:
                return
        self.xmpp.event('entity_caps', presence)

    def _handle_caps_feature(self, features):
        """Process caps advertised by the server in stream features."""
        # We already have a method to process presence with
        # caps, so wrap things up and use that.
        p = Presence()
        p['from'] = self.xmpp.boundjid.domain
        p.append(features['caps'])
        self.xmpp.features.add('caps')

        self.xmpp.event('entity_caps', p)

    async def _process_caps(self, pres):
        """Handle an 'entity_caps' event for presence *pres*.

        Hash-less (legacy) caps are re-emitted as 'entity_caps_legacy'.
        Otherwise the advertised verification string is assigned to the
        sender: immediately if it is already cached, or after fetching the
        ``node#ver`` disco#info and validating it against the string.
        """
        caps_hash = pres['caps']['hash']
        if not caps_hash:
            log.debug("Received unsupported legacy caps: %s, %s, %s",
                      pres['caps']['node'],
                      pres['caps']['ver'],
                      pres['caps']['ext'])
            self.xmpp.event('entity_caps_legacy', pres)
            return

        ver = pres['caps']['ver']

        existing_verstring = self.get_verstring(pres['from'].full)
        if str(existing_verstring) == str(ver):
            return

        existing_caps = self.get_caps(verstring=ver)
        if existing_caps is not None:
            self.assign_verstring(pres['from'], ver)
            return

        # As a component, query from the JID the presence was addressed to.
        ifrom = pres['to'] if self.xmpp.is_component else None
        disco = self.xmpp['xep_0030']

        if caps_hash not in self.hashes:
            # The verification string cannot be checked; only issue a plain
            # disco#info query. It is awaited so that errors reach the
            # except clause instead of an unobserved future.
            log.debug("Unknown caps hash: %s", caps_hash)
            try:
                await disco.get_info(jid=pres['from'], ifrom=ifrom,
                                     coroutine=True)
            except XMPPError:
                log.debug("Could not retrieve disco#info results for %s",
                          pres['from'])
            return

        log.debug("New caps verification string: %s", ver)
        node = '%s#%s' % (pres['caps']['node'], ver)
        try:
            caps = await disco.get_info(pres['from'], node,
                                        coroutine=True,
                                        ifrom=ifrom)

            if isinstance(caps, Iq):
                caps = caps['disco_info']

            if self._validate_caps(caps, caps_hash, ver):
                self.assign_verstring(pres['from'], ver)
        except XMPPError:
            log.debug("Could not retrieve disco#info results for caps for %s",
                      node)

    def _validate_caps(self, caps, hash, check_verstring):
        """Validate disco#info *caps* against *check_verstring*.

        Duplicate identities, duplicate features, duplicate FORM_TYPE
        values, or a FORM_TYPE with repeated values make the result
        invalid. Non-form extensions, forms without FORM_TYPE, and forms
        whose FORM_TYPE field is not 'hidden' are removed from
        ``caps.xml`` before hashing.

        :param caps: disco#info stanza (modified in place as above).
        :param hash: hash algorithm name, a key of ``self.hashes``.
        :param check_verstring: verification string the peer advertised.
        :returns: True if valid (the info is then cached), else False.
        """
        # Check identities
        if len(caps.get_identities(dedupe=False)) != \
                len(caps.get_identities()):
            log.debug("Duplicate disco identities found, invalid for caps")
            return False

        # Check features
        if len(caps.get_features(dedupe=False)) != len(caps.get_features()):
            log.debug("Duplicate disco features found, invalid for caps")
            return False

        # Check forms
        form_cls = self.xmpp['xep_0004'].stanza.Form
        form_types = []
        deduped_form_types = set()
        for sub in caps['substanzas']:
            if not isinstance(sub, form_cls):
                log.debug("Non form extension found, ignoring for caps")
                caps.xml.remove(sub.xml)
                continue

            fields = sub.get_fields()
            if 'FORM_TYPE' not in fields:
                log.debug("No FORM_TYPE found, ignoring form for caps")
                caps.xml.remove(sub.xml)
                continue

            form_type_field = fields['FORM_TYPE']
            f_type = tuple(form_type_field['value'])
            form_types.append(f_type)
            deduped_form_types.add(f_type)
            if len(form_types) != len(deduped_form_types):
                log.debug("Duplicated FORM_TYPE values, "
                          "invalid for caps")
                return False

            if len(f_type) > 1 and len(f_type) != len(set(f_type)):
                log.debug("Extra FORM_TYPE data, invalid for caps")
                return False

            if form_type_field['type'] != 'hidden':
                log.debug("Field FORM_TYPE type not 'hidden', "
                          "ignoring form for caps")
                caps.xml.remove(sub.xml)

        verstring = self.generate_verstring(caps, hash)
        if verstring != check_verstring:
            log.debug("Verification strings do not match: %s, %s",
                      verstring, check_verstring)
            return False

        self.cache_caps(verstring, caps)
        return True

    def generate_verstring(self, info, hash):
        """Compute the XEP-0115 verification string for disco#info *info*.

        :param info: disco#info stanza providing 'identities', 'features'
                     and 'substanzas'.
        :param hash: hash algorithm name, looked up in ``self.hashes``.
        :returns: base64-encoded digest as ``str``, or None if *hash* is
                  not a supported algorithm.
        """
        hasher = self.hashes.get(hash)
        if hasher is None:
            return None

        # Identities are joined as category/type/lang/name, with missing
        # (None) components rendered as empty strings.
        identities = sorted('/'.join(part or '' for part in ident)
                            for ident in info['identities'])
        features = sorted(info['features'])

        # Every identity and every feature is followed by '<'. Unlike
        # '<'.join(items) + '<', this adds nothing for an empty list.
        parts = [item + '<' for item in identities]
        parts.extend(item + '<' for item in features)

        # Group data forms carrying a FORM_TYPE field by its first value.
        form_cls = self.xmpp['xep_0004'].stanza.Form
        form_types = {}
        for sub in info['substanzas']:
            if not isinstance(sub, form_cls):
                continue
            if 'FORM_TYPE' not in sub.get_fields():
                continue
            f_type = sub['values']['FORM_TYPE']
            if len(f_type):
                f_type = f_type[0]
            else:
                # An empty value list is unhashable and cannot key the
                # dict; group it under the empty string instead.
                f_type = ''
            form_types.setdefault(f_type, []).append(sub)

        for f_type in sorted(form_types):
            for form in form_types[f_type]:
                parts.append('%s<' % f_type)
                fields = form.get_fields()
                for var in sorted(fields):
                    if var == 'FORM_TYPE':
                        continue
                    parts.append('%s<' % var)
                    vals = fields[var].get_value(convert=False)
                    if vals is None:
                        # NOTE(review): a None value is hashed as a single
                        # empty value, as before; confirm against the spec.
                        parts.append('<')
                        continue
                    if not isinstance(vals, list):
                        vals = [vals]
                    parts.extend(val + '<' for val in sorted(vals))

        digest = hasher(''.join(parts).encode('utf8')).digest()
        return base64.b64encode(digest).decode('utf-8')

    async def update_caps(self, jid=None, node=None, preserve=False):
        """Regenerate and publish our verification string for *jid*.

        Local disco#info for *jid*/*node* is hashed with ``self.hash``.
        The result is registered under ``caps_node#ver``, cached, and
        assigned to *jid*. If the session has started and ``broadcast`` is
        on, presence is re-sent: per contact of ``roster[jid]`` when running
        as a component or when *preserve* is true, otherwise via
        ``roster[jid].send_last_presence()``. XMPPError aborts silently.
        """
        try:
            info = await self.xmpp['xep_0030'].get_info(jid, node, local=True)
            if isinstance(info, Iq):
                info = info['disco_info']
            ver = self.generate_verstring(info, self.hash)
            if ver is None:
                # Publishing would otherwise advertise 'caps_node#None'.
                log.warning("Unsupported caps hash %r, caps not updated",
                            self.hash)
                return
            self.xmpp['xep_0030'].set_info(
                    jid=jid,
                    node='%s#%s' % (self.caps_node, ver),
                    info=info)
            self.cache_caps(ver, info)
            self.assign_verstring(jid, ver)

            if self.xmpp.sessionstarted and self.broadcast:
                if self.xmpp.is_component or preserve:
                    for contact in self.xmpp.roster[jid]:
                        self.xmpp.roster[jid][contact].send_last_presence()
                else:
                    self.xmpp.roster[jid].send_last_presence()
        except XMPPError:
            return

    def get_verstring(self, jid=None):
        """Return the verification string assigned to *jid* (default: our
        bound full JID) via the 'get_verstring' API op."""
        if jid in ('', None):
            jid = self.xmpp.boundjid.full
        if isinstance(jid, JID):
            jid = jid.full
        return self.api['get_verstring'](jid)

    def assign_verstring(self, jid=None, verstring=None):
        """Assign *verstring* to *jid* (default: our bound full JID) via
        the 'assign_verstring' API op."""
        if jid in (None, ''):
            jid = self.xmpp.boundjid.full
        if isinstance(jid, JID):
            jid = jid.full
        return self.api['assign_verstring'](jid, args={
            'verstring': verstring})

    def cache_caps(self, verstring=None, info=None):
        """Store disco#info *info* under *verstring* via the 'cache_caps'
        API op."""
        data = {'verstring': verstring, 'info': info}
        return self.api['cache_caps'](args=data)

    def get_caps(self, jid=None, verstring=None):
        """Return cached disco#info for *verstring*, or for the string
        assigned to *jid* when *verstring* is None.

        Returns None when neither *jid* nor *verstring* is given.
        """
        if verstring is None:
            if jid is not None:
                verstring = self.get_verstring(jid)
            else:
                return None
        if isinstance(jid, JID):
            jid = jid.full
        data = {'verstring': verstring}
        return self.api['get_caps'](jid, args=data)
341