1# -*- coding: utf-8 -*-
2# libolm python bindings
3# Copyright © 2015-2017 OpenMarket Ltd
4# Copyright © 2018 Damir Jelić <poljar@termina.org.uk>
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17"""libolm Group session module.
18
19This module contains the group session part of the Olm library. It contains two
20classes for creating inbound and outbound group sessions.
21
22Examples:
23    >>> outbound = OutboundGroupSession()
24    >>> InboundGroupSession(outbound.session_key)
25"""
26
27# pylint: disable=redefined-builtin,unused-import
28from builtins import bytes, super
29from typing import AnyStr, Optional, Tuple, Type
30
31from future.utils import bytes_to_native_str
32
33# pylint: disable=no-name-in-module
34from _libolm import ffi, lib  # type: ignore
35
36from ._compat import URANDOM, to_bytearray, to_bytes, to_unicode_str
37from ._finalize import track_for_finalization
38
39
def _clear_inbound_group_session(session):
    # type: (ffi.cdata) -> None
    """Wipe the libolm state of an inbound group session.

    Registered with track_for_finalization by InboundGroupSession.__new__.

    Args:
        session: The session handle returned by
            ``lib.olm_inbound_group_session``.
    """
    clear_session = lib.olm_clear_inbound_group_session
    clear_session(session)
43
44
def _clear_outbound_group_session(session):
    # type: (ffi.cdata) -> None
    """Wipe the libolm state of an outbound group session.

    Registered with track_for_finalization by OutboundGroupSession.__new__.

    Args:
        session: The session handle returned by
            ``lib.olm_outbound_group_session``.
    """
    clear_session = lib.olm_clear_outbound_group_session
    clear_session(session)
48
49
class OlmGroupSessionError(Exception):
    """Raised when a libolm group session operation reports an error.

    The exception message is the error string libolm reported for the
    failing inbound or outbound group session.
    """
