1# -*- coding: utf-8 -*-
2
3#    Copyright (C) 2015 Yahoo! Inc. All Rights Reserved.
4#
5#    Licensed under the Apache License, Version 2.0 (the "License"); you may
6#    not use this file except in compliance with the License. You may obtain
7#    a copy of the License at
8#
9#         http://www.apache.org/licenses/LICENSE-2.0
10#
11#    Unless required by applicable law or agreed to in writing, software
12#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14#    License for the specific language governing permissions and limitations
15#    under the License.
16
17try:
18    from contextlib import ExitStack
19except ImportError:
20    from contextlib2 import ExitStack
21
22import collections
23import contextlib
24
25from concurrent import futures
26from concurrent.futures import _base
27import six
28
29import futurist
30from futurist import _utils
31
32try:
33    from eventlet.green import threading as greenthreading
34except ImportError:
35    greenthreading = None
36
37
38#: Named tuple returned from ``wait_for*`` calls.
39DoneAndNotDoneFutures = collections.namedtuple(
40    'DoneAndNotDoneFutures', 'done not_done')
41
42_DONE_STATES = frozenset([
43    _base.CANCELLED_AND_NOTIFIED,
44    _base.FINISHED,
45])
46
47
48@contextlib.contextmanager
49def _acquire_and_release_futures(fs):
50    # Do this to ensure that we always get the futures in the same order (aka
51    # always acquire the conditions in the same order, no matter what; a way
52    # to avoid dead-lock).
53    fs = sorted(fs, key=id)
54    with ExitStack() as stack:
55        for fut in fs:
56            stack.enter_context(fut._condition)
57        yield
58
59
def _ensure_eventlet(func):
    """Decorator that refuses to run *func* unless eventlet is usable.

    The wrapped callable raises ``RuntimeError`` when either
    ``_utils.EVENTLET_AVAILABLE`` is false or the green ``threading`` module
    failed to import at module load time (``greenthreading`` is ``None``);
    otherwise it forwards all arguments to *func* and returns its result.
    """

    @six.wraps(func)
    def wrapper(*args, **kwargs):
        eventlet_usable = (_utils.EVENTLET_AVAILABLE and
                           greenthreading is not None)
        if eventlet_usable:
            return func(*args, **kwargs)
        raise RuntimeError('Eventlet is needed to wait on green futures')

    return wrapper
70
71
def _wait_for(fs, no_green_return_when, on_all_green_cb,
              caller_name, timeout=None):
    """Dispatch a wait to the green or the non-green implementation.

    :param fs: iterable of futures to wait on (a one-shot iterator such as a
               generator is accepted).
    :param no_green_return_when: ``return_when`` value handed to
                                 :func:`concurrent.futures.wait` when none of
                                 the futures are green.
    :param on_all_green_cb: callable invoked as ``cb(fs, timeout=timeout)``
                            when every future is a green future.
    :param caller_name: public function name used in the error message.
    :param timeout: maximum number of seconds to wait; ``None`` means no
                    limit.
    :returns: a ``DoneAndNotDoneFutures`` pair (the green path returns
              whatever *on_all_green_cb* returns).
    :raises RuntimeError: if green and non-green futures are mixed.
    """
    # Materialize exactly once: ``fs`` is iterated several times below (the
    # green count, ``len()``, and again by the wait implementation), so a
    # generator would otherwise be exhausted by the first pass. Using a set
    # also collapses duplicate futures; the results are sets anyway.
    fs = set(fs)
    green_fs = sum(1 for f in fs if isinstance(f, futurist.GreenFuture))
    if not green_fs:
        done, not_done = futures.wait(fs, timeout=timeout,
                                      return_when=no_green_return_when)
        return DoneAndNotDoneFutures(done, not_done)
    non_green_fs = len(fs) - green_fs
    if non_green_fs:
        # Green and native waiting use different primitives; mixing them in
        # one call can't be guaranteed to be dead-lock free.
        raise RuntimeError("Can not wait on %s green futures and %s"
                           " non-green futures in the same"
                           " `%s` call" % (green_fs, non_green_fs,
                                           caller_name))
    return on_all_green_cb(fs, timeout=timeout)
88
89
def wait_for_all(fs, timeout=None):
    """Wait until every future in *fs* has completed (or *timeout* expires).

    Green and non-green futures are both supported, but they may not be
    mixed in a single call: the two kinds are waited on with different
    mechanisms, and combining them can't be guaranteed to avoid dead-lock.

    :param fs: the futures to wait on.
    :param timeout: maximum number of seconds to wait; ``None`` waits
                    without a limit.
    :returns: pair (done futures, not done futures).
    :raises RuntimeError: if green and non-green futures are mixed, or if
                          green futures are given but eventlet is missing.
    """
    return _wait_for(fs,
                     no_green_return_when=futures.ALL_COMPLETED,
                     on_all_green_cb=_wait_for_all_green,
                     caller_name='wait_for_all',
                     timeout=timeout)
102
103
def wait_for_any(fs, timeout=None):
    """Wait until at least one (**any**) future in *fs* has completed.

    Green and non-green futures are both supported, but they may not be
    mixed in a single call: the two kinds are waited on with different
    mechanisms, and combining them can't be guaranteed to avoid dead-lock.

    :param fs: the futures to wait on.
    :param timeout: maximum number of seconds to wait; ``None`` waits
                    without a limit.
    :returns: pair (done futures, not done futures).
    :raises RuntimeError: if green and non-green futures are mixed, or if
                          green futures are given but eventlet is missing.
    """
    return _wait_for(fs,
                     no_green_return_when=futures.FIRST_COMPLETED,
                     on_all_green_cb=_wait_for_any_green,
                     caller_name='wait_for_any',
                     timeout=timeout)
