1#
2# libvirtaio -- asyncio adapter for libvirt
3# Copyright (C) 2017  Wojtek Porczyk <woju@invisiblethingslab.com>
4#
5# This library is free software; you can redistribute it and/or
6# modify it under the terms of the GNU Lesser General Public
7# License as published by the Free Software Foundation; either
8# version 2.1 of the License, or (at your option) any later version.
9#
10# This library is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13# Lesser General Public License for more details.
14#
15# You should have received a copy of the GNU Lesser General Public
16# License along with this library; if not, see
17# <http://www.gnu.org/licenses/>.
18#
19
20'''Libvirt event loop implementation using asyncio
21
22Register the implementation of default loop:
23
24    >>> import libvirtaio
25    >>> libvirtaio.virEventRegisterAsyncIOImpl()
26
27.. seealso::
28    https://libvirt.org/html/libvirt-libvirt-event.html
29'''
30
31import asyncio
32import itertools
33import logging
34import warnings
35
36import libvirt
37
38from typing import Any, Callable, Dict, Generator, Optional, TypeVar  # noqa F401
# Generic type of the opaque object that libvirt hands to the callbacks.
_T = TypeVar('_T')

__author__ = 'Wojtek Porczyk <woju@invisiblethingslab.com>'
__license__ = 'LGPL-2.1+'
# Public API of this module.
__all__ = ['getCurrentImpl', 'virEventAsyncIOImpl', 'virEventRegisterAsyncIOImpl']
48
# Python < 3.4.4 doesn't have 'ensure_future', so we have to fall
# back to 'async'; however, since 'async' is a reserved keyword
# in Python >= 3.7, we can't perform a straightforward import and
# we have to resort to getattr() instead
try:
    ensure_future = asyncio.ensure_future
except AttributeError:
    # only reachable on interpreters predating asyncio.ensure_future
    ensure_future = getattr(asyncio, "async")
56
57
class Callback(object):
    '''Base class for holding callback

    Stores a libvirt callback together with its opaque object and the
    implementation that owns it.  Each instance (of this class or any
    subclass) takes a unique integer identifier from one counter shared by
    the whole class hierarchy.

    :param virEventAsyncIOImpl impl: the implementation in which we run
    :param cb: the callback itself
    :param opaque: the opaque object passed by libvirt
    '''
    # pylint: disable=too-few-public-methods

    _iden_counter = itertools.count()

    def __init__(self, impl: "virEventAsyncIOImpl", cb: Callable[[int, _T], None], opaque: _T, *args: Any, **kwargs: Any) -> None:
        # Pass leftover arguments up the MRO (cooperative inheritance).
        super().__init__(*args, **kwargs)  # type: ignore
        self.iden = next(self._iden_counter)
        self.impl, self.cb, self.opaque = impl, cb, opaque

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return '<{} iden={}>'.format(class_name, self.iden)

    def close(self) -> None:
        '''Schedule *ff* callback

        Asks the owning implementation to schedule the ff (free) callback
        for this callback's opaque object.
        '''
        impl, iden = self.impl, self.iden
        impl.log.debug('callback %d close(), scheduling ff', iden)
        impl.schedule_ff_callback(iden, self.opaque)
