1#!/usr/bin/env python
2
3import binascii
4import argparse
5import re
6import subprocess
7import sys
8import os.path
9import logging
10import time
11from collections import OrderedDict
12import multiprocessing
13from multiprocessing.pool import ThreadPool
14
# Group names as they appear in the [..] headers of the .vec files, mapped to
# the value passed to the botan cli via "--mode=".  Groups not listed here are
# skipped when test cases are collected.
SUPPORTED_ALGORITHMS = dict([
    ("AES-128/CFB", "aes-128-cfb"),
    ("AES-192/CFB", "aes-192-cfb"),
    ("AES-256/CFB", "aes-256-cfb"),
    ("AES-128/GCM", "aes-128-gcm"),
    ("AES-192/GCM", "aes-192-gcm"),
    ("AES-256/GCM", "aes-256-gcm"),
    ("AES-128/OCB", "aes-128-ocb"),
    ("AES-128/XTS", "aes-128-xts"),
    ("AES-256/XTS", "aes-256-xts"),
    ("ChaCha20Poly1305", "chacha20poly1305"),
])
27
class VecDocument:
    """Parse a ``.vec`` test-vector file into named groups of testcases.

    The file is processed line by line (after stripping whitespace):

    * lines starting with ``#`` are comments and are ignored;
    * ``[Name]`` starts a new group;
    * ``Key = Value`` adds a field to the current testcase (keys are ASCII
      letters only, values are whitespace-stripped strings);
    * an empty line ends the current testcase;
    * any other line is ignored.

    The result, returned by :meth:`get_data`, is an ``OrderedDict`` mapping
    each group name to its list of testcase dicts, in file order.  Fields seen
    before the first group header are collected under the group name ``""``.
    """

    # Compiled once.  Raw strings avoid invalid-escape warnings for "\[", "\s".
    PATTERN_GROUPHEADER = re.compile(r"^\[(.+)\]$")
    PATTERN_KEYVALUE = re.compile(r"^\s*([a-zA-Z]+)\s*=(.*)$")

    def __init__(self, filepath):
        """Read and parse ``filepath``; the parsed groups go to ``self.data``."""
        self.data = OrderedDict()
        group_name = ""
        testcase = {}

        with open(filepath, 'r') as f:
            lines = f.read().splitlines()

        for line in lines:
            line = line.strip()
            if line.startswith("#"):
                continue

            if line == "":
                self._finish_testcase(group_name, testcase)
                testcase = {}
                continue

            header = self.PATTERN_GROUPHEADER.match(line)
            if header:
                # A header also terminates the pending testcase even when the
                # file omits the separating blank line; otherwise the old
                # group's fields would be merged into the new group's first
                # testcase.
                if testcase:
                    self._finish_testcase(group_name, testcase)
                    testcase = {}
                group_name = header.group(1)
                continue

            keyvalue = self.PATTERN_KEYVALUE.match(line)
            if keyvalue:
                testcase[keyvalue.group(1)] = keyvalue.group(2).strip()

        # The file need not end with a blank line: flush the final testcase.
        self._finish_testcase(group_name, testcase)

    def _finish_testcase(self, group_name, testcase):
        """Append ``testcase`` to ``group_name``'s list, creating the list if
        needed; an empty testcase only ensures the group entry exists."""
        group = self.data.setdefault(group_name, [])
        if testcase:
            group.append(testcase)

    def get_data(self):
        """Return the parsed ``OrderedDict`` of group name -> testcase list."""
        return self.data
75
# Module-level counters.  TESTS_FAILED is bumped by TestLogHandler for every
# log record at ERROR level or above and decides main()'s exit status.
# NOTE(review): TESTS_RUN is not updated by any code visible here; main()
# reports len(kats) instead.
TESTS_RUN = TESTS_FAILED = 0
78
class TestLogHandler(logging.StreamHandler, object):
    """A ``logging.StreamHandler`` that doubles as a failure counter.

    Records are written exactly as the base ``StreamHandler`` writes them;
    in addition, each record at ERROR level or above increments the
    module-level ``TESTS_FAILED`` counter.
    """

    def emit(self, record):
        global TESTS_FAILED
        # Output first, so what is printed is unchanged from StreamHandler.
        super(TestLogHandler, self).emit(record)
        if record.levelno < logging.ERROR:
            return
        TESTS_FAILED += 1
86
def setup_logging(options):
    """Attach a TestLogHandler writing to stdout to the root logger.

    The root level is DEBUG when ``options.verbose`` is set, otherwise
    WARNING when ``options.quiet`` is set, otherwise INFO (``verbose`` wins
    if both flags are given).
    """
    level = (logging.DEBUG if options.verbose
             else logging.WARNING if options.quiet
             else logging.INFO)

    handler = TestLogHandler(sys.stdout)
    handler.setFormatter(logging.Formatter('%(levelname) 7s: %(message)s'))

    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(level)
