"""
Unit tests for Exscript.protocols.protocol.Protocol.

ProtocolTest is also intended as a base class for the tests of concrete
protocol implementations (see the class docstring).
"""
from future import standard_library
standard_library.install_aliases()
from builtins import str
import sys
import unittest
import re
import os.path
# Make the in-tree Exscript package importable; must run before the
# Exscript imports below.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))

import time
from functools import partial
from configparser import RawConfigParser
from Exscript import Account, PrivateKey
from Exscript.emulators import VirtualDevice
from Exscript.protocols.exception import TimeoutException, \
    InvalidCommandException, ExpectCancelledException
from Exscript.protocols import drivers
from Exscript.protocols.protocol import Protocol


class ProtocolTest(unittest.TestCase):

    """
    Since protocols.Protocol is abstract, this test is only a base class
    for other protocols. It does not do anything fancy on its own.

    setUp() calls the createVirtualDevice(), createDaemon() and
    createProtocol() hooks; in this base class createDaemon() does nothing
    and createProtocol() builds a bare Protocol. Tests that cannot run
    against the abstract base only check that it raises.
    """
    CORRELATE = Protocol

    def setUp(self):
        """Create the virtual device, optional daemon and protocol."""
        self.hostname = '127.0.0.1'
        self.port = 1236
        self.user = 'user'
        self.password = 'password'
        self.account = Account(self.user, password=self.password)
        self.daemon = None

        self.createVirtualDevice()
        self.createDaemon()
        if self.daemon is not None:
            self.daemon.start()
            # Give the daemon thread a moment to start listening.
            time.sleep(.2)
        self.createProtocol()

    def tearDown(self):
        """Close the protocol (unless abstract) and stop the daemon."""
        if not self._is_abstract():
            self.protocol.close(True)
        if self.daemon is not None:
            self.daemon.exit()
            self.daemon.join()

    def _is_abstract(self):
        """Return True if self.protocol is the bare abstract Protocol."""
        return self.protocol.__class__ == Protocol

    def createVirtualDevice(self):
        """Build the emulated device and register its canned commands."""
        self.banner = 'Welcome to %s!\n' % self.hostname
        self.prompt = self.hostname + '> '
        self.device = VirtualDevice(self.hostname, echo=True)
        ls_response = '-rw-r--r-- 1 sab nmc 1628 Aug 18 10:02 file'
        self.device.add_command('ls', ls_response)
        self.device.add_command('df', 'foobar')
        self.device.add_command('exit', '')
        self.device.add_command('this-command-causes-an-error',
                                '\ncommand not found')

    def createDaemon(self):
        """Hook for creating self.daemon; no daemon in the base class."""
        pass

    def createProtocol(self):
        """Hook for creating self.protocol; a bare Protocol here."""
        self.protocol = Protocol(timeout=1)

    def disableDriver(self):
        """Disable the 'ios' driver and check it left the driver map."""
        # assertIn(member, container): the driver name is the member.
        self.assertIn('ios', drivers.driver_map)
        drivers.disable_driver('ios')
        self.assertNotIn('ios', drivers.driver_map)

    def doConnect(self):
        """Connect self.protocol to the test host and port."""
        self.protocol.connect(self.hostname, self.port)

    def doLogin(self, flush=True):
        """Connect, then perform a full login with self.account."""
        self.doConnect()
        self.protocol.login(self.account, flush=flush)

    def doProtocolAuthenticate(self, flush=True):
        """Connect, then perform protocol-level authentication.

        ``flush`` is accepted for symmetry with the other helpers but is
        not passed on to protocol_authenticate().
        """
        self.doConnect()
        self.protocol.protocol_authenticate(self.account)

    def doAppAuthenticate(self, flush=True):
        """Run app-level authentication on the current connection."""
        self.protocol.app_authenticate(self.account, flush)

    def doAppAuthorize(self, flush=True):
        """Run app-level authorization on the current connection."""
        self.protocol.app_authorize(self.account, flush)

    def _trymatch(self, prompts, string):
        """Return the first match of any regex in prompts, else None."""
        for regex in prompts:
            match = regex.search(string)
            if match:
                return match
        return None

    def testPrompts(self):
        prompts = ('[sam123@home ~]$',
                   '[MyHost-A1]',
                   '<MyHost-A1>',
                   'sam@knip:~/Code/exscript$',
                   'sam@MyHost-X123>',
                   'sam@MyHost-X123#',
                   'MyHost-ABC-CDE123>',
                   'MyHost-A1#',
                   'S-ABC#',
                   '0123456-1-1-abc#',
                   '0123456-1-1-a>',
                   'MyHost-A1(config)#',
                   'MyHost-A1(config)>',
                   'RP/0/RP0/CPU0:A-BC2#',
                   'FA/0/1/2/3>',
                   'FA/0/1/2/3(config)>',
                   'FA/0/1/2/3(config)#',
                   'ec-c3-c27s99(su)->',
                   'foobar:0>',
                   'admin@s-x-a6.a.bc.de.fg:/# ',
                   'admin@s-x-a6.a.bc.de.fg:/% ')
        notprompts = ('one two',
                      ' [MyHost-A1]',
                      '[edit]\r',
                      '[edit]\n',
                      '[edit foo]\r',
                      '[edit foo]\n',
                      '[edit foo]\r\n',
                      '[edit one two]')
        prompt_re = self.protocol.get_prompt()
        for prompt in prompts:
            if not self._trymatch(prompt_re, '\n' + prompt):
                self.fail('Prompt %s does not match exactly.' % repr(prompt))
            if not self._trymatch(prompt_re, 'this is a test\r\n' + prompt):
                self.fail('Prompt %s does not match.' % repr(prompt))
            if self._trymatch(prompt_re, 'some text ' + prompt):
                self.fail('Prompt %s matches incorrectly.' % repr(prompt))
        for prompt in notprompts:
            # None of these must match, bare or with surrounding whitespace.
            for candidate in (prompt, prompt + ' ', '\n' + prompt):
                if self._trymatch(prompt_re, candidate):
                    self.fail('Prompt %s matches incorrectly.' % repr(prompt))

    def testConstructor(self):
        self.assertIsInstance(self.protocol, Protocol)

    def testCopy(self):
        self.assertEqual(self.protocol, self.protocol.__copy__())

    def testDeepcopy(self):
        self.assertEqual(self.protocol, self.protocol.__deepcopy__({}))

    def testIsDummy(self):
        self.assertEqual(self.protocol.is_dummy(), False)

    def testSetDriver(self):
        self.assertIsNotNone(self.protocol.get_driver())
        self.assertEqual(self.protocol.get_driver().name, 'generic')

        self.protocol.set_driver()
        self.assertIsNotNone(self.protocol.get_driver())
        self.assertEqual(self.protocol.get_driver().name, 'generic')

        self.protocol.set_driver('ios')
        self.assertIsNotNone(self.protocol.get_driver())
        self.assertEqual(self.protocol.get_driver().name, 'ios')

        self.protocol.set_driver()
        self.assertIsNotNone(self.protocol.get_driver())
        self.assertEqual(self.protocol.get_driver().name, 'generic')

    def testGetDriver(self):
        pass  # Already tested in testSetDriver()

    def testGetBanner(self):
        self.assertEqual(self.protocol.get_banner(), None)
        if self._is_abstract():
            self.assertRaises(Exception, self.protocol.connect)
            return
        self.doConnect()
        self.assertEqual(self.protocol.get_banner(), None)

    def testGetRemoteVersion(self):
        self.assertEqual(self.protocol.get_remote_version(), None)
        if self._is_abstract():
            self.assertRaises(Exception, self.protocol.connect)
            return
        self.doConnect()
        self.assertEqual(self.protocol.get_remote_version(), None)

    def testAutoinit(self):
        self.protocol.autoinit()

    def _test_prompt_setter(self, getter, setter):
        """Check a prompt getter/setter pair: set, read back, reset."""
        initial_regex = getter()
        self.assertIsInstance(initial_regex, list)
        self.assertTrue(hasattr(initial_regex[0], 'groups'))

        my_re = re.compile(r'% username')
        setter(my_re)
        regex = getter()
        self.assertIsInstance(regex, list)
        self.assertTrue(hasattr(regex[0], 'groups'))
        self.assertEqual(regex[0], my_re)

        # Calling the setter without arguments restores the default.
        setter()
        regex = getter()
        self.assertEqual(regex, initial_regex)

    def testSetUsernamePrompt(self):
        self._test_prompt_setter(self.protocol.get_username_prompt,
                                 self.protocol.set_username_prompt)

    def testGetUsernamePrompt(self):
        pass  # Already tested in testSetUsernamePrompt()

    def testSetPasswordPrompt(self):
        self._test_prompt_setter(self.protocol.get_password_prompt,
                                 self.protocol.set_password_prompt)

    def testGetPasswordPrompt(self):
        pass  # Already tested in testSetPasswordPrompt()

    def testSetPrompt(self):
        self._test_prompt_setter(self.protocol.get_prompt,
                                 self.protocol.set_prompt)

    def testGetPrompt(self):
        pass  # Already tested in testSetPrompt()

    def testSetErrorPrompt(self):
        self._test_prompt_setter(self.protocol.get_error_prompt,
                                 self.protocol.set_error_prompt)

    def testGetErrorPrompt(self):
        pass  # Already tested in testSetErrorPrompt()

    def testSetLoginErrorPrompt(self):
        self._test_prompt_setter(self.protocol.get_login_error_prompt,
                                 self.protocol.set_login_error_prompt)

    def testGetLoginErrorPrompt(self):
        pass  # Already tested in testSetLoginErrorPrompt()

    def testSetConnectTimeout(self):
        self.assertEqual(self.protocol.get_connect_timeout(), 30)
        self.protocol.set_connect_timeout(60)
        self.assertEqual(self.protocol.get_connect_timeout(), 60)

    def testGetConnectTimeout(self):
        pass  # Already tested in testSetConnectTimeout()

    def testSetTimeout(self):
        self.assertEqual(self.protocol.get_timeout(), 1)
        self.protocol.set_timeout(60)
        self.assertEqual(self.protocol.get_timeout(), 60)

    def testGetTimeout(self):
        pass  # Already tested in testSetTimeout()

    def testConnect(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(Exception, self.protocol.connect)
            return
        self.assertEqual(self.protocol.response, None)
        self.doConnect()
        self.assertEqual(self.protocol.response, None)
        self.assertEqual(self.protocol.get_host(), self.hostname)

    def testLogin(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(Exception,
                              self.protocol.login,
                              self.account)
            return
        # Password login.
        self.doLogin(flush=False)
        self.assertIsNotNone(self.protocol.response)
        self.assertGreater(len(self.protocol.response), 0)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())

        # Key login: start over with a fresh connection.
        self.tearDown()
        self.setUp()
        key = PrivateKey.from_file('foo', keytype='rsa')
        account = Account(self.user, self.password, key=key)
        self.doConnect()
        self.assertFalse(self.protocol.is_protocol_authenticated())
        self.assertFalse(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())
        self.protocol.login(account, flush=False)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())

    def testAuthenticate(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(Exception,
                              self.protocol.authenticate,
                              self.account)
            return
        self.doConnect()

        # Password login.
        self.assertFalse(self.protocol.is_protocol_authenticated())
        self.assertFalse(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())
        self.protocol.authenticate(self.account, flush=False)
        self.assertIsNotNone(self.protocol.response)
        self.assertGreater(len(self.protocol.response), 0)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())

        # Key login: start over with a fresh connection.
        self.tearDown()
        self.setUp()
        key = PrivateKey.from_file('foo', keytype='rsa')
        account = Account(self.user, self.password, key=key)
        self.doConnect()
        self.assertFalse(self.protocol.is_protocol_authenticated())
        self.assertFalse(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())
        self.protocol.authenticate(account, flush=False)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())

    def testProtocolAuthenticate(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.protocol.protocol_authenticate(self.account)
            return
        # There is no guarantee that the device provided any response
        # during protocol level authentication.
        self.doProtocolAuthenticate(flush=False)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertFalse(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())

    def testIsProtocolAuthenticated(self):
        pass  # See testProtocolAuthenticate()

    def testAppAuthenticate(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(Exception,
                              self.protocol.app_authenticate,
                              self.account)
            return
        self.testProtocolAuthenticate()
        self.doAppAuthenticate(flush=False)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())

    def testIsAppAuthenticated(self):
        pass  # See testAppAuthenticate()

    def testAppAuthorize(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(Exception, self.protocol.app_authorize)
            return
        self.doProtocolAuthenticate(flush=False)
        self.doAppAuthenticate(flush=False)
        response = self.protocol.response

        # Authorize should see that a prompt is still in the buffer,
        # and do nothing.
        self.doAppAuthorize(flush=False)
        self.assertEqual(self.protocol.response, response)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())

        self.doAppAuthorize(flush=True)
        self.assertEqual(self.protocol.response, response)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())

    def testAppAuthorize2(self):
        # Same test as above, but using flush=True all the way.
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(Exception, self.protocol.app_authorize)
            return
        self.doProtocolAuthenticate(flush=True)
        self.doAppAuthenticate(flush=True)

        # At this point app_authorize should fail because the buffer is
        # empty due to flush=True above. In other words, app_authorize
        # will wait for a prompt until a timeout happens.
        self.assertRaises(TimeoutException, self.doAppAuthorize)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertFalse(self.protocol.is_app_authorized())

    def testAutoAppAuthorize(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(TypeError, self.protocol.auto_app_authorize)
            return

        self.testAppAuthenticate()
        response = self.protocol.response

        # This should do nothing, because our test host does not
        # support AAA. Can't think of a way to test against a
        # device using AAA.
        self.protocol.auto_app_authorize(self.account, flush=False)
        self.assertEqual(self.protocol.response, response)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())

        self.protocol.auto_app_authorize(self.account, flush=True)
        self.assertEqual(self.protocol.response, response)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())

    def testIsAppAuthorized(self):
        pass  # see testAppAuthorize()

    def testSend(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(Exception, self.protocol.send, 'ls')
            return
        self.doLogin()
        self.protocol.execute('ls')

        # send() does not wait for output, so the response from the
        # preceding execute('ls') is still the current one.
        self.protocol.send('df\r')
        self.assertIsNotNone(self.protocol.response)
        self.assertTrue(self.protocol.response.startswith('ls'))

        self.protocol.send('exit\r')
        self.assertIsNotNone(self.protocol.response)
        self.assertTrue(self.protocol.response.startswith('ls'))

    def testExecute(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(Exception, self.protocol.execute, 'ls')
            return
        self.doLogin()
        self.protocol.execute('ls')
        self.assertIsNotNone(self.protocol.response)
        self.assertTrue(self.protocol.response.startswith('ls'))

        # Make sure that we raise an error if the device responds
        # with something that matches any of the error prompts.
        self.protocol.set_error_prompt('.')
        self.assertRaises(InvalidCommandException,
                          self.protocol.execute,
                          'this-command-causes-an-error')

    def testWaitfor(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(Exception, self.protocol.waitfor, 'ls')
            return
        self.doLogin()
        oldresponse = self.protocol.response
        self.protocol.send('ls\r')
        self.assertEqual(oldresponse, self.protocol.response)
        self.protocol.waitfor(re.compile(r'[\r\n]'))
        self.assertNotEqual(oldresponse, self.protocol.response)
        # waitfor() does not consume the match, so a second call
        # yields the same response.
        oldresponse = self.protocol.response
        self.protocol.waitfor(re.compile(r'[\r\n]'))
        self.assertEqual(oldresponse, self.protocol.response)

    def testExpect(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(Exception, self.protocol.expect, 'ls')
            return
        self.doLogin()
        oldresponse = self.protocol.response
        self.protocol.send('ls\r')
        self.assertEqual(oldresponse, self.protocol.response)
        self.protocol.expect(re.compile(r'[\r\n]'))
        self.assertNotEqual(oldresponse, self.protocol.response)

    def testExpectPrompt(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(Exception, self.protocol.expect, 'ls')
            return
        self.doLogin()
        oldresponse = self.protocol.response
        self.protocol.send('ls\r')
        self.assertEqual(oldresponse, self.protocol.response)
        self.protocol.expect_prompt()
        self.assertNotEqual(oldresponse, self.protocol.response)

    def testAddMonitor(self):
        # Set the monitor callback up; it records its arguments in data.
        def monitor_cb(thedata, *args, **kwargs):
            thedata['args'] = args
            thedata['kwargs'] = kwargs
        data = {}
        self.protocol.add_monitor('abc', partial(monitor_cb, data))

        # Simulate some non-matching data.
        self.protocol.buffer.append(u'aaa')
        self.assertEqual(data, {})

        # Simulate some matching data.
        self.protocol.buffer.append(u'abc')
        # Fail cleanly (instead of a TypeError below) if the monitor
        # never fired.
        self.assertIn('args', data)
        args = data['args']
        self.assertEqual(len(args), 3)
        self.assertEqual(args[0], self.protocol)
        self.assertEqual(args[1], 0)
        self.assertEqual(args[2].group(0), 'abc')
        self.assertEqual(data.get('kwargs'), {})

    def testGetBuffer(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            return
        self.assertEqual(str(self.protocol.buffer), '')
        self.doLogin()
        # Depending on whether the connected host sends a banner,
        # the buffer may or may not contain anything now.

        before = str(self.protocol.buffer)
        self.protocol.send('ls\r')
        self.protocol.waitfor(self.protocol.get_prompt())
        self.assertNotEqual(str(self.protocol.buffer), before)

    def _cancel_cb(self, data):
        """data_received_event handler that cancels any pending expect."""
        self.protocol.cancel_expect()

    def testCancelExpect(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            return
        self.doLogin()
        oldresponse = self.protocol.response
        self.protocol.data_received_event.connect(self._cancel_cb)
        self.protocol.send('ls\r')
        self.assertEqual(oldresponse, self.protocol.response)
        self.assertRaises(ExpectCancelledException,
                          self.protocol.expect,
                          'notgoingtohappen')

    def testInteract(self):
        # Test can not work on the abstract base.
        if self._is_abstract():
            self.assertRaises(Exception, self.protocol.interact)
            return
        # Can't really be tested.

    def testClose(self):
        if not self._is_abstract():
            self.doConnect()
        self.protocol.close(True)

    def testGetHost(self):
        self.assertIsNone(self.protocol.get_host())
        if self._is_abstract():
            return
        self.doConnect()
        self.assertEqual(self.protocol.get_host(), self.hostname)

    def testGuessOs(self):
        self.assertEqual('unknown', self.protocol.guess_os())
        # Other tests can not work on the abstract base.
        if self._is_abstract():
            return
        self.doConnect()
        self.assertEqual('unknown', self.protocol.guess_os())
        self.protocol.login(self.account)
        self.assertTrue(self.protocol.is_protocol_authenticated())
        self.assertTrue(self.protocol.is_app_authenticated())
        self.assertTrue(self.protocol.is_app_authorized())
        self.assertEqual('shell', self.protocol.guess_os())


def suite():
    """Return a unittest suite containing all ProtocolTest tests."""
    return unittest.TestLoader().loadTestsFromTestCase(ProtocolTest)


if __name__ == '__main__':
    unittest.TextTestRunner(verbosity=2).run(suite())