1''' 2Quick start 3----------- 4 5The goal of NDB is to provide an easy access to RTNL info and entities via 6Python objects, like `pyroute2.ndb.objects.interface` (see also: 7:ref:`ndbinterfaces`), `pyroute2.ndb.objects.route` (see also: 8:ref:`ndbroutes`) etc. These objects do not 9only reflect the system state for the time of their instantiation, but 10continuously monitor the system for relevant updates. The monitoring is 11done via netlink notifications, thus no polling. Also the objects allow 12to apply changes back to the system and rollback the changes. 13 14On the other hand it's too expensive to create Python objects for all the 15available RTNL entities, e.g. when there are hundreds of interfaces and 16thousands of routes. Thus NDB creates objects only upon request, when 17the user calls `.create()` to create new objects or runs 18`ndb.<view>[selector]` (e.g. `ndb.interfaces['eth0']`) to access an 19existing object. 20 21To list existing RTNL entities NDB uses objects of the class `RecordSet` 22that `yield` individual `Record` objects for every entity (see also: 23:ref:`ndbreports`). An object of the `Record` class is immutable, doesn't 24monitor any updates, doesn't contain any links to other objects and essentially 25behaves like a simple named tuple. 26 27.. 
aafig:: 28 :scale: 80 29 :textual: 30 31 32 +---------------------+ 33 | | 34 | | 35 | `NDB() instance` | 36 | | 37 | | 38 +---------------------+ 39 | 40 | 41 +-------------------+ 42 +-------------------+ | 43 +-------------------+ | |-----------+--------------------------+ 44 | | | | | | 45 | | | | | | 46 | `View()` | | | | | 47 | | |-+ | | 48 | |-+ | | 49 +-------------------+ | | 50 +------------------+ +------------------+ 51 | | | | 52 | | | | 53 | `.dump()` | | `.create()` | 54 | `.summary()` | | `.__getitem__()` | 55 | | | | 56 | | | | 57 +------------------+ +------------------+ 58 | | 59 | | 60 v v 61 +-------------------+ +------------------+ 62 | | +------------------+ | 63 | | +------------------+ | | 64 | `RecordSet()` | | `Interface()` | | | 65 | | | `Address()` | | | 66 | | | `Route()` | | | 67 +-------------------+ | `Neighbour()` | | | 68 | | `Rule()` | |-+ 69 | | ... |-+ 70 v +------------------+ 71 +-------------------+ 72 +-------------------+ | 73 +-------------------+ | | 74 | `filter()` | | | 75 | `select()` | | | 76 | `transform()` | | | 77 | `join()` | |-+ 78 | ... |-+ 79 +-------------------+ 80 | 81 v 82 +-------------------+ 83 +-------------------+ | 84 +-------------------+ | | 85 | | | | 86 | | | | 87 | `Record()` | | | 88 | | |-+ 89 | |-+ 90 +-------------------+ 91 92Here are some simple NDB usage examples. More info see in the reference 93documentation below. 94 95Print all the interface names on the system, assume we have an NDB 96instance `ndb`:: 97 98 for interface in ndb.interfaces.dump(): 99 print(interface.ifname) 100 101Print the routing information in the CSV format:: 102 103 for line in ndb.routes.summary().format('csv'): 104 print(record) 105 106.. note:: More on report filtering and formatting: :ref:`ndbreports` 107.. 
note:: Since 0.5.11; versions 0.5.10 and earlier used 108 syntax `summary(format='csv', match={...})` 109 110Print IP addresses of interfaces in several network namespaces as:: 111 112 nslist = ['netns01', 113 'netns02', 114 'netns03'] 115 116 for nsname in nslist: 117 ndb.sources.add(netns=nsname) 118 119 for line in ndb.addresses.summary().format('json'): 120 print(line) 121 122Add an IP address on an interface:: 123 124 (ndb 125 .interfaces['eth0'] 126 .add_ip('10.0.0.1/24') 127 .commit()) 128 # ---> <--- NDB waits until the address actually 129 130Change an interface property:: 131 132 (ndb 133 .interfaces['eth0'] 134 .set('state', 'up') 135 .set('address', '00:11:22:33:44:55') 136 .commit()) 137 # ---> <--- NDB waits here for the changes to be applied 138 139 # same as above, but using properties as argument names 140 (ndb 141 .interfaces['eth0'] 142 .set(state='up') 143 .set(address='00:11:22:33:44:55') 144 .commit()) 145 146 # ... or with another syntax 147 with ndb.interfaces['eth0'] as i: 148 i['state'] = 'up' 149 i['address'] = '00:11:22:33:44:55' 150 # ---> <--- the commit() is called authomatically by 151 # the context manager's __exit__() 152 153''' 154import gc 155import sys 156import json 157import time 158import errno 159import atexit 160import sqlite3 161import logging 162import logging.handlers 163import threading 164import traceback 165import ctypes 166import ctypes.util 167from functools import partial 168from collections import OrderedDict 169from pr2modules import config 170from pr2modules import cli 171from pr2modules.common import basestring 172from pr2modules.netlink import nlmsg_base 173## 174# NDB stuff 175from . 
import schema
from .events import (DBMExitException,
                     ShutdownException,
                     InvalidateHandlerException,
                     RescheduleException)
