1# The MIT License (MIT)
2#
3# Copyright (c) 2014 Richard Moore
4#
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21# THE SOFTWARE.
22
23
24import sys
25sys.path.append('../pyaes')
26
27import os
28import random
29
try:
    # Python 2: in-memory file object that accepts byte strings.
    from StringIO import StringIO
except ImportError:
    # Python 3: the stream tests below push raw bytes (os.urandom output)
    # through in-memory files, so alias the bytes-oriented buffer.
    import io
    StringIO = io.BytesIO
35
36import pyaes
37from pyaes.blockfeeder import Decrypter, Encrypter
38from pyaes import decrypt_stream, encrypt_stream
39from pyaes.util import to_bufferable
40
41
# Random key shared by every round-trip test below.
key = os.urandom(32)


def _mode_kwargs(mode_name, key):
    """Return constructor keyword arguments for the AES mode *mode_name*.

    CBC, CFB and OFB also receive a fresh random 16-byte IV; every other
    mode is constructed from the key alone.
    """
    kw = dict(key=key)
    if mode_name in ('cbc', 'cfb', 'ofb'):
        kw['iv'] = os.urandom(16)
    return kw


def _feed_in_random_chunks(feeder, data, max_chunk=128):
    """Push *data* through *feeder* in random-sized pieces, then finalize it.

    Each piece is 1..max_chunk bytes (inclusive), so the feeder sees data
    arriving at arbitrary, non-block-aligned offsets.  Returns everything
    the feeder emitted, including the output of the final ``feed(None)``.
    """
    output = to_bufferable('')
    index = 0
    while index < len(data):
        length = random.randint(1, max_chunk)
        # No clamping needed: slicing past the end just yields the tail.
        output += feeder.feed(data[index:index + length])
        index += length
    return output + feeder.feed(None)


def _run_stream(stream_func, mode_of_operation, data):
    """Run *stream_func* (pyaes.encrypt_stream / pyaes.decrypt_stream) over
    in-memory buffers and return the bytes it wrote."""
    output = StringIO()
    stream_func(mode_of_operation, StringIO(data), output)
    return output.getvalue()


def _report(ciphertext, passed):
    """Print the per-mode result line shared by every test below."""
    print("  cipher-length=%(cipher_length)s passed=%(passed)s"
          % dict(cipher_length=len(ciphertext), passed=passed))


# Test every mode of operation through the block feeder (default padding).
plaintext = os.urandom(1000)

for mode_name in pyaes.AESModesOfOperation:
    mode = pyaes.AESModesOfOperation[mode_name]
    print(mode.name)

    kw = _mode_kwargs(mode_name, key)
    ciphertext = _feed_in_random_chunks(Encrypter(mode(**kw)), plaintext)
    decrypted = _feed_in_random_chunks(Decrypter(mode(**kw)), ciphertext)

    _report(ciphertext, decrypted == plaintext)

# Test block modes of operation with no padding.
# The plaintext is block-aligned (1024 bytes = 64 AES blocks of 16 bytes),
# so with PADDING_NONE the ciphertext must be exactly as long.
plaintext = os.urandom(1024)

for mode_name in ['ecb', 'cbc']:
    mode = pyaes.AESModesOfOperation[mode_name]
    print(mode.name + ' (no padding)')

    kw = _mode_kwargs(mode_name, key)

    encrypter = Encrypter(mode(**kw), padding=pyaes.PADDING_NONE)
    ciphertext = _feed_in_random_chunks(encrypter, plaintext)

    length_ok = len(ciphertext) == len(plaintext)
    if not length_ok:
        print('  failed to encrypt with correct padding')

    decrypter = Decrypter(mode(**kw), padding=pyaes.PADDING_NONE)
    decrypted = _feed_in_random_chunks(decrypter, ciphertext)

    # A wrong ciphertext length is a failure even if decryption round-trips.
    _report(ciphertext, length_ok and decrypted == plaintext)

# Test every mode of operation through the whole-stream helpers.
plaintext = os.urandom(1000)

for mode_name in pyaes.AESModesOfOperation:
    mode = pyaes.AESModesOfOperation[mode_name]
    print(mode.name + ' (stream operations)')

    kw = _mode_kwargs(mode_name, key)
    ciphertext = _run_stream(pyaes.encrypt_stream, mode(**kw), plaintext)
    decrypted = _run_stream(pyaes.decrypt_stream, mode(**kw), ciphertext)

    _report(ciphertext, decrypted == plaintext)
149
150# Issue #15
151# https://github.com/ricmoo/pyaes/issues/15
152# @TODO: These tests need a severe overhaul; they should use deterministic input, keys and IVs...
def TestIssue15():
    """Regression check for https://github.com/ricmoo/pyaes/issues/15.

    Interleaves empty-string feeds with real data in both the CBC encrypter
    and decrypter.  It then compares the ciphertext with a known-good value
    and checks that decryption restores the doubled message.  Prints a
    single ``passed=`` line.
    """
    print('Issue #15')

    aes_key = b"abcdefghijklmnop"
    aes_iv = b"0123456789012345"
    message = b"Hello World!!!!!"
    known_ciphertext = b'(Ob\xe5\xae"\xdc\xb0\x84\xc5\x04\x04GQ\xd8.\x0e4\xd2b\xc1\x15\xe5\x11M\xfc\x9a\xd2\xd5\xc8xP\x00[\xd57\x92\x01\xbb\xc42\x18\xbc\xbf\x1ay\x19P'

    # Encrypt: data, an empty feed, data again, then finalize with None.
    enc = pyaes.Encrypter(pyaes.AESModeOfOperationCBC(aes_key, aes_iv))
    encrypted = to_bufferable('')
    for chunk in (message, '', message, None):
        encrypted += enc.feed(chunk)

    # Decrypt: surround the ciphertext with empty feeds, then finalize.
    dec = pyaes.Decrypter(pyaes.AESModeOfOperationCBC(aes_key, aes_iv))
    recovered = to_bufferable('')
    for chunk in ('', encrypted, '', None):
        recovered += dec.feed(chunk)

    ok = encrypted == known_ciphertext and recovered == (message + message)
    print("  passed=%(passed)s" % {'passed': ok})

TestIssue15()
182