1# -*- coding: utf-8 -*-
2
3#    Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
4#
5#    Licensed under the Apache License, Version 2.0 (the "License"); you may
6#    not use this file except in compliance with the License. You may obtain
7#    a copy of the License at
8#
9#         http://www.apache.org/licenses/LICENSE-2.0
10#
11#    Unless required by applicable law or agreed to in writing, software
12#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14#    License for the specific language governing permissions and limitations
15#    under the License.
16
17import contextlib
18import string
19import threading
20import time
21
22from oslo_utils import timeutils
23import redis
24import six
25
26from taskflow import exceptions
27from taskflow.listeners import capturing
28from taskflow.persistence.backends import impl_memory
29from taskflow import retry
30from taskflow import task
31from taskflow.types import failure
32from taskflow.utils import kazoo_utils
33from taskflow.utils import redis_utils
34
ARGS_KEY = '__args__'
KWARGS_KEY = '__kwargs__'
ORDER_KEY = '__order__'
# Client settings used to reach the local ZooKeeper server in tests.
ZK_TEST_CONFIG = dict(
    timeout=1.0,
    hosts=["localhost:2181"],
)
# Upper bound (in seconds) that tests block while waiting for a latch to
# drain or an event to be set; hitting it almost always signals a bug that
# should be debugged, rather than letting the test deadlock forever...
WAIT_TIMEOUT = 300
45
46
@contextlib.contextmanager
def wrap_all_failures():
    """Re-raise any exception from the managed block as a WrappedFailure.

    Handy when a test expects several failures: every ``Exception`` raised
    inside the ``with`` body is captured in a ``failure.Failure`` and
    re-raised inside a single-element ``exceptions.WrappedFailure`` so that
    all errors can be handled the same way.
    """
    try:
        yield
    except Exception:
        # Failure() with no arguments presumably captures the exception being
        # handled (via sys.exc_info()) - so build it inside this handler.
        captured = failure.Failure()
        raise exceptions.WrappedFailure([captured])
59
60
def zookeeper_available(min_version, timeout=3):
    """Return whether the ZooKeeper server in ``ZK_TEST_CONFIG`` is usable.

    :param min_version: minimum acceptable server version, compared with
        ``>=`` against ``client.server_version()``; a falsy value skips the
        version check entirely.
    :param timeout: seconds to wait for the client to connect.
    :returns: ``True`` if the client connected and (when requested) the
        server is new enough, ``False`` otherwise. Any exception raised while
        connecting or querying the version is treated as "not available".
    """
    client = kazoo_utils.make_client(ZK_TEST_CONFIG.copy())
    try:
        # NOTE(imelnikov): the default of 3 seconds should be enough for
        # localhost.
        client.start(timeout=float(timeout))
        if not min_version:
            return True
        return bool(client.server_version() >= min_version)
    except Exception:
        return False
    finally:
        # Always tear the client down, whether or not it connected.
        kazoo_utils.finalize_client(client)
78
79
def redis_available(min_version):
    """Return whether a usable Redis server is reachable.

    Connects using the default ``redis.StrictRedis()`` settings, sends a
    PING and then asks ``redis_utils.is_server_new_enough`` whether the
    server satisfies ``min_version``.

    :param min_version: minimum server version, passed through unchanged
        to ``redis_utils.is_server_new_enough``.
    :returns: ``False`` if the PING fails for any reason, otherwise the
        ``ok`` flag reported by ``is_server_new_enough``.
    """
    client = redis.StrictRedis()
    try:
        try:
            client.ping()
        except Exception:
            return False
        ok, _server_version = redis_utils.is_server_new_enough(client,
                                                               min_version)
        return ok
    finally:
        # Release any pooled sockets instead of leaving them for the GC.
        client.connection_pool.disconnect()
90
91
class NoopRetry(retry.AlwaysRevert):
    """``retry.AlwaysRevert`` subclass that adds no behavior of its own."""