from .messages import (cmsg,
                       cmsg_event,
                       cmsg_failed,
                       cmsg_sstart)
from .source import (Source,
                     SourceProxy)
from .auth_manager import check_auth
from .auth_manager import AuthManager
from .objects import RSLV_DELETE
from .objects.interface import Interface
from .objects.interface import Vlan
from .objects.address import Address
from .objects.route import Route
from .objects.neighbour import Neighbour
from .objects.rule import Rule
from .objects.netns import NetNS
from .report import (RecordSet,
                     Record)
# Python 2/3 compatibility imports
try:
    from urlparse import urlparse
except ImportError:
    from urllib.parse import urlparse

try:
    import queue
except ImportError:
    import Queue as queue

# psycopg2 is optional: PostgreSQL support is enabled only when
# the module is importable
try:
    import psycopg2
except ImportError:
    psycopg2 = None

log = logging.getLogger(__name__)
# Lock and flags guarding one-time registration of the SQL type
# adapters below; registration is global per process, hence the
# module-level state.
_sql_adapters_lock = threading.Lock()
_sql_adapters_psycopg2_registered = False
_sql_adapters_sqlite3_registered = False


def target_adapter(value):
    #
    # MPLS target adapter for SQLite3
    #
    # Serialize lists/dicts (e.g. MPLS label stacks) to JSON so
    # they can be stored in a text column.
    #
    return json.dumps(value)


class PostgreSQLAdapter(object):
    # psycopg2 adapter: wraps a list/dict and renders it as a
    # quoted JSON string for SQL statements.

    def __init__(self, obj):
        self.obj = obj

    def getquoted(self):
        # NOTE(review): psycopg2 on Python 3 normally expects bytes
        # from getquoted() -- confirm against the psycopg2 version
        # in use
        return "'%s'" % json.dumps(self.obj)


def register_sqlite3_adapters():
    # Register the JSON adapters for sqlite3 exactly once per process.
    global _sql_adapters_lock
    global _sql_adapters_sqlite3_registered
    with _sql_adapters_lock:
        if not _sql_adapters_sqlite3_registered:
            _sql_adapters_sqlite3_registered = True
            sqlite3.register_adapter(list, target_adapter)
            sqlite3.register_adapter(dict, target_adapter)


def regsiter_postgres_adapters():
    # Same as above, but for psycopg2; a no-op when psycopg2 is
    # not installed.
    # NOTE(review): the name is misspelled ("regsiter") -- kept as-is
    # for compatibility with existing callers.
    global _sql_adapters_lock
    global _sql_adapters_psycopg2_registered
    with _sql_adapters_lock:
        if psycopg2 is not None and not _sql_adapters_psycopg2_registered:
            _sql_adapters_psycopg2_registered = True
            psycopg2.extensions.register_adapter(list, PostgreSQLAdapter)
            psycopg2.extensions.register_adapter(dict, PostgreSQLAdapter)


class Transaction(object):
    '''
    A batch of NDB object changes committed as one unit.

    Objects queued with `push()`/`append()`/`insert()` are committed
    in order by `commit()`; if any commit fails, the already committed
    objects are rolled back in reverse order and the exception is
    re-raised.  All mutators return `self` for chaining.
    '''

    def __init__(self, log):
        # log: the NDB Log object; a dedicated log channel is
        # created per transaction, keyed by id(self)
        self.queue = []
        self.event = threading.Event()
        self.event.clear()
        self.log = log.channel('transaction.%s' % id(self))
        self.log.debug('begin transaction')

    def push(self, *argv):
        # queue one or more objects for commit
        for obj in argv:
            self.log.debug('queue %s' % type(obj))
            self.queue.append(obj)
        return self

    def append(self, obj):
        # list-like alias for push()
        # NOTE(review): the 'queue' message is logged here and then
        # again inside push()
        self.log.debug('queue %s' % type(obj))
        self.push(obj)
        return self

    def pop(self, index=-1):
        # drop a queued object by index (default: the last one)
        self.log.debug('pop %s' % index)
        self.queue.pop(index)
        return self

    def insert(self, index, obj):
        # queue an object at a specific position
        self.log.debug('insert %i %s' % (index, type(obj)))
        self.queue.insert(index, obj)
        return self

    def cancel(self):
        # drop all queued objects without touching the system
        self.log.debug('cancel transaction')
        self.queue = []
        return self

    def wait(self):
        # block until commit() has completed (possibly in another thread)
        return self.event.wait()

    def done(self):
        # True once commit() has completed
        return self.event.is_set()

    def commit(self):
        # commit all queued objects in order; roll back on the
        # first failure
        self.log.debug('commit')
        rollbacks = []
        for obj in self.queue:
            # the current object goes into the rollback list before
            # its commit, so a partially applied commit is reverted too
            rollbacks.append(obj)
            try:
                obj.commit()
            except Exception:
                for rb in reversed(rollbacks):
                    try:
                        rb.rollback()
                    except Exception as e:
                        self.log.warning('ignore rollback exception: %s' % e)
                raise
        self.event.set()
        return self