83
84
85#
86# file descriptors
87#
88
class Descriptor(object):
    '''Manager of one file descriptor

    Holds every handle callback watching a single file descriptor and keeps
    the event loop's reader/writer registration for that fd in sync with the
    union of the callbacks' event bitsets.

    :param virEventAsyncIOImpl impl: the implementation in which we run
    :param int fd: the file descriptor
    '''
    def __init__(self, impl: "virEventAsyncIOImpl", fd: int) -> None:
        self.impl = impl
        self.fd = fd
        self.callbacks = {}  # type: Dict[int, FDCallback]

    def _handle(self, event: int) -> None:
        '''Dispatch the event to the descriptors

        :param int event: The event (from libvirt's constants) being dispatched
        '''
        # Iterate over a snapshot so callbacks can add/remove handles on this
        # descriptor while we are dispatching.
        for callback in list(self.callbacks.values()):
            if callback.event is not None and callback.event & event:
                callback.cb(callback.iden, self.fd, event, callback.opaque)

    def _event_mask(self) -> int:
        '''Return the bitwise OR of all callbacks' events (0 if there are none).

        Callbacks whose ``event`` is None contribute nothing, matching how
        :meth:`_handle` skips them.
        '''
        mask = 0
        for callback in self.callbacks.values():
            if callback.event is not None:
                mask |= callback.event
        return mask

    def update(self) -> None:
        '''Register or unregister callbacks at event loop

        This should be called after change of any ``.event`` in callbacks.
        '''
        # It seems like loop.add_{reader,writer} can be run multiple times
        # and will still register the callback only once. Likewise,
        # remove_{reader,writer} may be run even if the reader/writer
        # is not registered (and will just return False).
        readable = libvirt.VIR_EVENT_HANDLE_READABLE
        writable = libvirt.VIR_EVENT_HANDLE_WRITABLE
        loop = self.impl.loop

        # The combined mask is 0 when there are no callbacks, which unwatches
        # the fd entirely.
        mask = self._event_mask()

        if mask & ~(readable | writable):
            warnings.warn(
                'The only event supported are VIR_EVENT_HANDLE_READABLE '
                'and VIR_EVENT_HANDLE_WRITABLE',
                UserWarning)

        if mask & readable:
            loop.add_reader(self.fd, self._handle, readable)
        else:
            loop.remove_reader(self.fd)

        if mask & writable:
            loop.add_writer(self.fd, self._handle, writable)
        else:
            loop.remove_writer(self.fd)

    def add_handle(self, callback: "FDCallback") -> None:
        '''Add a callback to the descriptor

        :param FDCallback callback: the callback to add
        :rtype: None

        After adding the callback, it is immediately watched.
        '''
        self.callbacks[callback.iden] = callback
        self.update()

    def remove_handle(self, iden: int) -> "FDCallback":
        '''Remove a callback from the descriptor

        :param int iden: the identifier of the callback
        :returns: the callback
        :rtype: FDCallback
        :raises KeyError: if no callback with *iden* is registered here

        After removing the callback, the descriptor may be unwatched, if there
        are no more handles for it.
        '''
        callback = self.callbacks.pop(iden)
        self.update()
        return callback
167
168
class DescriptorDict(dict):
    '''Descriptors collection

    Maps file descriptor numbers to their :class:`Descriptor` managers and is
    used internally by virEventAsyncIOImpl.  Indexing with an fd that is not
    present creates, stores and returns a fresh Descriptor.
    '''
    def __init__(self, impl: "virEventAsyncIOImpl") -> None:
        super().__init__()
        # handed to every Descriptor created on demand
        self.impl = impl

    def __missing__(self, fd: int) -> Descriptor:
        # Only reached from __getitem__ for an absent key, so setdefault
        # always inserts the new Descriptor and returns it.
        return self.setdefault(fd, Descriptor(self.impl, fd))
