#!/usr/bin/env python

"""Known-answer tests for the ``encryption`` command of a CLI binary.

Test vectors are read from ``.vec`` files below ``src/tests/data``. Every
vector of a supported algorithm is run through the CLI binary once for
encryption and once for decryption, and the hex-encoded output is compared
against the expected value from the vector file.
"""

import binascii
import argparse
import re
import subprocess
import sys
import os.path
import logging
import time
from collections import OrderedDict
import multiprocessing
from multiprocessing.pool import ThreadPool

# Maps the group names used in the .vec files to the value passed to the
# CLI binary via --mode.
SUPPORTED_ALGORITHMS = {
    "AES-128/CFB": "aes-128-cfb",
    "AES-192/CFB": "aes-192-cfb",
    "AES-256/CFB": "aes-256-cfb",
    "AES-128/GCM": "aes-128-gcm",
    "AES-192/GCM": "aes-192-gcm",
    "AES-256/GCM": "aes-256-gcm",
    "AES-128/OCB": "aes-128-ocb",
    "AES-128/XTS": "aes-128-xts",
    "AES-256/XTS": "aes-256-xts",
    "ChaCha20Poly1305": "chacha20poly1305",
}


class VecDocument:
    """Parsed contents of a ``.vec`` test vector file.

    The file format understood by this parser:

    * lines starting with ``#`` are comments and are skipped,
    * ``[Name]`` starts a new group,
    * ``Key = Value`` adds a field to the current test case (keys consist of
      ASCII letters only),
    * an empty line terminates the current test case,
    * any other line is ignored.

    After construction ``self.data`` is an ``OrderedDict`` mapping each group
    name to the list of its test cases, where a test case is a ``dict`` of
    the key/value strings. Lines before the first group header belong to the
    group named ``""``.
    """

    # Raw strings: "\[" and "\s" are regex escapes, not string escapes.
    PATTERN_GROUPHEADER = re.compile(r"^\[(.+)\]$")
    PATTERN_KEYVALUE = re.compile(r"^\s*([a-zA-Z]+)\s*=(.*)$")

    def __init__(self, filepath):
        """Read and parse the vector file at ``filepath``."""
        self.data = OrderedDict()
        current_group_name = ""
        current_testcase = {}

        with open(filepath, 'r') as f:
            # Append one empty line so the last test case is always flushed,
            # even if the file does not end with a blank line.
            lines = f.read().splitlines() + [""]

        for line in lines:
            line = line.strip()

            if line.startswith("#"):
                continue

            if line == "":
                # A blank line ends the test case. The group entry is created
                # even when there is nothing to append.
                testcases = self.data.setdefault(current_group_name, [])
                if current_testcase:
                    testcases.append(current_testcase)
                    current_testcase = {}
                continue

            match = self.PATTERN_GROUPHEADER.match(line)
            if match:
                current_group_name = match.group(1)
                continue

            match = self.PATTERN_KEYVALUE.match(line)
            if match:
                current_testcase[match.group(1)] = match.group(2).strip()

    def get_data(self):
        """Return the mapping of group name to list of test case dicts."""
        return self.data


TESTS_RUN = 0
TESTS_FAILED = 0


class TestLogHandler(logging.StreamHandler, object):
    """Stream handler that counts every ERROR (or worse) record as a failure.

    The count is kept in the module-level ``TESTS_FAILED``.
    """

    def emit(self, record):
        # Do the default stuff first
        super(TestLogHandler, self).emit(record)
        if record.levelno >= logging.ERROR:
            global TESTS_FAILED
            TESTS_FAILED += 1


def setup_logging(options):
    """Attach a ``TestLogHandler`` writing to stdout to the root logger.

    ``options.verbose`` selects DEBUG and takes precedence over
    ``options.quiet``, which selects WARNING; the default level is INFO.
    """
    if options.verbose:
        log_level = logging.DEBUG
    elif options.quiet:
        log_level = logging.WARNING
    else:
        log_level = logging.INFO

    lh = TestLogHandler(sys.stdout)
    lh.setFormatter(logging.Formatter('%(levelname) 7s: %(message)s'))
    logging.getLogger().addHandler(lh)
    logging.getLogger().setLevel(log_level)


