1# Copyright (C) 2013 by the Massachusetts Institute of Technology.
2# All rights reserved.
3#
4# Redistribution and use in source and binary forms, with or without
5# modification, are permitted provided that the following conditions
6# are met:
7#
8# * Redistributions of source code must retain the above copyright
9#   notice, this list of conditions and the following disclaimer.
10#
11# * Redistributions in binary form must reproduce the above copyright
12#   notice, this list of conditions and the following disclaimer in
13#   the documentation and/or other materials provided with the
14#   distribution.
15#
16# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
19# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
20# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
21# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
23# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
25# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
27# OF THE POSSIBILITY OF SUCH DAMAGE.
28
# XXX current status:
# * Done and tested
#   - AES encryption, checksum, string2key, prf
#   - DES3 and RC4 encryption, checksum, string2key
#   - cf2 (needed for FAST) for AES, DES3 and RC4
#   - DES-CBC-MD5 string2key
# * Still to do:
#   - DES enctypes and cksumtypes: DES-CRC and DES-MD4 enctypes, all
#     DES cksumtypes, and tests for DES-CBC-MD5 encryption
#   - RC4 exported enctype (if we need it for anything)
#   - Unkeyed checksums
#   - Special RC4, raw DES/DES3 operations for GSSAPI
# * Difficult or low priority:
#   - Camellia not supported by PyCrypto
#   - Cipher state only needed for kcmd suite
#   - Nonstandard enctypes and cksumtypes like des-hmac-sha1
42
import os
import random
import string
from binascii import unhexlify
from struct import pack, unpack

from Crypto.Util.number import GCD as gcd
from Crypto.Cipher import AES, DES3, ARC4, DES
from Crypto.Hash import HMAC, MD4, MD5, SHA
from Crypto.Protocol.KDF import PBKDF2
52
53
def get_random_bytes(lenBytes):
    # Return lenBytes random bytes, used as encryption confounders.
    # A confounder exists to randomize the ciphertext, so draw it from
    # the operating system's CSPRNG over the full 0-255 byte range
    # instead of the predictable `random` module restricted to
    # alphanumeric characters.
    return os.urandom(lenBytes)
57
58
class Enctype(object):
    # Numeric Kerberos encryption type identifiers used throughout this
    # module (as keys of the enctype profile table and Key.enctype).
    DES_CRC, DES_MD4, DES_MD5 = 1, 2, 3
    DES3 = 16
    AES128, AES256 = 17, 18
    RC4 = 23
67
68
class Cksumtype(object):
    # Numeric Kerberos checksum type identifiers.  HMAC_MD5 is a
    # negative number.
    CRC32 = 1
    MD4, MD4_DES = 2, 3
    MD5, MD5_DES = 7, 8
    SHA1 = 9
    SHA1_DES3 = 12
    SHA1_AES128, SHA1_AES256 = 15, 16
    HMAC_MD5 = -138
80
81
class InvalidChecksum(ValueError):
    """Raised when a ciphertext or checksum fails integrity verification."""
