from __future__ import absolute_import
from __future__ import unicode_literals

import logging
import operator
import sys
from threading import Lock
from threading import Semaphore
from threading import Thread

from docker.errors import APIError
from docker.errors import ImageNotFound
from six.moves import _thread as thread
from six.moves.queue import Empty
from six.moves.queue import Queue

from compose.cli.colors import green
from compose.cli.colors import red
from compose.cli.signals import ShutdownException
from compose.const import PARALLEL_LIMIT
from compose.errors import HealthCheckFailed
from compose.errors import NoHealthCheckConfigured
from compose.errors import OperationFailedError
from compose.utils import get_output_stream


log = logging.getLogger(__name__)

STOP = object()


class GlobalLimit(object):
    """Simple class to hold a global semaphore limiter for a project. This class
    should be treated as a singleton that is instantiated when the project is.
    """

    global_limiter = Semaphore(PARALLEL_LIMIT)

    @classmethod
    def set_global_limit(cls, value):
        if value is None:
            value = PARALLEL_LIMIT
        cls.global_limiter = Semaphore(value)
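
# Illustrative usage sketch (not part of the original module): the calling code
# can override the process-wide parallelism cap before kicking off operations.
# The value 4 below is arbitrary; passing None restores the PARALLEL_LIMIT default.
#
#     GlobalLimit.set_global_limit(4)
#     ...  # run parallel operations capped at 4 concurrent producers
#     GlobalLimit.set_global_limit(None)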


def parallel_execute_watch(events, writer, errors, results, msg, get_name, fail_check):
    """Watch events from a parallel execution, update status and fill errors and results.
    Returns the exception to re-raise, if any.
    """
    error_to_reraise = None
    for obj, result, exception in events:
        if exception is None:
            if fail_check is not None and fail_check(obj):
                writer.write(msg, get_name(obj), 'failed', red)
            else:
                writer.write(msg, get_name(obj), 'done', green)
            results.append(result)
        elif isinstance(exception, ImageNotFound):
            # This is to bubble up ImageNotFound exceptions to the client so we
            # can prompt the user if they want to rebuild.
            errors[get_name(obj)] = exception.explanation
            writer.write(msg, get_name(obj), 'error', red)
            error_to_reraise = exception
        elif isinstance(exception, APIError):
            errors[get_name(obj)] = exception.explanation
            writer.write(msg, get_name(obj), 'error', red)
        elif isinstance(exception, (OperationFailedError, HealthCheckFailed, NoHealthCheckConfigured)):
            errors[get_name(obj)] = exception.msg
            writer.write(msg, get_name(obj), 'error', red)
        elif isinstance(exception, UpstreamError):
            writer.write(msg, get_name(obj), 'error', red)
        else:
            errors[get_name(obj)] = exception
            error_to_reraise = exception
    return error_to_reraise


def parallel_execute(objects, func, get_name, msg, get_deps=None, limit=None, fail_check=None):
    """Runs func on objects in parallel while ensuring that func is
    run on an object only after it has been run on all of its dependencies.

    get_deps called on an object must return a collection of its dependencies.
    get_name called on an object must return its name.
    fail_check is an additional failure check for cases that should display as a failure
        in the CLI logs but don't raise an exception (such as attempting to start 0 containers).
    """
    objects = list(objects)
    stream = get_output_stream(sys.stderr)

    if ParallelStreamWriter.instance:
        writer = ParallelStreamWriter.instance
    else:
        writer = ParallelStreamWriter(stream)

    for obj in objects:
        writer.add_object(msg, get_name(obj))
    for obj in objects:
        writer.write_initial(msg, get_name(obj))

    events = parallel_execute_iter(objects, func, get_deps, limit)

    errors = {}
    results = []
    error_to_reraise = parallel_execute_watch(
        events, writer, errors, results, msg, get_name, fail_check
    )

    for obj_name, error in errors.items():
        stream.write("\nERROR: for {}  {}\n".format(obj_name, error))

    if error_to_reraise:
        raise error_to_reraise

    return results, errors
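
# Illustrative sketch (not from the original source) of how a caller might drive
# parallel_execute. The container objects and their `create`, `dependencies` and
# `was_noop` attributes are hypothetical; note that get_deps must return
# (dependency, ready_check) pairs, as consumed by feed_queue further below,
# where ready_check may be None.
#
#     results, errors = parallel_execute(
#         containers,
#         func=lambda c: c.create(),
#         get_name=operator.attrgetter('name'),
#         msg='Creating',
#         get_deps=lambda c: [(dep, None) for dep in c.dependencies],
#         limit=5,
#         fail_check=lambda c: c.was_noop,
#     )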


def _no_deps(x):
    return []


class State(object):
    """
    Holds the state of a partially-complete parallel operation.

    state.started:   objects being processed
    state.finished:  objects which have been processed
    state.failed:    objects which either failed or whose dependencies failed
    """
    def __init__(self, objects):
        self.objects = objects

        self.started = set()
        self.finished = set()
        self.failed = set()

    def is_done(self):
        return len(self.finished) + len(self.failed) >= len(self.objects)

    def pending(self):
        return set(self.objects) - self.started - self.finished - self.failed
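
# Rough illustration with hypothetical objects: finishing 'a' and failing 'b'
# leaves only 'c' pending, and the operation is done once finished + failed
# covers every object.
#
#     state = State(['a', 'b', 'c'])
#     state.finished.add('a')
#     state.failed.add('b')
#     state.pending()   # -> {'c'}
#     state.is_done()   # -> False (until 'c' finishes or fails)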


class NoLimit(object):
    def __enter__(self):
        pass

    def __exit__(self, *ex):
        pass


def parallel_execute_iter(objects, func, get_deps, limit):
    """
    Runs func on objects in parallel while ensuring that func is
    run on an object only after it has been run on all of its dependencies.

    Returns an iterator of tuples which look like:

    # if func returned normally when run on object
    (object, result, None)

    # if func raised an exception when run on object
    (object, None, exception)

    # if func raised an exception when run on one of object's dependencies
    (object, None, UpstreamError())
    """
    if get_deps is None:
        get_deps = _no_deps

    if limit is None:
        limiter = NoLimit()
    else:
        limiter = Semaphore(limit)

    results = Queue()
    state = State(objects)

    while True:
        feed_queue(objects, func, get_deps, results, state, limiter)

        try:
            event = results.get(timeout=0.1)
        except Empty:
            continue
        # See https://github.com/docker/compose/issues/189
        except thread.error:
            raise ShutdownException()

        if event is STOP:
            break

        obj, _, exception = event
        if exception is None:
            log.debug('Finished processing: {}'.format(obj))
            state.finished.add(obj)
        else:
            log.debug('Failed: {}'.format(obj))
            state.failed.add(obj)

        yield event
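
# Illustrative sketch (hypothetical `services` objects and `restart_service`
# callable): each yielded event is one of (obj, result, None) on success,
# (obj, None, exception) on failure, or (obj, None, UpstreamError()) when a
# dependency failed first.
#
#     events = parallel_execute_iter(services, restart_service, get_deps=None, limit=2)
#     for obj, result, exception in events:
#         if isinstance(exception, UpstreamError):
#             continue  # a dependency already failed; nothing was attempted
#         if exception is not None:
#             raise exception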


def producer(obj, func, results, limiter):
    """
    The entry point for a producer thread which runs func on a single object.
    Places a tuple on the results queue once func has either returned or raised.
    """
    with limiter, GlobalLimit.global_limiter:
        try:
            result = func(obj)
            results.put((obj, result, None))
        except Exception as e:
            results.put((obj, None, e))


def feed_queue(objects, func, get_deps, results, state, limiter):
    """
    Starts producer threads for any objects which are ready to be processed
    (i.e. all of their dependencies have already been processed successfully).

    Shortcuts any objects whose dependencies have failed and places an
    (object, None, UpstreamError()) tuple on the results queue.
    """
    pending = state.pending()
    log.debug('Pending: {}'.format(pending))

    for obj in pending:
        deps = get_deps(obj)
        try:
            if any(dep[0] in state.failed for dep in deps):
                log.debug('{} has upstream errors - not processing'.format(obj))
                results.put((obj, None, UpstreamError()))
                state.failed.add(obj)
            elif all(
                dep not in objects or (
                    dep in state.finished and (not ready_check or ready_check(dep))
                ) for dep, ready_check in deps
            ):
                log.debug('Starting producer thread for {}'.format(obj))
                t = Thread(target=producer, args=(obj, func, results, limiter))
                t.daemon = True
                t.start()
                state.started.add(obj)
        except (HealthCheckFailed, NoHealthCheckConfigured) as e:
            log.debug(
                'Healthcheck for service(s) upstream of {} failed - '
                'not processing'.format(obj)
            )
            results.put((obj, None, e))

    if state.is_done():
        results.put(STOP)


class UpstreamError(Exception):
    pass


class ParallelStreamWriter(object):
    """Write out messages for operations happening in parallel.

    Each operation has its own line, and ANSI escape codes are used
    to jump to the correct line and overwrite it.
    """

    noansi = False
    lock = Lock()
    instance = None

    @classmethod
    def set_noansi(cls, value=True):
        cls.noansi = value

    def __init__(self, stream):
        self.stream = stream
        self.lines = []
        self.width = 0
        ParallelStreamWriter.instance = self

    def add_object(self, msg, obj_index):
        if msg is None:
            return
        self.lines.append(msg + obj_index)
        self.width = max(self.width, len(msg + ' ' + obj_index))

    def write_initial(self, msg, obj_index):
        if msg is None:
            return
        return self._write_noansi(msg, obj_index, '')

    def _write_ansi(self, msg, obj_index, status):
        with self.lock:
            position = self.lines.index(msg + obj_index)
            diff = len(self.lines) - position
            # move up
            self.stream.write("%c[%dA" % (27, diff))
            # erase
            self.stream.write("%c[2K\r" % 27)
            self.stream.write("{:<{width}} ... {}\r".format(msg + ' ' + obj_index,
                              status, width=self.width))
            # move back down
            self.stream.write("%c[%dB" % (27, diff))
            self.stream.flush()

    def _write_noansi(self, msg, obj_index, status):
        self.stream.write(
            "{:<{width}} ... {}\r\n".format(
                msg + ' ' + obj_index, status, width=self.width
            )
        )
        self.stream.flush()

    def write(self, msg, obj_index, status, color_func):
        if msg is None:
            return
        if self.noansi:
            self._write_noansi(msg, obj_index, status)
        else:
            self._write_ansi(msg, obj_index, color_func(status))
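
# Illustrative sketch (not part of the original module): direct use of the
# writer outside parallel_execute. Each object gets its own status line, which
# is rewritten in place when ANSI output is available. The message and object
# name are arbitrary examples.
#
#     writer = ParallelStreamWriter(sys.stderr)
#     writer.add_object('Creating', 'web_1')
#     writer.write_initial('Creating', 'web_1')
#     ...  # do the work
#     writer.write('Creating', 'web_1', 'done', green)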


def get_stream_writer():
    instance = ParallelStreamWriter.instance
    if instance is None:
        raise RuntimeError('ParallelStreamWriter has not yet been instantiated')
    return instance


def parallel_operation(containers, operation, options, message):
    parallel_execute(
        containers,
        operator.methodcaller(operation, **options),
        operator.attrgetter('name'),
        message,
    )
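
# Illustrative sketch: the wrappers below all follow this pattern, so a new bulk
# operation only needs the container method name, its keyword options, and a
# progress message. The 'restart' method and timeout value here are assumptions.
#
#     parallel_operation(containers, 'restart', {'timeout': 10}, 'Restarting')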


def parallel_remove(containers, options):
    stopped_containers = [c for c in containers if not c.is_running]
    parallel_operation(stopped_containers, 'remove', options, 'Removing')


def parallel_pause(containers, options):
    parallel_operation(containers, 'pause', options, 'Pausing')


def parallel_unpause(containers, options):
    parallel_operation(containers, 'unpause', options, 'Unpausing')


def parallel_kill(containers, options):
    parallel_operation(containers, 'kill', options, 'Killing')