class View(dict):
    '''
    The View() object returns RTNL objects on demand::

        ifobj1 = ndb.interfaces['eth0']
        ifobj2 = ndb.interfaces['eth0']
        # ifobj1 != ifobj2
    '''

    def __init__(self,
                 ndb,
                 table,
                 chain=None,
                 auth_managers=None):
        self.ndb = ndb
        self.log = ndb.log.channel('view.%s' % table)
        self.table = table
        self.event = table  # FIXME
        self.chain = chain
        self.cache = {}
        if auth_managers is None:
auth_managers = [] 336 if chain: 337 auth_managers += chain.auth_managers 338 self.auth_managers = auth_managers 339 self.constraints = {} 340 self.classes = OrderedDict() 341 self.classes['interfaces'] = Interface 342 self.classes['addresses'] = Address 343 self.classes['neighbours'] = Neighbour 344 self.classes['routes'] = Route 345 self.classes['rules'] = Rule 346 self.classes['netns'] = NetNS 347 self.classes['vlans'] = Vlan 348 349 def __enter__(self): 350 return self 351 352 def __exit__(self, exc_type, exc_value, traceback): 353 pass 354 355 @property 356 def default_target(self): 357 if self.table == 'netns': 358 return self.ndb.nsmanager 359 else: 360 return self.ndb.localhost 361 362 @property 363 def context(self): 364 if self.chain is not None: 365 return self.chain.context 366 else: 367 return {} 368 369 def getmany(self, spec, table=None): 370 return self.ndb.schema.get(table or self.table, spec) 371 372 def getone(self, spec, table=None): 373 for obj in self.getmany(spec, table): 374 return obj 375 376 def get(self, spec, table=None): 377 try: 378 return self.__getitem__(spec, table) 379 except KeyError: 380 return None 381 382 def template(self, key, table=None): 383 if self.chain: 384 context = self.chain.context 385 else: 386 context = {} 387 iclass = self.classes[table or self.table] 388 389 spec = (iclass 390 .new_spec(key, self.default_target) 391 .load_context(context) 392 .get_spec) 393 394 return iclass(self, 395 spec, 396 load=False, 397 master=self.chain, 398 auth_managers=self.auth_managers) 399 400 @cli.change_pointer 401 def create(self, *argspec, **kwspec): 402 iclass = self.classes[self.table] 403 404 if self.chain: 405 context = self.chain.context 406 else: 407 context = {} 408 409 spec = (iclass 410 .new_spec(kwspec or argspec[0], self.default_target) 411 .load_context(context) 412 .get_spec) 413 414 if self.chain: 415 spec['ndb_chain'] = self.chain 416 spec['create'] = True 417 return self[spec] 418 419 @cli.change_pointer 420 def 
add(self, *argspec, **kwspec): 421 self.log.warning('''\n 422 The name add() will be removed in future releases, use create() 423 instead. If you believe that the idea to rename is wrong, please 424 file your opinion to the project's bugtracker. 425 426 The reason behind the rename is not to confuse interfaces.add() with 427 bridge and bond port operations, that don't create any new interfaces 428 but work on existing ones. 429 ''') 430 return self.create(*argspec, **kwspec) 431 432 def wait(self, **spec): 433 ret = None 434 timeout = spec.pop('timeout', None) 435 ctime = time.time() 436 437 # install a limited events queue -- for a possible immediate reaction 438 evq = queue.Queue(maxsize=100) 439 440 def handler(evq, target, event): 441 # ignore the "queue full" exception 442 # 443 # if we miss some events here, nothing bad happens: we just 444 # load them from the DB after a timeout, falling back to 445 # the DB polling 446 # 447 # the most important here is not to allocate too much memory 448 try: 449 evq.put_nowait((target, event)) 450 except queue.Full: 451 pass 452 # 453 hdl = partial(handler, evq) 454 (self 455 .ndb 456 .register_handler(self 457 .ndb 458 .schema 459 .classes[self.event], hdl)) 460 # 461 try: 462 ret = self.__getitem__(spec) 463 for key in spec: 464 if ret[key] != spec[key]: 465 ret = None 466 break 467 except KeyError: 468 ret = None 469 470 while ret is None: 471 if timeout is not None: 472 if ctime + timeout < time.time(): 473 break 474 try: 475 target, msg = evq.get(timeout=1) 476 except queue.Empty: 477 try: 478 ret = self.__getitem__(spec) 479 for key in spec: 480 if ret[key] != spec[key]: 481 ret = None 482 raise KeyError() 483 break 484 except KeyError: 485 continue 486 487 # 488 for key, value in spec.items(): 489 if key == 'target' and value != target: 490 break 491 elif value not in (msg.get(key), 492 msg.get_attr(msg.name2nla(key))): 493 break 494 else: 495 while ret is None: 496 try: 497 ret = self.__getitem__(spec) 498 except 
KeyError: 499 time.sleep(0.1) 500 501 # 502 (self 503 .ndb 504 .unregister_handler(self 505 .ndb 506 .schema 507 .classes[self.event], hdl)) 508 509 del evq 510 del hdl 511 gc.collect() 512 if ret is None: 513 raise TimeoutError() 514 return ret 515 516 @check_auth('obj:read') 517 def locate(self, spec=None, table=None, **kwarg): 518 ''' 519 This method works like `__getitem__()`, but the important 520 difference is that it uses only key fields to locate the 521 object in the DB, ignoring all other keys. 522 523 It is useful to locate objects that may change attributes 524 during request, like an interface may come up/down, or an 525 address may become primary/secondary, so plain 526 `__getitem__()` will not match while the object still 527 exists. 528 ''' 529 if isinstance(spec, Record): 530 spec = spec._as_dict() 531 spec = spec or kwarg 532 if not spec: 533 raise TypeError('got an empty spec') 534 535 table = table or self.table 536 iclass = self.classes[table] 537 spec = iclass.spec_normalize(spec) 538 kspec = (self 539 .ndb 540 .schema 541 .compiled[table]['norm_idx']) 542 543 request = {} 544 for name in kspec: 545 name = iclass.nla2name(name) 546 if name in spec: 547 request[name] = spec[name] 548 549 if not request: 550 raise KeyError('got an empty key') 551 552 return self[request] 553 554 @check_auth('obj:read') 555 def __getitem__(self, key, table=None): 556 557 ret = self.template(key, table) 558 559 # rtnl_object.key() returns a dcitionary that can not 560 # be used as a cache key. Create here a tuple from it. 561 # The key order guaranteed by the dictionary. 562 cache_key = tuple(ret.key.items()) 563 564 rtime = time.time() 565 566 # Iterate all the cache to remove unused and clean 567 # (without any started transaction) objects. 
568 for ckey in tuple(self.cache): 569 # Skip the current cache_key to avoid extra 570 # cache del/add records in the logs 571 if ckey == cache_key: 572 continue 573 # The number of referrers must be > 1, the first 574 # one is the cache itself 575 rcount = len(gc.get_referrers(self.cache[ckey])) 576 # Remove only expired items 577 expired = (rtime - self.cache[ckey].atime) > config.cache_expire 578 # The number of changed rtnl_object fields must 579 # be 0 which means that no transaction is started 580 if rcount == 1 and self.cache[ckey].clean and expired: 581 self.log.debug('cache del %s' % (ckey, )) 582 self.cache.pop(ckey, None) 583 584 if cache_key in self.cache: 585 self.log.debug('cache hit %s' % (cache_key, )) 586 # Explicitly get rid of the created object 587 del ret 588 # The object from the cache has already 589 # registered callbacks, simply return it 590 ret = self.cache[cache_key] 591 ret.atime = rtime 592 return ret 593 else: 594 # Cache only existing objects 595 if self.exists(key): 596 ret.load_sql() 597 self.log.debug('cache add %s' % (cache_key, )) 598 self.cache[cache_key] = ret 599 600 ret.register() 601 return ret 602 603 def exists(self, key, table=None): 604 ''' 605 Check if the specified object exists in the database:: 606 607 ndb.interfaces.exists('eth0') 608 ndb.interfaces.exists({'ifname': 'eth0', 'target': 'localhost'}) 609 ndb.addresses.exists('127.0.0.1/8') 610 ''' 611 if self.chain: 612 context = self.chain.context 613 else: 614 context = {} 615 616 iclass = self.classes[self.table] 617 key = iclass.new_spec(key, self.default_target).load_context(context).get_spec 618 619 iclass.resolve(view=self, 620 spec=key, 621 fields=iclass.resolve_fields, 622 policy=RSLV_DELETE) 623 624 table = table or self.table 625 schema = self.ndb.schema 626 names = schema.compiled[self.table]['all_names'] 627 628 self.log.debug('check if the key %s exists in table %s' % 629 (key, table)) 630 keys = [] 631 values = [] 632 for name, value in key.items(): 
633 nla_name = iclass.name2nla(name) 634 if nla_name in names: 635 name = nla_name 636 if value is not None and name in names: 637 keys.append('f_%s = %s' % (name, schema.plch)) 638 values.append(value) 639 spec = (schema 640 .fetchone('SELECT * FROM %s WHERE %s' % 641 (self.table, ' AND '.join(keys)), values)) 642 if spec is not None: 643 self.log.debug('exists') 644 return True 645 else: 646 self.log.debug('not exists') 647 return False 648 649 def __setitem__(self, key, value): 650 raise NotImplementedError() 651 652 def __delitem__(self, key): 653 raise NotImplementedError() 654 655 def __iter__(self): 656 return self.keys() 657 658 def __contains__(self, key): 659 return key in self.dump() 660 661 @check_auth('obj:list') 662 def keys(self): 663 for record in self.dump(): 664 yield record 665 666 @check_auth('obj:list') 667 def values(self): 668 for key in self.keys(): 669 yield self[key] 670 671 @check_auth('obj:list') 672 def items(self): 673 for key in self.keys(): 674 yield (key, self[key]) 675 676 @cli.show_result 677 def count(self): 678 return (self 679 .ndb 680 .schema 681 .fetchone('SELECT count(*) FROM %s' % self.table))[0] 682 683 def __len__(self): 684 return self.count() 685 686 def _keys(self, iclass): 687 return (['target', 'tflags'] + 688 self.ndb.schema.compiled[iclass.view or iclass.table]['names']) 689 690 def _native(self, dump): 691 fnames = next(dump) 692 for record in dump: 693 yield Record(fnames, record, self.classes[self.table]) 694 695 @cli.show_result 696 @check_auth('obj:list') 697 def dump(self): 698 iclass = self.classes[self.table] 699 return RecordSet(self._native(iclass.dump(self))) 700 701 @cli.show_result 702 @check_auth('obj:list') 703 def summary(self): 704 iclass = self.classes[self.table] 705 return RecordSet(self._native(iclass.summary(self))) 706 707 708class SourcesView(View): 709 710 def __init__(self, ndb, auth_managers=None): 711 super(SourcesView, self).__init__(ndb, 'sources') 712 self.classes['sources'] = Source 
713 self.cache = {} 714 self.proxy = {} 715 self.lock = threading.Lock() 716 if auth_managers is None: 717 auth_managers = [] 718 self.auth_managers = auth_managers 719 720 def async_add(self, **spec): 721 spec = dict(Source.defaults(spec)) 722 self.cache[spec['target']] = Source(self.ndb, **spec).start() 723 return self.cache[spec['target']] 724 725 def add(self, **spec): 726 spec = dict(Source.defaults(spec)) 727 if 'event' not in spec: 728 sync = True 729 spec['event'] = threading.Event() 730 else: 731 sync = False 732 self.cache[spec['target']] = Source(self.ndb, **spec).start() 733 if sync: 734 self.cache[spec['target']].event.wait() 735 return self.cache[spec['target']] 736 737 def remove(self, target, code=errno.ECONNRESET, sync=True): 738 with self.lock: 739 if target in self.cache: 740 source = self.cache[target] 741 source.close(code=code, sync=sync) 742 return self.cache.pop(target) 743 744 @check_auth('obj:list') 745 def keys(self): 746 for key in self.cache: 747 yield key 748 749 def _keys(self, iclass): 750 return ['target', 'kind'] 751 752 def wait(self, **spec): 753 raise NotImplementedError() 754 755 def _summary(self, *argv, **kwarg): 756 return self._dump(*argv, **kwarg) 757 758 def __getitem__(self, key, table=None): 759 if isinstance(key, basestring): 760 target = key 761 elif isinstance(key, dict) and 'target' in key.keys(): 762 target = key['target'] 763 else: 764 raise ValueError('key format not supported') 765 766 if target in self.cache: 767 return self.cache[target] 768 elif target in self.proxy: 769 return self.proxy[target] 770 else: 771 proxy = SourceProxy(self.ndb, target) 772 self.proxy[target] = proxy 773 return proxy 774 775 776class Log(object): 777 778 def __init__(self, log_id=None): 779 self.logger = None 780 self.state = False 781 self.log_id = log_id or id(self) 782 self.logger = logging.getLogger('pyroute2.ndb.%s' % self.log_id) 783 self.main = self.channel('main') 784 785 def __call__(self, target=None, level=logging.INFO): 
        if target is None:
            # no arguments: report whether logging is configured at all
            return self.logger is not None

        # each call fully reconfigures the logger: drop all the
        # handlers installed so far
        if self.logger is not None:
            for handler in tuple(self.logger.handlers):
                self.logger.removeHandler(handler)

        if target in ('off', False):
            # NOTE(review): self.state is initialized to False in
            # __init__ and is not visibly updated anywhere in this
            # file, so this branch appears to be dead -- confirm
            if self.state:
                self.logger.setLevel(0)
                self.logger.addHandler(logging.NullHandler())
            return

        # choose a handler according to the target spec
        if target in ('on', 'stderr'):
            handler = logging.StreamHandler()
        elif target == 'debug':
            handler = logging.StreamHandler()
            level = logging.DEBUG
        elif isinstance(target, basestring):
            url = urlparse(target)
            if not url.scheme and url.path:
                # a bare path -> log to a file
                handler = logging.FileHandler(url.path)
            elif url.scheme == 'syslog':
                # e.g. 'syslog://host:514'
                # NOTE(review): netloc.split(':') yields a list of
                # strings, while SysLogHandler expects an address
                # tuple (host, int(port)) -- confirm this works for
                # 'syslog://host:port' targets
                handler = (logging
                           .handlers
                           .SysLogHandler(address=url.netloc.split(':')))
            else:
                raise ValueError('logging scheme not supported')
        else:
            # assume a ready-made logging.Handler instance
            handler = target

        # set formatting only for new created logging handlers
        if handler is not target:
            fmt = '%(asctime)s %(levelname)8s %(name)s: %(message)s'
            formatter = logging.Formatter(fmt)
            handler.setFormatter(formatter)

        self.logger.addHandler(handler)
        self.logger.setLevel(level)

    @property
    def on(self):
        # property used for its side effect: enable stderr logging
        self.__call__(target='on')

    @property
    def off(self):
        # property used for its side effect: disable logging
        self.__call__(target='off')

    def close(self):
        # purge all loggers of this NDB instance from the logging
        # module registry
        manager = self.logger.manager
        name = self.logger.name
        # the loggerDict can be huge, so don't
        # cache all the keys -- cache only the
        # needed ones
        purge_list = []
        for logger in manager.loggerDict.keys():
            if logger.startswith(name):
                purge_list.append(logger)
        # now shoot them one by one
        for logger in purge_list:
            del manager.loggerDict[logger]
        # don't force GC, leave it to the user
        del manager
        del name
        del purge_list

    def channel(self, name):
        # a child logger under this instance's namespace
        return logging.getLogger('pyroute2.ndb.%s.%s' % (self.log_id, name))

    def debug(self, *argv, **kwarg):
        # shortcut to the 'main' channel
        return self.main.debug(*argv, **kwarg)
    def info(self, *argv, **kwarg):
        # shortcut to the 'main' channel
        return self.main.info(*argv, **kwarg)

    def warning(self, *argv, **kwarg):
        # shortcut to the 'main' channel
        return self.main.warning(*argv, **kwarg)

    def error(self, *argv, **kwarg):
        # shortcut to the 'main' channel
        return self.main.error(*argv, **kwarg)

    def critical(self, *argv, **kwarg):
        # shortcut to the 'main' channel
        return self.main.critical(*argv, **kwarg)


