1# -*- coding: utf-8 -*-
2'''
3Generic netlink
4===============
5
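Generic netlink socket support: resolve a generic netlink family by name
and manage membership in its multicast groups.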
7'''
8import errno
9import logging
10from pr2modules.netlink import CTRL_CMD_GETFAMILY
11from pr2modules.netlink import GENL_ID_CTRL
12from pr2modules.netlink import NLM_F_REQUEST
13from pr2modules.netlink import SOL_NETLINK
14from pr2modules.netlink import NETLINK_ADD_MEMBERSHIP
15from pr2modules.netlink import NETLINK_DROP_MEMBERSHIP
16from pr2modules.netlink import ctrlmsg
17from pr2modules.netlink.nlsocket import NetlinkSocket
18
19
class GenericNetlinkSocket(NetlinkSocket):
    '''
    Low-level generic netlink socket interface. Provides everything
    a usual socket does, can be used in poll/select, and doesn't
    create any implicit threads.

    After `bind()` the socket holds the numeric family id of the
    requested protocol in `self.prid` and a mapping of the family's
    multicast group names to group ids in `self.mcast_groups`.
    '''

    # Class-level default only: bind() replaces it with a fresh
    # per-instance dict, so this shared object is seen solely on
    # sockets that were never bound.
    mcast_groups = {}
    # Optional extra hint logged when the family lookup fails with ENOENT.
    module_err_message = None
    # Name of the `logging` module-level function used for the lookup
    # failure messages (e.g. 'error', 'warning').
    module_err_level = 'error'

    def bind(self, proto, msg_class, groups=0, pid=None, **kwarg):
        '''
        Bind the socket and perform the generic netlink family lookup.

        :param proto: generic netlink family name, like "TASKSTATS"
        :param msg_class: message class used to parse messages of
            the resolved family
        :param groups: passed through to `NetlinkSocket.bind()`
        :param pid: passed through to `NetlinkSocket.bind()`
        :raises: whatever `discovery()` raises when the lookup fails
        '''
        NetlinkSocket.bind(self, groups, pid, **kwarg)
        # Controller replies must be parsed as ctrlmsg before
        # discovery() reads them.
        self.marshal.msg_map[GENL_ID_CTRL] = ctrlmsg
        msg = self.discovery(proto)
        self.prid = msg.get_attr('CTRL_ATTR_FAMILY_ID')
        self.mcast_groups = {
            group.get_attr('CTRL_ATTR_MCAST_GRP_NAME'):
                group.get_attr('CTRL_ATTR_MCAST_GRP_ID')
            for group in msg.get_attr('CTRL_ATTR_MCAST_GROUPS', [])
        }
        # From now on, messages carrying the family id are parsed
        # with the caller-supplied class.
        self.marshal.msg_map[self.prid] = msg_class

    def add_membership(self, group):
        '''
        Join the multicast group named `group` of the bound family.

        :raises KeyError: if `group` is not in `self.mcast_groups`
        '''
        self.setsockopt(SOL_NETLINK,
                        NETLINK_ADD_MEMBERSHIP,
                        self.mcast_groups[group])

    def drop_membership(self, group):
        '''
        Leave the multicast group named `group` of the bound family.

        :raises KeyError: if `group` is not in `self.mcast_groups`
        '''
        self.setsockopt(SOL_NETLINK,
                        NETLINK_DROP_MEMBERSHIP,
                        self.mcast_groups[group])

    def discovery(self, proto):
        '''
        Resolve a generic netlink protocol by name.

        Sends a CTRL_CMD_GETFAMILY request to the generic netlink
        controller and returns the first reply message.

        :param proto: generic netlink family name, like "TASKSTATS"
        :returns: the controller reply describing the family
        :raises: the error object found in the reply header
            (presumably a NetlinkError -- confirm against the marshal).
            When its `code` is ENOENT, `extra_code` is set to ENOTSUP
            and hints are logged at `module_err_level` before raising.
        '''
        msg = ctrlmsg()
        msg['cmd'] = CTRL_CMD_GETFAMILY
        msg['version'] = 1
        msg['attrs'].append(['CTRL_ATTR_FAMILY_NAME', proto])
        msg['header']['type'] = GENL_ID_CTRL
        msg['header']['flags'] = NLM_F_REQUEST
        msg['header']['pid'] = self.pid
        msg.encode()
        self.sendto(msg.data, (0, 0))
        msg = self.get()[0]
        err = msg['header'].get('error', None)
        if err is not None:
            if getattr(err, 'code', None) == errno.ENOENT:
                # An unknown family is reported to callers as
                # "not supported" rather than "no such entry".
                err.extra_code = errno.ENOTSUP
                logger = getattr(logging, self.module_err_level)
                logger('Generic netlink protocol %s not found', proto)
                logger('Please check if the protocol module is loaded')
                if self.module_err_message is not None:
                    logger(self.module_err_message)
            raise err
        return msg
84