116
117
class _AllGreenWaiter(object):
    """Waiter whose ``event`` fires once ``pending`` futures have finished.

    ``_wait_for_all_green`` blocks on ``event``. Every completion callback
    (result, exception or cancellation) counts one future as finished; once
    the outstanding count drops to zero or below, ``event`` is set.
    """

    def __init__(self, pending):
        # Number of futures still expected to report completion.
        self.pending = pending
        # Guards ``pending`` against completion callbacks racing each other.
        self.lock = greenthreading.Lock()
        self.event = greenthreading.Event()

    def _decrement_pending(self):
        with self.lock:
            self.pending -= 1
            all_finished = self.pending <= 0
            if all_finished:
                self.event.set()

    def add_result(self, future):
        """Count *future* as finished because it produced a result."""
        self._decrement_pending()

    def add_exception(self, future):
        """Count *future* as finished because it raised."""
        self._decrement_pending()

    def add_cancelled(self, future):
        """Count *future* as finished because it was cancelled."""
        self._decrement_pending()
140
141
class _AnyGreenWaiter(object):
    """Waiter whose ``event`` fires as soon as any watched future finishes.

    ``_wait_for_any_green`` blocks on ``event``. Every completion callback
    (result, exception or cancellation) sets it; setting an already-set
    event is harmless, so only the first one matters.
    """

    def __init__(self):
        self.event = greenthreading.Event()

    def _wake(self):
        self.event.set()

    def add_result(self, future):
        """Wake the waiter: *future* produced a result."""
        self._wake()

    def add_exception(self, future):
        """Wake the waiter: *future* raised."""
        self._wake()

    def add_cancelled(self, future):
        """Wake the waiter: *future* was cancelled."""
        self._wake()
156
157
def _partition_futures(fs):
    """Split *fs* into a ``(done, not_done)`` pair of sets.

    A future counts as done when its ``_state`` is in ``_DONE_STATES``;
    every other future lands in ``not_done``. The callers in this module
    hold the futures' conditions while calling this, so states don't change
    mid-scan.
    """
    done, not_done = set(), set()
    for fut in fs:
        bucket = done if fut._state in _DONE_STATES else not_done
        bucket.add(fut)
    return done, not_done
167
168
169def _create_and_install_waiters(fs, waiter_cls, *args, **kwargs):
170    waiter = waiter_cls(*args, **kwargs)
171    for f in fs:
172        f._waiters.append(waiter)
173    return waiter
174
175
@_ensure_eventlet
def _wait_for_all_green(fs, timeout=None):
    """Block until every green future in *fs* finishes (or *timeout* expires).

    :param fs: collection of green futures (iterated more than once).
    :param timeout: maximum number of seconds to wait; ``None`` waits
                    without a limit.
    :returns: ``DoneAndNotDoneFutures`` describing the futures' states at
              return time.
    """
    if not fs:
        return DoneAndNotDoneFutures(set(), set())

    with _acquire_and_release_futures(fs):
        done, not_done = _partition_futures(fs)
        # Test the partitioned set, not ``len(done) == len(fs)``. If *fs*
        # contains the same future twice, ``len(done)`` can never reach
        # ``len(fs)``. We would then install a waiter expecting zero
        # futures, whose event is never set, and wait forever.
        if not not_done:
            return DoneAndNotDoneFutures(done, not_done)
        waiter = _create_and_install_waiters(not_done,
                                             _AllGreenWaiter,
                                             len(not_done))
    try:
        waiter.event.wait(timeout)
    finally:
        # Uninstall the waiter even if the wait was interrupted (e.g. by an
        # exception raised into this green thread), so it doesn't linger in
        # the futures' ``_waiters`` lists.
        for f in not_done:
            with f._condition:
                f._waiters.remove(waiter)

    with _acquire_and_release_futures(fs):
        done, not_done = _partition_futures(fs)
        return DoneAndNotDoneFutures(done, not_done)
196
197
@_ensure_eventlet
def _wait_for_any_green(fs, timeout=None):
    """Block until at least one green future in *fs* finishes (or timeout).

    :param fs: collection of green futures (iterated more than once).
    :param timeout: maximum number of seconds to wait; ``None`` waits
                    without a limit.
    :returns: ``DoneAndNotDoneFutures`` describing the futures' states at
              return time.
    """
    if not fs:
        return DoneAndNotDoneFutures(set(), set())

    with _acquire_and_release_futures(fs):
        done, not_done = _partition_futures(fs)
        if done:
            return DoneAndNotDoneFutures(done, not_done)
        # Nothing is done yet, so ``not_done`` holds every distinct future;
        # installing on it (rather than on *fs*) attaches the waiter once
        # per future even if *fs* has duplicates.
        waiter = _create_and_install_waiters(not_done, _AnyGreenWaiter)

    try:
        waiter.event.wait(timeout)
    finally:
        # Uninstall the waiter even if the wait was interrupted (e.g. by an
        # exception raised into this green thread), so it doesn't linger in
        # the futures' ``_waiters`` lists.
        for f in not_done:
            with f._condition:
                f._waiters.remove(waiter)

    with _acquire_and_release_futures(fs):
        done, not_done = _partition_futures(fs)
        return DoneAndNotDoneFutures(done, not_done)
217