1# client_async.py
2#
3# Copyright 2018-2020 Geaaru <geaaru<@>gmail.com>
4
# Module-level declaration of the docstring markup format, for
# documentation tools that honour the ``__docformat__`` convention.
__docformat__ = "epytext en"
6
7from datetime import datetime
8import asyncio
9import six
10import logging
11import random
12
13from pyrad.packet import Packet, AuthPacket, AcctPacket, CoAPacket
14
15
class DatagramProtocolClient(asyncio.Protocol):
    """UDP protocol handler talking to one RADIUS server endpoint.

    Tracks in-flight requests keyed by RADIUS packet id. A request that
    gets no reply within ``timeout`` seconds is resent, up to ``retries``
    times. Each request's future is resolved with the verified reply or
    failed with ``TimeoutError``.

    The instance is its own protocol factory (see ``__call__``), so it can
    be passed directly to ``loop.create_datagram_endpoint``.
    """

    def __init__(self, server, port, logger,
                 client, retries=3, timeout=30):
        """Constructor.

        :param  server: hostname or IP address of the RADIUS server
        :param    port: remote UDP port of the RADIUS server
        :param  logger: logger used for diagnostic messages
        :param  client: owning client; its ``loop`` and ``dict`` are used
        :param retries: number of resends before a request times out
        :param timeout: seconds to wait for a reply before resending
        """
        self.transport = None
        self.port = port
        self.server = server
        self.logger = logger
        self.retries = retries
        self.timeout = timeout
        self.client = client

        # Map of pending requests: packet id -> request state dict.
        self.pending_requests = {}

        # Use cryptographic-safe random generator as provided by the OS
        # to choose the initial packet id.
        random_generator = random.SystemRandom()
        self.packet_id = random_generator.randrange(0, 256)

        self.timeout_future = None

    async def __timeout_handler__(self):
        """Resend or expire requests that got no reply; runs until cancelled.

        Consider a request whose last transmission is at least ``timeout``
        seconds old. If it has already been resent ``retries`` times, its
        future is failed with ``TimeoutError`` and the request is dropped.
        Otherwise the request is sent again.

        Requests whose future is already done are dropped without being
        resent. This typically happens when the caller cancelled the future.
        """
        try:
            while True:
                finished = []
                now = datetime.now()
                next_wake_up = self.timeout
                for pkt_id, req in self.pending_requests.items():
                    future = req['future']
                    if future.done():
                        # Nobody is waiting for this reply any more.
                        finished.append(pkt_id)
                        continue

                    # Time elapsed since the last transmission. It is
                    # computed as now - send_date so the duration is
                    # non-negative. total_seconds() is used because a
                    # negative timedelta has a huge ``.seconds`` value.
                    elapsed = (now - req['send_date']).total_seconds()
                    if elapsed < self.timeout:
                        # Wake up exactly when this request would expire.
                        next_wake_up = min(next_wake_up,
                                           self.timeout - elapsed)
                        continue

                    if req['retries'] >= self.retries:
                        self.logger.debug('[%s:%d] For request %d execute all retries', self.server, self.port, pkt_id)
                        future.set_exception(
                            TimeoutError('Timeout on Reply')
                        )
                        finished.append(pkt_id)
                    else:
                        # Send again packet
                        req['send_date'] = now
                        req['retries'] += 1
                        self.logger.debug('[%s:%d] For request %d execute retry %d', self.server, self.port, pkt_id, req['retries'])
                        self.transport.sendto(req['packet'].RequestPacket())

                # Remove requests from the map outside of the iteration.
                for pkt_id in finished:
                    del self.pending_requests[pkt_id]

                await asyncio.sleep(next_wake_up)

        except asyncio.CancelledError:
            pass

    def send_packet(self, packet, future):
        """Register ``packet`` as pending and transmit it.

        :param packet: request packet; its ``id`` keys the pending map
        :param future: future resolved with the reply or a TimeoutError
        :raises Exception: if the transport is not connected, or a request
                           with the same id is still pending
        """
        if self.transport is None:
            raise Exception('Transport not initialized')

        if packet.id in self.pending_requests:
            raise Exception('Packet with id %d already present' % packet.id)

        now = datetime.now()
        # Store packet on pending requests map
        self.pending_requests[packet.id] = {
            'packet': packet,
            'creation_date': now,
            'retries': 0,
            'future': future,
            'send_date': now,
        }

        # In queue packet raw on socket buffer
        self.transport.sendto(packet.RequestPacket())

    def connection_made(self, transport):
        """Store the transport and start the retry/timeout handler."""
        self.transport = transport
        sock = transport.get_extra_info('socket')
        local = sock.getsockname()
        self.logger.info(
            '[%s:%d] Transport created with binding in %s:%d',
            self.server, self.port, local[0], local[1]
        )

        # Schedule the handler on the client's loop directly instead of
        # temporarily swapping the thread's current event loop.
        self.timeout_future = self.client.loop.create_task(
            self.__timeout_handler__()
        )

    def error_received(self, exc):
        """Log transport-level errors reported by the event loop."""
        self.logger.error('[%s:%d] Error received: %s', self.server, self.port, exc)

    def connection_lost(self, exc):
        """Log why the transport was closed."""
        if exc:
            self.logger.warning('[%s:%d] Connection lost: %s', self.server, self.port, str(exc))
        else:
            self.logger.info('[%s:%d] Transport closed', self.server, self.port)

    # noinspection PyUnusedLocal
    def datagram_received(self, data, addr):
        """Match an incoming datagram to a pending request and resolve it.

        The reply is decoded and looked up by id in the pending map. If
        ``VerifyReply`` accepts it, the request's future is resolved with
        it. Unknown or invalid replies are logged and ignored. No exception
        escapes this callback.
        """
        try:
            reply = Packet(packet=data, dict=self.client.dict)

            # Look the id up explicitly instead of testing ``reply`` for
            # truth. NOTE(review): pyrad packets appear to be mapping-like,
            # so a reply without attributes could be falsy.
            req = self.pending_requests.get(reply.id)
            if req is None:
                self.logger.warning('[%s:%d] Ignore invalid reply: %s', self.server, self.port, data)
                return

            packet = req['packet']
            reply.dict = packet.dict
            reply.secret = packet.secret

            if packet.VerifyReply(reply, data):
                # Remove the request first so that a future the caller has
                # already cancelled cannot leave a stale entry behind.
                del self.pending_requests[reply.id]
                if not req['future'].done():
                    req['future'].set_result(reply)
            else:
                self.logger.warning('[%s:%d] Ignore invalid reply for id %d. %s', self.server, self.port, reply.id, data)

        except Exception as exc:
            self.logger.error('[%s:%d] Error on decode packet: %s', self.server, self.port, exc)

    async def close_transport(self):
        """Close the transport and stop the retry/timeout handler."""
        if self.transport:
            self.logger.debug('[%s:%d] Closing transport...', self.server, self.port)
            self.transport.close()
            self.transport = None
        if self.timeout_future:
            self.timeout_future.cancel()
            await self.timeout_future
            self.timeout_future = None

    def create_id(self):
        """Return the next packet id, wrapping around within 0-255."""
        self.packet_id = (self.packet_id + 1) % 256
        return self.packet_id

    def __str__(self):
        return 'DatagramProtocolClient(server=%s, port=%d)' % (self.server, self.port)

    # Used as protocol_factory
    def __call__(self):
        return self
