1#!/usr/bin/env python2
2#
3# eapol_test controller
4# Copyright (c) 2015, Jouni Malinen <j@w1.fi>
5#
6# This software may be distributed under the terms of the BSD license.
7# See README for more details.
8
9import argparse
10import logging
11import os
12import Queue
13import sys
14import threading
15
# Root logger (getLogger() with no name); eapol_test.wait_event() logs every
# received control-interface event to it at debug level.
logger = logging.getLogger()
# Directory containing this script.
# NOTE(review): this shadows the dir() builtin at module scope.
dir = os.path.dirname(os.path.realpath(sys.modules[__name__].__file__))
# Make the wpaspy control-interface module importable from ../wpaspy relative
# to this script; this must happen before the import on the next line.
sys.path.append(os.path.join(dir, '..', 'wpaspy'))
import wpaspy
# Default directory holding the per-instance eapol_test control sockets;
# main() replaces it when --ctrl is given.
wpas_ctrl = '/tmp/eapol_test'
21
class eapol_test:
    """Control-interface client for one running eapol_test instance.

    Two wpaspy.Ctrl connections are opened to the same socket path:
    ``ctrl`` carries request/response commands and ``mon`` is attached so
    that unsolicited events can be collected by wait_event().
    """

    def __init__(self, ifname):
        """Connect to the control socket <wpas_ctrl>/<ifname>.

        Raises:
            Exception: if the instance does not answer PING with PONG.
        """
        self.ifname = ifname
        path = os.path.join(wpas_ctrl, ifname)
        self.ctrl = wpaspy.Ctrl(path)
        if "PONG" not in self.ctrl.request("PING"):
            # Release the command socket now instead of leaving it to GC.
            self.ctrl.close()
            raise Exception("Failed to connect to eapol_test (%s)" % ifname)
        self.mon = wpaspy.Ctrl(path)
        self.mon.attach()

    def close(self):
        """Close the monitor and command control connections."""
        self.mon.close()
        self.ctrl.close()

    def _checked_request(self, cmd, name):
        """Send *cmd* and return the reply.

        Raises:
            Exception: "<name> failed" if the reply contains FAIL.
        """
        reply = self.request(cmd)
        if "FAIL" in reply:
            raise Exception(name + " failed")
        return reply

    def add_network(self):
        """Create a new network block and return its id as an int."""
        return int(self._checked_request("ADD_NETWORK", "ADD_NETWORK"))

    def remove_network(self, id):
        """Remove network block *id*. Raises Exception on a FAIL reply."""
        self._checked_request("REMOVE_NETWORK " + str(id), "REMOVE_NETWORK")
        return None

    def set_network(self, id, field, value):
        """Set *field* of network *id* to *value*, sent unquoted.

        *value* may be any object; it is formatted with %s.
        Raises Exception on a FAIL reply.
        """
        self._checked_request("SET_NETWORK %s %s %s" % (id, field, value),
                              "SET_NETWORK")
        return None

    def set_network_quoted(self, id, field, value):
        """Set *field* of network *id* to *value* wrapped in double quotes.

        *value* is not escaped, so it must not itself contain a double quote.
        Raises Exception on a FAIL reply.
        """
        self._checked_request('SET_NETWORK %s %s "%s"' % (id, field, value),
                              "SET_NETWORK")
        return None

    def request(self, cmd, timeout=10):
        """Send a raw control-interface command and return the reply text."""
        return self.ctrl.request(cmd, timeout=timeout)

    def wait_event(self, events, timeout=10):
        """Wait up to *timeout* seconds for an event containing a substring.

        *events* is a list of substrings; a single string is also accepted.
        Every received event is logged at debug level. Non-matching events
        are consumed and discarded.

        Returns:
            The first event line that contains any of *events*, or None if
            the timeout expires first.
        """
        if isinstance(events, str):
            # Iterating a bare string would match on single characters.
            events = [events]
        start = os.times()[4]  # elapsed real time, in seconds
        while True:
            while self.mon.pending():
                ev = self.mon.recv()
                logger.debug(self.ifname + ": " + ev)
                if any(event in ev for event in events):
                    return ev
            remaining = start + timeout - os.times()[4]
            if remaining <= 0:
                break
            # presumably blocks up to `remaining` seconds for new data
            if not self.mon.pending(timeout=remaining):
                break
        return None