52
53
class InboundGroupSession(object):
    """Inbound group session for encrypted multiuser communication."""

    def __new__(
        cls,              # type: Type[InboundGroupSession]
        session_key=None  # type: Optional[str]
    ):
        # type: (...) -> InboundGroupSession
        # ``session_key`` is accepted but unused here, so that both
        # ``InboundGroupSession(key)`` and the bare ``cls.__new__(cls)`` used
        # by the alternate constructors below work.
        obj = super().__new__(cls)
        # The libolm session lives inside a Python-owned buffer.
        obj._buf = ffi.new("char[]", lib.olm_inbound_group_session_size())
        obj._session = lib.olm_inbound_group_session(obj._buf)
        # Associate _clear_inbound_group_session with the session handle
        # (presumably run when ``obj`` is garbage collected).
        track_for_finalization(obj, obj._session, _clear_inbound_group_session)
        return obj

    def __init__(self, session_key):
        # type: (AnyStr) -> None
        """Create a new inbound group session.

        Start a new inbound group session, from a key exported from
        an outbound group session.

        Raises OlmGroupSessionError on failure. The error message of the
        exception will be "OLM_INVALID_BASE64" if the session key is not valid
        base64 and "OLM_BAD_SESSION_KEY" if the session key is invalid.

        Args:
            session_key(str): The session key, e.g. the value of
                ``OutboundGroupSession.session_key``.
        """
        if False:  # pragma: no cover
            # Never executed; only declares the type of _session for type
            # checkers.
            self._session = self._session  # type: ffi.cdata

        byte_session_key = to_bytearray(session_key)

        try:
            ret = lib.olm_init_inbound_group_session(
                self._session,
                ffi.from_buffer(byte_session_key), len(byte_session_key)
            )
        finally:
            # Wipe our private copy of the key. If to_bytearray handed back
            # the caller's own object, leave the caller's data untouched.
            if byte_session_key is not session_key:
                for i in range(0, len(byte_session_key)):
                    byte_session_key[i] = 0
        self._check_error(ret)

    def pickle(self, passphrase=""):
        # type: (Optional[str]) -> bytes
        """Store an inbound group session.

        Stores a group session as a base64 string. Encrypts the session using
        the supplied passphrase. Returns a byte object containing the base64
        encoded string of the pickled session. Raises OlmGroupSessionError on
        failure.

        Args:
            passphrase(str, optional): The passphrase to be used to encrypt
                the session.
        """
        byte_passphrase = bytearray(passphrase, "utf-8") if passphrase else b""

        pickle_length = lib.olm_pickle_inbound_group_session_length(
            self._session)
        pickle_buffer = ffi.new("char[]", pickle_length)

        try:
            ret = lib.olm_pickle_inbound_group_session(
                self._session,
                ffi.from_buffer(byte_passphrase), len(byte_passphrase),
                pickle_buffer, pickle_length
            )
            self._check_error(ret)
        finally:
            # clear out copies of the passphrase
            for i in range(0, len(byte_passphrase)):
                byte_passphrase[i] = 0

        return ffi.unpack(pickle_buffer, pickle_length)

    @classmethod
    def from_pickle(cls, pickle, passphrase=""):
        # type: (AnyStr, Optional[str]) -> InboundGroupSession
        """Load a previously stored inbound group session.

        Loads an inbound group session from a pickled base64 string and returns
        an InboundGroupSession object. Decrypts the session using the supplied
        passphrase. Raises OlmGroupSessionError on failure. If the passphrase
        doesn't match the one used to encrypt the session then the error
        message for the exception will be "BAD_ACCOUNT_KEY". If the base64
        couldn't be decoded then the error message will be "INVALID_BASE64".

        Args:
            pickle(bytes or str): Base64 encoded string containing the pickled
                session. A str is converted with ``to_bytes`` before use.
            passphrase(str, optional): The passphrase used to encrypt the
                session

        Raises:
            ValueError: If ``pickle`` is empty.
            OlmGroupSessionError: If libolm fails to unpickle the session.
        """
        if not pickle:
            raise ValueError("Pickle can't be empty")

        byte_pickle = to_bytes(pickle)
        byte_passphrase = bytearray(passphrase, "utf-8") if passphrase else b""
        # copy because unpickle will destroy the buffer
        pickle_buffer = ffi.new("char[]", byte_pickle)

        obj = cls.__new__(cls)

        try:
            ret = lib.olm_unpickle_inbound_group_session(
                obj._session,
                ffi.from_buffer(byte_passphrase),
                len(byte_passphrase),
                pickle_buffer,
                len(byte_pickle)
            )
            obj._check_error(ret)
        finally:
            # clear out copies of the passphrase
            for i in range(0, len(byte_passphrase)):
                byte_passphrase[i] = 0

        return obj

    def _check_error(self, ret):
        # type: (int) -> None
        """Raise OlmGroupSessionError if ``ret`` is libolm's error sentinel.

        The exception message is the session's last libolm error string.
        """
        if ret != lib.olm_error():
            return

        last_error = bytes_to_native_str(ffi.string(
            lib.olm_inbound_group_session_last_error(self._session)))

        raise OlmGroupSessionError(last_error)

    def decrypt(self, ciphertext, unicode_errors="replace"):
        # type: (AnyStr, str) -> Tuple[str, int]
        """Decrypt a message

        Returns a tuple of the decrypted plain-text and the message index of
        the decrypted message or raises OlmGroupSessionError on failure.
        On failure the error message of the exception will be:

        * OLM_INVALID_BASE64         if the message is not valid base64
        * OLM_BAD_MESSAGE_VERSION    if the message was encrypted with an
            unsupported version of the protocol
        * OLM_BAD_MESSAGE_FORMAT     if the message headers could not be
            decoded
        * OLM_BAD_MESSAGE_MAC        if the message could not be verified
        * OLM_UNKNOWN_MESSAGE_INDEX  if we do not have a session key
            corresponding to the message's index (i.e., it was sent before
            the session key was shared with us)

        The internal plaintext buffer is zeroed before returning, including
        when decryption or unicode decoding fails.

        Args:
            ciphertext(str): Base64 encoded ciphertext containing the encrypted
                message
            unicode_errors(str, optional): The error handling scheme to use for
                unicode decoding errors. The default is "replace" meaning that
                the character that was unable to decode will be replaced with
                the unicode replacement character (U+FFFD). Other possible
                values are "strict" and "ignore" as well as any other name
                registered with codecs.register_error that can handle
                UnicodeDecodeErrors.

        Raises:
            ValueError: If ``ciphertext`` is empty.
            OlmGroupSessionError: If libolm fails to decrypt the message.
        """
        if not ciphertext:
            raise ValueError("Ciphertext can't be empty.")

        byte_ciphertext = to_bytes(ciphertext)

        # copy because max_plaintext_length will destroy the buffer
        ciphertext_buffer = ffi.new("char[]", byte_ciphertext)

        max_plaintext_length = lib.olm_group_decrypt_max_plaintext_length(
            self._session, ciphertext_buffer, len(byte_ciphertext)
        )
        self._check_error(max_plaintext_length)
        plaintext_buffer = ffi.new("char[]", max_plaintext_length)
        # The previous call destroyed the first copy, so hand
        # olm_group_decrypt a fresh one.
        ciphertext_buffer = ffi.new("char[]", byte_ciphertext)

        message_index = ffi.new("uint32_t*")

        try:
            plaintext_length = lib.olm_group_decrypt(
                self._session, ciphertext_buffer, len(byte_ciphertext),
                plaintext_buffer, max_plaintext_length,
                message_index
            )

            self._check_error(plaintext_length)

            plaintext = to_unicode_str(
                ffi.unpack(plaintext_buffer, plaintext_length),
                errors=unicode_errors
            )
        finally:
            # Clear out copies of the plaintext even if decryption or
            # decoding (e.g. unicode_errors="strict") raised.
            lib.memset(plaintext_buffer, 0, max_plaintext_length)

        return plaintext, message_index[0]

    @property
    def id(self):
        # type: () -> str
        """str: A base64 encoded identifier for this session."""
        id_length = lib.olm_inbound_group_session_id_length(self._session)
        id_buffer = ffi.new("char[]", id_length)
        ret = lib.olm_inbound_group_session_id(
            self._session,
            id_buffer,
            id_length
        )
        self._check_error(ret)
        return bytes_to_native_str(ffi.unpack(id_buffer, id_length))

    @property
    def first_known_index(self):
        # type: () -> int
        """int: The first message index we know how to decrypt."""
        return lib.olm_inbound_group_session_first_known_index(self._session)

    def export_session(self, message_index):
        # type: (int) -> str
        """Export an inbound group session

        Export the base64-encoded ratchet key for this session, at the given
        index, in a format which can be used by import_session().

        Raises OlmGroupSessionError on failure. The error message for the
        exception will be:

        * OLM_UNKNOWN_MESSAGE_INDEX if we do not have a session key
            corresponding to the given index (ie, it was sent before the
            session key was shared with us)

        Args:
            message_index(int): The message index at which the session should
                be exported.
        """

        export_length = lib.olm_export_inbound_group_session_length(
            self._session)

        export_buffer = ffi.new("char[]", export_length)
        try:
            ret = lib.olm_export_inbound_group_session(
                self._session,
                export_buffer,
                export_length,
                message_index
            )
            self._check_error(ret)
            export_str = bytes_to_native_str(
                ffi.unpack(export_buffer, export_length))
        finally:
            # clear out copies of the key, on success and on failure
            lib.memset(export_buffer, 0, export_length)

        return export_str

    @classmethod
    def import_session(cls, session_key):
        # type: (AnyStr) -> InboundGroupSession
        """Create an InboundGroupSession from an exported session key.

        Creates an InboundGroupSession with an previously exported session key,
        raises OlmGroupSessionError on failure. The error message for the
        exception will be:

        * OLM_INVALID_BASE64  if the session_key is not valid base64
        * OLM_BAD_SESSION_KEY if the session_key is invalid

        Args:
            session_key(str): The exported session key with which the inbound
                group session will be created
        """
        obj = cls.__new__(cls)

        byte_session_key = to_bytearray(session_key)

        try:
            ret = lib.olm_import_inbound_group_session(
                obj._session,
                ffi.from_buffer(byte_session_key),
                len(byte_session_key)
            )
            obj._check_error(ret)
        finally:
            # Clear out our copy of the key. If to_bytearray handed back the
            # caller's own object, leave the caller's data untouched.
            if byte_session_key is not session_key:
                for i in range(0, len(byte_session_key)):
                    byte_session_key[i] = 0

        return obj