class ReadOnly(object):
    '''
    Context manager that disables DB schema writes for its scope::

        with ndb.readonly:
            ...
    '''

    def __init__(self, ndb):
        self.ndb = ndb

    def __enter__(self):
        self.ndb.schema.allow_write(False)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.ndb.schema.allow_write(True)


class DeadEnd(object):
    # Queue replacement installed on shutdown: any put() fails fast.

    def put(self, *argv, **kwarg):
        raise ShutdownException('shutdown in progress')


class EventQueue(object):
    '''
    A queue with two input channels.

    The normal `put()` channel can be cut off by `shutdown()`, after
    which producers get ShutdownException; `bypass()` keeps working,
    so the shutdown message itself can still be delivered.
    '''

    def __init__(self, *argv, **kwarg):
        # both names point to the same Queue until shutdown()
        self._bypass = self._queue = queue.Queue(*argv, **kwarg)

    def put(self, *argv, **kwarg):
        return self._queue.put(*argv, **kwarg)

    def shutdown(self):
        # cut off the normal channel; bypass() still reaches the queue
        self._queue = DeadEnd()

    def bypass(self, *argv, **kwarg):
        return self._bypass.put(*argv, **kwarg)

    def get(self, *argv, **kwarg):
        return self._bypass.get(*argv, **kwarg)

    def qsize(self):
        return self._bypass.qsize()


