1from future import standard_library
2standard_library.install_aliases()
3from builtins import str
4import sys
5import unittest
6import re
7import os.path
8sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
9
10import time
11from functools import partial
12from configparser import RawConfigParser
13from Exscript import Account, PrivateKey
14from Exscript.emulators import VirtualDevice
15from Exscript.protocols.exception import TimeoutException, \
16    InvalidCommandException, ExpectCancelledException
17from Exscript.protocols import drivers
18from Exscript.protocols.protocol import Protocol
19
20
class ProtocolTest(unittest.TestCase):
    """
    Since protocols.Protocol is abstract, this test is only a base class
    for other protocols. It does not do anything fancy on its own.

    Subclasses are expected to override createDaemon() (to start a server
    backed by self.device) and createProtocol() (to build a concrete
    protocol). When self.protocol is the abstract Protocol itself, most
    tests merely check that the requested operation raises.
    """
    CORRELATE = Protocol

    def setUp(self):
        self.hostname = '127.0.0.1'
        self.port = 1236
        self.user = 'user'
        self.password = 'password'
        self.account = Account(self.user, password=self.password)
        self.daemon = None

        self.createVirtualDevice()
        self.createDaemon()
        if self.daemon is not None:
            self.daemon.start()
            # Give the daemon a moment to come up before connecting to it.
            time.sleep(.2)
        self.createProtocol()

    def tearDown(self):
        if not self._is_abstract_protocol():
            self.protocol.close(True)
        if self.daemon is not None:
            self.daemon.exit()
            self.daemon.join()

    def _is_abstract_protocol(self):
        """Return True if self.protocol is the abstract Protocol base itself."""
        return self.protocol.__class__ is Protocol

    def createVirtualDevice(self):
        """Build self.device, a VirtualDevice answering a few fixed commands."""
        self.banner = 'Welcome to %s!\n' % self.hostname
        self.prompt = self.hostname + '> '
        self.device = VirtualDevice(self.hostname, echo=True)
        ls_response = '-rw-r--r--  1 sab  nmc    1628 Aug 18 10:02 file'
        self.device.add_command('ls', ls_response)
        self.device.add_command('df', 'foobar')
        self.device.add_command('exit', '')
        self.device.add_command('this-command-causes-an-error',
                                '\ncommand not found')

    def createDaemon(self):
        """Hook for subclasses to set self.daemon; the base uses none."""
        pass

    def createProtocol(self):
        """Hook for subclasses; the base uses the abstract Protocol."""
        self.protocol = Protocol(timeout=1)

    def disableDriver(self):
        """Disable the 'ios' driver and check it leaves the driver map."""
        # assertIn/assertNotIn take (member, container).
        self.assertIn('ios', drivers.driver_map)
        drivers.disable_driver('ios')
        self.assertNotIn('ios', drivers.driver_map)

    def doConnect(self):
        """Connect self.protocol to self.hostname:self.port."""
        self.protocol.connect(self.hostname, self.port)

    def doLogin(self, flush=True):
        """Connect, then perform a full login with self.account."""
        self.doConnect()
        self.protocol.login(self.account, flush=flush)

    def doProtocolAuthenticate(self, flush=True):
        """Connect, then authenticate at the protocol level only.

        ``flush`` is accepted for symmetry with the other do* helpers but
        is not forwarded to protocol_authenticate().
        """
        self.doConnect()
        self.protocol.protocol_authenticate(self.account)

    def doAppAuthenticate(self, flush=True):
        """Run application-level authentication on the open connection."""
        self.protocol.app_authenticate(self.account, flush)

    def doAppAuthorize(self, flush=True):
        """Run application-level authorization on the open connection."""
        self.protocol.app_authorize(self.account, flush)

    def _trymatch(self, prompts, string):
        """Return the first match of any regex in prompts, or None."""
        for regex in prompts:
            match = regex.search(string)
            if match:
                return match
        return None

    def testPrompts(self):
        prompts = ('[sam123@home ~]$',
                   '[MyHost-A1]',
                   '<MyHost-A1>',
                   'sam@knip:~/Code/exscript$',
                   'sam@MyHost-X123>',
                   'sam@MyHost-X123#',
                   'MyHost-ABC-CDE123>',
                   'MyHost-A1#',
                   'S-ABC#',
                   '0123456-1-1-abc#',
                   '0123456-1-1-a>',
                   'MyHost-A1(config)#',
                   'MyHost-A1(config)>',
                   'RP/0/RP0/CPU0:A-BC2#',
                   'FA/0/1/2/3>',
                   'FA/0/1/2/3(config)>',
                   'FA/0/1/2/3(config)#',
                   'ec-c3-c27s99(su)->',
                   'foobar:0>',
                   'admin@s-x-a6.a.bc.de.fg:/# ',
                   'admin@s-x-a6.a.bc.de.fg:/% ')
        notprompts = ('one two',
                      ' [MyHost-A1]',
                      '[edit]\r',
                      '[edit]\n',
                      '[edit foo]\r',
                      '[edit foo]\n',
                      '[edit foo]\r\n',
                      '[edit one two]')
        prompt_re = self.protocol.get_prompt()
        for prompt in prompts:
            if not self._trymatch(prompt_re, '\n' + prompt):
                self.fail('Prompt %s does not match exactly.' % repr(prompt))
            if not self._trymatch(prompt_re, 'this is a test\r\n' + prompt):
                self.fail('Prompt %s does not match.' % repr(prompt))
            # A prompt preceded by other text on the same line must not match.
            if self._trymatch(prompt_re, 'some text ' + prompt):
                self.fail('Prompt %s matches incorrectly.' % repr(prompt))
        for prompt in notprompts:
            if self._trymatch(prompt_re, prompt):
                self.fail('Prompt %s matches incorrectly.' % repr(prompt))
            if self._trymatch(prompt_re, prompt + ' '):
                self.fail('Prompt %s matches incorrectly.' % repr(prompt))
            if self._trymatch(prompt_re, '\n' + prompt):
                self.fail('Prompt %s matches incorrectly.' % repr(prompt))

    def testConstructor(self):
        self.assertIsInstance(self.protocol, Protocol)

    def testCopy(self):
        self.assertEqual(self.protocol, self.protocol.__copy__())

    def testDeepcopy(self):
        self.assertEqual(self.protocol, self.protocol.__deepcopy__({}))

    def testIsDummy(self):
        self.assertEqual(self.protocol.is_dummy(), False)

    def testSetDriver(self):
        self.assertIsNotNone(self.protocol.get_driver())
        self.assertEqual(self.protocol.get_driver().name, 'generic')

        self.protocol.set_driver()
        self.assertIsNotNone(self.protocol.get_driver())
        self.assertEqual(self.protocol.get_driver().name, 'generic')

        self.protocol.set_driver('ios')
        self.assertIsNotNone(self.protocol.get_driver())
        self.assertEqual(self.protocol.get_driver().name, 'ios')

        # Calling set_driver() without a name reverts to the generic driver.
        self.protocol.set_driver()
        self.assertIsNotNone(self.protocol.get_driver())
        self.assertEqual(self.protocol.get_driver().name, 'generic')

    def testGetDriver(self):
        pass  # Already tested in testSetDriver()

    def testGetBanner(self):
        self.assertIsNone(self.protocol.get_banner())
        if self._is_abstract_protocol():
            self.assertRaises(Exception, self.protocol.connect)
            return
        self.doConnect()
        self.assertIsNone(self.protocol.get_banner())

    def testGetRemoteVersion(self):
        self.assertIsNone(self.protocol.get_remote_version())
        if self._is_abstract_protocol():
            self.assertRaises(Exception, self.protocol.connect)
            return
        self.doConnect()
        self.assertIsNone(self.protocol.get_remote_version())

    def testAutoinit(self):
        self.protocol.autoinit()

    def _test_prompt_setter(self, getter, setter):
        """Check a prompt getter/setter pair.

        The getter must return a list of compiled regexes. Setting a regex
        must replace the list, and calling the setter with no argument
        must restore the initial value.
        """
        initial_regex = getter()
        self.assertIsInstance(initial_regex, list)
        self.assertTrue(hasattr(initial_regex[0], 'groups'))

        my_re = re.compile(r'% username')
        setter(my_re)
        regex = getter()
        self.assertIsInstance(regex, list)
        self.assertTrue(hasattr(regex[0], 'groups'))
        self.assertEqual(regex[0], my_re)

        setter()
        regex = getter()
        self.assertEqual(regex, initial_regex)

    def testSetUsernamePrompt(self):
        self._test_prompt_setter(self.protocol.get_username_prompt,
                                 self.protocol.set_username_prompt)

    def testGetUsernamePrompt(self):
        pass  # Already tested in testSetUsernamePrompt()

    def testSetPasswordPrompt(self):
        self._test_prompt_setter(self.protocol.get_password_prompt,
                                 self.protocol.set_password_prompt)

    def testGetPasswordPrompt(self):
        pass  # Already tested in testSetPasswordPrompt()

    def testSetPrompt(self):
        self._test_prompt_setter(self.protocol.get_prompt,
                                 self.protocol.set_prompt)

    def testGetPrompt(self):
        pass  # Already tested in testSetPrompt()

    def testSetErrorPrompt(self):
        self._test_prompt_setter(self.protocol.get_error_prompt,
                                 self.protocol.set_error_prompt)

    def testGetErrorPrompt(self):
        pass  # Already tested in testSetErrorPrompt()

    def testSetLoginErrorPrompt(self):
        self._test_prompt_setter(self.protocol.get_login_error_prompt,
                                 self.protocol.set_login_error_prompt)

    def testGetLoginErrorPrompt(self):
        pass  # Already tested in testSetLoginErrorPrompt()

    def testSetConnectTimeout(self):
        self.assertEqual(self.protocol.get_connect_timeout(), 30)
        self.protocol.set_connect_timeout(60)
        self.assertEqual(self.protocol.get_connect_timeout(), 60)

    def testGetConnectTimeout(self):
        pass  # Already tested in testSetConnectTimeout()

    def testSetTimeout(self):
        # createProtocol() builds the protocol with timeout=1.
        self.assertEqual(self.protocol.get_timeout(), 1)
        self.protocol.set_timeout(60)
        self.assertEqual(self.protocol.get_timeout(), 60)

    def testGetTimeout(self):
        pass  # Already tested in testSetTimeout()

    def testConnect(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.assertRaises(Exception, self.protocol.connect)
            return
        self.assertIsNone(self.protocol.response)
        self.doConnect()
        self.assertIsNone(self.protocol.response)
        self.assertEqual(self.protocol.get_host(), self.hostname)

    def testLogin(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.assertRaises(Exception,
                              self.protocol.login,
                              self.account)
            return
        # Password login.
        self.doLogin(flush=False)
        self.assertIsNotNone(self.protocol.response)
        self.assertGreater(len(self.protocol.response), 0)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())

        # Key login. Restart the fixture to get a fresh connection.
        self.tearDown()
        self.setUp()
        key = PrivateKey.from_file('foo', keytype='rsa')
        account = Account(self.user, self.password, key=key)
        self.doConnect()
        self.assertFalse(self.protocol.is_protocol_authenticated())
        self.assertFalse(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())
        self.protocol.login(account, flush=False)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())

    def testAuthenticate(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.assertRaises(Exception,
                              self.protocol.authenticate,
                              self.account)
            return
        self.doConnect()

        # Password login. authenticate() is expected to stop short of
        # app-level authorization.
        self.assertFalse(self.protocol.is_protocol_authenticated())
        self.assertFalse(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())
        self.protocol.authenticate(self.account, flush=False)
        self.assertIsNotNone(self.protocol.response)
        self.assertGreater(len(self.protocol.response), 0)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())

        # Key login. Restart the fixture to get a fresh connection.
        self.tearDown()
        self.setUp()
        key = PrivateKey.from_file('foo', keytype='rsa')
        account = Account(self.user, self.password, key=key)
        self.doConnect()
        self.assertFalse(self.protocol.is_protocol_authenticated())
        self.assertFalse(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())
        self.protocol.authenticate(account, flush=False)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())

    def testProtocolAuthenticate(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.protocol.protocol_authenticate(self.account)
            return
        # There is no guarantee that the device provided any response
        # during protocol level authentification.
        self.doProtocolAuthenticate(flush=False)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertFalse(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())

    def testIsProtocolAuthenticated(self):
        pass  # See testProtocolAuthenticate()

    def testAppAuthenticate(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.assertRaises(Exception,
                              self.protocol.app_authenticate,
                              self.account)
            return
        self.testProtocolAuthenticate()
        self.doAppAuthenticate(flush=False)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())

    def testIsAppAuthenticated(self):
        pass  # See testAppAuthenticate()

    def testAppAuthorize(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.assertRaises(Exception, self.protocol.app_authorize)
            return
        self.doProtocolAuthenticate(flush=False)
        self.doAppAuthenticate(flush=False)
        response = self.protocol.response

        # Authorize should see that a prompt is still in the buffer,
        # and do nothing.
        self.doAppAuthorize(flush=False)
        self.assertEqual(self.protocol.response, response)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())

        self.doAppAuthorize(flush=True)
        self.assertEqual(self.protocol.response, response)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())

    def testAppAuthorize2(self):
        # Same test as above, but using flush=True all the way.
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.assertRaises(Exception, self.protocol.app_authorize)
            return
        self.doProtocolAuthenticate(flush=True)
        self.doAppAuthenticate(flush=True)

        # At this point app_authorize should fail because the buffer is
        # empty due to flush=True above. In other words, app_authorize
        # will wait for a prompt until a timeout happens.
        self.assertRaises(TimeoutException, self.doAppAuthorize)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())

    def testAutoAppAuthorize(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            # Called without the required account argument.
            self.assertRaises(TypeError, self.protocol.auto_app_authorize)
            return

        self.testAppAuthenticate()
        response = self.protocol.response

        # This should do nothing, because our test host does not
        # support AAA. Can't think of a way to test against a
        # device using AAA.
        self.protocol.auto_app_authorize(self.account, flush=False)
        self.assertEqual(self.protocol.response, response)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())

        self.protocol.auto_app_authorize(self.account, flush=True)
        self.assertEqual(self.protocol.response, response)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())

    def testIsAppAuthorized(self):
        pass  # see testAppAuthorize()

    def testSend(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.assertRaises(Exception, self.protocol.send, 'ls')
            return
        self.doLogin()
        self.protocol.execute('ls')

        # The test expects send() to leave self.protocol.response
        # untouched, so it still holds the output of the 'ls' above.
        self.protocol.send('df\r')
        self.assertIsNotNone(self.protocol.response)
        self.assertTrue(self.protocol.response.startswith('ls'))

        self.protocol.send('exit\r')
        self.assertIsNotNone(self.protocol.response)
        self.assertTrue(self.protocol.response.startswith('ls'))

    def testExecute(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.assertRaises(Exception, self.protocol.execute, 'ls')
            return
        self.doLogin()
        self.protocol.execute('ls')
        self.assertIsNotNone(self.protocol.response)
        self.assertTrue(self.protocol.response.startswith('ls'))

        # Make sure that we raise an error if the device responds
        # with something that matches any of the error prompts.
        self.protocol.set_error_prompt('.')
        self.assertRaises(InvalidCommandException,
                          self.protocol.execute,
                          'this-command-causes-an-error')

    def testWaitfor(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.assertRaises(Exception, self.protocol.waitfor, 'ls')
            return
        self.doLogin()
        oldresponse = self.protocol.response
        self.protocol.send('ls\r')
        self.assertEqual(oldresponse, self.protocol.response)
        self.protocol.waitfor(re.compile(r'[\r\n]'))
        self.assertNotEqual(oldresponse, self.protocol.response)
        # A second waitfor() on the same pattern is expected not to
        # change the response again.
        oldresponse = self.protocol.response
        self.protocol.waitfor(re.compile(r'[\r\n]'))
        self.assertEqual(oldresponse, self.protocol.response)

    def testExpect(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.assertRaises(Exception, self.protocol.expect, 'ls')
            return
        self.doLogin()
        oldresponse = self.protocol.response
        self.protocol.send('ls\r')
        self.assertEqual(oldresponse, self.protocol.response)
        self.protocol.expect(re.compile(r'[\r\n]'))
        self.assertNotEqual(oldresponse, self.protocol.response)

    def testExpectPrompt(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.assertRaises(Exception, self.protocol.expect, 'ls')
            return
        self.doLogin()
        oldresponse = self.protocol.response
        self.protocol.send('ls\r')
        self.assertEqual(oldresponse, self.protocol.response)
        self.protocol.expect_prompt()
        self.assertNotEqual(oldresponse, self.protocol.response)

    def testAddMonitor(self):
        # Set the monitor callback up.
        def monitor_cb(thedata, *args, **kwargs):
            thedata['args'] = args
            thedata['kwargs'] = kwargs
        data = {}
        self.protocol.add_monitor('abc', partial(monitor_cb, data))

        # Simulate some non-matching data.
        self.protocol.buffer.append(u'aaa')
        self.assertEqual(data, {})

        # Simulate some matching data.
        self.protocol.buffer.append(u'abc')
        # Fail cleanly (rather than with a TypeError on len(None)) if
        # the monitor callback was never invoked.
        self.assertIn('args', data, 'monitor callback was not invoked')
        args = data['args']
        self.assertEqual(len(args), 3)
        self.assertEqual(args[0], self.protocol)
        self.assertEqual(args[1], 0)
        self.assertEqual(args[2].group(0), 'abc')
        self.assertEqual(data.get('kwargs'), {})

    def testGetBuffer(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            return
        self.assertEqual(str(self.protocol.buffer), '')
        self.doLogin()
        # Depending on whether the connected host sends a banner,
        # the buffer may or may not contain anything now.

        before = str(self.protocol.buffer)
        self.protocol.send('ls\r')
        self.protocol.waitfor(self.protocol.get_prompt())
        self.assertNotEqual(str(self.protocol.buffer), before)

    def _cancel_cb(self, data):
        """data_received_event handler that cancels the pending expect()."""
        self.protocol.cancel_expect()

    def testCancelExpect(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            return
        self.doLogin()
        oldresponse = self.protocol.response
        self.protocol.data_received_event.connect(self._cancel_cb)
        self.protocol.send('ls\r')
        self.assertEqual(oldresponse, self.protocol.response)
        self.assertRaises(ExpectCancelledException,
                          self.protocol.expect,
                          'notgoingtohappen')

    def testInteract(self):
        # Test can not work on the abstract base.
        if self._is_abstract_protocol():
            self.assertRaises(Exception, self.protocol.interact)
            return
        # Can't really be tested.

    def testClose(self):
        if not self._is_abstract_protocol():
            self.doConnect()
        self.protocol.close(True)

    def testGetHost(self):
        self.assertIsNone(self.protocol.get_host())
        if self._is_abstract_protocol():
            return
        self.doConnect()
        self.assertEqual(self.protocol.get_host(), self.hostname)

    def testGuessOs(self):
        self.assertEqual('unknown', self.protocol.guess_os())
        # Other tests can not work on the abstract base.
        if self._is_abstract_protocol():
            return
        self.doConnect()
        self.assertEqual('unknown', self.protocol.guess_os())
        self.protocol.login(self.account)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())
        self.assertEqual('shell', self.protocol.guess_os())
585
586
def suite():
    """Return a TestSuite holding every test method of ProtocolTest."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromTestCase(ProtocolTest)


if __name__ == '__main__':
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite())
591