def test_cipher_kat(cli_binary, data):
    """Run one known-answer test through the CLI binary.

    ``data`` must contain the keys ``Nonce``, ``Key``, ``In``, ``Out``,
    ``Algorithm``, ``Direction`` and ``testname``; ``AD`` is optional and
    defaults to the empty string. ``In`` and ``Out`` are hex strings.

    For direction ``"decrypt"`` the ``Out`` value is fed to the binary and
    ``In`` is expected back; for any other direction it is the reverse.
    A mismatch is reported via ``logging.error``.

    Raises ``Exception`` if ``Algorithm`` is not in ``SUPPORTED_ALGORITHMS``.
    """
    iv = data['Nonce']
    key = data['Key']
    ad = data.get('AD', "")
    plaintext = data['In'].lower()
    ciphertext = data['Out'].lower()
    algorithm = data['Algorithm']
    direction = data['Direction']

    mode = SUPPORTED_ALGORITHMS.get(algorithm)
    if mode is None:
        raise Exception("Unknown algorithm: '" + algorithm + "'")

    cmd = [
        cli_binary,
        "encryption",
        "--mode=%s" % mode,
        "--iv=%s" % iv,
        "--ad=%s" % ad,
        "--key=%s" % key]

    if direction == "decrypt":
        cmd += ['--decrypt']
        invalue, expected = ciphertext, plaintext
    else:
        invalue, expected = plaintext, ciphertext

    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE,
                         stderr=subprocess.PIPE)
    out_raw = p.communicate(input=binascii.unhexlify(invalue))[0]
    output = binascii.hexlify(out_raw).decode("UTF-8").lower()

    if expected != output:
        logging.error("For test %s got %s expected %s",
                      data['testname'], output, expected)


# Despite the "test_" prefix this is a helper that needs explicit arguments;
# keep pytest/nose from trying to collect it as a test function.
test_cipher_kat.__test__ = False


def get_testdata(document, max_tests):
    """Expand parsed vector data into a flat list of test descriptions.

    ``document`` maps group (algorithm) names to lists of test case dicts, as
    returned by ``VecDocument.get_data()``. Groups that are not in
    ``SUPPORTED_ALGORITHMS`` are skipped. Each remaining test case yields two
    entries, one for ``'encrypt'`` and one for ``'decrypt'``, consisting of a
    copy of the test case plus the keys ``testname``, ``Algorithm`` and
    ``Direction``.

    At most ``max_tests`` test cases are taken per algorithm; a value of zero
    or less means no limit.
    """
    out = []
    for algorithm in document:
        if algorithm not in SUPPORTED_ALGORITHMS:
            continue

        for testcase_number, testcase in enumerate(document[algorithm], 1):
            for direction in ['encrypt', 'decrypt']:
                testname = "{} no {:0>3} ({})".format(
                    algorithm.lower(), testcase_number, direction)
                # Reduce the name to [a-z0-9-_] with single underscores.
                testname = re.sub("[^a-z0-9-]", "_", testname)
                testname = re.sub("_+", "_", testname)
                testname = testname.strip("_")

                test = {'testname': testname}
                test.update(testcase)
                test['Algorithm'] = algorithm
                test['Direction'] = direction

                out.append(test)

            # ">=" so that exactly max_tests cases are emitted; the check
            # runs after the current case has already been appended.
            if max_tests > 0 and testcase_number >= max_tests:
                break
    return out


def main(args=None):
    """Run all known-answer tests and return the process exit code.

    ``args`` is the full argument vector including the program name and
    defaults to ``sys.argv``. Returns 1 if any test failed, otherwise 0.
    """
    if args is None:
        args = sys.argv

    parser = argparse.ArgumentParser(description="")
    parser.add_argument('cli_binary', help='path to the botan cli binary')
    parser.add_argument('--max-tests', type=int, default=50, metavar="M")
    parser.add_argument('--threads', type=int, default=0, metavar="T")
    parser.add_argument('--verbose', action='store_true', default=False)
    parser.add_argument('--quiet', action='store_true', default=False)
    # Skip the program name; honours an explicitly passed argument vector.
    options = parser.parse_args(args[1:])

    setup_logging(options)

    cli_binary = options.cli_binary
    max_tests = options.max_tests
    threads = options.threads

    if threads == 0:
        threads = multiprocessing.cpu_count()

    test_data_dir = os.path.join('src', 'tests', 'data')

    mode_test_data = [os.path.join(test_data_dir, 'modes', 'cfb.vec'),
                      os.path.join(test_data_dir, 'aead', 'gcm.vec'),
                      os.path.join(test_data_dir, 'aead', 'ocb.vec'),
                      os.path.join(test_data_dir, 'modes', 'xts.vec'),
                      os.path.join(test_data_dir, 'aead', 'chacha20poly1305.vec')]

    kats = []
    for f in mode_test_data:
        vecfile = VecDocument(f)
        kats += get_testdata(vecfile.get_data(), max_tests)

    start_time = time.time()

    if threads > 1:
        pool = ThreadPool(processes=threads)
        try:
            results = [pool.apply_async(test_cipher_kat, (cli_binary, test))
                       for test in kats]
            # get() re-raises any exception from the worker thread.
            for result in results:
                result.get()
        finally:
            pool.close()
            pool.join()
    else:
        for test in kats:
            test_cipher_kat(cli_binary, test)

    end_time = time.time()

    print("Ran %d tests with %d failures in %.02f seconds" % (
        len(kats), TESTS_FAILED, end_time - start_time))

    if TESTS_FAILED > 0:
        return 1
    return 0


if __name__ == '__main__':
    sys.exit(main())