1import email.utils
2import json
3import os
4import tempfile
5import unittest
6
7from mox3 import mox
8import six
9
10import duo_openvpn
11
def mock_client_factory(mock):
    """
    Build a duo_openvpn.Client subclass that talks to *mock* instead of
    opening a real HTTP connection.

    The returned class:
      * reports its constructor arguments through mock.duo_client_init()
        before running the real Client constructor;
      * reports set_proxy() arguments through mock.duo_client_set_proxy()
        before delegating to the real set_proxy() and returning its result;
      * returns *mock* itself from _connect(), so whatever the base Client
        does with its connection object is done to *mock*.
    """
    recorder = mock

    class MockClient(duo_openvpn.Client):
        def __init__(self, *args, **kwargs):
            # Record the arguments first, then run the real constructor.
            recorder.duo_client_init(*args, **kwargs)
            super(MockClient, self).__init__(*args, **kwargs)

        def set_proxy(self, *args, **kwargs):
            recorder.duo_client_set_proxy(*args, **kwargs)
            parent = super(MockClient, self)
            return parent.set_proxy(*args, **kwargs)

        def _connect(self):
            # Every "connection" is the same mock object.
            return recorder

    return MockClient
32
class MockResponse(six.StringIO, object):
    """
    Minimal stand-in for an HTTP response object.

    The body is readable like a file (it is a six.StringIO), and the
    ``status`` and ``reason`` attributes are plain values set at
    construction. ``object`` is listed as an extra base, presumably so that
    super() works on Python 2, where six.StringIO is the old-style
    StringIO.StringIO class.
    """
    def __init__(self, status, body, reason='some reason'):
        super(MockResponse, self).__init__(body)
        self.status, self.reason = status, reason
