1# -*- coding: utf-8 -*-
2
3from . import storageprotos_pb2 as storageprotos
4from ..identitykeypair import IdentityKey, IdentityKeyPair
5from ..ratchet.rootkey import RootKey
6from ..kdf.hkdf import HKDF
7from ..ecc.curve import Curve
8from ..ecc.eckeypair import ECKeyPair
9from ..ratchet.chainkey import ChainKey
10from ..kdf.messagekeys import MessageKeys
11
12
class SessionState:
    """Accessor wrapper around a ``storageprotos.SessionStructure`` message.

    The structure holds:

    * the identity keys of both parties;
    * the root key;
    * the sending chain;
    * a bounded list of receiving chains, each with any stored message keys;
    * an optional pending key exchange;
    * an optional unacknowledged pre-key message.
    """

    # addReceiverChain discards the oldest receiver chain once more than this
    # many are stored.
    _MAX_RECEIVER_CHAINS = 5

    def __init__(self, session=None):
        """
        :param session: one of the following:

            * ``None`` to start from an empty ``SessionStructure``;
            * another ``SessionState``, whose structure is deep-copied;
            * a ``SessionStructure`` message, which is wrapped without copying.
        """
        if session is None:
            self.sessionStructure = storageprotos.SessionStructure()
        elif isinstance(session, SessionState):
            # Copy so the new state and the source can diverge independently.
            self.sessionStructure = storageprotos.SessionStructure()
            self.sessionStructure.CopyFrom(session.sessionStructure)
        else:
            self.sessionStructure = session

    def getStructure(self):
        """Return the underlying ``SessionStructure`` message (not a copy)."""
        return self.sessionStructure

    def getAliceBaseKey(self):
        return self.sessionStructure.aliceBaseKey

    def setAliceBaseKey(self, aliceBaseKey):
        self.sessionStructure.aliceBaseKey = aliceBaseKey

    def setSessionVersion(self, version):
        self.sessionStructure.sessionVersion = version

    def getSessionVersion(self):
        """Return the session version; a stored value of 0 is reported as 2."""
        sessionVersion = self.sessionStructure.sessionVersion
        return 2 if sessionVersion == 0 else sessionVersion

    def setRemoteIdentityKey(self, identityKey):
        self.sessionStructure.remoteIdentityPublic = identityKey.serialize()

    def setLocalIdentityKey(self, identityKey):
        self.sessionStructure.localIdentityPublic = identityKey.serialize()

    def getRemoteIdentityKey(self):
        """Return the remote party's ``IdentityKey``, or ``None`` if never set.

        A protobuf bytes field is never ``None``; an unset field reads as an
        empty value. Presence must therefore be tested with ``HasField``.
        """
        if not self.sessionStructure.HasField("remoteIdentityPublic"):
            return None
        return IdentityKey(self.sessionStructure.remoteIdentityPublic, 0)

    def getLocalIdentityKey(self):
        return IdentityKey(self.sessionStructure.localIdentityPublic, 0)

    def getPreviousCounter(self):
        return self.sessionStructure.previousCounter

    def setPreviousCounter(self, previousCounter):
        self.sessionStructure.previousCounter = previousCounter

    def getRootKey(self):
        """Return the root key, with an HKDF chosen by the session version."""
        return RootKey(HKDF.createFor(self.getSessionVersion()), self.sessionStructure.rootKey)

    def setRootKey(self, rootKey):
        self.sessionStructure.rootKey = rootKey.getKeyBytes()

    def getSenderRatchetKey(self):
        """Decode and return the public ratchet key of the sender chain."""
        return Curve.decodePoint(bytearray(self.sessionStructure.senderChain.senderRatchetKey), 0)

    def getSenderRatchetKeyPair(self):
        """Return the sender chain's ratchet key pair as an ``ECKeyPair``."""
        publicKey = self.getSenderRatchetKey()
        privateKey = Curve.decodePrivatePoint(self.sessionStructure.senderChain.senderRatchetKeyPrivate)
        return ECKeyPair(publicKey, privateKey)

    def hasReceiverChain(self, ECPublickKey_senderEphemeral):
        """Return True if a receiver chain exists for the given sender ratchet key."""
        return self.getReceiverChain(ECPublickKey_senderEphemeral) is not None

    def hasSenderChain(self):
        return self.sessionStructure.HasField("senderChain")

    def getReceiverChain(self, ECPublickKey_senderEphemeral):
        """Find the receiver chain whose sender ratchet key equals the given key.

        :return: a ``(chain, index)`` tuple, or ``None`` if no chain matches.
            ``chain`` is the ``SessionStructure.Chain`` element itself (protobuf
            repeated-message elements are references, so mutating it updates
            this session). ``index`` is its position in ``receiverChains``.
        """
        for index, receiverChain in enumerate(self.sessionStructure.receiverChains):
            chainSenderRatchetKey = Curve.decodePoint(bytearray(receiverChain.senderRatchetKey), 0)
            if chainSenderRatchetKey == ECPublickKey_senderEphemeral:
                return (receiverChain, index)
        return None

    def _getReceiverChainOrNone(self, senderEphemeral):
        """Return only the matching receiver chain message, or ``None``."""
        chainAndIndex = self.getReceiverChain(senderEphemeral)
        return None if chainAndIndex is None else chainAndIndex[0]

    def _requireReceiverChain(self, senderEphemeral):
        """Return the matching receiver chain message; raise ValueError if absent."""
        chain = self._getReceiverChainOrNone(senderEphemeral)
        if chain is None:
            raise ValueError("No receiver chain for the given sender ratchet key")
        return chain

    def getReceiverChainKey(self, ECPublicKey_senderEphemeral):
        """Return the ``ChainKey`` of the matching receiver chain, or ``None``."""
        receiverChain = self._getReceiverChainOrNone(ECPublicKey_senderEphemeral)
        if receiverChain is None:
            return None

        return ChainKey(HKDF.createFor(self.getSessionVersion()),
                        receiverChain.chainKey.key,
                        receiverChain.chainKey.index)

    def addReceiverChain(self, ECPublickKey_senderRatchetKey, chainKey):
        """Append a receiver chain, dropping the oldest chain if over the limit.

        :param ECPublickKey_senderRatchetKey: the sender's public ratchet key.
        :param chainKey: ``ChainKey`` to start the chain from.
        """
        # Build the chain fully before appending so a failing serialize()
        # cannot leave a half-initialised chain in the session.
        chain = storageprotos.SessionStructure.Chain()
        chain.senderRatchetKey = ECPublickKey_senderRatchetKey.serialize()
        chain.chainKey.key = chainKey.getKey()
        chain.chainKey.index = chainKey.getIndex()

        receiverChains = self.sessionStructure.receiverChains
        receiverChains.extend([chain])

        if len(receiverChains) > self._MAX_RECEIVER_CHAINS:
            del receiverChains[0]

    def setSenderChain(self, ECKeyPair_senderRatchetKeyPair, chainKey):
        """Replace the sender chain with one built from the given key pair and chain key.

        The sender chain is replaced as a whole; nothing from the previous
        sender chain is kept.
        """
        senderChain = storageprotos.SessionStructure.Chain()
        senderChain.senderRatchetKey = ECKeyPair_senderRatchetKeyPair.getPublicKey().serialize()
        senderChain.senderRatchetKeyPrivate = ECKeyPair_senderRatchetKeyPair.getPrivateKey().serialize()
        senderChain.chainKey.key = chainKey.getKey()
        senderChain.chainKey.index = chainKey.getIndex()

        self.sessionStructure.senderChain.CopyFrom(senderChain)

    def getSenderChainKey(self):
        """Return the sender chain's current ``ChainKey``."""
        chainKeyStructure = self.sessionStructure.senderChain.chainKey
        return ChainKey(HKDF.createFor(self.getSessionVersion()),
                        chainKeyStructure.key, chainKeyStructure.index)

    def setSenderChainKey(self, ChainKey_nextChainKey):
        """Store ``ChainKey_nextChainKey`` as the sender chain's chain key."""
        chainKeyStructure = self.sessionStructure.senderChain.chainKey
        chainKeyStructure.key = ChainKey_nextChainKey.getKey()
        chainKeyStructure.index = ChainKey_nextChainKey.getIndex()

    def hasMessageKeys(self, ECPublickKey_senderEphemeral, counter):
        """Return True if the matching receiver chain stores keys for ``counter``.

        Returns False when no receiver chain matches the sender key.
        """
        chain = self._getReceiverChainOrNone(ECPublickKey_senderEphemeral)
        if chain is None:
            return False
        return any(messageKey.index == counter for messageKey in chain.messageKeys)

    def removeMessageKeys(self, ECPublicKey_senderEphemeral, counter):
        """Remove and return the stored ``MessageKeys`` for ``counter``.

        :return: the removed ``MessageKeys``, or ``None`` if either the
            receiver chain or the counter is not found.
        """
        chain = self._getReceiverChainOrNone(ECPublicKey_senderEphemeral)
        if chain is None:
            return None

        messageKeyList = chain.messageKeys
        for i, messageKey in enumerate(messageKeyList):
            if messageKey.index == counter:
                result = MessageKeys(messageKey.cipherKey, messageKey.macKey, messageKey.iv, messageKey.index)
                # The chain is a live element of the session, so deleting here
                # updates the session directly. Returning immediately makes the
                # delete-during-iteration safe.
                del messageKeyList[i]
                return result

        return None

    def setMessageKeys(self, ECPublicKey_senderEphemeral, messageKeys):
        """Store ``messageKeys`` on the matching receiver chain.

        :raises ValueError: if no receiver chain matches the sender key.
        """
        chain = self._requireReceiverChain(ECPublicKey_senderEphemeral)
        messageKeyStructure = chain.messageKeys.add()
        messageKeyStructure.cipherKey = messageKeys.getCipherKey()
        messageKeyStructure.macKey = messageKeys.getMacKey()
        messageKeyStructure.index = messageKeys.getCounter()
        messageKeyStructure.iv = messageKeys.getIv()

    def setReceiverChainKey(self, ECPublicKey_senderEphemeral, chainKey):
        """Update the chain key of the matching receiver chain.

        :raises ValueError: if no receiver chain matches the sender key.
        """
        chain = self._requireReceiverChain(ECPublicKey_senderEphemeral)
        chain.chainKey.key = chainKey.getKey()
        chain.chainKey.index = chainKey.getIndex()

    def setPendingKeyExchange(self, sequence, ourBaseKey, ourRatchetKey, ourIdentityKey):
        """Record our half of a pending key exchange, replacing any previous one.

        :type sequence: int
        :type ourBaseKey: ECKeyPair
        :type ourRatchetKey: ECKeyPair
        :type ourIdentityKey: IdentityKeyPair
        """
        structure = storageprotos.SessionStructure.PendingKeyExchange()
        structure.sequence = sequence
        structure.localBaseKey = ourBaseKey.getPublicKey().serialize()
        structure.localBaseKeyPrivate = ourBaseKey.getPrivateKey().serialize()
        structure.localRatchetKey = ourRatchetKey.getPublicKey().serialize()
        structure.localRatchetKeyPrivate = ourRatchetKey.getPrivateKey().serialize()
        structure.localIdentityKey = ourIdentityKey.getPublicKey().serialize()
        structure.localIdentityKeyPrivate = ourIdentityKey.getPrivateKey().serialize()

        self.sessionStructure.pendingKeyExchange.CopyFrom(structure)

    def getPendingKeyExchangeSequence(self):
        return self.sessionStructure.pendingKeyExchange.sequence

    def getPendingKeyExchangeBaseKey(self):
        """Return our pending-exchange base key pair as an ``ECKeyPair``."""
        pending = self.sessionStructure.pendingKeyExchange
        publicKey = Curve.decodePoint(bytearray(pending.localBaseKey), 0)
        privateKey = Curve.decodePrivatePoint(pending.localBaseKeyPrivate)
        return ECKeyPair(publicKey, privateKey)

    def getPendingKeyExchangeRatchetKey(self):
        """Return our pending-exchange ratchet key pair as an ``ECKeyPair``."""
        pending = self.sessionStructure.pendingKeyExchange
        publicKey = Curve.decodePoint(bytearray(pending.localRatchetKey), 0)
        privateKey = Curve.decodePrivatePoint(pending.localRatchetKeyPrivate)
        return ECKeyPair(publicKey, privateKey)

    def getPendingKeyExchangeIdentityKey(self):
        """Return our pending-exchange identity key pair as an ``IdentityKeyPair``."""
        pending = self.sessionStructure.pendingKeyExchange
        publicKey = IdentityKey(bytearray(pending.localIdentityKey), 0)
        privateKey = Curve.decodePrivatePoint(pending.localIdentityKeyPrivate)
        return IdentityKeyPair(publicKey, privateKey)

    def hasPendingKeyExchange(self):
        return self.sessionStructure.HasField("pendingKeyExchange")

    def setUnacknowledgedPreKeyMessage(self, preKeyId, signedPreKeyId, baseKey):
        """Record the pre-key message we sent, replacing any earlier record.

        :type preKeyId: int or None
        :type signedPreKeyId: int
        :type baseKey: ECPublicKey
        """
        serializedBaseKey = baseKey.serialize()

        # Replace the record wholesale. Otherwise a preKeyId stored by an
        # earlier call would survive a later call that passes preKeyId=None.
        self.sessionStructure.ClearField("pendingPreKey")
        pendingPreKey = self.sessionStructure.pendingPreKey
        pendingPreKey.signedPreKeyId = signedPreKeyId
        pendingPreKey.baseKey = serializedBaseKey

        if preKeyId is not None:
            pendingPreKey.preKeyId = preKeyId

    def hasUnacknowledgedPreKeyMessage(self):
        return self.sessionStructure.HasField("pendingPreKey")

    def getUnacknowledgedPreKeyMessageItems(self):
        """Return the recorded pre-key message as ``UnacknowledgedPreKeyMessageItems``.

        ``preKeyId`` in the result is ``None`` when it was not recorded.
        """
        pendingPreKey = self.sessionStructure.pendingPreKey
        preKeyId = pendingPreKey.preKeyId if pendingPreKey.HasField("preKeyId") else None

        return SessionState.UnacknowledgedPreKeyMessageItems(preKeyId,
                                                             pendingPreKey.signedPreKeyId,
                                                             Curve.decodePoint(bytearray(pendingPreKey.baseKey), 0))

    def clearUnacknowledgedPreKeyMessage(self):
        self.sessionStructure.ClearField("pendingPreKey")

    def setRemoteRegistrationId(self, registrationId):
        self.sessionStructure.remoteRegistrationId = registrationId

    def getRemoteRegistrationId(self, registrationId=None):
        """Return the stored remote registration id.

        :param registrationId: ignored. It is accepted only so that existing
            callers that pass an argument keep working.
        """
        return self.sessionStructure.remoteRegistrationId

    def setLocalRegistrationId(self, registrationId):
        self.sessionStructure.localRegistrationId = registrationId

    def getLocalRegistrationId(self):
        return self.sessionStructure.localRegistrationId

    def serialize(self):
        """Return the wire-format bytes of the underlying ``SessionStructure``."""
        return self.sessionStructure.SerializeToString()

    class UnacknowledgedPreKeyMessageItems:
        """Plain value holder for a recorded, unacknowledged pre-key message."""

        def __init__(self, preKeyId, signedPreKeyId, baseKey):
            """
            :type preKeyId: int or None
            :type signedPreKeyId: int
            :type baseKey: ECPublicKey
            """
            self.preKeyId = preKeyId
            self.signedPreKeyId = signedPreKeyId
            self.baseKey = baseKey

        def getPreKeyId(self):
            return self.preKeyId

        def getSignedPreKeyId(self):
            return self.signedPreKeyId

        def getBaseKey(self):
            return self.baseKey
295