94
95
class NoopTask(task.Task):
    """Task taking no arguments whose execution does nothing."""

    def execute(self):
        """Do nothing and return ``None``."""
100
101
class DummyTask(task.Task):
    """Task whose ``execute`` takes a ``context`` argument and does nothing."""

    def execute(self, context, *args, **kwargs):
        """Ignore ``context`` and any extra arguments; return ``None``."""
106
107
class EmittingTask(task.Task):
    """Task that sends a custom ``'hi'`` notification when executed."""

    # Besides progress updates, this task declares the custom 'hi' event.
    TASK_EVENTS = (task.EVENT_UPDATE_PROGRESS, 'hi')

    def execute(self, *args, **kwargs):
        # Echo back the call arguments along with a timeutils.utcnow() stamp.
        payload = {
            'sent_on': timeutils.utcnow(),
            'args': args,
            'kwargs': kwargs,
        }
        self.notifier.notify('hi', details=payload)
115
116
class AddOneSameProvidesRequires(task.Task):
    """Task reading ``value`` and providing the result under ``'value'`` too."""

    default_provides = 'value'

    def execute(self, value):
        """Return ``value`` incremented by one."""
        incremented = value + 1
        return incremented
122
123
class AddOne(task.Task):
    """Task that increments ``source`` and provides it as ``'result'``."""

    default_provides = 'result'

    def execute(self, source):
        """Return ``source`` incremented by one."""
        incremented = source + 1
        return incremented
129
130
class GiveBackRevert(task.Task):
    """Task whose revert returns its execute result plus one."""

    def execute(self, value):
        """Return ``value`` incremented by one."""
        incremented = value + 1
        return incremented

    def revert(self, *args, **kwargs):
        """Return ``result + 1`` when an integer ``result`` was passed in."""
        result = kwargs.get('result')
        if not isinstance(result, six.integer_types):
            # Execution may have failed, timed out, or otherwise not produced
            # a valid integer result; nothing to give back then.
            return None
        return result + 1
142
143
class FakeTask(object):
    """Plain object (not a ``task.Task``) that only has an ``execute`` method."""

    def execute(self, **kwargs):
        """Accept any keyword arguments and return ``None``."""
148
149
class LongArgNameTask(task.Task):
    """Task that returns its single, verbosely named argument unchanged."""

    def execute(self, long_arg_name):
        echoed = long_arg_name
        return echoed
154
155
# Class names along RuntimeError's MRO, most-derived first. Python 2 has an
# extra ``StandardError`` base (right after RuntimeError) that Python 3 lacks.
RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception',
                         'BaseException', 'object']
if not six.PY3:
    RUNTIME_ERROR_CLASSES.insert(1, 'StandardError')
162
163
class ProvidesRequiresTask(task.Task):
    """Task whose output is synthesized from its ``provides`` names.

    If ``provides`` was given as a list or tuple, ``execute`` returns the
    tuple ``(0, 1, ..., n - 1)`` (one index per provided name); otherwise it
    returns a dict mapping each provided name to itself.

    NOTE(review): the ``return_tuple`` argument is accepted but ignored -
    the return shape is always inferred from the type of ``provides``.
    Confirm whether callers rely on this before making it honor the flag.
    """

    def __init__(self, name, provides, requires, return_tuple=True):
        super(ProvidesRequiresTask, self).__init__(name=name,
                                                   provides=provides,
                                                   requires=requires)
        self.return_tuple = isinstance(provides, (tuple, list))

    def execute(self, *args, **kwargs):
        if not self.return_tuple:
            return {key: key for key in self.provides}
        return tuple(range(len(self.provides)))
