1#
2#  SelfTest/Util/test_strxor.py: Self-test for XORing
3#
4# ===================================================================
5#
6# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
7# All rights reserved.
8#
9# Redistribution and use in source and binary forms, with or without
10# modification, are permitted provided that the following conditions
11# are met:
12#
13# 1. Redistributions of source code must retain the above copyright
14#    notice, this list of conditions and the following disclaimer.
15# 2. Redistributions in binary form must reproduce the above copyright
16#    notice, this list of conditions and the following disclaimer in
17#    the documentation and/or other materials provided with the
18#    distribution.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31# POSSIBILITY OF SUCH DAMAGE.
32# ===================================================================
33
34import unittest
35from binascii import unhexlify, hexlify
36
37from Crypto.SelfTest.st_common import list_test_cases
38from Crypto.Util.strxor import strxor, strxor_c
39
40
class StrxorTests(unittest.TestCase):
    """Tests for ``strxor``: known answers, accepted input types and ``output=``."""

    # Known-answer vector shared by most tests: TERM1 ^ TERM2 == XORED.
    # Kept as immutable bytes; tests that need a mutable copy build their own.
    TERM1 = unhexlify(b"ff339a83e5cd4cdf5649")
    TERM2 = unhexlify(b"383d4ba020573314395b")
    XORED = unhexlify(b"c70ed123c59a7fcb6f12")

    def test1(self):
        """XOR matches the known answer with the operands in either order."""
        for left, right in ((self.TERM1, self.TERM2), (self.TERM2, self.TERM1)):
            self.assertEqual(strxor(left, right), self.XORED)

    def test2(self):
        """XOR of two empty strings is the empty string."""
        empty = b""
        self.assertEqual(strxor(empty, empty), empty)

    def test3(self):
        """XOR of a string with itself is all zero bytes."""
        zeros = bytes(len(self.TERM1))
        self.assertEqual(strxor(self.TERM1, self.TERM1), zeros)

    def test_wrong_length(self):
        """Operands of different lengths are rejected with ValueError."""
        longer = unhexlify(b"ff339a83e5cd4cdf564990")
        with self.assertRaises(ValueError):
            strxor(self.TERM1, longer)

    def test_bytearray(self):
        """A bytearray is accepted as the first operand."""
        self.assertEqual(strxor(bytearray(self.TERM1), self.TERM2), self.XORED)

    def test_memoryview(self):
        """A memoryview is accepted as the first operand."""
        self.assertEqual(strxor(memoryview(self.TERM1), self.TERM2), self.XORED)

    def _check_output_buffer(self, output):
        """Store TERM1 ^ TERM2 into *output*; check the return and the inputs."""
        left, right = self.TERM1, self.TERM2
        left_before, right_before = left[:], right[:]

        self.assertIsNone(strxor(left, right, output=output))
        self.assertEqual(output, self.XORED)
        self.assertEqual(left, left_before)
        self.assertEqual(right, right_before)

    def test_output_bytearray(self):
        """Verify result can be stored in pre-allocated memory"""
        self._check_output_buffer(bytearray(len(self.TERM1)))

    def test_output_memoryview(self):
        """Verify result can be stored in pre-allocated memory"""
        self._check_output_buffer(memoryview(bytearray(len(self.TERM1))))

    def _check_output_in_place(self, target):
        """XOR TERM2 into *target* (initially holding TERM1) via ``output=target``."""
        right = self.TERM2
        right_before = right[:]

        self.assertIsNone(strxor(target, right, output=target))
        self.assertEqual(target, self.XORED)
        self.assertEqual(right, right_before)

    def test_output_overlapping_bytearray(self):
        """Verify result can be stored in overlapping memory"""
        self._check_output_in_place(bytearray(self.TERM1))

    def test_output_overlapping_memoryview(self):
        """Verify result can be stored in overlapping memory"""
        self._check_output_in_place(memoryview(bytearray(self.TERM1)))

    def test_output_ro_bytes(self):
        """Verify result cannot be stored in read-only memory"""
        with self.assertRaises(TypeError):
            strxor(self.TERM1, self.TERM2, output=self.TERM1)

    def test_output_ro_memoryview(self):
        """Verify result cannot be stored in read-only memory"""
        readonly = memoryview(self.TERM1)
        with self.assertRaises(TypeError):
            strxor(readonly, self.TERM2, output=readonly)

    def test_output_incorrect_length(self):
        """Verify result cannot be stored in memory of incorrect length"""
        too_short = bytearray(len(self.TERM1) - 1)
        with self.assertRaises(ValueError):
            strxor(self.TERM1, self.TERM2, output=too_short)
