'''
Quick start
-----------

The goal of NDB is to provide easy access to RTNL info and entities via
Python objects, like `pyroute2.ndb.objects.interface` (see also:
:ref:`ndbinterfaces`), `pyroute2.ndb.objects.route` (see also:
:ref:`ndbroutes`) etc. These objects not only reflect the system state
at the time of their instantiation, but also continuously monitor the
system for relevant updates. The monitoring is done via netlink
notifications, so no polling is involved. The objects also allow you to
apply changes back to the system and to roll the changes back.

On the other hand, it's too expensive to create Python objects for all the
available RTNL entities, e.g. when there are hundreds of interfaces and
thousands of routes. Thus NDB creates objects only upon request, when
the user calls `.create()` to create new objects or runs
`ndb.<view>[selector]` (e.g. `ndb.interfaces['eth0']`) to access an
existing object.

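The on-demand approach can be pictured with a small self-contained sketch.
It is not the NDB implementation: `LazyView` and `FakeInterface` are
hypothetical names, and a plain dict stands in for the NDB database. The
sketch only shows that listing names builds no objects, while a full object
is built when a key is requested.

```python
class FakeInterface(object):
    # Hypothetical stand-in for an NDB object; not a pyroute2 class.
    def __init__(self, ifname, attrs):
        self.ifname = ifname
        self.attrs = dict(attrs)


class LazyView(object):
    # Hypothetical view: builds objects only when they are requested.
    def __init__(self, table):
        self.table = table      # plain records, cheap to keep around
        self.created = []       # names of the objects built so far

    def names(self):
        # Listing does not instantiate any object.
        return sorted(self.table)

    def __getitem__(self, ifname):
        # Raises KeyError for unknown names, like a dict would.
        obj = FakeInterface(ifname, self.table[ifname])
        self.created.append(ifname)
        return obj


table = {'lo': {'state': 'up'}, 'eth0': {'state': 'down'}}
view = LazyView(table)
names = view.names()
eth0 = view['eth0']
```

After the last line only `eth0` has been turned into an object; `lo` is
still just a row in the table.
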
To list existing RTNL entities NDB uses objects of the class `RecordSet`
that `yield` individual `Record` objects for every entity (see also:
:ref:`ndbreports`). An object of the `Record` class is immutable, doesn't
monitor any updates, doesn't contain any links to other objects and essentially
behaves like a simple named tuple.

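The named tuple comparison can be illustrated with the standard library
alone. This is a sketch, not the actual `Record` class: the
`InterfaceRecord` type, its three fields and the `dump()` helper are
assumptions made for the example.

```python
from collections import namedtuple

# Hypothetical field names; real NDB records carry many more fields.
InterfaceRecord = namedtuple('InterfaceRecord',
                             ('target', 'ifname', 'state'))


def dump(rows):
    # Yield one immutable record per row, the way a report is iterated.
    for row in rows:
        yield InterfaceRecord(*row)


rows = [('localhost', 'lo', 'up'), ('localhost', 'eth0', 'down')]
records = list(dump(rows))

try:
    records[0].state = 'down'
    mutable = True
except AttributeError:
    # Named tuples reject attribute assignment.
    mutable = False
```

A record like this is a snapshot: it can be read by field name, but it
neither changes nor follows later system updates.
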
.. aafig::
    :scale: 80
    :textual:


      +---------------------+
      |                     |
      |                     |
      | `NDB() instance`    |
      |                     |
      |                     |
      +---------------------+
                 |
                 |
        +-------------------+
      +-------------------+ |
    +-------------------+ | |-----------+--------------------------+
    |                   | | |           |                          |
    |                   | | |           |                          |
    | `View()`          | | |           |                          |
    |                   | |-+           |                          |
    |                   |-+             |                          |
    +-------------------+               |                          |
                               +------------------+       +------------------+
                               |                  |       |                  |
                               |                  |       |                  |
                               | `.dump()`        |       | `.create()`      |
                               | `.summary()`     |       | `.__getitem__()` |
                               |                  |       |                  |
                               |                  |       |                  |
                               +------------------+       +------------------+
                                        |                           |
                                        |                           |
                                        v                           v
                              +-------------------+        +------------------+
                              |                   |      +------------------+ |
                              |                   |    +------------------+ | |
                              | `RecordSet()`     |    | `Interface()`    | | |
                              |                   |    | `Address()`      | | |
                              |                   |    | `Route()`        | | |
                              +-------------------+    | `Neighbour()`    | | |
                                        |              | `Rule()`         | |-+
                                        |              |  ...             |-+
                                        v              +------------------+
                                +-------------------+
                              +-------------------+ |
                            +-------------------+ | |
                            | `filter()`        | | |
                            | `select()`        | | |
                            | `transform()`     | | |
                            | `join()`          | |-+
                            |  ...              |-+
                            +-------------------+
                                        |
                                        v
                                +-------------------+
                              +-------------------+ |
                            +-------------------+ | |
                            |                   | | |
                            |                   | | |
                            | `Record()`        | | |
                            |                   | |-+
                            |                   |-+
                            +-------------------+

Here are some simple NDB usage examples. See the reference documentation
below for more information.

Print all the interface names on the system, assuming we have an NDB
instance `ndb`::

    for interface in ndb.interfaces.dump():
        print(interface.ifname)

Print the routing information in the CSV format::

    for record in ndb.routes.summary().format('csv'):
        print(record)

.. note:: More on report filtering and formatting: :ref:`ndbreports`
.. note:: This syntax is available since 0.5.11; versions 0.5.10 and
          earlier used the syntax `summary(format='csv', match={...})`

Print the IP addresses of interfaces in several network namespaces as
JSON::

    nslist = ['netns01',
              'netns02',
              'netns03']

    for nsname in nslist:
        ndb.sources.add(netns=nsname)

    for line in ndb.addresses.summary().format('json'):
        print(line)

Add an IP address on an interface::

    (ndb
     .interfaces['eth0']
     .add_ip('10.0.0.1/24')
     .commit())
    # ---> <---  NDB waits until the address is actually
    #            applied to the interface

Change an interface property::

    (ndb
     .interfaces['eth0']
     .set('state', 'up')
     .set('address', '00:11:22:33:44:55')
     .commit())
    # ---> <---  NDB waits here for the changes to be applied

    # same as above, but using properties as argument names
    (ndb
     .interfaces['eth0']
     .set(state='up')
     .set(address='00:11:22:33:44:55')
     .commit())

    # ... or with another syntax
    with ndb.interfaces['eth0'] as i:
        i['state'] = 'up'
        i['address'] = '00:11:22:33:44:55'
    # ---> <---  the commit() is called automatically by
    #            the context manager's __exit__()

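The context manager form can be pictured with a minimal sketch. It is not
NDB code: `Staged` is a hypothetical class that keeps changes in memory,
and committing only on a clean exit is an assumption of the sketch, not a
statement about how NDB objects behave on errors.

```python
class Staged(object):
    # Hypothetical object, not a pyroute2 class: changes are staged
    # in memory and applied only by commit().
    def __init__(self):
        self.applied = {}
        self.pending = {}

    def __setitem__(self, key, value):
        self.pending[key] = value

    def commit(self):
        self.applied.update(self.pending)
        self.pending = {}
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Assumption of this sketch: commit only on a clean exit.
        if exc_type is None:
            self.commit()


iface = Staged()
with iface as i:
    i['state'] = 'up'
    # Nothing is applied yet while the block is still running.
    staged_inside = dict(iface.applied)
```

Inside the `with` block the change is only staged; it is applied when the
block ends and `__exit__()` calls `commit()`.
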
'''
import gc
import sys
import json
import time
import errno
import atexit
import sqlite3
import logging
import logging.handlers
import threading
import traceback
import ctypes
import ctypes.util
from functools import partial
from collections import OrderedDict
from pr2modules import config
from pr2modules import cli
from pr2modules.common import basestring
from pr2modules.netlink import nlmsg_base
##
# NDB stuff
from . import schema
from .events import (DBMExitException,
                     ShutdownException,
                     InvalidateHandlerException,
                     RescheduleException)
from .messages import (cmsg,
                       cmsg_event,
                       cmsg_failed,
                       cmsg_sstart)
from .source import (Source,
                     SourceProxy)
from .auth_manager import check_auth
from .auth_manager import AuthManager
from .objects import RSLV_DELETE
from .objects.interface import Interface
from .objects.interface import Vlan
from .objects.address import Address
from .objects.route import Route
from .objects.neighbour import Neighbour
from .objects.rule import Rule
from .objects.netns import NetNS
from .report import (RecordSet,
                     Record)
try:
    from urlparse import urlparse
except ImportError:
    from urllib.parse import urlparse

try:
    import queue
except ImportError:
    import Queue as queue

try:
    import psycopg2
except ImportError:
    psycopg2 = None

log = logging.getLogger(__name__)
_sql_adapters_lock = threading.Lock()
_sql_adapters_psycopg2_registered = False
_sql_adapters_sqlite3_registered = False


def target_adapter(value):
    #
    # MPLS target adapter for SQLite3
    #
    return json.dumps(value)


class PostgreSQLAdapter(object):

    def __init__(self, obj):
        self.obj = obj

    def getquoted(self):
        return "'%s'" % json.dumps(self.obj)


def register_sqlite3_adapters():
    global _sql_adapters_lock
    global _sql_adapters_sqlite3_registered
    with _sql_adapters_lock:
        if not _sql_adapters_sqlite3_registered:
            _sql_adapters_sqlite3_registered = True
            sqlite3.register_adapter(list, target_adapter)
            sqlite3.register_adapter(dict, target_adapter)


def regsiter_postgres_adapters():
    global _sql_adapters_lock
    global _sql_adapters_psycopg2_registered
    with _sql_adapters_lock:
        if psycopg2 is not None and not _sql_adapters_psycopg2_registered:
            _sql_adapters_psycopg2_registered = True
            psycopg2.extensions.register_adapter(list, PostgreSQLAdapter)
            psycopg2.extensions.register_adapter(dict, PostgreSQLAdapter)


class Transaction(object):

    def __init__(self, log):
        self.queue = []
        self.event = threading.Event()
        self.event.clear()
        self.log = log.channel('transaction.%s' % id(self))
        self.log.debug('begin transaction')

    def push(self, *argv):
        for obj in argv:
            self.log.debug('queue %s' % type(obj))
            self.queue.append(obj)
        return self

    def append(self, obj):
        self.log.debug('queue %s' % type(obj))
        self.push(obj)
        return self

    def pop(self, index=-1):
        self.log.debug('pop %s' % index)
        self.queue.pop(index)
        return self

    def insert(self, index, obj):
        self.log.debug('insert %i %s' % (index, type(obj)))
        self.queue.insert(index, obj)
        return self

    def cancel(self):
        self.log.debug('cancel transaction')
        self.queue = []
        return self

    def wait(self):
        return self.event.wait()

    def done(self):
        return self.event.is_set()

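    def commit(self):
        self.log.debug('commit')
        # objects committed so far, in the commit order
        rollbacks = []
        for obj in self.queue:
            rollbacks.append(obj)
            try:
                obj.commit()
            except Exception:
                # undo in the reverse order, including the failed object;
                # rollback errors are logged and ignored, so the original
                # exception is the one that is re-raised
                for rb in reversed(rollbacks):
                    try:
                        rb.rollback()
                    except Exception as e:
                        self.log.warning('ignore rollback exception: %s' % e)
                raise
        # signal the waiters only after all the objects are committed
        self.event.set()
        return self

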
class View(dict):
    '''
    The View() object returns RTNL objects on demand::

        ifobj1 = ndb.interfaces['eth0']
        ifobj2 = ndb.interfaces['eth0']
        # ifobj1 != ifobj2
    '''

    def __init__(self,
                 ndb,
                 table,
                 chain=None,
                 auth_managers=None):
        self.ndb = ndb
        self.log = ndb.log.channel('view.%s' % table)
        self.table = table
        self.event = table  # FIXME
        self.chain = chain
        self.cache = {}
        if auth_managers is None:
            auth_managers = []
        if chain:
            auth_managers += chain.auth_managers
        self.auth_managers = auth_managers
        self.constraints = {}
        self.classes = OrderedDict()
        self.classes['interfaces'] = Interface
        self.classes['addresses'] = Address
        self.classes['neighbours'] = Neighbour
        self.classes['routes'] = Route
        self.classes['rules'] = Rule
        self.classes['netns'] = NetNS
        self.classes['vlans'] = Vlan

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass

    @property
    def default_target(self):
        if self.table == 'netns':
            return self.ndb.nsmanager
        else:
            return self.ndb.localhost

    @property
    def context(self):
        if self.chain is not None:
            return self.chain.context
        else:
            return {}

    def getmany(self, spec, table=None):
        return self.ndb.schema.get(table or self.table, spec)

    def getone(self, spec, table=None):
        for obj in self.getmany(spec, table):
            return obj

    def get(self, spec, table=None):
        try:
            return self.__getitem__(spec, table)
        except KeyError:
            return None

    def template(self, key, table=None):
        if self.chain:
            context = self.chain.context
        else:
            context = {}
        iclass = self.classes[table or self.table]

        spec = (iclass
                .new_spec(key, self.default_target)
                .load_context(context)
                .get_spec)

        return iclass(self,
                      spec,
                      load=False,
                      master=self.chain,
                      auth_managers=self.auth_managers)

    @cli.change_pointer
    def create(self, *argspec, **kwspec):
        iclass = self.classes[self.table]

        if self.chain:
            context = self.chain.context
        else:
            context = {}

        spec = (iclass
                .new_spec(kwspec or argspec[0], self.default_target)
                .load_context(context)
                .get_spec)

        if self.chain:
            spec['ndb_chain'] = self.chain
        spec['create'] = True
        return self[spec]

    @cli.change_pointer
    def add(self, *argspec, **kwspec):
        self.log.warning('''\n
        The name add() will be removed in future releases, use create()
        instead. If you believe that the idea to rename is wrong, please
        file your opinion to the project's bugtracker.

        The reason behind the rename is not to confuse interfaces.add() with
        bridge and bond port operations, that don't create any new interfaces
        but work on existing ones.
        ''')
        return self.create(*argspec, **kwspec)

    def wait(self, **spec):
        ret = None
        timeout = spec.pop('timeout', None)
        ctime = time.time()

        # install a limited events queue -- for a possible immediate reaction
        evq = queue.Queue(maxsize=100)

        def handler(evq, target, event):
            # ignore the "queue full" exception
            #
            # if we miss some events here, nothing bad happens: we just
            # load them from the DB after a timeout, falling back to
            # the DB polling
            #
            # the most important thing here is not to allocate too much
            # memory
            try:
                evq.put_nowait((target, event))
            except queue.Full:
                pass
        #
        hdl = partial(handler, evq)
        (self
         .ndb
         .register_handler(self
                           .ndb
                           .schema
                           .classes[self.event], hdl))
        #
        try:
            ret = self.__getitem__(spec)
            for key in spec:
                if ret[key] != spec[key]:
                    ret = None
                    break
        except KeyError:
            ret = None

        while ret is None:
            if timeout is not None:
                if ctime + timeout < time.time():
                    break
            try:
                target, msg = evq.get(timeout=1)
            except queue.Empty:
                try:
                    ret = self.__getitem__(spec)
                    for key in spec:
                        if ret[key] != spec[key]:
                            ret = None
                            raise KeyError()
                    break
                except KeyError:
                    continue

            #
            for key, value in spec.items():
                if key == 'target' and value != target:
                    break
                elif value not in (msg.get(key),
                                   msg.get_attr(msg.name2nla(key))):
                    break
            else:
                while ret is None:
                    try:
                        ret = self.__getitem__(spec)
                    except KeyError:
                        time.sleep(0.1)

        #
        (self
         .ndb
         .unregister_handler(self
                             .ndb
                             .schema
                             .classes[self.event], hdl))

        del evq
        del hdl
        gc.collect()
        if ret is None:
            raise TimeoutError()
        return ret

    @check_auth('obj:read')
    def locate(self, spec=None, table=None, **kwarg):
        '''
        This method works like `__getitem__()`, with one important
        difference: it uses only the key fields to locate the
        object in the DB and ignores all other keys.

        It is useful for locating objects whose attributes may change
        during the request: an interface may go up or down, or an
        address may become primary or secondary. In such cases a plain
        `__getitem__()` will not match even though the object still
        exists.
        '''
        if isinstance(spec, Record):
            spec = spec._as_dict()
        spec = spec or kwarg
        if not spec:
            raise TypeError('got an empty spec')

        table = table or self.table
        iclass = self.classes[table]
        spec = iclass.spec_normalize(spec)
        kspec = (self
                 .ndb
                 .schema
                 .compiled[table]['norm_idx'])

        request = {}
        for name in kspec:
            name = iclass.nla2name(name)
            if name in spec:
                request[name] = spec[name]

        if not request:
            raise KeyError('got an empty key')

        return self[request]

    @check_auth('obj:read')
    def __getitem__(self, key, table=None):

        ret = self.template(key, table)

        # rtnl_object.key() returns a dictionary that can not
        # be used as a cache key. Create a tuple from it here.
        # The key order is guaranteed by the dictionary.
        cache_key = tuple(ret.key.items())

        rtime = time.time()

        # Iterate all the cache to remove unused and clean
        # (without any started transaction) objects.
        for ckey in tuple(self.cache):
            # Skip the current cache_key to avoid extra
            # cache del/add records in the logs
            if ckey == cache_key:
                continue
            # An object is unused when its only referrer
            # is the cache itself, so rcount must be 1
            rcount = len(gc.get_referrers(self.cache[ckey]))
            # Remove only expired items
            expired = (rtime - self.cache[ckey].atime) > config.cache_expire
            # The number of changed rtnl_object fields must
            # be 0, which means that no transaction is started
            if rcount == 1 and self.cache[ckey].clean and expired:
                self.log.debug('cache del %s' % (ckey, ))
                self.cache.pop(ckey, None)

        if cache_key in self.cache:
            self.log.debug('cache hit %s' % (cache_key, ))
            # Explicitly get rid of the created object
            del ret
            # The object from the cache has already
            # registered callbacks, simply return it
            ret = self.cache[cache_key]
            ret.atime = rtime
            return ret
        else:
            # Cache only existing objects
            if self.exists(key):
                ret.load_sql()
                self.log.debug('cache add %s' % (cache_key, ))
                self.cache[cache_key] = ret

        ret.register()
        return ret

    def exists(self, key, table=None):
        '''
        Check if the specified object exists in the database::

            ndb.interfaces.exists('eth0')
            ndb.interfaces.exists({'ifname': 'eth0', 'target': 'localhost'})
            ndb.addresses.exists('127.0.0.1/8')
        '''
        if self.chain:
            context = self.chain.context
        else:
            context = {}

        iclass = self.classes[self.table]
        key = (iclass
               .new_spec(key, self.default_target)
               .load_context(context)
               .get_spec)

        iclass.resolve(view=self,
                       spec=key,
                       fields=iclass.resolve_fields,
                       policy=RSLV_DELETE)

        table = table or self.table
        schema = self.ndb.schema
        names = schema.compiled[self.table]['all_names']

        self.log.debug('check if the key %s exists in table %s' %
                       (key, table))
        keys = []
        values = []
        for name, value in key.items():
            nla_name = iclass.name2nla(name)
            if nla_name in names:
                name = nla_name
            if value is not None and name in names:
                keys.append('f_%s = %s' % (name, schema.plch))
                values.append(value)
        spec = (schema
                .fetchone('SELECT * FROM %s WHERE %s' %
                          (self.table, ' AND '.join(keys)), values))
        if spec is not None:
            self.log.debug('exists')
            return True
        else:
            self.log.debug('not exists')
            return False

    def __setitem__(self, key, value):
        raise NotImplementedError()

    def __delitem__(self, key):
        raise NotImplementedError()

    def __iter__(self):
        return self.keys()

    def __contains__(self, key):
        return key in self.dump()

    @check_auth('obj:list')
    def keys(self):
        for record in self.dump():
            yield record

    @check_auth('obj:list')
    def values(self):
        for key in self.keys():
            yield self[key]

    @check_auth('obj:list')
    def items(self):
        for key in self.keys():
            yield (key, self[key])

    @cli.show_result
    def count(self):
        return (self
                .ndb
                .schema
                .fetchone('SELECT count(*) FROM %s' % self.table))[0]

    def __len__(self):
        return self.count()

    def _keys(self, iclass):
        return (['target', 'tflags'] +
                self.ndb.schema.compiled[iclass.view or iclass.table]['names'])

    def _native(self, dump):
        fnames = next(dump)
        for record in dump:
            yield Record(fnames, record, self.classes[self.table])

    @cli.show_result
    @check_auth('obj:list')
    def dump(self):
        iclass = self.classes[self.table]
        return RecordSet(self._native(iclass.dump(self)))

    @cli.show_result
    @check_auth('obj:list')
    def summary(self):
        iclass = self.classes[self.table]
        return RecordSet(self._native(iclass.summary(self)))


class SourcesView(View):

    def __init__(self, ndb, auth_managers=None):
        super(SourcesView, self).__init__(ndb, 'sources')
        self.classes['sources'] = Source
        self.cache = {}
        self.proxy = {}
        self.lock = threading.Lock()
        if auth_managers is None:
            auth_managers = []
        self.auth_managers = auth_managers

    def async_add(self, **spec):
        spec = dict(Source.defaults(spec))
        self.cache[spec['target']] = Source(self.ndb, **spec).start()
        return self.cache[spec['target']]

    def add(self, **spec):
        spec = dict(Source.defaults(spec))
        if 'event' not in spec:
            sync = True
            spec['event'] = threading.Event()
        else:
            sync = False
        self.cache[spec['target']] = Source(self.ndb, **spec).start()
        if sync:
            self.cache[spec['target']].event.wait()
        return self.cache[spec['target']]

    def remove(self, target, code=errno.ECONNRESET, sync=True):
        with self.lock:
            if target in self.cache:
                source = self.cache[target]
                source.close(code=code, sync=sync)
                return self.cache.pop(target)

    @check_auth('obj:list')
    def keys(self):
        for key in self.cache:
            yield key

    def _keys(self, iclass):
        return ['target', 'kind']

    def wait(self, **spec):
        raise NotImplementedError()

    def _summary(self, *argv, **kwarg):
        return self._dump(*argv, **kwarg)

    def __getitem__(self, key, table=None):
        if isinstance(key, basestring):
            target = key
        elif isinstance(key, dict) and 'target' in key.keys():
            target = key['target']
        else:
            raise ValueError('key format not supported')

        if target in self.cache:
            return self.cache[target]
        elif target in self.proxy:
            return self.proxy[target]
        else:
            proxy = SourceProxy(self.ndb, target)
            self.proxy[target] = proxy
            return proxy


class Log(object):

    def __init__(self, log_id=None):
        self.logger = None
        self.state = False
        self.log_id = log_id or id(self)
        self.logger = logging.getLogger('pyroute2.ndb.%s' % self.log_id)
        self.main = self.channel('main')

    def __call__(self, target=None, level=logging.INFO):
        if target is None:
            return self.logger is not None

        if self.logger is not None:
            for handler in tuple(self.logger.handlers):
                self.logger.removeHandler(handler)

        if target in ('off', False):
            if self.state:
                self.logger.setLevel(0)
                self.logger.addHandler(logging.NullHandler())
            return

        if target in ('on', 'stderr'):
            handler = logging.StreamHandler()
        elif target == 'debug':
            handler = logging.StreamHandler()
            level = logging.DEBUG
        elif isinstance(target, basestring):
            url = urlparse(target)
            if not url.scheme and url.path:
                handler = logging.FileHandler(url.path)
            elif url.scheme == 'syslog':
                handler = (logging
                           .handlers
                           .SysLogHandler(address=url.netloc.split(':')))
            else:
                raise ValueError('logging scheme not supported')
        else:
            handler = target

        # set formatting only for newly created logging handlers
        if handler is not target:
            fmt = '%(asctime)s %(levelname)8s %(name)s: %(message)s'
            formatter = logging.Formatter(fmt)
            handler.setFormatter(formatter)

        self.logger.addHandler(handler)
        self.logger.setLevel(level)

    @property
    def on(self):
        self.__call__(target='on')

    @property
    def off(self):
        self.__call__(target='off')

    def close(self):
        manager = self.logger.manager
        name = self.logger.name
        # the loggerDict can be huge, so don't
        # cache all the keys -- cache only the
        # needed ones
        purge_list = []
        for logger in manager.loggerDict.keys():
            if logger.startswith(name):
                purge_list.append(logger)
        # now shoot them one by one
        for logger in purge_list:
            del manager.loggerDict[logger]
        # don't force GC, leave it to the user
        del manager
        del name
        del purge_list

    def channel(self, name):
        return logging.getLogger('pyroute2.ndb.%s.%s' % (self.log_id, name))

    def debug(self, *argv, **kwarg):
        return self.main.debug(*argv, **kwarg)

    def info(self, *argv, **kwarg):
        return self.main.info(*argv, **kwarg)

    def warning(self, *argv, **kwarg):
        return self.main.warning(*argv, **kwarg)

    def error(self, *argv, **kwarg):
        return self.main.error(*argv, **kwarg)

    def critical(self, *argv, **kwarg):
        return self.main.critical(*argv, **kwarg)


class ReadOnly(object):

    def __init__(self, ndb):
        self.ndb = ndb

    def __enter__(self):
        self.ndb.schema.allow_write(False)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.ndb.schema.allow_write(True)


class DeadEnd(object):

    def put(self, *argv, **kwarg):
        raise ShutdownException('shutdown in progress')


class EventQueue(object):

    def __init__(self, *argv, **kwarg):
        self._bypass = self._queue = queue.Queue(*argv, **kwarg)

    def put(self, *argv, **kwarg):
        return self._queue.put(*argv, **kwarg)

    def shutdown(self):
        self._queue = DeadEnd()

    def bypass(self, *argv, **kwarg):
        return self._bypass.put(*argv, **kwarg)

    def get(self, *argv, **kwarg):
        return self._bypass.get(*argv, **kwarg)

    def qsize(self):
        return self._bypass.qsize()


def Events(*argv):
    for sequence in argv:
        if sequence is not None:
            for item in sequence:
                yield item


class AuthProxy(object):

    def __init__(self, ndb, auth_managers):
        self._ndb = ndb
        self._auth_managers = auth_managers

        for spec in ('interfaces',
                     'addresses',
                     'routes',
                     'neighbours',
                     'rules',
                     'netns',
                     'vlans'):
            view = View(self._ndb, spec,
                        auth_managers=self._auth_managers)
            setattr(self, spec, view)


class NDB(object):

    @property
    def nsmanager(self):
        return '%s/nsmanager' % self.localhost

    def __init__(self,
                 sources=None,
                 localhost='localhost',
                 db_provider='sqlite3',
                 db_spec=':memory:',
                 db_cleanup=True,
                 rtnl_debug=False,
                 log=False,
                 auto_netns=False,
                 libc=None,
                 messenger=None):

        # 'postgres' is accepted as an alias for the psycopg2 provider
        if db_provider == 'postgres':
            db_provider = 'psycopg2'

        if db_provider == 'sqlite3':
            register_sqlite3_adapters()
        elif db_provider == 'psycopg2':
            regsiter_postgres_adapters()

        self.localhost = localhost
        self.ctime = self.gctime = time.time()
        self.schema = None
        self.config = {}
        self.libc = libc or ctypes.CDLL(ctypes.util.find_library('c'),
                                        use_errno=True)
        self.log = Log(log_id=id(self))
        self.readonly = ReadOnly(self)
        self._auto_netns = auto_netns
        self._db = None
        self._dbm_thread = None
        self._dbm_ready = threading.Event()
        self._dbm_shutdown = threading.Event()
        self._db_cleanup = db_cleanup
        self._global_lock = threading.Lock()
        self._event_map = None
        self._event_queue = EventQueue(maxsize=100)
        self.messenger = messenger
        if messenger is not None:
            self._mm_thread = threading.Thread(target=self.__mm__,
                                               name='Messenger')
            self._mm_thread.daemon = True
            self._mm_thread.start()
        # configure logging: the spec is passed to self.log() as a
        # single argument, as positional or as keyword arguments
        if log:
            if isinstance(log, basestring):
                self.log(log)
            elif isinstance(log, (tuple, list)):
                self.log(*log)
            elif isinstance(log, dict):
                self.log(**log)
            else:
                raise TypeError('wrong log spec format')
        # default sources: the local system and, on Linux, the
        # network namespaces manager
        if sources is None:
            sources = [{'target': self.localhost,
                        'kind': 'local',
                        'nlm_generator': 1}]
            if sys.platform.startswith('linux'):
                sources.append({'target': self.nsmanager,
                                'kind': 'nsmanager'})
        elif not isinstance(sources, (list, tuple)):
            raise ValueError('sources format not supported')

        # only the first source without an explicit target gets
        # the default one
        for spec in sources:
            if 'target' not in spec:
                spec['target'] = self.localhost
                break

        # the default auth manager allows to list, read and modify
        am = AuthManager({'obj:list': True,
                          'obj:read': True,
                          'obj:modify': True},
                         self.log.channel('auth'))
        self.sources = SourcesView(self, auth_managers=[am])
        self._call_registry = {}
        self._nl = sources
        self._db_provider = db_provider
        self._db_spec = db_spec
        self._db_rtnl_log = rtnl_debug
        atexit.register(self.close)
        self._dbm_ready.clear()
        self._dbm_error = None
        self._dbm_autoload = set()
        # start the DB thread and wait until it reports either
        # that the sources are started or an init error
        self._dbm_thread = threading.Thread(target=self.__dbm__,
                                            name='NDB main loop')
        self._dbm_thread.daemon = True
        self._dbm_thread.start()
        self._dbm_ready.wait()
        if self._dbm_error is not None:
            raise self._dbm_error
        for event in tuple(self._dbm_autoload):
            event.wait()
        self._dbm_autoload = None
        for spec in ('interfaces',
                     'addresses',
                     'routes',
                     'neighbours',
                     'rules',
                     'netns',
                     'vlans'):
            view = View(self,
                        spec,
                        auth_managers=[am])
            setattr(self, spec, view)
        # self.query = Query(self.schema)

    def _get_view(self, name, chain=None):
        return View(self, name, chain)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def begin(self):
        return Transaction(self.log)

    def auth_proxy(self, auth_manager):
        return AuthProxy(self, [auth_manager, ])

    def register_handler(self, event, handler):
        if event not in self._event_map:
            self._event_map[event] = []
        self._event_map[event].append(handler)

    def unregister_handler(self, event, handler):
        self._event_map[event].remove(handler)

    def execute(self, *argv, **kwarg):
        return self.schema.execute(*argv, **kwarg)

    def close(self):
        with self._global_lock:
            # make close() idempotent
            if self._dbm_shutdown.is_set():
                return
            else:
                self._dbm_shutdown.set()
            if hasattr(atexit, 'unregister'):
                atexit.unregister(self.close)
            else:
                try:
                    atexit._exithandlers.remove((self.close, (), {}))
                except ValueError:
                    pass
            # shutdown the _dbm_thread
            self._event_queue.shutdown()
            self._event_queue.bypass((cmsg(None, ShutdownException()), ))
            self._dbm_thread.join()
            # shutdown the logger -- free the resources
            self.log.close()

    def reload(self, kinds=None):
        '''
        Restart every source whose kind is listed in `kinds`.
        With the default `kinds=None` no source is restarted.
        '''
        for source in self.sources.values():
            if kinds is not None and source.kind in kinds:
                source.restart()

    def __mm__(self):
        # notify neighbours by sending hello
        for peer in self.messenger.transport.peers:
            peer.hello()
        # receive events
        for msg in self.messenger:
            if msg['type'] == 'system' and msg['data'] == 'HELLO':
                for peer in self.messenger.transport.peers:
                    peer.last_exception_time = 0
                self.reload(kinds=['local', 'netns', 'remote'])
            elif msg['type'] == 'transport':
                message = msg['data'][0](data=msg['data'][1])
                message.decode()
                message['header']['target'] = msg['target']
                self._event_queue.put((message, ))
            elif msg['type'] == 'response':
                if msg['call_id'] in self._call_registry:
                    event = self._call_registry.pop(msg['call_id'])
                    self._call_registry[msg['call_id']] = msg
                    event.set()
            elif msg['type'] == 'api':
                if msg['target'] in self.messenger.targets:
                    try:
                        ret = self.sources[msg['target']].api(msg['name'],
                                                              *msg['argv'],
                                                              **msg['kwarg'])
                        self.messenger.emit({'type': 'response',
                                             'call_id': msg['call_id'],
                                             'return': ret})
                    except Exception as e:
                        self.messenger.emit({'type': 'response',
                                             'call_id': msg['call_id'],
                                             'exception': e})
            else:
                self.log.warning('unknown protocol via messenger')

    def __dbm__(self):

        def default_handler(target, event):
            if isinstance(getattr(event, 'payload', None), Exception):
                raise event.payload
            log.warning('unsupported event ignored: %s' % type(event))

        def check_sources_started(self, _locals, target, event):
            _locals['countdown'] -= 1
            if _locals['countdown'] == 0:
                self._dbm_ready.set()

        _locals = {'countdown': len(self._nl)}

        # init the events map
        event_map = {cmsg_event: [lambda t, x: x.payload.set()],
                     cmsg_failed: [lambda t, x: (self
                                                 .schema
                                                 .mark(t, 1))],
                     cmsg_sstart: [partial(check_sources_started,
                                           self, _locals)]}
        self._event_map = event_map

        event_queue = self._event_queue

        try:
            with _sql_adapters_lock:
                if self._db_provider == 'sqlite3':
                    self._db = sqlite3.connect(self._db_spec)
                elif self._db_provider == 'psycopg2':
                    self._db = psycopg2.connect(**self._db_spec)

            self.schema = schema.init(self,
                                      self._db,
                                      self._db_provider,
                                      self._db_rtnl_log,
                                      id(threading.current_thread()))
        except Exception as e:
            self._dbm_error = e
            self._dbm_ready.set()
            return

        for spec in self._nl:
            spec['event'] = None
            self.sources.add(**spec)

        for (event, handlers) in self.schema.event_map.items():
            for handler in handlers:
                self.register_handler(event, handler)

        stop = False
        reschedule = []
        while not stop:
            # block until the next batch arrives, then process it
            # followed by the events rescheduled on the previous pass
            events = Events(event_queue.get(), reschedule)
            reschedule = []
            try:
                for event in events:
                    handlers = event_map.get(event.__class__,
                                             [default_handler, ])
                    # relay netlink messages for the messenger targets
                    if self.messenger is not None and\
                            (event
                             .get('header', {})
                             .get('target', None) in self.messenger.targets):
                        if isinstance(event, nlmsg_base):
                            if event.data is not None:
                                data = event.data[event.offset:
                                                  event.offset + event.length]
                            else:
                                event.reset()
                                event.encode()
                                data = event.data
                            data = (type(event), data)
                            tgt = event['header']['target']
                            self.messenger.emit({'type': 'transport',
                                                 'target': tgt,
                                                 'data': data})

                    for handler in tuple(handlers):
                        try:
                            target = event['header']['target']
                            handler(target, event)
                        except RescheduleException:
                            # retry the event up to 3 times, then drop it
                            if 'rcounter' not in event['header']:
                                event['header']['rcounter'] = 0
                            if event['header']['rcounter'] < 3:
                                event['header']['rcounter'] += 1
                                self.log.debug('reschedule %s' % (event, ))
                                reschedule.append(event)
                            else:
                                self.log.error('drop %s' % (event, ))
                        except InvalidateHandlerException:
                            try:
                                handlers.remove(handler)
                            except Exception:
                                self.log.error('could not invalidate '
                                               'event handler:\n%s'
                                               % traceback.format_exc())
                        except ShutdownException:
                            stop = True
                            break
                        except DBMExitException:
                            return
                        except Exception:
                            self.log.error('could not load event:\n%s\n%s'
                                           % (event, traceback.format_exc()))
                    if time.time() - self.gctime > config.gc_timeout:
                        self.gctime = time.time()
            except Exception as e:
                self.log.error('exception <%s> in source %s' % (e, target))
                # restart the target
                try:
                    self.sources[target].restart(reason=e)
                except KeyError:
                    pass

        # release all the sources
        for target in tuple(self.sources.cache):
            source = self.sources.remove(target, sync=False)
            if source is not None and source.th is not None:
                source.shutdown.set()
                source.th.join()
                if self._db_cleanup:
                    self.log.debug('flush DB for the target %s' % target)
                    self.schema.flush(target)
                else:
                    self.log.debug('leave DB for debug')

        # close the database
        self.schema.commit()
        self.schema.close()

        # close the logging
        for handler in self.log.logger.handlers:
            handler.close()