176
177
# Used to format the captured values into strings (which are easier to
# check later in tests)... Maps each capture kind to a pair of (suffix
# appended to the atom's name, key in the notification ``details`` dict
# that holds that name).
LOOKUP_NAME_POSTFIX = {
    capturing.CaptureListener.TASK: ('.t', 'task_name'),
    capturing.CaptureListener.RETRY: ('.r', 'retry_name'),
    capturing.CaptureListener.FLOW: ('.f', 'flow_name'),
}
185
186
class CaptureListener(capturing.CaptureListener):
    """Capture listener that records each event as a compact string.

    Captures look like ``'<name>.t <state>'`` or, when the event details
    carry a result, ``'<name>.t <state>(<result>)'`` (``.r``/``.f`` suffixes
    are used for retries and flows).
    """

    @staticmethod
    def _format_capture(kind, state, details):
        postfix, key = LOOKUP_NAME_POSTFIX[kind]
        label = details[key] + postfix
        if 'result' not in details:
            return label + " %s" % state
        return label + ' %s(%s)' % (state, details['result'])
198
199
class MultiProgressingTask(task.Task):
    """Task that reports every value of ``progress_chunks`` as progress."""

    def execute(self, progress_chunks):
        """Report each chunk in order, then return how many there were."""
        for fraction in progress_chunks:
            self.update_progress(fraction)
        chunk_count = len(progress_chunks)
        return chunk_count
205
206
class ProgressingTask(task.Task):
    """Task that reports start/end progress in both execute and revert."""

    def execute(self, **kwargs):
        """Report progress 0.0 then 1.0 and return the constant 5."""
        for fraction in (0.0, 1.0):
            self.update_progress(fraction)
        return 5

    def revert(self, **kwargs):
        """Report progress 0 then 1.0."""
        for fraction in (0, 1.0):
            self.update_progress(fraction)
216
217
class FailingTask(ProgressingTask):
    """Task that reports almost-complete progress and then fails."""

    def execute(self, **kwargs):
        for fraction in (0, 0.99):
            self.update_progress(fraction)
        raise RuntimeError('Woot!')
223
224
class OptionalTask(task.Task):
    """Task with one required (``a``) and one optional (``b=5``) argument."""

    def execute(self, a, b=5):
        """Return the product ``a * b``."""
        return a * b
229
230
class TaskWithFailure(task.Task):
    """Task that always raises ``RuntimeError`` when executed."""

    def execute(self, **kwargs):
        error = RuntimeError('Woot!')
        raise error
235
236
class FailingTaskWithOneArg(ProgressingTask):
    """Task that fails with an error message mentioning its ``x`` argument."""

    def execute(self, x, **kwargs):
        message = 'Woot with %s' % x
        raise RuntimeError(message)
240
241
class NastyTask(task.Task):
    """Task that executes fine but raises whenever it is reverted."""

    def execute(self, **kwargs):
        """Do nothing and return ``None``."""

    def revert(self, **kwargs):
        error = RuntimeError('Gotcha!')
        raise error
249
250
class NastyFailingTask(NastyTask):
    """``NastyTask`` whose execute also fails, so both directions raise."""

    def execute(self, **kwargs):
        error = RuntimeError('Woot!')
        raise error
254
255
class TaskNoRequiresNoReturns(task.Task):
    """Task with no named arguments whose execute/revert do nothing."""

    def execute(self, **kwargs):
        """Accept any keyword arguments and return ``None``."""

    def revert(self, **kwargs):
        """Accept any keyword arguments and return ``None``."""
263
264
class TaskOneArg(task.Task):
    """Task whose execute/revert take a single named argument ``x``."""

    def execute(self, x, **kwargs):
        """Accept ``x`` (plus any extras) and return ``None``."""

    def revert(self, x, **kwargs):
        """Accept ``x`` (plus any extras) and return ``None``."""
272
273
class TaskMultiArg(task.Task):
    """Task whose execute/revert take named arguments ``x``, ``y``, ``z``."""

    def execute(self, x, y, z, **kwargs):
        """Accept ``x``, ``y``, ``z`` (plus any extras); return ``None``."""

    def revert(self, x, y, z, **kwargs):
        """Accept ``x``, ``y``, ``z`` (plus any extras); return ``None``."""