166
167
class Strxor_cTests(unittest.TestCase):
    """Tests for ``strxor_c``: XOR every byte of a string with one constant."""

    def test1(self):
        """XOR with 65 (0x41) matches the known answer."""
        term1 = unhexlify(b"ff339a83e5cd4cdf5649")
        result = unhexlify(b"be72dbc2a48c0d9e1708")
        self.assertEqual(strxor_c(term1, 65), result)

    def test2(self):
        """XOR with 0 leaves the input unchanged."""
        term1 = unhexlify(b"ff339a83e5cd4cdf5649")
        self.assertEqual(strxor_c(term1, 0), term1)

    def test3(self):
        """An empty input yields an empty output."""
        self.assertEqual(strxor_c(b"", 90), b"")

    def test_wrong_range(self):
        """Constants outside 0..255 are rejected with ValueError."""
        term1 = unhexlify(b"ff339a83e5cd4cdf5649")
        self.assertRaises(ValueError, strxor_c, term1, -1)
        self.assertRaises(ValueError, strxor_c, term1, 256)

    def test_range_upper_bound(self):
        """255, the largest valid constant, complements every byte."""
        term1 = unhexlify(b"ff339a83e5cd4cdf5649")
        # Each byte b becomes 0xff ^ b, i.e. its bitwise complement.
        result = unhexlify(b"00cc657c1a32b320a9b6")
        self.assertEqual(strxor_c(term1, 255), result)

    def test_bytearray(self):
        """A bytearray is accepted as input."""
        term1 = unhexlify(b"ff339a83e5cd4cdf5649")
        term1_ba = bytearray(term1)
        result = unhexlify(b"be72dbc2a48c0d9e1708")

        self.assertEqual(strxor_c(term1_ba, 65), result)

    def test_memoryview(self):
        """A memoryview is accepted as input."""
        term1 = unhexlify(b"ff339a83e5cd4cdf5649")
        term1_mv = memoryview(term1)
        result = unhexlify(b"be72dbc2a48c0d9e1708")

        self.assertEqual(strxor_c(term1_mv, 65), result)

    def test_output_bytearray(self):
        """Verify result can be stored in pre-allocated memory"""

        term1 = unhexlify(b"ff339a83e5cd4cdf5649")
        original_term1 = term1[:]
        expected_result = unhexlify(b"be72dbc2a48c0d9e1708")
        output = bytearray(len(term1))

        result = strxor_c(term1, 65, output=output)

        # With output= given, the result is written in place and None returned.
        self.assertEqual(result, None)
        self.assertEqual(output, expected_result)
        self.assertEqual(term1, original_term1)

    def test_output_memoryview(self):
        """Verify result can be stored in pre-allocated memory"""

        term1 = unhexlify(b"ff339a83e5cd4cdf5649")
        original_term1 = term1[:]
        expected_result = unhexlify(b"be72dbc2a48c0d9e1708")
        output = memoryview(bytearray(len(term1)))

        result = strxor_c(term1, 65, output=output)

        self.assertEqual(result, None)
        self.assertEqual(output, expected_result)
        self.assertEqual(term1, original_term1)

    def test_output_overlapping_bytearray(self):
        """Verify result can be stored in overlapping memory"""

        term1 = bytearray(unhexlify(b"ff339a83e5cd4cdf5649"))
        expected_xor = unhexlify(b"be72dbc2a48c0d9e1708")

        result = strxor_c(term1, 65, output=term1)

        self.assertEqual(result, None)
        self.assertEqual(term1, expected_xor)

    def test_output_overlapping_memoryview(self):
        """Verify result can be stored in overlapping memory"""

        term1 = memoryview(bytearray(unhexlify(b"ff339a83e5cd4cdf5649")))
        expected_xor = unhexlify(b"be72dbc2a48c0d9e1708")

        result = strxor_c(term1, 65, output=term1)

        self.assertEqual(result, None)
        self.assertEqual(term1, expected_xor)

    def test_output_ro_bytes(self):
        """Verify result cannot be stored in read-only memory"""

        term1 = unhexlify(b"ff339a83e5cd4cdf5649")

        self.assertRaises(TypeError, strxor_c, term1, 65, output=term1)

    def test_output_ro_memoryview(self):
        """Verify result cannot be stored in read-only memory"""

        # A memoryview over immutable bytes is read-only.
        term1 = memoryview(unhexlify(b"ff339a83e5cd4cdf5649"))

        self.assertRaises(TypeError, strxor_c, term1, 65, output=term1)

    def test_output_incorrect_length(self):
        """Verify result cannot be stored in memory of incorrect length"""

        term1 = unhexlify(b"ff339a83e5cd4cdf5649")
        output = bytearray(len(term1) - 1)

        self.assertRaises(ValueError, strxor_c, term1, 65, output=output)
269
270
def get_tests(config=None):
    """Return the test cases defined in this module.

    :param config: not used by this module. It defaults to ``None`` instead
        of a shared mutable ``{}``; passing a dict still works.
    :returns: a ``list`` holding every test case of ``StrxorTests`` followed
        by every test case of ``Strxor_cTests``.
    """
    tests = []
    for case_class in (StrxorTests, Strxor_cTests):
        # list_test_cases may return a suite rather than a list, so extend
        # the list from it instead of concatenating results with ``+``.
        tests += list_test_cases(case_class)
    return tests
276
277
if __name__ == '__main__':
    def suite():
        """Build the suite that ``unittest.main`` runs as its default test."""
        return unittest.TestSuite(get_tests())

    unittest.main(defaultTest='suite')
281