163
164
class ClientAsync:
    """Basic asynchronous RADIUS client.
    This class implements a basic RADIUS client. It can send requests
    to a RADIUS server, taking care of timeouts and retries, and
    validate its replies.

    :ivar retries: number of times to retry sending a RADIUS request
    :type retries: integer
    :ivar timeout: number of seconds to wait for an answer
    :type timeout: integer
    """
    # noinspection PyShadowingBuiltins
    def __init__(self, server, auth_port=1812, acct_port=1813,
                 coa_port=3799, secret=six.b(''), dict=None,
                 loop=None, retries=3, timeout=30,
                 logger_name='pyrad'):

        """Constructor.

        :param      server: hostname or IP address of RADIUS server
        :type       server: string
        :param   auth_port: port to use for authentication packets
        :type    auth_port: integer
        :param   acct_port: port to use for accounting packets
        :type    acct_port: integer
        :param    coa_port: port to use for CoA packets
        :type     coa_port: integer
        :param      secret: RADIUS secret
        :type       secret: string
        :param        dict: RADIUS dictionary
        :type         dict: pyrad.dictionary.Dictionary
        :param        loop: Python loop handler; defaults to
                            ``asyncio.get_event_loop()``
        :type         loop: asyncio event loop
        :param     retries: number of resends before a request times out
        :type      retries: integer
        :param     timeout: seconds to wait for a reply before resending
        :type      timeout: integer
        :param logger_name: name passed to ``logging.getLogger``
        :type  logger_name: string
        """
        self.loop = loop if loop else asyncio.get_event_loop()
        self.logger = logging.getLogger(logger_name)

        self.server = server
        self.secret = secret
        self.retries = retries
        self.timeout = timeout
        self.dict = dict

        self.auth_port = auth_port
        self.protocol_auth = None

        self.acct_port = acct_port
        self.protocol_acct = None

        self.protocol_coa = None
        self.coa_port = coa_port

    def _new_endpoint(self, port, local_addr, local_port):
        """Build a protocol for remote ``port`` and its connect coroutine.

        The local address is bound only when both ``local_addr`` and
        ``local_port`` are given.
        """
        protocol = DatagramProtocolClient(
            self.server,
            port,
            self.logger, self,
            retries=self.retries,
            timeout=self.timeout
        )
        bind_addr = None
        if local_addr and local_port:
            bind_addr = (local_addr, local_port)

        connect = self.loop.create_datagram_endpoint(
            protocol,
            reuse_port=True,
            remote_addr=(self.server, port),
            local_addr=bind_addr
        )
        return protocol, connect

    async def initialize_transports(self, enable_acct=False,
                                    enable_auth=False, enable_coa=False,
                                    local_addr=None, local_auth_port=None,
                                    local_acct_port=None, local_coa_port=None):
        """Create the selected UDP transports that do not exist yet.

        If any endpoint fails to connect, every transport created by this
        call is closed and reset to ``None``, so a later call can retry.
        The first error is then re-raised.

        :raises Exception: if no transport is selected
        """
        if not enable_acct and not enable_auth and not enable_coa:
            raise Exception('No transports selected')

        # (attribute name, protocol, connect coroutine) for every transport
        # created by this call.
        created = []

        if enable_acct and not self.protocol_acct:
            self.protocol_acct, connect = self._new_endpoint(
                self.acct_port, local_addr, local_acct_port)
            created.append(('protocol_acct', self.protocol_acct, connect))

        if enable_auth and not self.protocol_auth:
            self.protocol_auth, connect = self._new_endpoint(
                self.auth_port, local_addr, local_auth_port)
            created.append(('protocol_auth', self.protocol_auth, connect))

        if enable_coa and not self.protocol_coa:
            self.protocol_coa, connect = self._new_endpoint(
                self.coa_port, local_addr, local_coa_port)
            created.append(('protocol_coa', self.protocol_coa, connect))

        results = await asyncio.gather(
            *(connect for _, _, connect in created),
            return_exceptions=True,
        )
        errors = [res for res in results if isinstance(res, BaseException)]
        if errors:
            # Do not leave half-initialized protocols behind: they would make
            # later calls skip initialization and make sends fail.
            for attr, protocol, _ in created:
                await protocol.close_transport()
                setattr(self, attr, None)
            raise errors[0]

    # noinspection SpellCheckingInspection
    async def deinitialize_transports(self, deinit_coa=True,
                                      deinit_auth=True,
                                      deinit_acct=True):
        """Close the selected transports, in CoA, auth, acct order."""
        for attr, selected in (('protocol_coa', deinit_coa),
                               ('protocol_auth', deinit_auth),
                               ('protocol_acct', deinit_acct)):
            protocol = getattr(self, attr)
            if protocol and selected:
                await protocol.close_transport()
                setattr(self, attr, None)

    # noinspection PyPep8Naming
    def CreateAuthPacket(self, **args):
        """Create a new RADIUS authentication packet.
        The packet is initialized with this client's dictionary and secret,
        and with the next id of the authentication transport.

        :return: a new empty packet instance
        :rtype:  pyrad.packet.AuthPacket
        :raises Exception: if the auth transport is not initialized
        """
        if not self.protocol_auth:
            raise Exception('Transport not initialized')

        return AuthPacket(dict=self.dict,
                          id=self.protocol_auth.create_id(),
                          secret=self.secret, **args)

    # noinspection PyPep8Naming
    def CreateAcctPacket(self, **args):
        """Create a new RADIUS accounting packet.
        The packet is initialized with this client's dictionary and secret,
        and with the next id of the accounting transport.

        :return: a new empty packet instance
        :rtype:  pyrad.packet.AcctPacket
        :raises Exception: if the acct transport is not initialized
        """
        if not self.protocol_acct:
            raise Exception('Transport not initialized')

        return AcctPacket(id=self.protocol_acct.create_id(),
                          dict=self.dict,
                          secret=self.secret, **args)

    # noinspection PyPep8Naming
    def CreateCoAPacket(self, **args):
        """Create a new RADIUS CoA packet.
        The packet is initialized with this client's dictionary and secret,
        and with the next id of the CoA transport.

        :return: a new empty packet instance
        :rtype:  pyrad.packet.CoAPacket
        :raises Exception: if the CoA transport is not initialized
        """
        # Check the CoA transport, since its id generator is the one used.
        if not self.protocol_coa:
            raise Exception('Transport not initialized')

        return CoAPacket(id=self.protocol_coa.create_id(),
                         dict=self.dict,
                         secret=self.secret, **args)

    # noinspection PyPep8Naming
    # noinspection PyShadowingBuiltins
    def CreatePacket(self, id, **args):
        """Create a generic packet with an explicit id.

        :param id: RADIUS packet id; 0 is a valid identifier
        :raises Exception: if ``id`` is None
        """
        if id is None:
            raise Exception('Missing mandatory packet id')

        return Packet(id=id, dict=self.dict,
                      secret=self.secret, **args)

    # noinspection PyPep8Naming
    def SendPacket(self, pkt):
        """Send a packet to a RADIUS server.

        :param pkt: the packet to send
        :type  pkt: pyrad.packet.Packet
        :return:    Future related with packet to send
        :rtype:     asyncio.Future
        :raises Exception: for unsupported packet types or when the
                           matching transport is not initialized
        """
        if isinstance(pkt, AuthPacket):
            protocol = self.protocol_auth
        elif isinstance(pkt, AcctPacket):
            protocol = self.protocol_acct
        elif isinstance(pkt, CoAPacket):
            protocol = self.protocol_coa
        else:
            raise Exception('Unsupported packet')

        if not protocol:
            raise Exception('Transport not initialized')

        # Create the future only once the request is known to be sendable.
        ans = self.loop.create_future()
        protocol.send_packet(pkt, ans)
        return ans
413