334
335
class OutboundGroupSession(object):
    """Outbound group session for encrypted multiuser communication."""

    def __new__(cls):
        # type: (Type[OutboundGroupSession]) -> OutboundGroupSession
        obj = super().__new__(cls)
        # The libolm session lives inside a Python-owned buffer.
        obj._buf = ffi.new("char[]", lib.olm_outbound_group_session_size())
        obj._session = lib.olm_outbound_group_session(obj._buf)
        # Associate _clear_outbound_group_session with the session handle
        # (presumably run when ``obj`` is garbage collected).
        track_for_finalization(
            obj,
            obj._session,
            _clear_outbound_group_session
        )
        return obj

    def __init__(self):
        # type: () -> None
        """Create a new outbound group session.

        Start a new outbound group session. Raises OlmGroupSessionError on
        failure.
        """
        if False:  # pragma: no cover
            # Never executed; only declares the type of _session for type
            # checkers.
            self._session = self._session  # type: ffi.cdata

        # libolm reports how much randomness it needs; it is supplied from
        # URANDOM.
        random_length = lib.olm_init_outbound_group_session_random_length(
            self._session
        )
        random = URANDOM(random_length)

        ret = lib.olm_init_outbound_group_session(
            self._session, ffi.from_buffer(random), random_length
        )
        self._check_error(ret)

    def _check_error(self, ret):
        # type: (int) -> None
        """Raise OlmGroupSessionError if ``ret`` is libolm's error sentinel.

        The exception message is the session's last libolm error string.
        """
        if ret != lib.olm_error():
            return

        last_error = bytes_to_native_str(ffi.string(
            lib.olm_outbound_group_session_last_error(self._session)
        ))

        raise OlmGroupSessionError(last_error)

    def pickle(self, passphrase=""):
        # type: (Optional[str]) -> bytes
        """Store an outbound group session.

        Stores a group session as a base64 string. Encrypts the session using
        the supplied passphrase. Returns a byte object containing the base64
        encoded string of the pickled session. Raises OlmGroupSessionError on
        failure.

        Args:
            passphrase(str, optional): The passphrase to be used to encrypt
                the session.
        """
        byte_passphrase = bytearray(passphrase, "utf-8") if passphrase else b""
        pickle_length = lib.olm_pickle_outbound_group_session_length(
            self._session)
        pickle_buffer = ffi.new("char[]", pickle_length)

        try:
            ret = lib.olm_pickle_outbound_group_session(
                self._session,
                ffi.from_buffer(byte_passphrase), len(byte_passphrase),
                pickle_buffer, pickle_length
            )
            self._check_error(ret)
        finally:
            # clear out copies of the passphrase
            for i in range(0, len(byte_passphrase)):
                byte_passphrase[i] = 0

        return ffi.unpack(pickle_buffer, pickle_length)

    @classmethod
    def from_pickle(cls, pickle, passphrase=""):
        # type: (AnyStr, Optional[str]) -> OutboundGroupSession
        """Load a previously stored outbound group session.

        Loads an outbound group session from a pickled base64 string and
        returns an OutboundGroupSession object. Decrypts the session using the
        supplied passphrase. Raises OlmGroupSessionError on failure. If the
        passphrase doesn't match the one used to encrypt the session then the
        error message for the exception will be "BAD_ACCOUNT_KEY". If the
        base64 couldn't be decoded then the error message will be
        "INVALID_BASE64".

        Args:
            pickle(bytes or str): Base64 encoded string containing the pickled
                session. A str is converted with ``to_bytes`` before use.
            passphrase(str, optional): The passphrase used to encrypt the
                session

        Raises:
            ValueError: If ``pickle`` is empty.
            OlmGroupSessionError: If libolm fails to unpickle the session.
        """
        if not pickle:
            raise ValueError("Pickle can't be empty")

        byte_pickle = to_bytes(pickle)
        byte_passphrase = bytearray(passphrase, "utf-8") if passphrase else b""
        # copy because unpickle will destroy the buffer
        pickle_buffer = ffi.new("char[]", byte_pickle)

        obj = cls.__new__(cls)

        try:
            ret = lib.olm_unpickle_outbound_group_session(
                obj._session,
                ffi.from_buffer(byte_passphrase),
                len(byte_passphrase),
                pickle_buffer,
                len(byte_pickle)
            )
            obj._check_error(ret)
        finally:
            # clear out copies of the passphrase
            for i in range(0, len(byte_passphrase)):
                byte_passphrase[i] = 0

        return obj

    def encrypt(self, plaintext):
        # type: (AnyStr) -> str
        """Encrypt a message.

        Returns the encrypted ciphertext as a str. Raises OlmGroupSessionError
        on failure.

        Args:
            plaintext(str): A string that will be encrypted using the group
                session.
        """
        byte_plaintext = to_bytearray(plaintext)
        message_length = lib.olm_group_encrypt_message_length(
            self._session, len(byte_plaintext)
        )

        message_buffer = ffi.new("char[]", message_length)

        try:
            ret = lib.olm_group_encrypt(
                self._session,
                ffi.from_buffer(byte_plaintext), len(byte_plaintext),
                message_buffer, message_length,
            )
            self._check_error(ret)
        finally:
            # Clear out our copy of the plaintext. If to_bytearray handed
            # back the caller's own object, leave the caller's data untouched.
            if byte_plaintext is not plaintext:
                for i in range(0, len(byte_plaintext)):
                    byte_plaintext[i] = 0

        return bytes_to_native_str(ffi.unpack(message_buffer, message_length))

    @property
    def id(self):
        # type: () -> str
        """str: A base64 encoded identifier for this session."""
        id_length = lib.olm_outbound_group_session_id_length(self._session)
        id_buffer = ffi.new("char[]", id_length)

        ret = lib.olm_outbound_group_session_id(
            self._session,
            id_buffer,
            id_length
        )
        self._check_error(ret)

        return bytes_to_native_str(ffi.unpack(id_buffer, id_length))

    @property
    def message_index(self):
        # type: () -> int
        """int: The current message index of the session.

        Each message is encrypted with an increasing index. This is the index
        for the next message.
        """
        return lib.olm_outbound_group_session_message_index(self._session)

    @property
    def session_key(self):
        # type: () -> str
        """The base64-encoded current ratchet key for this session.

        Each message is encrypted with a different ratchet key. This function
        returns the ratchet key that will be used for the next message. The
        internal key buffer is zeroed before returning. Raises
        OlmGroupSessionError on failure.
        """
        key_length = lib.olm_outbound_group_session_key_length(self._session)
        key_buffer = ffi.new("char[]", key_length)

        try:
            ret = lib.olm_outbound_group_session_key(
                self._session,
                key_buffer,
                key_length
            )
            self._check_error(ret)
            session_key = bytes_to_native_str(
                ffi.unpack(key_buffer, key_length))
        finally:
            # Clear out copies of the key (mirrors export_session on the
            # inbound side).
            lib.memset(key_buffer, 0, key_length)

        return session_key
533