281
282
class TaskOneReturn(task.Task):
    """Task whose execute returns the single value ``1``."""

    def execute(self, **kwargs):
        single_value = 1
        return single_value

    def revert(self, **kwargs):
        """Accept any keyword arguments and return ``None``."""
290
291
class TaskMultiReturn(task.Task):
    """Task whose execute returns the 3-tuple ``(1, 3, 5)``."""

    def execute(self, **kwargs):
        return (1, 3, 5)

    def revert(self, **kwargs):
        """Accept any keyword arguments and return ``None``."""
299
300
class TaskOneArgOneReturn(task.Task):
    """Task taking ``x`` whose execute always returns ``1``."""

    def execute(self, x, **kwargs):
        """Ignore ``x`` and return the constant ``1``."""
        single_value = 1
        return single_value

    def revert(self, x, **kwargs):
        """Accept ``x`` (plus any extras) and return ``None``."""
308
309
class TaskMultiArgOneReturn(task.Task):
    """Task taking ``x``, ``y``, ``z`` whose execute returns their sum."""

    def execute(self, x, y, z, **kwargs):
        """Return ``x + y + z`` (evaluated left to right)."""
        total = x + y + z
        return total

    def revert(self, x, y, z, **kwargs):
        """Accept ``x``, ``y``, ``z`` (plus any extras); return ``None``."""
317
318
class TaskMultiArgMultiReturn(task.Task):
    """Task taking ``x``, ``y``, ``z`` whose execute returns ``(1, 3, 5)``."""

    def execute(self, x, y, z, **kwargs):
        """Ignore the arguments and return a fixed 3-tuple."""
        return (1, 3, 5)

    def revert(self, x, y, z, **kwargs):
        """Accept ``x``, ``y``, ``z`` (plus any extras); return ``None``."""
326
327
class TaskMultiDict(task.Task):
    """Task returning a dict keyed by its provided names."""

    def execute(self):
        """Map each provided name to its position in sorted name order."""
        return {name: position
                for position, name in enumerate(sorted(self.provides))}
335
336
class NeverRunningTask(task.Task):
    """Task whose execute/revert must never be invoked.

    Both methods raise ``AssertionError`` explicitly instead of using an
    ``assert`` statement, so the guard still fires when Python runs with
    optimizations enabled (``-O``), which strips ``assert`` statements.
    """

    def execute(self, **kwargs):
        raise AssertionError('This method should not be called')

    def revert(self, **kwargs):
        raise AssertionError('This method should not be called')
343
344
class TaskRevertExtraArgs(task.Task):
    """Task that always fails so its extended ``revert`` signature is used."""

    def execute(self, **kwargs):
        reason = "We want to force a revert here"
        raise exceptions.ExecutionFailure(reason)

    def revert(self, revert_arg, flow_failures, result, **kwargs):
        """Accept the extra revert-time arguments and return ``None``."""
351
352
class SleepTask(task.Task):
    """Task that blocks the calling thread for ``duration`` seconds."""

    def execute(self, duration, **kwargs):
        """Sleep via ``time.sleep(duration)`` and return ``None``."""
        time.sleep(duration)
        return None
356
357
class EngineTestBase(object):
    """Mixin giving engine tests a fresh in-memory persistence backend.

    Intended to be combined with a test-case class: it cooperatively calls
    ``super().setUp()``/``super().tearDown()``. Subclasses must override
    ``_make_engine``.
    """

    def setUp(self):
        super(EngineTestBase, self).setUp()
        self.backend = impl_memory.MemoryBackend(conf={})

    def tearDown(self):
        # Resets a class-level ``values`` attribute; presumably shared state
        # that some tests populate - TODO confirm which tests rely on it.
        EngineTestBase.values = None
        try:
            with contextlib.closing(self.backend) as be:
                with contextlib.closing(be.get_connection()) as conn:
                    conn.clear_all()
        finally:
            # Run the parent cleanup even if clearing the backend failed.
            super(EngineTestBase, self).tearDown()

    def _make_engine(self, flow, **kwargs):
        """Build an engine for ``flow``; must be overridden by subclasses.

        :raises exceptions.NotImplementedError: always, in this base class.
        """
        raise exceptions.NotImplementedError("_make_engine() must be"
                                             " overridden if an engine is"
                                             " desired")