74
def run(ifname, count, no_fast_reauth, res):
    """Run *count* EAP-TLS authentications against one eapol_test instance.

    Connects to control interface *ifname* and configures a temporary EAP-TLS
    network. It then triggers REASSOCIATE up to *count* times, and each
    attempt must end in CTRL-EVENT-CONNECTED.

    The outcome is put on the queue *res* as "PASS <n>" or "FAIL (<n> OK)",
    where <n> is the number of successful authentications. If an exception
    escapes, nothing is put on *res*.
    """
    et = eapol_test(ifname)

    et.request("AP_SCAN 0")
    et.request("SET fast_reauth " + ("0" if no_fast_reauth else "1"))
    net_id = et.add_network()

    ok = 0
    fail = False
    try:
        et.set_network(net_id, "key_mgmt", "IEEE8021X")
        et.set_network(net_id, "eapol_flags", "0")
        et.set_network(net_id, "eap", "TLS")
        et.set_network_quoted(net_id, "identity", "user")
        et.set_network_quoted(net_id, "ca_cert", 'ca.pem')
        et.set_network_quoted(net_id, "client_cert", 'client.pem')
        et.set_network_quoted(net_id, "private_key", 'client.key')
        et.set_network_quoted(net_id, "private_key_passwd", 'whatever')
        et.set_network(net_id, "disabled", "0")

        for _ in range(count):
            et.request("REASSOCIATE")
            ev = et.wait_event(["CTRL-EVENT-CONNECTED",
                                "CTRL-EVENT-EAP-FAILURE"])
            if ev is None or "CTRL-EVENT-CONNECTED" not in ev:
                fail = True
                break
            ok += 1
    finally:
        # Always drop the temporary network so it does not linger in the
        # eapol_test instance, even if a request above raised.
        et.remove_network(net_id)

    # Counting successes explicitly also makes count == 0 report "PASS 0"
    # (the old loop-index version raised NameError there).
    if fail:
        res.put("FAIL (%d OK)" % ok)
    else:
        res.put("PASS %d" % ok)
108
def main():
    """Run one worker thread per eapol_test instance and print the results.

    Instance i (0 <= i < --num) is reached at <ctrl dir>/<i>, and each worker
    performs --iter authentications via run(). Once every thread is joined,
    one "<i>: <result>" line is printed per instance. "N/A" means the worker
    put no result (e.g. it raised).
    """
    global wpas_ctrl

    parser = argparse.ArgumentParser(description='eapol_test controller')
    parser.add_argument('--ctrl', help='control interface directory')
    # type/required give a clear usage error instead of int(None) crashing.
    parser.add_argument('--num', type=int, required=True,
                        help='number of processes')
    parser.add_argument('--iter', type=int, required=True,
                        help='number of iterations')
    parser.add_argument('--no-fast-reauth', action='store_true',
                        dest='no_fast_reauth',
                        help='disable TLS session resumption')
    args = parser.parse_args()

    if args.ctrl:
        wpas_ctrl = args.ctrl

    queues = [Queue.Queue() for _ in range(args.num)]
    threads = [threading.Thread(target=run,
                                args=(str(i), args.iter,
                                      args.no_fast_reauth, queues[i]))
               for i in range(args.num)]
    for thread in threads:
        thread.start()
    for i, (thread, queue) in enumerate(zip(threads, queues)):
        thread.join()
        try:
            result = queue.get(False)
        except Queue.Empty:
            result = "N/A"
        print("%d: %s" % (i, result))
140
# Script entry point: start the controller only when executed directly,
# not when this file is imported as a module.
if "__main__" == __name__:
    main()
143