84
85
86def _zeropad(s, padsize):
87    # Return s padded with 0 bytes to a multiple of padsize.
88    padlen = (padsize - (len(s) % padsize)) % padsize
89    return s + '\0'*padlen
90
91
92def _xorbytes(b1, b2):
93    # xor two strings together and return the resulting string.
94    assert len(b1) == len(b2)
95    return ''.join(chr(ord(x) ^ ord(y)) for x, y in zip(b1, b2))
96
97
98def _mac_equal(mac1, mac2):
99    # Constant-time comparison function.  (We can't use HMAC.verify
100    # since we use truncated macs.)
101    assert len(mac1) == len(mac2)
102    res = 0
103    for x, y in zip(mac1, mac2):
104        res |= ord(x) ^ ord(y)
105    return res == 0
106
107
def _nfold(str, nbytes):
    # Convert str to a string of length nbytes using the RFC 3961 nfold
    # operation.  Called by derive() to fit usage constants to the
    # cipher block size, and by DES3 string_to_key.
    #
    #   str    -- input byte string.  Presumably must be non-empty: an
    #             empty str would divide by zero at "lcm / slen" below.
    #             (The name shadows the builtin str in this function.)
    #   nbytes -- length of the result in bytes.
    # Returns a byte string of exactly nbytes bytes.

    # Rotate the bytes in str to the right by nbits bits.  Negative
    # indices wrap around, so the whole string is rotated as one
    # circular big-endian bit string.
    def rotate_right(str, nbits):
        nbytes, remain = (nbits//8) % len(str), nbits % 8
        # Output byte i takes its low-order bits from the high part of
        # input byte i-nbytes and its high-order bits from the low part
        # of input byte i-nbytes-1.
        return ''.join(chr((ord(str[i-nbytes]) >> remain) |
                           ((ord(str[i-nbytes-1]) << (8-remain)) & 0xff))
                       for i in xrange(len(str)))

    # Add equal-length strings together with end-around carry.
    def add_ones_complement(str1, str2):
        n = len(str1)
        v = [ord(a) + ord(b) for a, b in zip(str1, str2)]
        # Propagate carry bits to the left until there aren't any left.
        # Byte i absorbs the carry out of byte i+1 (index i-n+1); for the
        # last byte that index is 0, so the carry out of the most
        # significant byte wraps around into the least significant one.
        while any(x & ~0xff for x in v):
            v = [(v[i-n+1]>>8) + (v[i]&0xff) for i in xrange(n)]
        return ''.join(chr(x) for x in v)

    # Concatenate copies of str to produce the least common multiple
    # of len(str) and nbytes, rotating each copy of str to the right
    # by 13 bits times its list position.  Decompose the concatenation
    # into slices of length nbytes, and add them together as
    # big-endian ones' complement integers.
    slen = len(str)
    # Python 2 integer '/': both divisions here are exact.
    lcm = nbytes * slen / gcd(nbytes, slen)
    bigstr = ''.join((rotate_right(str, 13 * i) for i in xrange(lcm / slen)))
    slices = (bigstr[p:p+nbytes] for p in xrange(0, lcm, nbytes))
    return reduce(add_ones_complement, slices)
138
139
140def _is_weak_des_key(keybytes):
141    return keybytes in ('\x01\x01\x01\x01\x01\x01\x01\x01',
142                        '\xFE\xFE\xFE\xFE\xFE\xFE\xFE\xFE',
143                        '\x1F\x1F\x1F\x1F\x0E\x0E\x0E\x0E',
144                        '\xE0\xE0\xE0\xE0\xF1\xF1\xF1\xF1',
145                        '\x01\xFE\x01\xFE\x01\xFE\x01\xFE',
146                        '\xFE\x01\xFE\x01\xFE\x01\xFE\x01',
147                        '\x1F\xE0\x1F\xE0\x0E\xF1\x0E\xF1',
148                        '\xE0\x1F\xE0\x1F\xF1\x0E\xF1\x0E',
149                        '\x01\xE0\x01\xE0\x01\xF1\x01\xF1',
150                        '\xE0\x01\xE0\x01\xF1\x01\xF1\x01',
151                        '\x1F\xFE\x1F\xFE\x0E\xFE\x0E\xFE',
152                        '\xFE\x1F\xFE\x1F\xFE\x0E\xFE\x0E',
153                        '\x01\x1F\x01\x1F\x01\x0E\x01\x0E',
154                        '\x1F\x01\x1F\x01\x0E\x01\x0E\x01',
155                        '\xE0\xFE\xE0\xFE\xF1\xFE\xF1\xFE',
156                        '\xFE\xE0\xFE\xE0\xFE\xF1\xFE\xF1')
157
158
159class _EnctypeProfile(object):
160    # Base class for enctype profiles.  Usable enctype classes must define:
161    #   * enctype: enctype number
162    #   * keysize: protocol size of key in bytes
163    #   * seedsize: random_to_key input size in bytes
164    #   * random_to_key (if the keyspace is not dense)
165    #   * string_to_key
166    #   * encrypt
167    #   * decrypt
168    #   * prf
169
170    @classmethod
171    def random_to_key(cls, seed):
172        if len(seed) != cls.seedsize:
173            raise ValueError('Wrong seed length')
174        return Key(cls.enctype, seed)
175
176
class _SimplifiedEnctype(_EnctypeProfile):
    # Base class for enctypes using the RFC 3961 simplified profile.
    # Defines the derive, encrypt, decrypt, and prf methods.  Subclasses
    # must define:
    #   * blocksize: Underlying cipher block size in bytes
    #   * padsize: Underlying cipher padding multiple (1 or blocksize)
    #   * macsize: Size of integrity MAC in bytes
    #   * hashmod: PyCrypto hash module for underlying hash function
    #   * basic_encrypt, basic_decrypt: Underlying CBC/CTS cipher

    @classmethod
    def derive(cls, key, constant):
        # Derive a key of this enctype from key and a constant: n-fold
        # the constant to the block size, then repeatedly encrypt it,
        # concatenating the outputs until there are seedsize bytes, and
        # pass them to random_to_key.
        #
        # RFC 3961 only says to n-fold the constant only if it is
        # shorter than the cipher block size.  But all Unix
        # implementations n-fold constants if their length is larger
        # than the block size as well, and n-folding when the length
        # is equal to the block size is a no-op.
        plaintext = _nfold(constant, cls.blocksize)
        rndseed = ''
        while len(rndseed) < cls.seedsize:
            ciphertext = cls.basic_encrypt(key, plaintext)
            rndseed += ciphertext
            plaintext = ciphertext
        return cls.random_to_key(rndseed[0:cls.seedsize])

    @classmethod
    def encrypt(cls, key, keyusage, plaintext, confounder):
        # Return E(Ke, confounder | padded plaintext) followed by the
        # first macsize bytes of HMAC(Ki, confounder | padded plaintext),
        # with Ki/Ke derived from keyusage.  confounder is blocksize
        # bytes, or None to generate one randomly.
        ki = cls.derive(key, pack('>IB', keyusage, 0x55))
        ke = cls.derive(key, pack('>IB', keyusage, 0xAA))
        if confounder is None:
            confounder = get_random_bytes(cls.blocksize)
        basic_plaintext = confounder + _zeropad(plaintext, cls.padsize)
        hmac = HMAC.new(ki.contents, basic_plaintext, cls.hashmod).digest()
        return cls.basic_encrypt(ke, basic_plaintext) + hmac[:cls.macsize]

    @classmethod
    def decrypt(cls, key, keyusage, ciphertext):
        # Inverse of encrypt.  Raises ValueError for malformed ciphertext
        # and InvalidChecksum if the MAC does not verify.  The returned
        # plaintext still contains any padding added by encrypt.
        #
        # Validate the shape first so malformed input is rejected before
        # the two key derivations are performed.
        if len(ciphertext) < cls.blocksize + cls.macsize:
            raise ValueError('ciphertext too short')
        basic_ctext, mac = ciphertext[:-cls.macsize], ciphertext[-cls.macsize:]
        if len(basic_ctext) % cls.padsize != 0:
            raise ValueError('ciphertext does not meet padding requirement')
        ki = cls.derive(key, pack('>IB', keyusage, 0x55))
        ke = cls.derive(key, pack('>IB', keyusage, 0xAA))
        basic_plaintext = cls.basic_decrypt(ke, basic_ctext)
        hmac = HMAC.new(ki.contents, basic_plaintext, cls.hashmod).digest()
        expmac = hmac[:cls.macsize]
        if not _mac_equal(mac, expmac):
            raise InvalidChecksum('ciphertext integrity failure')
        # Discard the confounder.
        return basic_plaintext[cls.blocksize:]

    @classmethod
    def prf(cls, key, string):
        # Hash the input.  RFC 3961 says to truncate to the padding
        # size, but implementations truncate to the block size.
        hashval = cls.hashmod.new(string).digest()
        # Keep the largest whole number of blocks.  Slicing by an explicit
        # end index keeps a hash whose length is already a block multiple
        # intact, e.g. MD5 with an 8-byte DES block.  The form
        # hashval[:-(len % blocksize)] would give hashval[:-0] == '' in
        # that case.
        truncated = hashval[:len(hashval) - len(hashval) % cls.blocksize]
        # Encrypt the hash with a derived key.
        kp = cls.derive(key, 'prf')
        return cls.basic_encrypt(kp, truncated)
238
class _DESCBC(_SimplifiedEnctype):
    # des-cbc-md5 (RFC 3961, section 6.2).  Unlike the simplified
    # profile, encryption uses the key directly (keyusage is ignored),
    # and integrity comes from an unkeyed MD5 checksum stored inside the
    # encrypted data:
    #   E(key, confounder | MD5(confounder | zeros | msg) | msg)
    # where msg is the plaintext zero-padded to 8 bytes.  derive and prf
    # are inherited from _SimplifiedEnctype.
    enctype = Enctype.DES_MD5
    keysize = 8
    seedsize = 8
    blocksize = 8
    padsize = 8
    macsize = 16
    hashmod = MD5

    @classmethod
    def encrypt(cls, key, keyusage, plaintext, confounder):
        # confounder is blocksize bytes, or None to generate randomly.
        if confounder is None:
            confounder = get_random_bytes(cls.blocksize)
        padded = _zeropad(plaintext, cls.padsize)
        # The checksum is computed with zeros in its own slot and then
        # stored in that slot.
        checksum = cls.hashmod.new(
            confounder + '\x00' * cls.macsize + padded).digest()
        return cls.basic_encrypt(key, confounder + checksum + padded)

    @classmethod
    def decrypt(cls, key, keyusage, ciphertext):
        # Return the decrypted message, including any zero padding added
        # by encrypt.  Raises ValueError for malformed ciphertext and
        # InvalidChecksum if the embedded MD5 checksum does not match.
        if len(ciphertext) < cls.blocksize + cls.macsize:
            raise ValueError('ciphertext too short')
        # Check alignment here so malformed input raises ValueError
        # instead of tripping the assert in basic_decrypt.
        if len(ciphertext) % cls.padsize != 0:
            raise ValueError('ciphertext does not meet padding requirement')

        complex_plaintext = cls.basic_decrypt(key, ciphertext)
        confounder = complex_plaintext[:cls.blocksize]
        mac = complex_plaintext[cls.blocksize:cls.blocksize + cls.macsize]
        message = complex_plaintext[cls.blocksize + cls.macsize:]

        expmac = cls.hashmod.new(
            confounder + '\x00' * cls.macsize + message).digest()
        if not _mac_equal(mac, expmac):
            raise InvalidChecksum('ciphertext integrity failure')
        return message

    @classmethod
    def mit_des_string_to_key(cls, string, salt):
        # RFC 3961 mit_des_string_to_key.  Fan-fold string|salt into 56
        # bits and turn that into a DES key.  Then use the DES-CBC
        # checksum of the padded input under that key as the final key.

        def fixparity(deskey):
            # Keep the high 7 bits of each byte and set the low bit to
            # give the byte odd parity.
            temp = ''
            for byte in deskey:
                high7 = ord(byte) & 0xFE
                if bin(high7).count('1') % 2 == 0:
                    temp += chr(high7 | 0x01)
                else:
                    temp += chr(high7)
            return temp

        def addparity(l1):
            # Shift each 7-bit value up one bit and fill the new low bit
            # to give the byte odd parity.
            temp = list()
            for byte in l1:
                if (bin(byte).count('1') % 2) == 0:
                    byte = (byte << 1) | 0b00000001
                else:
                    byte = (byte << 1) & 0b11111110
                temp.append(byte)
            return temp

        def XOR(l1, l2):
            # XOR two lists of 7-bit values element-wise.
            temp = list()
            for b1, b2 in zip(l1, l2):
                temp.append((b1 ^ b2) & 0b01111111)
            return temp

        def fix_weak(deskey):
            # Replace a weak or semi-weak key by XORing its last byte with
            # 0xF0.  Strings are immutable, so build a new string; item
            # assignment (deskey[7] = ...) raises TypeError.
            if _is_weak_des_key(deskey):
                deskey = deskey[:7] + chr(ord(deskey[7]) ^ 0xF0)
            return deskey

        odd = True
        tempstring = [0, 0, 0, 0, 0, 0, 0, 0]
        # Pad with nulls to an 8-byte boundary.  Input that is already
        # aligned gets no extra block.  A zero block would not change the
        # fan-fold, but it would change the CBC checksum computed below.
        # Empty input still becomes a single zero block.
        s = _zeropad(string + salt, 8) or '\x00' * 8

        for block in [s[i:i+8] for i in range(0, len(s), 8)]:
            # Drop the high bit of each byte, leaving 56 bits.
            temp56 = list()
            for byte in block:
                temp56.append(ord(byte) & 0b01111111)

            # Every second block is bit-reversed as a 56-bit string.
            if not odd:
                bintemp = ''
                for byte in temp56:
                    bintemp += (bin(byte)[2:]).rjust(7, '0')
                bintemp = bintemp[::-1]

                temp56 = list()
                for bits7 in [bintemp[i:i+7] for i in range(0, len(bintemp), 7)]:
                    temp56.append(int(bits7, 2))

            odd = not odd

            tempstring = XOR(tempstring, temp56)

        tempkey = fix_weak(''.join(chr(byte) for byte in addparity(tempstring)))

        # DES-CBC checksum of s, with tempkey as both key and IV.  The
        # last ciphertext block, with parity fixed, is the result.
        cipher = DES.new(tempkey, DES.MODE_CBC, tempkey)
        checksumkey = fix_weak(fixparity(cipher.encrypt(s)[-8:]))

        return Key(cls.enctype, checksumkey)

    @classmethod
    def basic_encrypt(cls, key, plaintext):
        # Single-DES CBC with a zero IV; input must be block-aligned.
        assert len(plaintext) % 8 == 0
        des = DES.new(key.contents, DES.MODE_CBC, '\0' * 8)
        return des.encrypt(plaintext)

    @classmethod
    def basic_decrypt(cls, key, ciphertext):
        # Single-DES CBC with a zero IV; input must be block-aligned.
        assert len(ciphertext) % 8 == 0
        des = DES.new(key.contents, DES.MODE_CBC, '\0' * 8)
        return des.decrypt(ciphertext)

    @classmethod
    def string_to_key(cls, string, salt, params):
        # DES takes no string-to-key parameters.
        if params is not None and params != '':
            raise ValueError('Invalid DES string-to-key parameters')
        key = cls.mit_des_string_to_key(string, salt)
        return key
359
360
361
class _DES3CBC(_SimplifiedEnctype):
    # des3-cbc-sha1: the RFC 3961 simplified profile over triple-DES CBC
    # with HMAC-SHA1.
    enctype = Enctype.DES3
    keysize = 24
    seedsize = 21
    blocksize = 8
    padsize = 8
    macsize = 20
    hashmod = SHA

    @classmethod
    def random_to_key(cls, seed):
        # XXX Maybe reframe as _DESEnctype.random_to_key and use that
        # way from DES3 random-to-key when DES is implemented, since
        # MIT does this instead of the RFC 3961 random-to-key.
        def expand(seed):
            # Expand 7 seed bytes into one 8-byte DES key.  The high 7
            # bits of the first seven key bytes come from the seed bytes.
            # Each seed byte's low bit i is placed at bit i+1 of the eighth
            # byte.  Every byte then gets odd parity in its low bit.
            def parity(b):
                # Return b with the low-order bit set to yield odd parity.
                b &= ~1
                return b if bin(b).count('1') % 2 else b | 1
            assert len(seed) == 7
            firstbytes = [parity(ord(b) & ~1) for b in seed]
            lastbyte = parity(sum((ord(seed[i]) & 1) << i+1 for i in xrange(7)))
            keybytes = ''.join(chr(b) for b in firstbytes + [lastbyte])
            if _is_weak_des_key(keybytes):
                # str is immutable: rebuild it; assigning to keybytes[7]
                # raises TypeError.
                keybytes = keybytes[:7] + chr(ord(keybytes[7]) ^ 0xF0)
            return keybytes

        if len(seed) != cls.seedsize:
            raise ValueError('Wrong seed length')
        k1, k2, k3 = expand(seed[:7]), expand(seed[7:14]), expand(seed[14:])
        return Key(cls.enctype, k1 + k2 + k3)

    @classmethod
    def string_to_key(cls, string, salt, params):
        # n-fold string|salt to 21 bytes, turn it into a key, and derive
        # the final key with the "kerberos" constant.
        if params is not None and params != '':
            raise ValueError('Invalid DES3 string-to-key parameters')
        k = cls.random_to_key(_nfold(string + salt, 21))
        return cls.derive(k, 'kerberos')

    @classmethod
    def basic_encrypt(cls, key, plaintext):
        # Triple-DES CBC with a zero IV; input must be block-aligned.
        assert len(plaintext) % 8 == 0
        des3 = DES3.new(key.contents, DES3.MODE_CBC, '\0' * 8)
        return des3.encrypt(plaintext)

    @classmethod
    def basic_decrypt(cls, key, ciphertext):
        # Triple-DES CBC with a zero IV; input must be block-aligned.
        assert len(ciphertext) % 8 == 0
        des3 = DES3.new(key.contents, DES3.MODE_CBC, '\0' * 8)
        return des3.decrypt(ciphertext)
412
413
class _AESEnctype(_SimplifiedEnctype):
    # Base class for aes128-cts and aes256-cts: AES-CBC with ciphertext
    # stealing, and HMAC-SHA1 truncated to 12 bytes.
    blocksize = 16
    padsize = 1
    macsize = 12
    hashmod = SHA

    @classmethod
    def string_to_key(cls, string, salt, params):
        # params is a 4-byte big-endian PBKDF2 iteration count.  A false
        # value (None or '') selects the default of 0x1000 (4096).
        params = params or '\x00\x00\x10\x00'
        if len(params) != 4:
            # Raise ValueError like the DES/DES3 string_to_key instead of
            # letting unpack() raise struct.error.
            raise ValueError('Invalid AES string-to-key parameters')
        (iterations,) = unpack('>L', params)
        prf = lambda p, s: HMAC.new(p, s, SHA).digest()
        seed = PBKDF2(string, salt, cls.seedsize, iterations, prf)
        tkey = cls.random_to_key(seed)
        return cls.derive(tkey, 'kerberos')

    @classmethod
    def basic_encrypt(cls, key, plaintext):
        # AES-CBC (zero IV) with ciphertext stealing; any length >= 16.
        assert len(plaintext) >= 16
        aes = AES.new(key.contents, AES.MODE_CBC, '\0' * 16)
        ctext = aes.encrypt(_zeropad(plaintext, 16))
        if len(plaintext) > 16:
            # Swap the last two ciphertext blocks and truncate the
            # final block to match the plaintext length.
            lastlen = len(plaintext) % 16 or 16
            ctext = ctext[:-32] + ctext[-16:] + ctext[-32:-16][:lastlen]
        return ctext

    @classmethod
    def basic_decrypt(cls, key, ciphertext):
        # Inverse of basic_encrypt, done block by block with ECB so the
        # stolen bytes of the final block can be recovered.
        assert len(ciphertext) >= 16
        aes = AES.new(key.contents, AES.MODE_ECB)
        if len(ciphertext) == 16:
            return aes.decrypt(ciphertext)
        # Split the ciphertext into blocks.  The last block may be partial.
        cblocks = [ciphertext[p:p+16] for p in xrange(0, len(ciphertext), 16)]
        lastlen = len(cblocks[-1])
        # CBC-decrypt all but the last two blocks.
        prev_cblock = '\0' * 16
        plaintext = ''
        for b in cblocks[:-2]:
            plaintext += _xorbytes(aes.decrypt(b), prev_cblock)
            prev_cblock = b
        # Decrypt the second-to-last cipher block.  The left side of
        # the decrypted block will be the final block of plaintext
        # xor'd with the final partial cipher block; the right side
        # will be the omitted bytes of ciphertext from the final
        # block.
        b = aes.decrypt(cblocks[-2])
        lastplaintext = _xorbytes(b[:lastlen], cblocks[-1])
        omitted = b[lastlen:]
        # Decrypt the final cipher block plus the omitted bytes to get
        # the second-to-last plaintext block.
        plaintext += _xorbytes(aes.decrypt(cblocks[-1] + omitted), prev_cblock)
        return plaintext + lastplaintext
468
469
class _AES128CTS(_AESEnctype):
    # aes128-cts: 16-byte keys; the inherited random_to_key uses the
    # 16-byte seed as the key.
    seedsize = 16
    keysize = 16
    enctype = Enctype.AES128
474
475
class _AES256CTS(_AESEnctype):
    # aes256-cts: 32-byte keys; the inherited random_to_key uses the
    # 32-byte seed as the key.
    seedsize = 32
    keysize = 32
    enctype = Enctype.AES256
480
481
class _RC4(_EnctypeProfile):
    # rc4-hmac (RFC 4757).  The ciphertext is an HMAC-MD5 checksum
    # followed by RC4 encryption of confounder | plaintext.
    enctype = Enctype.RC4
    keysize = 16
    seedsize = 16

    @staticmethod
    def usage_str(keyusage):
        # Return a four-byte string for an RFC 3961 keyusage, using
        # the RFC 4757 rules.  Per the errata, do not map 9 to 8.
        table = {3: 8, 23: 13}
        return pack('<I', table.get(keyusage, keyusage))

    @classmethod
    def string_to_key(cls, string, salt, params):
        # The key is MD4 of the UTF-16LE encoding of the UTF-8 password;
        # salt and params are not used.
        utf16string = string.decode('UTF-8').encode('UTF-16LE')
        return Key(cls.enctype, MD4.new(utf16string).digest())

    @classmethod
    def encrypt(cls, key, keyusage, plaintext, confounder):
        # Output is cksum | RC4(ke, confounder | plaintext), where
        # ki = HMAC-MD5(key, usage), cksum = HMAC-MD5(ki, confounder |
        # plaintext) and ke = HMAC-MD5(ki, cksum).
        if confounder is None:
            confounder = get_random_bytes(8)
        ki = HMAC.new(key.contents, cls.usage_str(keyusage), MD5).digest()
        cksum = HMAC.new(ki, confounder + plaintext, MD5).digest()
        ke = HMAC.new(ki, cksum, MD5).digest()
        return cksum + ARC4.new(ke).encrypt(confounder + plaintext)

    @classmethod
    def _decrypt_with_usage(cls, key, usage, cksum, basic_ctext):
        # Decrypt basic_ctext with the keys for the 4-byte usage string
        # and return (checksum_ok, basic_plaintext).  ke is derived from
        # ki, so trying a different usage means redoing all of this.
        ki = HMAC.new(key.contents, usage, MD5).digest()
        ke = HMAC.new(ki, cksum, MD5).digest()
        basic_plaintext = ARC4.new(ke).decrypt(basic_ctext)
        exp_cksum = HMAC.new(ki, basic_plaintext, MD5).digest()
        return _mac_equal(cksum, exp_cksum), basic_plaintext

    @classmethod
    def decrypt(cls, key, keyusage, ciphertext):
        # Minimum length: 16-byte checksum plus 8-byte confounder.
        if len(ciphertext) < 24:
            raise ValueError('ciphertext too short')
        cksum, basic_ctext = ciphertext[:16], ciphertext[16:]
        ok, basic_plaintext = cls._decrypt_with_usage(
            key, cls.usage_str(keyusage), cksum, basic_ctext)
        if not ok and keyusage == 9:
            # Try again with usage 8, due to RFC 4757 errata.  Both the
            # integrity key and the encryption key depend on the usage,
            # so the ciphertext must be decrypted again.  Recomputing only
            # the checksum over the usage-9 plaintext would almost never
            # match.
            ok, basic_plaintext = cls._decrypt_with_usage(
                key, pack('<I', 8), cksum, basic_ctext)
        if not ok:
            raise InvalidChecksum('ciphertext integrity failure')
        # Discard the confounder.
        return basic_plaintext[8:]

    @classmethod
    def prf(cls, key, string):
        # PRF output is HMAC-SHA1 of string under the raw key.
        return HMAC.new(key.contents, string, SHA).digest()
532
533
534class _ChecksumProfile(object):
535    # Base class for checksum profiles.  Usable checksum classes must
536    # define:
537    #   * checksum
538    #   * verify (if verification is not just checksum-and-compare)
539    @classmethod
540    def verify(cls, key, keyusage, text, cksum):
541        expected = cls.checksum(key, keyusage, text)
542        if not _mac_equal(cksum, expected):
543            raise InvalidChecksum('checksum verification failure')
544
545
class _SimplifiedChecksum(_ChecksumProfile):
    # Keyed checksum per the RFC 3961 simplified profile.  It is an HMAC
    # over text using the enctype's hash, keyed by
    # Kc = derive(key, usage | 0x99), truncated to macsize bytes.
    # Subclasses set:
    #   * macsize: Size of checksum in bytes
    #   * enc: Profile of associated enctype

    @classmethod
    def checksum(cls, key, keyusage, text):
        enc = cls.enc
        kc = enc.derive(key, pack('>IB', keyusage, 0x99))
        mac = HMAC.new(kc.contents, text, enc.hashmod)
        return mac.digest()[:cls.macsize]

    @classmethod
    def verify(cls, key, keyusage, text, cksum):
        # Only keys of the associated enctype are accepted.
        if key.enctype == cls.enc.enctype:
            super(_SimplifiedChecksum, cls).verify(key, keyusage, text, cksum)
        else:
            raise ValueError('Wrong key type for checksum')
564
565
class _SHA1AES128(_SimplifiedChecksum):
    # HMAC-SHA1 checksum truncated to 12 bytes, keyed from an aes128 key.
    enc = _AES128CTS
    macsize = 12
569
570
class _SHA1AES256(_SimplifiedChecksum):
    # HMAC-SHA1 checksum truncated to 12 bytes, keyed from an aes256 key.
    enc = _AES256CTS
    macsize = 12
574
575
class _SHA1DES3(_SimplifiedChecksum):
    # Full 20-byte HMAC-SHA1 checksum keyed from a des3 key.
    enc = _DES3CBC
    macsize = 20
579
580
class _HMACMD5(_ChecksumProfile):
    # HMAC-MD5 checksum for rc4-hmac keys.
    @classmethod
    def checksum(cls, key, keyusage, text):
        # Ksign = HMAC-MD5(key, "signaturekey\0");
        # result = HMAC-MD5(Ksign, MD5(usage | text)).
        ksign = HMAC.new(key.contents, 'signaturekey\0', MD5).digest()
        inner = MD5.new(_RC4.usage_str(keyusage) + text)
        return HMAC.new(ksign, inner.digest(), MD5).digest()

    @classmethod
    def verify(cls, key, keyusage, text, cksum):
        # Only RC4 keys are accepted.
        if key.enctype == Enctype.RC4:
            super(_HMACMD5, cls).verify(key, keyusage, text, cksum)
        else:
            raise ValueError('Wrong key type for checksum')
593
594
# Map each supported enctype number to its profile class, keyed by the
# class's own enctype attribute.
_enctype_table = {
    profile.enctype: profile
    for profile in (_DESCBC, _DES3CBC, _AES128CTS, _AES256CTS, _RC4)
}


# Map each supported checksum type number to its profile class.
# HMAC-MD5 is registered both as -138 and as 0xffffff76, the unsigned
# 32-bit two's-complement form of -138 (presumably so that a value
# decoded as unsigned also resolves).
_checksum_table = {
    Cksumtype.SHA1_DES3: _SHA1DES3,
    Cksumtype.SHA1_AES128: _SHA1AES128,
    Cksumtype.SHA1_AES256: _SHA1AES256,
    Cksumtype.HMAC_MD5: _HMACMD5,
    0xffffff76: _HMACMD5,
}
611
612
def _get_enctype_profile(enctype):
    # Return the profile class for enctype; ValueError if unsupported.
    profile = _enctype_table.get(enctype)
    if profile is None:
        raise ValueError('Invalid enctype %d' % enctype)
    return profile
617
618
def _get_checksum_profile(cksumtype):
    # Return the profile class for cksumtype; ValueError if unsupported.
    profile = _checksum_table.get(cksumtype)
    if profile is None:
        raise ValueError('Invalid cksumtype %d' % cksumtype)
    return profile
623
624
class Key(object):
    # A protocol key: an enctype number plus the raw key bytes.
    def __init__(self, enctype, contents):
        # Raises ValueError for an unsupported enctype, or if contents is
        # not exactly that enctype's keysize bytes long.
        expected_len = _get_enctype_profile(enctype).keysize
        if len(contents) != expected_len:
            raise ValueError('Wrong key length')
        self.enctype = enctype
        self.contents = contents
632
633
def random_to_key(enctype, seed):
    # Build a key of the given enctype from seedsize bytes of seed
    # material.  Raises ValueError on an unsupported enctype or a seed of
    # the wrong length.
    profile = _get_enctype_profile(enctype)
    if len(seed) == profile.seedsize:
        return profile.random_to_key(seed)
    raise ValueError('Wrong crypto seed length')
639
640
def string_to_key(enctype, string, salt, params=None):
    # Derive a key of the given enctype from a password and salt.
    # params is the enctype-specific parameter string (None = default).
    return _get_enctype_profile(enctype).string_to_key(string, salt, params)
644
645
def encrypt(key, keyusage, plaintext, confounder=None):
    # Encrypt plaintext with key for keyusage.  confounder is normally
    # left as None so that a random one is generated.
    return _get_enctype_profile(key.enctype).encrypt(
        key, keyusage, plaintext, confounder)
649
650
def decrypt(key, keyusage, ciphertext):
    # Decrypt ciphertext with key for keyusage and return the plaintext.
    # Throw InvalidChecksum on checksum failure.  Throw ValueError on
    # invalid key enctype or malformed ciphertext.
    profile = _get_enctype_profile(key.enctype)
    plaintext = profile.decrypt(key, keyusage, ciphertext)
    return plaintext
656
657
def prf(key, string):
    # Apply the pseudo-random function of key's enctype to string.
    return _get_enctype_profile(key.enctype).prf(key, string)
661
662
def make_checksum(cksumtype, key, keyusage, text):
    # Compute a checksum of type cksumtype over text with key/keyusage.
    profile = _get_checksum_profile(cksumtype)
    cksum = profile.checksum(key, keyusage, text)
    return cksum
666
667
def verify_checksum(cksumtype, key, keyusage, text, cksum):
    # Check that cksum is the cksumtype checksum of text; returns None on
    # success.
    # Throw InvalidChecksum exception on checksum failure.  Throw
    # ValueError on invalid cksumtype, invalid key enctype, or
    # malformed checksum.
    _get_checksum_profile(cksumtype).verify(key, keyusage, text, cksum)
674
675
def cf2(enctype, key1, key2, pepper1, pepper2):
    # Combine two keys and two pepper strings to produce a result key
    # of type enctype, using the RFC 6113 KRB-FX-CF2 function.
    def prfplus(key, pepper, length):
        # PRF+ (RFC 6113): concatenate prf(key, counter | pepper) for a
        # one-byte counter starting at 1, truncated to length bytes.
        chunks = []
        total = 0
        counter = 1
        while total < length:
            block = prf(key, chr(counter) + pepper)
            chunks.append(block)
            total += len(block)
            counter += 1
        return ''.join(chunks)[:length]

    profile = _get_enctype_profile(enctype)
    size = profile.seedsize
    left = prfplus(key1, pepper1, size)
    right = prfplus(key2, pepper2, size)
    return profile.random_to_key(_xorbytes(left, right))
691
692
if __name__ == '__main__':
    # Self-test: known-answer checks for each supported enctype and
    # cksumtype, run when this file is executed directly.  A mismatch
    # raises AssertionError (asserts are skipped under python -O) or,
    # for the checksum cases, InvalidChecksum from verify_checksum.
    # NOTE(review): the origin of these vectors is not recorded here;
    # presumably RFC 3961/3962/4757/6113 and MIT krb5 tests -- confirm.
    def h(hexstr):
        # Decode a hex literal into a byte string.
        return unhexlify(hexstr)

    # AES128 encrypt and decrypt
    kb = h('9062430C8CDA3388922E6D6A509F5B7A')
    conf = h('94B491F481485B9A0678CD3C4EA386AD')
    keyusage = 2
    plain = '9 bytesss'
    ctxt = h('68FB9679601F45C78857B2BF820FD6E53ECA8D42FD4B1D7024A09205ABB7CD2E'
             'C26C355D2F')
    k = Key(Enctype.AES128, kb)
    assert(encrypt(k, keyusage, plain, conf) == ctxt)
    assert(decrypt(k, keyusage, ctxt) == plain)

    # AES256 encrypt and decrypt
    kb = h('F1C795E9248A09338D82C3F8D5B567040B0110736845041347235B1404231398')
    conf = h('E45CA518B42E266AD98E165E706FFB60')
    keyusage = 4
    plain = '30 bytes bytes bytes bytes byt'
    ctxt = h('D1137A4D634CFECE924DBC3BF6790648BD5CFF7DE0E7B99460211D0DAEF3D79A'
             '295C688858F3B34B9CBD6EEBAE81DAF6B734D4D498B6714F1C1D')
    k = Key(Enctype.AES256, kb)
    assert(encrypt(k, keyusage, plain, conf) == ctxt)
    assert(decrypt(k, keyusage, ctxt) == plain)

    # AES128 checksum
    kb = h('9062430C8CDA3388922E6D6A509F5B7A')
    keyusage = 3
    plain = 'eight nine ten eleven twelve thirteen'
    cksum = h('01A4B088D45628F6946614E3')
    k = Key(Enctype.AES128, kb)
    verify_checksum(Cksumtype.SHA1_AES128, k, keyusage, plain, cksum)

    # AES256 checksum
    kb = h('B1AE4CD8462AFF1677053CC9279AAC30B796FB81CE21474DD3DDBCFEA4EC76D7')
    keyusage = 4
    plain = 'fourteen'
    cksum = h('E08739E3279E2903EC8E3836')
    k = Key(Enctype.AES256, kb)
    verify_checksum(Cksumtype.SHA1_AES256, k, keyusage, plain, cksum)

    # AES128 string-to-key
    string = 'password'
    salt = 'ATHENA.MIT.EDUraeburn'
    # params: big-endian PBKDF2 iteration count (here 2).
    params = h('00000002')
    kb = h('C651BF29E2300AC27FA469D693BDDA13')
    k = string_to_key(Enctype.AES128, string, salt, params)
    assert(k.contents == kb)

    # AES256 string-to-key
    string = 'X' * 64
    salt = 'pass phrase equals block size'
    params = h('000004B0')
    kb = h('89ADEE3608DB8BC71F1BFBFE459486B05618B70CBAE22092534E56C553BA4B34')
    k = string_to_key(Enctype.AES256, string, salt, params)
    assert(k.contents == kb)

    # AES128 prf
    kb = h('77B39A37A868920F2A51F9DD150C5717')
    k = string_to_key(Enctype.AES128, 'key1', 'key1')
    assert(prf(k, '\x01\x61') == kb)

    # AES256 prf
    kb = h('0D674DD0F9A6806525A4D92E828BD15A')
    k = string_to_key(Enctype.AES256, 'key2', 'key2')
    assert(prf(k, '\x02\x62') == kb)

    # AES128 cf2
    kb = h('97DF97E4B798B29EB31ED7280287A92A')
    k1 = string_to_key(Enctype.AES128, 'key1', 'key1')
    k2 = string_to_key(Enctype.AES128, 'key2', 'key2')
    k = cf2(Enctype.AES128, k1, k2, 'a', 'b')
    assert(k.contents == kb)

    # AES256 cf2
    kb = h('4D6CA4E629785C1F01BAF55E2E548566B9617AE3A96868C337CB93B5E72B1C7B')
    k1 = string_to_key(Enctype.AES256, 'key1', 'key1')
    k2 = string_to_key(Enctype.AES256, 'key2', 'key2')
    k = cf2(Enctype.AES256, k1, k2, 'a', 'b')
    assert(k.contents == kb)

    # DES3 encrypt and decrypt
    kb = h('0DD52094E0F41CECCB5BE510A764B35176E3981332F1E598')
    conf = h('94690A17B2DA3C9B')
    keyusage = 3
    plain = '13 bytes byte'
    ctxt = h('839A17081ECBAFBCDC91B88C6955DD3C4514023CF177B77BF0D0177A16F705E8'
             '49CB7781D76A316B193F8D30')
    k = Key(Enctype.DES3, kb)
    assert(encrypt(k, keyusage, plain, conf) == ctxt)
    # DES3 pads to the 8-byte block size and decrypt does not strip the
    # padding, hence the comparison against the zero-padded plaintext.
    assert(decrypt(k, keyusage, ctxt) == _zeropad(plain, 8))

    # DES3 string-to-key
    string = 'password'
    salt = 'ATHENA.MIT.EDUraeburn'
    kb = h('850BB51358548CD05E86768C313E3BFEF7511937DCF72C3E')
    k = string_to_key(Enctype.DES3, string, salt)
    assert(k.contents == kb)

    # DES3 checksum
    kb = h('7A25DF8992296DCEDA0E135BC4046E2375B3C14C98FBC162')
    keyusage = 2
    plain = 'six seven'
    cksum = h('0EEFC9C3E049AABC1BA5C401677D9AB699082BB4')
    k = Key(Enctype.DES3, kb)
    verify_checksum(Cksumtype.SHA1_DES3, k, keyusage, plain, cksum)

    # DES3 cf2
    kb = h('E58F9EB643862C13AD38E529313462A7F73E62834FE54A01')
    k1 = string_to_key(Enctype.DES3, 'key1', 'key1')
    k2 = string_to_key(Enctype.DES3, 'key2', 'key2')
    k = cf2(Enctype.DES3, k1, k2, 'a', 'b')
    assert(k.contents == kb)

    # RC4 encrypt and decrypt
    kb = h('68F263DB3FCE15D031C9EAB02D67107A')
    conf = h('37245E73A45FBF72')
    keyusage = 4
    plain = '30 bytes bytes bytes bytes byt'
    ctxt = h('95F9047C3AD75891C2E9B04B16566DC8B6EB9CE4231AFB2542EF87A7B5A0F260'
             'A99F0460508DE0CECC632D07C354124E46C5D2234EB8')
    k = Key(Enctype.RC4, kb)
    assert(encrypt(k, keyusage, plain, conf) == ctxt)
    assert(decrypt(k, keyusage, ctxt) == plain)

    # RC4 string-to-key
    # The RC4 string-to-key ignores the salt, so None is passed.
    string = 'foo'
    kb = h('AC8E657F83DF82BEEA5D43BDAF7800CC')
    k = string_to_key(Enctype.RC4, string, None)
    assert(k.contents == kb)

    # RC4 checksum
    kb = h('F7D3A155AF5E238A0B7A871A96BA2AB2')
    keyusage = 6
    plain = 'seventeen eighteen nineteen twenty'
    cksum = h('EB38CC97E2230F59DA4117DC5859D7EC')
    k = Key(Enctype.RC4, kb)
    verify_checksum(Cksumtype.HMAC_MD5, k, keyusage, plain, cksum)

    # RC4 cf2
    kb = h('24D7F6B6BAE4E5C00D2082C5EBAB3672')
    k1 = string_to_key(Enctype.RC4, 'key1', 'key1')
    k2 = string_to_key(Enctype.RC4, 'key2', 'key2')
    k = cf2(Enctype.RC4, k1, k2, 'a', 'b')
    assert(k.contents == kb)

    # DES string-to-key
    # NOTE(review): neither DES vector has len(string + salt) a multiple
    # of 8, and there is no DES-CBC-MD5 encrypt/decrypt vector; consider
    # adding cases that cover those paths.
    string = 'password'
    salt = 'ATHENA.MIT.EDUraeburn'
    kb = h('cbc22fae235298e3')
    k = string_to_key(Enctype.DES_MD5, string, salt)
    assert(k.contents == kb)

    # DES string-to-key
    string = 'potatoe'
    salt = 'WHITEHOUSE.GOVdanny'
    kb = h('df3d32a74fd92a01')
    k = string_to_key(Enctype.DES_MD5, string, salt)
    assert(k.contents == kb)
853
854
855