99
def test_cipher_kat(cli_binary, data):
    """Run one known-answer test through ``<cli_binary> encryption``.

    :param cli_binary: path to the botan cli executable.
    :param data: one test dict as produced by get_testdata(); reads 'Nonce',
        'Key', 'In', 'Out', 'Algorithm', 'Direction', 'testname' and the
        optional 'AD'.  'In' and 'Out' are hex-encoded byte strings.
    :raises Exception: if 'Algorithm' is not a key of SUPPORTED_ALGORITHMS.

    A failure (output mismatch or non-zero cli exit status) is reported with
    a single logging.error() call rather than by raising; TestLogHandler,
    when installed, counts that record as a failed test.
    """
    iv = data['Nonce']
    key = data['Key']
    ad = data.get('AD', "")
    plaintext = data['In'].lower()
    ciphertext = data['Out'].lower()
    algorithm = data['Algorithm']
    # Any direction other than "decrypt" is treated as encryption.
    decrypt = data['Direction'] == "decrypt"

    mode = SUPPORTED_ALGORITHMS.get(algorithm)
    if mode is None:
        raise Exception("Unknown algorithm: '" + algorithm + "'")

    cmd = [
        cli_binary,
        "encryption",
        "--mode=%s" % mode,
        "--iv=%s" % iv,
        "--ad=%s" % ad,
        "--key=%s" % key]

    if decrypt:
        cmd.append('--decrypt')
        invalue, expected = ciphertext, plaintext
    else:
        invalue, expected = plaintext, ciphertext

    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
    out_raw, err_raw = p.communicate(input=binascii.unhexlify(invalue))
    output = binascii.hexlify(out_raw).decode("UTF-8").lower()

    if expected != output or p.returncode != 0:
        message = "For test %s got %s expected %s" % (data['testname'], output, expected)
        if p.returncode != 0:
            # Surface the cli's own diagnostics instead of discarding them.
            message += " (cli exited with status %d: %s)" % (
                p.returncode, err_raw.decode("UTF-8", "replace").strip())
        logging.error(message)
135
def get_testdata(document, max_tests):
    """Expand parsed vec-file groups into a flat list of runnable tests.

    :param document: mapping of group name -> list of testcase dicts, e.g.
        the result of VecDocument.get_data().
    :param max_tests: maximum number of testcases taken from each group;
        0 or a negative value means no limit.
    :returns: list of dicts, two per testcase (an 'encrypt' and a 'decrypt'
        variant, in that order).  Each holds a 'testname' key, the testcase's
        fields, then 'Algorithm' (the group name) and 'Direction'.

    Groups whose name is not a key of SUPPORTED_ALGORITHMS are skipped.
    """
    out = []
    for algorithm in document:
        if algorithm not in SUPPORTED_ALGORITHMS:
            continue
        for testcase_number, testcase in enumerate(document[algorithm], 1):
            # Check the limit *before* emitting; the previous post-append
            # "> max_tests" check let max_tests + 1 testcases through.
            if max_tests > 0 and testcase_number > max_tests:
                break
            for direction in ['encrypt', 'decrypt']:
                testname = "{} no {:0>3} ({})".format(
                    algorithm.lower(), testcase_number, direction)
                # Collapse everything but [a-z0-9-] into single underscores.
                testname = re.sub("[^a-z0-9-]", "_", testname)
                testname = re.sub("_+", "_", testname)
                testname = testname.strip("_")

                test = {'testname': testname}
                test.update(testcase)
                test['Algorithm'] = algorithm
                test['Direction'] = direction
                out.append(test)
    return out
161
def main(args=None):
    """Command-line entry point: run the cipher-mode KATs against a botan cli.

    :param args: full argument vector including the program name (same shape
        as sys.argv); defaults to sys.argv.
    :returns: process exit status, 0 if no test failed and 1 otherwise.

    The .vec files are opened via paths under src/tests/data, resolved
    relative to the current working directory.
    """
    if args is None:
        args = sys.argv

    parser = argparse.ArgumentParser(description="")
    parser.add_argument('cli_binary', help='path to the botan cli binary')
    parser.add_argument('--max-tests', type=int, default=50, metavar="M")
    parser.add_argument('--threads', type=int, default=0, metavar="T")
    parser.add_argument('--verbose', action='store_true', default=False)
    parser.add_argument('--quiet', action='store_true', default=False)
    # Honour the caller-supplied vector; element 0 is the program name.
    options = parser.parse_args(args[1:])

    setup_logging(options)

    cli_binary = options.cli_binary
    max_tests = options.max_tests
    threads = options.threads

    if threads == 0:
        threads = multiprocessing.cpu_count()

    test_data_dir = os.path.join('src', 'tests', 'data')

    mode_test_data = [os.path.join(test_data_dir, 'modes', 'cfb.vec'),
                      os.path.join(test_data_dir, 'aead', 'gcm.vec'),
                      os.path.join(test_data_dir, 'aead', 'ocb.vec'),
                      os.path.join(test_data_dir, 'modes', 'xts.vec'),
                      os.path.join(test_data_dir, 'aead', 'chacha20poly1305.vec')]

    kats = []
    for f in mode_test_data:
        kats += get_testdata(VecDocument(f).get_data(), max_tests)

    start_time = time.time()

    if threads > 1:
        pool = ThreadPool(processes=threads)
        try:
            results = [pool.apply_async(test_cipher_kat, (cli_binary, test))
                       for test in kats]
            # get() re-raises any exception raised inside a worker.
            for result in results:
                result.get()
        finally:
            # Release the worker threads instead of leaking the pool.
            pool.close()
            pool.join()
    else:
        for test in kats:
            # cli_binary was previously omitted here (TypeError on 1 thread).
            test_cipher_kat(cli_binary, test)

    end_time = time.time()

    print("Ran %d tests with %d failures in %.02f seconds" % (
        len(kats), TESTS_FAILED, end_time - start_time))

    return 1 if TESTS_FAILED > 0 else 0
218
if __name__ == '__main__':
    # Use main()'s return value (0 = no failures, 1 = failures) as the
    # process exit status; equivalent to sys.exit(main()).
    raise SystemExit(main())
221