38
class TestIntegration(unittest.TestCase):
    """
    End-to-end tests for duo_openvpn.main().

    Each test records the sequence of client/HTTP calls it expects on a mox
    mock, runs main() with a synthetic environment, and then checks the
    text written to the control file and the SystemExit status.
    """
    IKEY = 'expected ikey'
    SKEY = 'expected skey'
    HOST = 'expected hostname'
    USERNAME = 'expected username'
    PASSCODE = 'expected passcode'
    IPADDR = 'expected_ipaddr'
    PROXY_HOST = 'expected proxy host'
    PROXY_PORT = 'expected proxy port'
    EXPECTED_USER_AGENT = 'duo_openvpn/' + duo_openvpn.__version__
    # Expected form bodies. They are compared stanza-by-stanza (see
    # compare_params), so the order written here does not matter.
    EXPECTED_PREAUTH_PARAMS = (
        'ipaddr=expected_ipaddr'
        '&user=expected+username'
    )
    EXPECTED_AUTH_PATH = '/rest/v1/auth'
    EXPECTED_PREAUTH_PATH = '/rest/v1/preauth'
    EXPECTED_AUTH_PARAMS = (
        'auto=expected+passcode'
        '&ipaddr=expected_ipaddr'
        '&user=expected+username'
        '&factor=auto'
    )

    def setUp(self):
        """Create a fresh mox controller and the mock that records calls."""
        self.mox = mox.Mox()
        self.expected_calls = self.mox.CreateMockAnything()

    def assert_auth(self, environ, expected_control,
                    send_control=True):
        """
        Replay the recorded expectations through duo_openvpn.main() and
        check its outcome.

        environ: environment mapping passed to main(). When send_control is
            true, a 'control' entry naming a fresh temporary file is added
            to it.
        expected_control: exact text expected in the control file afterwards
            ('' when nothing should be written). '1' additionally implies an
            exit status of 0; any other value implies an exit status of 1.
        send_control: whether to add the 'control' entry to environ.
        """
        self.mox.ReplayAll()

        with tempfile.NamedTemporaryFile() as control:
            if send_control:
                environ['control'] = control.name

            with self.assertRaises(SystemExit) as cm:
                duo_openvpn.main(
                    environ=environ,
                    Client=mock_client_factory(self.expected_calls),
                )
            # Fail if any recorded expectation was never consumed.
            self.mox.VerifyAll()

            control.seek(0, os.SEEK_SET)
            output = control.read()
            if not isinstance(output, six.text_type):
                output = output.decode('utf-8')
            self.assertEqual(expected_control, output)

            # SystemExit.code (unlike args[0]) is defined even for a bare
            # sys.exit(), so a wrong exit path fails the assertion cleanly.
            expected_exit = 0 if expected_control == '1' else 1
            self.assertEqual(expected_exit, cm.exception.code)

    def normal_environ(self):
        """
        Return a complete, valid environment for main().

        This also records the Client construction and the initial
        set_proxy(host=None, proxy_type=None) call that every test using
        this environment is expected to trigger.
        """
        environ = {
            'ikey': self.IKEY,
            'skey': self.SKEY,
            'host': self.HOST,
            'username': self.USERNAME,
            'password': self.PASSCODE,
            'ipaddr': self.IPADDR,
        }
        self.expected_calls.duo_client_init(
            ikey=self.IKEY,
            skey=self.SKEY,
            host=self.HOST,
            user_agent=self.EXPECTED_USER_AGENT,
        )
        self.expected_calls.duo_client_set_proxy(
            host=None,
            proxy_type=None,
        )
        return environ

    def compare_params(self, recv_params, sent_params):
        """
        Return True if two '&'-joined parameter strings hold exactly the
        same stanzas, ignoring order.

        Each stanza must match exactly, so 'b=2' does not match 'b=2x'.
        """
        return sorted(recv_params.split('&')) == sorted(sent_params.split('&'))

    def _ok_response(self, response):
        """Return a 200 MockResponse with body {'stat': 'OK', 'response': response}."""
        return MockResponse(
            status=200,
            body=json.dumps({'stat': 'OK', 'response': response}),
        )

    def expect_request(self, method, path, params, params_func=None,
                       response=None, raises=None):
        """
        Record one expected HTTP round trip on the mock connection.

        method, path: expected request method and path (or full URL).
        params: expected '&'-joined request body. It is used for the default
            body matcher when params_func is not given.
        params_func: optional predicate applied to the body actually sent.
        response: returned from getresponse() when raises is None. A close()
            is then expected as well.
        raises: if given, getresponse() raises it instead, and no close()
            is expected.
        """
        if params_func is None:
            params_func = lambda p: self.compare_params(p, params)
        headers = {
            'User-Agent': self.EXPECTED_USER_AGENT,
            'Host': self.HOST,
            'Content-type': 'application/x-www-form-urlencoded',
            # A Basic credential built from text, not from the repr of a
            # bytes object (which would start with "Basic b'").
            'Authorization': mox.Func(
                lambda s: s.startswith('Basic ') and not s.startswith('Basic b\'')),
            'Date': mox.Func(lambda s: bool(email.utils.parsedate_tz(s))),
        }
        self.expected_calls.request(
            method, path, mox.Func(params_func), headers)
        getresponse = self.expected_calls.getresponse()
        if raises is not None:
            getresponse.AndRaise(raises)
        else:
            getresponse.AndReturn(response)
            self.expected_calls.close()

    def expect_preauth(self, result, path=EXPECTED_PREAUTH_PATH, factor='push1'):
        """Record a successful preauth call that answers with *result*."""
        self.expect_request(
            method='POST',
            path=path,
            params=self.EXPECTED_PREAUTH_PARAMS,
            response=self._ok_response({
                'result': result,
                'status': 'expected status',
                'factors': {'default': factor},
            }),
        )

    def expect_auth(self, result, path=EXPECTED_AUTH_PATH):
        """Record a successful auth call that answers with *result*."""
        self.expect_request(
            method='POST',
            path=path,
            params=self.EXPECTED_AUTH_PARAMS,
            response=self._ok_response({
                'result': result,
                'status': 'expected status',
            }),
        )

    def test_preauth_allow(self):
        environ = self.normal_environ()
        self.expect_preauth('allow')
        self.assert_auth(
            environ=environ,
            expected_control='1',
        )

    def test_preauth_deny(self):
        environ = self.normal_environ()
        self.expect_preauth('deny')
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_preauth_enroll(self):
        environ = self.normal_environ()
        self.expect_preauth('enroll')
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_preauth_bogus(self):
        environ = self.normal_environ()
        self.expect_preauth('bogus')
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_preauth_missing_result(self):
        environ = self.normal_environ()
        self.expect_request(
            method='POST',
            path=self.EXPECTED_PREAUTH_PATH,
            params=self.EXPECTED_PREAUTH_PARAMS,
            response=self._ok_response({
                'status': 'expected status',
            }),
        )
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_preauth_missing_status(self):
        environ = self.normal_environ()
        self.expect_request(
            method='POST',
            path=self.EXPECTED_PREAUTH_PATH,
            params=self.EXPECTED_PREAUTH_PARAMS,
            response=self._ok_response({
                'result': 'deny',
            }),
        )
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_preauth_exception(self):
        environ = self.normal_environ()
        self.expect_request(
            method='POST',
            path=self.EXPECTED_PREAUTH_PATH,
            params=self.EXPECTED_PREAUTH_PARAMS,
            raises=Exception('whoops'),
        )
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_auth_allow(self):
        environ = self.normal_environ()
        self.expect_preauth('auth')
        self.expect_auth('allow')
        self.assert_auth(
            environ=environ,
            expected_control='1',
        )

    def test_auth_deny(self):
        environ = self.normal_environ()
        self.expect_preauth('auth')
        self.expect_auth('deny')
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_auth_bogus(self):
        environ = self.normal_environ()
        self.expect_preauth('auth')
        self.expect_auth('bogus')
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_auth_missing_reason(self):
        # NOTE(review): despite the name, the auth response here lacks the
        # 'result' field (it still has 'status').
        environ = self.normal_environ()
        self.expect_preauth('auth')
        self.expect_request(
            method='POST',
            path=self.EXPECTED_AUTH_PATH,
            params=self.EXPECTED_AUTH_PARAMS,
            response=self._ok_response({
                'status': 'expected status',
            }),
        )
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_auth_missing_status(self):
        environ = self.normal_environ()
        self.expect_preauth('auth')
        self.expect_request(
            method='POST',
            path=self.EXPECTED_AUTH_PATH,
            params=self.EXPECTED_AUTH_PARAMS,
            response=self._ok_response({
                'result': 'allow',
            }),
        )
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_auth_exception(self):
        environ = self.normal_environ()
        self.expect_preauth('auth')
        self.expect_request(
            method='POST',
            path=self.EXPECTED_AUTH_PATH,
            params=self.EXPECTED_AUTH_PARAMS,
            raises=Exception('whoops'),
        )
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_auth_no_ipaddr(self):
        # With no 'ipaddr' in the environment, 0.0.0.0 is expected on the wire.
        preauth_noip_params = (
            'ipaddr=0.0.0.0'
            '&user=expected+username'
        )
        auth_noip_params = (
            'auto=expected+passcode'
            '&ipaddr=0.0.0.0'
            '&user=expected+username'
            '&factor=auto'
        )
        environ = self.normal_environ()
        environ.pop('ipaddr')
        self.expect_request(
            method='POST',
            path=self.EXPECTED_PREAUTH_PATH,
            params=preauth_noip_params,
            response=self._ok_response({
                'result': 'auth',
                'status': 'expected status',
                'factors': {'default': 'push1'},
            }),
        )
        self.expect_request(
            method='POST',
            path=self.EXPECTED_AUTH_PATH,
            params=auth_noip_params,
            response=self._ok_response({
                'result': 'allow',
                'status': 'expected status',
            }),
        )
        self.assert_auth(
            environ=environ,
            expected_control='1',
        )

    def test_missing_control(self):
        environ = {
            'ikey': self.IKEY,
            'skey': self.SKEY,
            'host': self.HOST,
            'password': self.PASSCODE,
            'username': self.USERNAME,
            'ipaddr': self.IPADDR,
        }
        self.assert_auth(
            environ=environ,
            send_control=False,
            expected_control='',
        )

    def test_missing_username(self):
        environ = {
            'ikey': self.IKEY,
            'skey': self.SKEY,
            'host': self.HOST,
            'password': self.PASSCODE,
            'ipaddr': self.IPADDR,
        }
        self.assert_auth(
            environ=environ,
            expected_control='',
        )

    def test_missing_password(self):
        environ = self.normal_environ()
        del environ['password']
        self.expect_preauth('auth', factor=None)
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_missing_ikey(self):
        environ = {
            'skey': self.SKEY,
            'host': self.HOST,
            'password': self.PASSCODE,
            'username': self.USERNAME,
            'ipaddr': self.IPADDR,
        }
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_missing_skey(self):
        environ = {
            'ikey': self.IKEY,
            'host': self.HOST,
            'password': self.PASSCODE,
            'username': self.USERNAME,
            'ipaddr': self.IPADDR,
        }
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_missing_host(self):
        environ = {
            'ikey': self.IKEY,
            'skey': self.SKEY,
            'password': self.PASSCODE,
            'username': self.USERNAME,
            'ipaddr': self.IPADDR,
        }
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_proxy_success(self):
        environ = self.normal_environ()
        environ['proxy_host'] = self.PROXY_HOST
        environ['proxy_port'] = self.PROXY_PORT
        # Second set_proxy call, after the one recorded by normal_environ().
        self.expected_calls.duo_client_set_proxy(
            host=self.PROXY_HOST,
            port=self.PROXY_PORT,
        )
        # Through the proxy, requests are expected to carry absolute URLs.
        self.expect_preauth(
            result='auth',
            path=('https://' + self.HOST + self.EXPECTED_PREAUTH_PATH),
        )
        self.expect_auth(
            result='allow',
            path=('https://' + self.HOST + self.EXPECTED_AUTH_PATH),
        )
        self.assert_auth(
            environ=environ,
            expected_control='1',
        )

    def test_proxy_missing_port(self):
        environ = self.normal_environ()
        environ['proxy_host'] = self.PROXY_HOST
        self.assert_auth(
            environ=environ,
            expected_control='0',
        )

    def test_proxy_missing_host(self):
        environ = self.normal_environ()
        # proxy_port is ignored if proxy_host isn't present.
        environ['proxy_port'] = self.PROXY_PORT
        self.expect_preauth('auth')
        self.expect_auth('allow')
        self.assert_auth(
            environ=environ,
            expected_control='1',
        )
507
if __name__ == '__main__':
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
510