182
183
class FDCallback(Callback):
    '''Callback for file descriptor (watcher)

    Keyword-only arguments are consumed here; all remaining arguments are
    forwarded to :class:`Callback`.

    :param Descriptor descriptor: the descriptor manager
    :param int event: bitset of events on which to fire the callback
    '''
    # pylint: disable=too-few-public-methods

    def __init__(self, *args: Any, descriptor: Descriptor, event: int, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.descriptor, self.event = descriptor, event

    def __repr__(self) -> str:
        parts = (self.__class__.__name__, self.iden,
                 self.descriptor.fd, self.event)
        return '<{} iden={} fd={} event={}>'.format(*parts)

    def update(self, event: int) -> None:
        '''Update the callback and fix descriptor's watchers

        :param int event: the new bitset of events to watch for
        '''
        self.event = event
        # let the descriptor re-derive its reader/writer registration
        self.descriptor.update()
205
206
207#
208# timeouts
209#
210
class TimeoutCallback(Callback):
    '''Callback for timer

    The timer is driven by an asyncio task running :meth:`_timer`, which
    exists only while ``timeout`` is non-negative:

    * ``timeout > 0``: the callback fires every ``timeout`` milliseconds;
    * ``timeout == 0``: the callback fires on every event loop iteration;
    * ``timeout < 0``: the timer is disabled (no task).
    '''
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.timeout = -1
        self._task = None  # type: Optional[asyncio.Future]

    def __repr__(self) -> str:
        return '<{} iden={} timeout={}>'.format(
            self.__class__.__name__, self.iden, self.timeout)

    async def _timer(self) -> None:
        '''An actual timer running on the event loop.

        This is a coroutine.  It is a native ``async def`` coroutine because
        ``@asyncio.coroutine`` no longer exists as of Python 3.11.
        '''
        while True:
            try:
                # self.timeout is re-read on each iteration, so a change made
                # while sleeping only takes effect after the current sleep.
                if self.timeout > 0:
                    timeout = self.timeout * 1e-3
                    self.impl.log.debug('sleeping %r', timeout)
                    await asyncio.sleep(timeout)
                else:
                    # scheduling timeout for next loop iteration
                    await asyncio.sleep(0)

            except asyncio.CancelledError:
                self.impl.log.debug('timer %d cancelled', self.iden)
                break

            self.cb(self.iden, self.opaque)
            self.impl.log.debug('timer %r callback ended', self.iden)

    def update(self, timeout: int) -> None:
        '''Start or stop the timer, possibly updating timeout

        :param int timeout: new timeout in milliseconds; negative disables
        '''
        self.timeout = timeout

        if self.timeout >= 0 and self._task is None:
            self.impl.log.debug('timer %r start', self.iden)
            self._task = ensure_future(self._timer(),
                                       loop=self.impl.loop)

        elif self.timeout < 0 and self._task is not None:
            self.impl.log.debug('timer %r stop', self.iden)
            self._task.cancel()  # pylint: disable=no-member
            self._task = None

    def close(self) -> None:
        '''Stop the timer and call ff callback'''
        self.update(timeout=-1)
        super().close()
263
264
265#
266# main implementation
267#
268
class virEventAsyncIOImpl(object):
    '''Libvirt event adapter to asyncio.

    :param loop: asyncio's event loop

    If *loop* is not specified, the current (or default) event loop is used.
    '''

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        self.loop = loop or asyncio.get_event_loop()
        self.callbacks = {}  # type: Dict[int, Callback]
        self.descriptors = DescriptorDict(self)
        self.log = logging.getLogger(self.__class__.__name__)

        # Number of handles/timers added whose ff callback has not run yet.
        self._pending = 0
        # Event that drain() waits on.  It is created lazily inside drain()
        # rather than here: asyncio.Event() no longer accepts a ``loop``
        # argument (Python 3.10+), and before 3.10 it would bind to whatever
        # loop was current at construction time.
        # NOTE invariant: _finished is not None only while _pending > 0, and
        # it is never set while it is stored here.
        self._finished = None  # type: Optional[asyncio.Event]

    def __repr__(self) -> str:
        return '<{} callbacks={} descriptors={}>'.format(
            type(self).__name__, self.callbacks, self.descriptors)

    def _pending_inc(self) -> None:
        '''Increase the count of pending affairs. Do not use directly.'''
        # No need to touch self._finished: per the invariant it only exists
        # while _pending > 0 and is still unset.
        self._pending += 1

    def _pending_dec(self) -> None:
        '''Decrease the count of pending affairs. Do not use directly.'''
        assert self._pending > 0
        self._pending -= 1
        if self._pending == 0 and self._finished is not None:
            # Wake every drain() waiting on this event, then drop it so the
            # next drain() starts with a fresh, unset one.
            self._finished.set()
            self._finished = None

    def register(self) -> "virEventAsyncIOImpl":
        '''Register this instance as event loop implementation'''
        # pylint: disable=bad-whitespace
        self.log.debug('register()')
        libvirt.virEventRegisterImpl(
            self._add_handle, self._update_handle, self._remove_handle,
            self._add_timeout, self._update_timeout, self._remove_timeout)
        return self

    def schedule_ff_callback(self, iden: int, opaque: _T) -> None:
        '''Schedule a ff callback from one of the handles or timers'''
        ensure_future(self._ff_callback(iden, opaque), loop=self.loop)

    async def _ff_callback(self, iden: int, opaque: _T) -> None:
        '''Directly free the opaque object

        This is a coroutine.
        '''
        self.log.debug('ff_callback(iden=%d, opaque=...)', iden)
        libvirt.virEventInvokeFreeCallback(opaque)
        self._pending_dec()

    async def drain(self) -> None:
        '''Wait for the implementation to become idle.

        This is a coroutine.  Several drain() calls may wait concurrently;
        they share one event and are all released together.
        '''
        self.log.debug('drain()')
        if self._pending:
            if self._finished is None:
                # Created here, with the loop running, so it binds to the
                # running loop on every supported Python version.
                self._finished = asyncio.Event()
            await self._finished.wait()
        self.log.debug('drain ended')

    def is_idle(self) -> bool:
        '''Returns False if there are leftovers from a connection

        Those may happen if there are semantic problems while closing
        a connection. For example, not deregistered events before .close().
        '''
        return not self.callbacks and not self._pending

    def _add_handle(self, fd: int, event: int, cb: libvirt._EventCB, opaque: _T) -> int:
        '''Register a callback for monitoring file handle events

        :param int fd: file descriptor to listen on
        :param int event: bitset of events on which to fire the callback
        :param cb: the callback to be called when an event occurs
        :param opaque: user data to pass to the callback
        :rtype: int
        :returns: handle watch number to be used for updating and unregistering for events

        .. seealso::
            https://libvirt.org/html/libvirt-libvirt-event.html#virEventAddHandleFuncFunc
        '''
        descriptor = self.descriptors[fd]
        callback = FDCallback(self, cb, opaque,
                              descriptor=descriptor, event=event)
        assert callback.iden not in self.callbacks

        self.log.debug('add_handle(fd=%d, event=%d, cb=..., opaque=...) = %d',
                       fd, event, callback.iden)
        self.callbacks[callback.iden] = callback
        descriptor.add_handle(callback)
        self._pending_inc()
        return callback.iden

    def _update_handle(self, watch: int, event: int) -> None:
        '''Change event set for a monitored file handle

        :param int watch: file descriptor watch to modify
        :param int event: new events to listen on
        :raises KeyError: if *watch* is not a registered handle

        .. seealso::
            https://libvirt.org/html/libvirt-libvirt-event.html#virEventUpdateHandleFunc
        '''
        self.log.debug('update_handle(watch=%d, event=%d)', watch, event)
        callback = self.callbacks[watch]
        assert isinstance(callback, FDCallback)
        callback.update(event=event)

    def _remove_handle(self, watch: int) -> int:
        '''Unregister a callback from a file handle.

        :param int watch: file descriptor watch to stop listening on
        :returns: -1 on error, 0 on success

        .. seealso::
            https://libvirt.org/html/libvirt-libvirt-event.html#virEventRemoveHandleFunc
        '''
        self.log.debug('remove_handle(watch=%d)', watch)
        try:
            callback = self.callbacks.pop(watch)
        except KeyError as err:
            self.log.warning('remove_handle(): no such handle: %r', err.args[0])
            return -1
        assert isinstance(callback, FDCallback)
        fd = callback.descriptor.fd
        descriptor = self.descriptors[fd]
        # The removal must happen outside the assert: asserts are stripped
        # under ``python -O`` and the fd would otherwise stay watched.
        removed = descriptor.remove_handle(watch)
        assert removed is callback
        if not descriptor.callbacks:
            del self.descriptors[fd]
        callback.close()
        return 0

    def _add_timeout(self, timeout: int, cb: libvirt._TimerCB, opaque: _T) -> int:
        '''Register a callback for a timer event

        :param int timeout: the timeout to monitor
        :param cb: the callback to call when timeout has expired
        :param opaque: user data to pass to the callback
        :rtype: int
        :returns: a timer value

        .. seealso::
            https://libvirt.org/html/libvirt-libvirt-event.html#virEventAddTimeoutFunc
        '''
        callback = TimeoutCallback(self, cb, opaque)
        assert callback.iden not in self.callbacks

        self.log.debug('add_timeout(timeout=%d, cb=..., opaque=...) = %d',
                       timeout, callback.iden)
        self.callbacks[callback.iden] = callback
        callback.update(timeout=timeout)
        self._pending_inc()
        return callback.iden

    def _update_timeout(self, timer: int, timeout: int) -> None:
        '''Change frequency for a timer

        :param int timer: the timer to modify
        :param int timeout: the new timeout value in ms
        :raises KeyError: if *timer* is not a registered timer

        .. seealso::
            https://libvirt.org/html/libvirt-libvirt-event.html#virEventUpdateTimeoutFunc
        '''
        self.log.debug('update_timeout(timer=%d, timeout=%d)', timer, timeout)
        callback = self.callbacks[timer]
        assert isinstance(callback, TimeoutCallback)
        callback.update(timeout=timeout)

    def _remove_timeout(self, timer: int) -> int:
        '''Unregister a callback for a timer

        :param int timer: the timer to remove
        :returns: -1 on error, 0 on success

        .. seealso::
            https://libvirt.org/html/libvirt-libvirt-event.html#virEventRemoveTimeoutFunc
        '''
        self.log.debug('remove_timeout(timer=%d)', timer)
        try:
            callback = self.callbacks.pop(timer)
        except KeyError as err:
            self.log.warning('remove_timeout(): no such timeout: %r', err.args[0])
            return -1
        # Mirror _remove_handle: timer ids must refer to timer callbacks.
        assert isinstance(callback, TimeoutCallback)
        callback.close()
        return 0
460
461
# Implementation installed by the most recent virEventRegisterAsyncIOImpl()
# call; None until the first registration.
_current_impl = None  # type: Optional[virEventAsyncIOImpl]


def getCurrentImpl() -> Optional[virEventAsyncIOImpl]:
    '''Return the current implementation, or None if not yet registered

    :rtype: virEventAsyncIOImpl or None
    '''
    return _current_impl


def virEventRegisterAsyncIOImpl(loop: asyncio.AbstractEventLoop = None) -> virEventAsyncIOImpl:
    '''Arrange for libvirt's callbacks to be dispatched via asyncio event loop

    :param loop: event loop to use; ``None`` selects the current/default one
    :returns: the freshly registered implementation, which also becomes the
        value returned by :func:`getCurrentImpl`

    The implementation object is returned, but in normal usage it can safely be
    discarded.
    '''
    global _current_impl
    # Only publish the instance once registration has succeeded.
    impl = virEventAsyncIOImpl(loop=loop).register()
    _current_impl = impl
    return impl
479