def Events(*argv):
    # Chain several iterables, silently skipping None arguments.
    for sequence in argv:
        if sequence is not None:
            for item in sequence:
                yield item


class AuthProxy(object):
    # A restricted NDB facade: the same set of views, but bound to
    # the given auth managers instead of the default ones.

    def __init__(self, ndb, auth_managers):
        self._ndb = ndb
        self._auth_managers = auth_managers

        for spec in ('interfaces',
                     'addresses',
                     'routes',
                     'neighbours',
                     'rules',
                     'netns',
                     'vlans'):
            view = View(self._ndb, spec,
                        auth_managers=self._auth_managers)
            setattr(self, spec, view)


class NDB(object):

    @property
    def nsmanager(self):
        # target name of the netns manager source
        return '%s/nsmanager' % self.localhost

    def __init__(self,
                 sources=None,
                 localhost='localhost',
                 db_provider='sqlite3',
                 db_spec=':memory:',
                 db_cleanup=True,
rtnl_debug=False, 949 log=False, 950 auto_netns=False, 951 libc=None, 952 messenger=None): 953 954 if db_provider == 'postgres': 955 db_provider = 'psycopg2' 956 957 if db_provider == 'sqlite3': 958 register_sqlite3_adapters() 959 elif db_provider == 'psycopg2': 960 regsiter_postgres_adapters() 961 962 self.localhost = localhost 963 self.ctime = self.gctime = time.time() 964 self.schema = None 965 self.config = {} 966 self.libc = libc or ctypes.CDLL(ctypes.util.find_library('c'), 967 use_errno=True) 968 self.log = Log(log_id=id(self)) 969 self.readonly = ReadOnly(self) 970 self._auto_netns = auto_netns 971 self._db = None 972 self._dbm_thread = None 973 self._dbm_ready = threading.Event() 974 self._dbm_shutdown = threading.Event() 975 self._db_cleanup = db_cleanup 976 self._global_lock = threading.Lock() 977 self._event_map = None 978 self._event_queue = EventQueue(maxsize=100) 979 self.messenger = messenger 980 if messenger is not None: 981 self._mm_thread = threading.Thread(target=self.__mm__, 982 name='Messenger') 983 self._mm_thread.setDaemon(True) 984 self._mm_thread.start() 985 # 986 if log: 987 if isinstance(log, basestring): 988 self.log(log) 989 elif isinstance(log, (tuple, list)): 990 self.log(*log) 991 elif isinstance(log, dict): 992 self.log(**log) 993 else: 994 raise TypeError('wrong log spec format') 995 # 996 # fix sources prime 997 if sources is None: 998 sources = [{'target': self.localhost, 999 'kind': 'local', 1000 'nlm_generator': 1}] 1001 if sys.platform.startswith('linux'): 1002 sources.append({'target': self.nsmanager, 1003 'kind': 'nsmanager'}) 1004 elif not isinstance(sources, (list, tuple)): 1005 raise ValueError('sources format not supported') 1006 1007 for spec in sources: 1008 if 'target' not in spec: 1009 spec['target'] = self.localhost 1010 break 1011 1012 am = AuthManager({'obj:list': True, 1013 'obj:read': True, 1014 'obj:modify': True}, 1015 self.log.channel('auth')) 1016 self.sources = SourcesView(self, auth_managers=[am]) 1017 
self._call_registry = {} 1018 self._nl = sources 1019 self._db_provider = db_provider 1020 self._db_spec = db_spec 1021 self._db_rtnl_log = rtnl_debug 1022 atexit.register(self.close) 1023 self._dbm_ready.clear() 1024 self._dbm_error = None 1025 self._dbm_autoload = set() 1026 self._dbm_thread = threading.Thread(target=self.__dbm__, 1027 name='NDB main loop') 1028 self._dbm_thread.setDaemon(True) 1029 self._dbm_thread.start() 1030 self._dbm_ready.wait() 1031 if self._dbm_error is not None: 1032 raise self._dbm_error 1033 for event in tuple(self._dbm_autoload): 1034 event.wait() 1035 self._dbm_autoload = None 1036 for spec in ('interfaces', 1037 'addresses', 1038 'routes', 1039 'neighbours', 1040 'rules', 1041 'netns', 1042 'vlans'): 1043 view = View(self, 1044 spec, 1045 auth_managers=[am]) 1046 setattr(self, spec, view) 1047 # self.query = Query(self.schema) 1048 1049 def _get_view(self, name, chain=None): 1050 return View(self, name, chain) 1051 1052 def __enter__(self): 1053 return self 1054 1055 def __exit__(self, exc_type, exc_value, traceback): 1056 self.close() 1057 1058 def begin(self): 1059 return Transaction(self.log) 1060 1061 def auth_proxy(self, auth_manager): 1062 return AuthProxy(self, [auth_manager, ]) 1063 1064 def register_handler(self, event, handler): 1065 if event not in self._event_map: 1066 self._event_map[event] = [] 1067 self._event_map[event].append(handler) 1068 1069 def unregister_handler(self, event, handler): 1070 self._event_map[event].remove(handler) 1071 1072 def execute(self, *argv, **kwarg): 1073 return self.schema.execute(*argv, **kwarg) 1074 1075 def close(self): 1076 with self._global_lock: 1077 if self._dbm_shutdown.is_set(): 1078 return 1079 else: 1080 self._dbm_shutdown.set() 1081 if hasattr(atexit, 'unregister'): 1082 atexit.unregister(self.close) 1083 else: 1084 try: 1085 atexit._exithandlers.remove((self.close, (), {})) 1086 except ValueError: 1087 pass 1088 # shutdown the _dbm_thread 1089 self._event_queue.shutdown() 
1090 self._event_queue.bypass((cmsg(None, ShutdownException()), )) 1091 self._dbm_thread.join() 1092 # shutdown the logger -- free the resources 1093 self.log.close() 1094 1095 def reload(self, kinds=None): 1096 for source in self.sources.values(): 1097 if kinds is not None and source.kind in kinds: 1098 source.restart() 1099 1100 def __mm__(self): 1101 # notify neighbours by sending hello 1102 for peer in self.messenger.transport.peers: 1103 peer.hello() 1104 # receive events 1105 for msg in self.messenger: 1106 if msg['type'] == 'system' and msg['data'] == 'HELLO': 1107 for peer in self.messenger.transport.peers: 1108 peer.last_exception_time = 0 1109 self.reload(kinds=['local', 'netns', 'remote']) 1110 elif msg['type'] == 'transport': 1111 message = msg['data'][0](data=msg['data'][1]) 1112 message.decode() 1113 message['header']['target'] = msg['target'] 1114 self._event_queue.put((message, )) 1115 elif msg['type'] == 'response': 1116 if msg['call_id'] in self._call_registry: 1117 event = self._call_registry.pop(msg['call_id']) 1118 self._call_registry[msg['call_id']] = msg 1119 event.set() 1120 elif msg['type'] == 'api': 1121 if msg['target'] in self.messenger.targets: 1122 try: 1123 ret = self.sources[msg['target']].api(msg['name'], 1124 *msg['argv'], 1125 **msg['kwarg']) 1126 self.messenger.emit({'type': 'response', 1127 'call_id': msg['call_id'], 1128 'return': ret}) 1129 except Exception as e: 1130 self.messenger.emit({'type': 'response', 1131 'call_id': msg['call_id'], 1132 'exception': e}) 1133 else: 1134 self.log.warning('unknown protocol via messenger') 1135 1136 def __dbm__(self): 1137 1138 def default_handler(target, event): 1139 if isinstance(getattr(event, 'payload', None), Exception): 1140 raise event.payload 1141 log.warning('unsupported event ignored: %s' % type(event)) 1142 1143 def check_sources_started(self, _locals, target, event): 1144 _locals['countdown'] -= 1 1145 if _locals['countdown'] == 0: 1146 self._dbm_ready.set() 1147 1148 _locals = 
{'countdown': len(self._nl)} 1149 1150 # init the events map 1151 event_map = {cmsg_event: [lambda t, x: x.payload.set()], 1152 cmsg_failed: [lambda t, x: (self 1153 .schema 1154 .mark(t, 1))], 1155 cmsg_sstart: [partial(check_sources_started, 1156 self, _locals)]} 1157 self._event_map = event_map 1158 1159 event_queue = self._event_queue 1160 1161 try: 1162 with _sql_adapters_lock: 1163 if self._db_provider == 'sqlite3': 1164 self._db = sqlite3.connect(self._db_spec) 1165 elif self._db_provider == 'psycopg2': 1166 self._db = psycopg2.connect(**self._db_spec) 1167 1168 self.schema = schema.init(self, 1169 self._db, 1170 self._db_provider, 1171 self._db_rtnl_log, 1172 id(threading.current_thread())) 1173 except Exception as e: 1174 self._dbm_error = e 1175 self._dbm_ready.set() 1176 return 1177 1178 for spec in self._nl: 1179 spec['event'] = None 1180 self.sources.add(**spec) 1181 1182 for (event, handlers) in self.schema.event_map.items(): 1183 for handler in handlers: 1184 self.register_handler(event, handler) 1185 1186 stop = False 1187 reschedule = [] 1188 while not stop: 1189 events = Events(event_queue.get(), reschedule) 1190 reschedule = [] 1191 try: 1192 for event in events: 1193 handlers = event_map.get(event.__class__, 1194 [default_handler, ]) 1195 if self.messenger is not None and\ 1196 (event 1197 .get('header', {}) 1198 .get('target', None) in self.messenger.targets): 1199 if isinstance(event, nlmsg_base): 1200 if event.data is not None: 1201 data = event.data[event.offset: 1202 event.offset + event.length] 1203 else: 1204 event.reset() 1205 event.encode() 1206 data = event.data 1207 data = (type(event), data) 1208 tgt = event['header']['target'] 1209 self.messenger.emit({'type': 'transport', 1210 'target': tgt, 1211 'data': data}) 1212 1213 for handler in tuple(handlers): 1214 try: 1215 target = event['header']['target'] 1216 handler(target, event) 1217 except RescheduleException: 1218 if 'rcounter' not in event['header']: 1219 
event['header']['rcounter'] = 0 1220 if event['header']['rcounter'] < 3: 1221 event['header']['rcounter'] += 1 1222 self.log.debug('reschedule %s' % (event, )) 1223 reschedule.append(event) 1224 else: 1225 self.log.error('drop %s' % (event, )) 1226 except InvalidateHandlerException: 1227 try: 1228 handlers.remove(handler) 1229 except Exception: 1230 self.log.error('could not invalidate ' 1231 'event handler:\n%s' 1232 % traceback.format_exc()) 1233 except ShutdownException: 1234 stop = True 1235 break 1236 except DBMExitException: 1237 return 1238 except Exception: 1239 self.log.error('could not load event:\n%s\n%s' 1240 % (event, traceback.format_exc())) 1241 if time.time() - self.gctime > config.gc_timeout: 1242 self.gctime = time.time() 1243 except Exception as e: 1244 self.log.error('exception <%s> in source %s' % (e, target)) 1245 # restart the target 1246 try: 1247 self.sources[target].restart(reason=e) 1248 except KeyError: 1249 pass 1250 1251 # release all the sources 1252 for target in tuple(self.sources.cache): 1253 source = self.sources.remove(target, sync=False) 1254 if source is not None and source.th is not None: 1255 source.shutdown.set() 1256 source.th.join() 1257 if self._db_cleanup: 1258 self.log.debug('flush DB for the target %s' % target) 1259 self.schema.flush(target) 1260 else: 1261 self.log.debug('leave DB for debug') 1262 1263 # close the database 1264 self.schema.commit() 1265 self.schema.close() 1266 1267 # close the logging 1268 for handler in self.log.logger.handlers: 1269 handler.close() 1270