374
375
class FailureMatcher(object):
    """Equality adapter so failure objects can be compared in assertions.

    Wraps a failure and delegates ``==``/``!=`` to its ``matches()`` method;
    ``repr()`` shows the wrapped failure's string form.
    """

    def __init__(self, failure):
        self._failure = failure

    def __repr__(self):
        # Display the wrapped failure itself when an assertion message is
        # rendered.
        return str(self._failure)

    def __eq__(self, other):
        matched = self._failure.matches(other)
        return matched

    def __ne__(self, other):
        equal = self.__eq__(other)
        return not equal
390
391
class OneReturnRetry(retry.AlwaysRevert):
    """``AlwaysRevert`` retry whose execute returns ``1`` and revert no-ops."""

    def execute(self, **kwargs):
        single_value = 1
        return single_value

    def revert(self, **kwargs):
        """Accept any keyword arguments and return ``None``."""
399
400
class ConditionalTask(ProgressingTask):
    """Progress-reporting task that fails unless ``x`` equals ``y``."""

    def execute(self, x, y):
        # Emit the same progress updates as ProgressingTask; its return
        # value is not used here, so this returns None on success.
        super(ConditionalTask, self).execute()
        mismatch = x != y
        if mismatch:
            raise RuntimeError('Woot!')
407
408
class WaitForOneFromTask(ProgressingTask):
    """Task that blocks until another task reaches one of the given states.

    ``callback`` is meant to be registered as a notification listener; once
    it sees a ``task_name`` from ``wait_for`` entering a state listed in
    ``wait_states`` it releases ``execute``, which then proceeds like
    ``ProgressingTask.execute``. Waiting is bounded by ``WAIT_TIMEOUT``.
    """

    def __init__(self, name, wait_for, wait_states, **kwargs):
        super(WaitForOneFromTask, self).__init__(name, **kwargs)

        def as_list(value):
            # A lone string becomes a one-element list; anything else is
            # kept exactly as given.
            if isinstance(value, six.string_types):
                return [value]
            return value

        self.wait_for = as_list(wait_for)
        self.wait_states = as_list(wait_states)
        self.event = threading.Event()

    def execute(self):
        signalled = self.event.wait(WAIT_TIMEOUT)
        if signalled:
            return super(WaitForOneFromTask, self).execute()
        raise RuntimeError('%s second timeout occurred while waiting '
                           'for %s to change state to %s'
                           % (WAIT_TIMEOUT, self.wait_for,
                              self.wait_states))

    def callback(self, state, details):
        name = details.get('task_name', None)
        if name in self.wait_for and state in self.wait_states:
            self.event.set()
436
437
def make_many(amount, task_cls=DummyTask, offset=0):
    """Create ``amount`` tasks named with consecutive single letters.

    Names come from the pool ``a-z`` followed by ``A-Z``, starting at
    position ``offset`` of that pool.

    :param amount: number of tasks to create; ``<= 0`` yields ``[]``.
    :param task_cls: callable invoked as ``task_cls(name=...)`` per task.
    :param offset: index into the name pool of the first name to use.
    :returns: list of the created tasks, in name order.
    :raises AssertionError: if the pool does not hold enough names.
    """
    name_pool = string.ascii_lowercase + string.ascii_uppercase
    needed = offset + amount
    # Check the bounds up front so no tasks get built before failing.
    if amount > 0 and needed > len(name_pool):
        raise AssertionError('Name pool size too small (%s < %s)'
                             % (len(name_pool), needed))
    return [task_cls(name=name_pool[index])
            for index in range(offset, needed)]
449