# -*- coding: utf-8 -*-

from . import storageprotos_pb2 as storageprotos
from ..identitykeypair import IdentityKey, IdentityKeyPair
from ..ratchet.rootkey import RootKey
from ..kdf.hkdf import HKDF
from ..ecc.curve import Curve
from ..ecc.eckeypair import ECKeyPair
from ..ratchet.chainkey import ChainKey
from ..kdf.messagekeys import MessageKeys


class SessionState:
    """Accessor wrapper around a ``storageprotos.SessionStructure`` message.

    Provides getters/setters for the fields of one session: session version,
    local/remote identity keys, root key, sender chain, receiver chains (with
    their stored message keys), a pending key exchange, an unacknowledged
    pre-key message and the local/remote registration ids.
    """

    # Maximum number of receiver chains kept in the structure. When adding a
    # chain pushes the count above this, the oldest (front) chains are dropped.
    MAX_RECEIVER_CHAINS = 5

    def __init__(self, session=None):
        """
        :param session: ``None`` to start from an empty ``SessionStructure``;
            another ``SessionState`` to deep-copy its structure; otherwise a
            ``SessionStructure`` message that is wrapped directly (not copied).
        """
        if session is None:
            self.sessionStructure = storageprotos.SessionStructure()
        elif isinstance(session, SessionState):
            self.sessionStructure = storageprotos.SessionStructure()
            self.sessionStructure.CopyFrom(session.sessionStructure)
        else:
            self.sessionStructure = session

    def getStructure(self):
        """Return the underlying ``SessionStructure`` message (not a copy)."""
        return self.sessionStructure

    def getAliceBaseKey(self):
        return self.sessionStructure.aliceBaseKey

    def setAliceBaseKey(self, aliceBaseKey):
        self.sessionStructure.aliceBaseKey = aliceBaseKey

    def setSessionVersion(self, version):
        self.sessionStructure.sessionVersion = version

    def getSessionVersion(self):
        """Return the session version; a stored value of 0 is reported as 2.

        NOTE(review): presumably 0 means "field never set" on sessions stored
        before the version was recorded — confirm.
        """
        sessionVersion = self.sessionStructure.sessionVersion
        return 2 if sessionVersion == 0 else sessionVersion

    def setRemoteIdentityKey(self, identityKey):
        self.sessionStructure.remoteIdentityPublic = identityKey.serialize()

    def setLocalIdentityKey(self, identityKey):
        self.sessionStructure.localIdentityPublic = identityKey.serialize()

    def getRemoteIdentityKey(self):
        """Return the remote ``IdentityKey``, or ``None`` if none is stored."""
        remoteIdentityPublic = self.sessionStructure.remoteIdentityPublic
        # A protobuf bytes field never reads back as None: an unset field
        # yields empty bytes, so test for emptiness rather than identity.
        if not remoteIdentityPublic:
            return None
        return IdentityKey(remoteIdentityPublic, 0)

    def getLocalIdentityKey(self):
        return IdentityKey(self.sessionStructure.localIdentityPublic, 0)

    def getPreviousCounter(self):
        return self.sessionStructure.previousCounter

    def setPreviousCounter(self, previousCounter):
        self.sessionStructure.previousCounter = previousCounter

    def getRootKey(self):
        """Return a ``RootKey`` built with the HKDF for this session's version."""
        return RootKey(HKDF.createFor(self.getSessionVersion()), self.sessionStructure.rootKey)

    def setRootKey(self, rootKey):
        self.sessionStructure.rootKey = rootKey.getKeyBytes()

    def getSenderRatchetKey(self):
        """Decode and return the public sender ratchet key of the sender chain."""
        return Curve.decodePoint(bytearray(self.sessionStructure.senderChain.senderRatchetKey), 0)

    def getSenderRatchetKeyPair(self):
        """Return the sender chain's ratchet key as an ``ECKeyPair``."""
        publicKey = self.getSenderRatchetKey()
        privateKey = Curve.decodePrivatePoint(self.sessionStructure.senderChain.senderRatchetKeyPrivate)
        return ECKeyPair(publicKey, privateKey)

    def hasReceiverChain(self, ECPublickKey_senderEphemeral):
        """Return True if a receiver chain exists for the given sender ratchet key."""
        return self.getReceiverChain(ECPublickKey_senderEphemeral) is not None

    def hasSenderChain(self):
        return self.sessionStructure.HasField("senderChain")

    def getReceiverChain(self, ECPublickKey_senderEphemeral):
        """Find the receiver chain whose sender ratchet key equals the argument.

        :return: ``(chain, index)`` where ``chain`` is the live protobuf
            element inside ``receiverChains`` (mutations apply in place) and
            ``index`` is its position, or ``None`` if no chain matches.
        """
        receiverChains = self.sessionStructure.receiverChains
        for index, receiverChain in enumerate(receiverChains):
            chainSenderRatchetKey = Curve.decodePoint(bytearray(receiverChain.senderRatchetKey), 0)
            if chainSenderRatchetKey == ECPublickKey_senderEphemeral:
                return (receiverChain, index)
        return None

    def getReceiverChainKey(self, ECPublicKey_senderEphemeral):
        """Return the ``ChainKey`` of the matching receiver chain, or ``None``."""
        receiverChainAndIndex = self.getReceiverChain(ECPublicKey_senderEphemeral)
        # Check before indexing: getReceiverChain returns None, not (None, i).
        if receiverChainAndIndex is None:
            return None
        receiverChain = receiverChainAndIndex[0]
        return ChainKey(HKDF.createFor(self.getSessionVersion()),
                        receiverChain.chainKey.key,
                        receiverChain.chainKey.index)

    def addReceiverChain(self, ECPublickKey_senderRatchetKey, chainKey):
        """Append a receiver chain, dropping the oldest beyond MAX_RECEIVER_CHAINS."""
        receiverChains = self.sessionStructure.receiverChains

        chain = receiverChains.add()
        chain.senderRatchetKey = ECPublickKey_senderRatchetKey.serialize()
        chain.chainKey.key = chainKey.getKey()
        chain.chainKey.index = chainKey.getIndex()

        # Chains are appended at the end, so index 0 is always the oldest.
        while len(receiverChains) > self.MAX_RECEIVER_CHAINS:
            del receiverChains[0]

    def setSenderChain(self, ECKeyPair_senderRatchetKeyPair, chainKey):
        """Store the sender ratchet key pair and chain key as the sender chain."""
        senderChain = self.sessionStructure.senderChain
        senderChain.senderRatchetKey = ECKeyPair_senderRatchetKeyPair.getPublicKey().serialize()
        senderChain.senderRatchetKeyPrivate = ECKeyPair_senderRatchetKeyPair.getPrivateKey().serialize()
        senderChain.chainKey.key = chainKey.getKey()
        senderChain.chainKey.index = chainKey.getIndex()

    def getSenderChainKey(self):
        """Return the sender chain's ``ChainKey``."""
        chainKeyStructure = self.sessionStructure.senderChain.chainKey
        return ChainKey(HKDF.createFor(self.getSessionVersion()),
                        chainKeyStructure.key, chainKeyStructure.index)

    def setSenderChainKey(self, ChainKey_nextChainKey):
        """Replace the sender chain's key and index with ``ChainKey_nextChainKey``."""
        chainKeyStructure = self.sessionStructure.senderChain.chainKey
        chainKeyStructure.key = ChainKey_nextChainKey.getKey()
        chainKeyStructure.index = ChainKey_nextChainKey.getIndex()

    def hasMessageKeys(self, ECPublickKey_senderEphemeral, counter):
        """Return True if the matching receiver chain stores message keys for ``counter``.

        Returns False when no receiver chain matches the sender key.
        """
        chainAndIndex = self.getReceiverChain(ECPublickKey_senderEphemeral)
        if chainAndIndex is None:
            return False
        return any(messageKey.index == counter for messageKey in chainAndIndex[0].messageKeys)

    def removeMessageKeys(self, ECPublicKey_senderEphemeral, counter):
        """Remove and return the stored ``MessageKeys`` for ``counter``.

        :return: the removed ``MessageKeys``, or ``None`` if there is no
            matching receiver chain or no stored entry with that index.
        """
        chainAndIndex = self.getReceiverChain(ECPublicKey_senderEphemeral)
        if chainAndIndex is None:
            return None

        # The chain is a live element of receiverChains, so deleting from its
        # messageKeys updates the session structure directly.
        messageKeyList = chainAndIndex[0].messageKeys
        for i, messageKey in enumerate(messageKeyList):
            if messageKey.index == counter:
                # Build the result before deleting the protobuf element.
                result = MessageKeys(messageKey.cipherKey, messageKey.macKey,
                                     messageKey.iv, messageKey.index)
                del messageKeyList[i]
                return result
        return None

    def setMessageKeys(self, ECPublicKey_senderEphemeral, messageKeys):
        """Append ``messageKeys`` to the matching receiver chain's stored keys.

        The receiver chain must already exist (see ``hasReceiverChain``);
        otherwise unpacking the ``None`` lookup result raises TypeError.
        """
        chain, _ = self.getReceiverChain(ECPublicKey_senderEphemeral)
        messageKeyStructure = chain.messageKeys.add()
        messageKeyStructure.cipherKey = messageKeys.getCipherKey()
        messageKeyStructure.macKey = messageKeys.getMacKey()
        messageKeyStructure.index = messageKeys.getCounter()
        messageKeyStructure.iv = messageKeys.getIv()

    def setReceiverChainKey(self, ECPublicKey_senderEphemeral, chainKey):
        """Replace the matching receiver chain's key and index.

        The receiver chain must already exist (see ``hasReceiverChain``);
        otherwise unpacking the ``None`` lookup result raises TypeError.
        """
        chain, _ = self.getReceiverChain(ECPublicKey_senderEphemeral)
        chain.chainKey.key = chainKey.getKey()
        chain.chainKey.index = chainKey.getIndex()

    def setPendingKeyExchange(self, sequence, ourBaseKey, ourRatchetKey, ourIdentityKey):
        """
        Record our side of a pending key exchange.

        :type sequence: int
        :type ourBaseKey: ECKeyPair
        :type ourRatchetKey: ECKeyPair
        :type ourIdentityKey: IdentityKeyPair
        """
        structure = storageprotos.SessionStructure.PendingKeyExchange()
        structure.sequence = sequence
        structure.localBaseKey = ourBaseKey.getPublicKey().serialize()
        structure.localBaseKeyPrivate = ourBaseKey.getPrivateKey().serialize()
        structure.localRatchetKey = ourRatchetKey.getPublicKey().serialize()
        structure.localRatchetKeyPrivate = ourRatchetKey.getPrivateKey().serialize()
        structure.localIdentityKey = ourIdentityKey.getPublicKey().serialize()
        structure.localIdentityKeyPrivate = ourIdentityKey.getPrivateKey().serialize()

        self.sessionStructure.pendingKeyExchange.MergeFrom(structure)

    def getPendingKeyExchangeSequence(self):
        return self.sessionStructure.pendingKeyExchange.sequence

    def getPendingKeyExchangeBaseKey(self):
        """Return our pending-exchange base key as an ``ECKeyPair``."""
        pending = self.sessionStructure.pendingKeyExchange
        publicKey = Curve.decodePoint(bytearray(pending.localBaseKey), 0)
        privateKey = Curve.decodePrivatePoint(pending.localBaseKeyPrivate)
        return ECKeyPair(publicKey, privateKey)

    def getPendingKeyExchangeRatchetKey(self):
        """Return our pending-exchange ratchet key as an ``ECKeyPair``."""
        pending = self.sessionStructure.pendingKeyExchange
        publicKey = Curve.decodePoint(bytearray(pending.localRatchetKey), 0)
        privateKey = Curve.decodePrivatePoint(pending.localRatchetKeyPrivate)
        return ECKeyPair(publicKey, privateKey)

    def getPendingKeyExchangeIdentityKey(self):
        """Return our pending-exchange identity key as an ``IdentityKeyPair``."""
        pending = self.sessionStructure.pendingKeyExchange
        publicKey = IdentityKey(bytearray(pending.localIdentityKey), 0)
        privateKey = Curve.decodePrivatePoint(pending.localIdentityKeyPrivate)
        return IdentityKeyPair(publicKey, privateKey)

    def hasPendingKeyExchange(self):
        return self.sessionStructure.HasField("pendingKeyExchange")

    def setUnacknowledgedPreKeyMessage(self, preKeyId, signedPreKeyId, baseKey):
        """
        Record the ids and base key of a pre-key message not yet acknowledged.

        :type preKeyId: int
        :type signedPreKeyId: int
        :type baseKey: ECPublicKey
        """
        pendingPreKey = self.sessionStructure.pendingPreKey
        pendingPreKey.signedPreKeyId = signedPreKeyId
        pendingPreKey.baseKey = baseKey.serialize()

        # preKeyId is optional; leave the field unset when it is None.
        if preKeyId is not None:
            pendingPreKey.preKeyId = preKeyId

    def hasUnacknowledgedPreKeyMessage(self):
        return self.sessionStructure.HasField("pendingPreKey")

    def getUnacknowledgedPreKeyMessageItems(self):
        """Return the stored pending pre-key data; ``preKeyId`` is None if unset."""
        pendingPreKey = self.sessionStructure.pendingPreKey
        preKeyId = None
        if pendingPreKey.HasField("preKeyId"):
            preKeyId = pendingPreKey.preKeyId

        return SessionState.UnacknowledgedPreKeyMessageItems(
            preKeyId,
            pendingPreKey.signedPreKeyId,
            Curve.decodePoint(bytearray(pendingPreKey.baseKey), 0))

    def clearUnacknowledgedPreKeyMessage(self):
        self.sessionStructure.ClearField("pendingPreKey")

    def setRemoteRegistrationId(self, registrationId):
        self.sessionStructure.remoteRegistrationId = registrationId

    def getRemoteRegistrationId(self, registrationId=None):
        """Return the remote registration id.

        ``registrationId`` is unused; it is kept, now optional, only so that
        existing callers that pass an argument keep working.
        """
        return self.sessionStructure.remoteRegistrationId

    def setLocalRegistrationId(self, registrationId):
        self.sessionStructure.localRegistrationId = registrationId

    def getLocalRegistrationId(self):
        return self.sessionStructure.localRegistrationId

    def serialize(self):
        """Return the session structure serialized via ``SerializeToString``."""
        return self.sessionStructure.SerializeToString()

    class UnacknowledgedPreKeyMessageItems:
        """Plain holder for the data of an unacknowledged pre-key message."""

        def __init__(self, preKeyId, signedPreKeyId, baseKey):
            """
            :type preKeyId: int
            :type signedPreKeyId: int
            :type baseKey: ECPublicKey
            """
            self.preKeyId = preKeyId
            self.signedPreKeyId = signedPreKeyId
            self.baseKey = baseKey

        def getPreKeyId(self):
            return self.preKeyId

        def getSignedPreKeyId(self):
            return self.signedPreKeyId

        def getBaseKey(self):
            return self.baseKey