# -*- coding: utf-8 -*-

#    Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import contextlib
import string
import threading
import time

from oslo_utils import timeutils
import redis
import six

from taskflow import exceptions
from taskflow.listeners import capturing
from taskflow.persistence.backends import impl_memory
from taskflow import retry
from taskflow import task
from taskflow.types import failure
from taskflow.utils import kazoo_utils
from taskflow.utils import redis_utils

ARGS_KEY = '__args__'
KWARGS_KEY = '__kwargs__'
ORDER_KEY = '__order__'
ZK_TEST_CONFIG = {
    'timeout': 1.0,
    'hosts': ["localhost:2181"],
}
# If latches/events take longer than this to become empty/set, something is
# usually wrong and should be debugged instead of deadlocking...
WAIT_TIMEOUT = 300


@contextlib.contextmanager
def wrap_all_failures():
    """Convert any exceptions to WrappedFailure.

    When you expect several failures, it may be convenient
    to wrap any exception with WrappedFailure in order to
    unify error handling.
    """
    try:
        yield
    except Exception:
        # ``failure.Failure()`` is built with no arguments while the
        # exception is being handled.
        raise exceptions.WrappedFailure([failure.Failure()])


def zookeeper_available(min_version, timeout=3):
    """Check whether the test zookeeper server can be used.

    :param min_version: minimum acceptable server version; when falsey the
                        version check is skipped
    :param timeout: how many seconds to wait for the client to start
    :returns: ``True`` if the client started (and the server version, when
              requested, is at least ``min_version``); ``False`` otherwise,
              including when any exception is raised along the way
    """
    client = kazoo_utils.make_client(ZK_TEST_CONFIG.copy())
    try:
        # NOTE(imelnikov): 3 seconds should be enough for localhost
        client.start(timeout=float(timeout))
        if not min_version:
            return True
        return bool(client.server_version() >= min_version)
    except Exception:
        return False
    finally:
        # Always release the client, whether or not it managed to start.
        kazoo_utils.finalize_client(client)


def redis_available(min_version):
    """Check whether a redis server can be used.

    A client built with the library defaults is pinged; any exception from
    the ping is treated as "not available".

    :param min_version: minimum acceptable server version, handed to
                        ``redis_utils.is_server_new_enough``
    :returns: ``False`` if the ping failed, otherwise the first element of
              the ``is_server_new_enough`` result
    """
    client = redis.StrictRedis()
    try:
        client.ping()
    except Exception:
        return False
    else:
        # Only the verdict is needed here; the reported version is unused.
        ok, _redis_version = redis_utils.is_server_new_enough(client,
                                                              min_version)
        return ok


class NoopRetry(retry.AlwaysRevert):
    """Retry controller that adds nothing on top of ``AlwaysRevert``."""


class NoopTask(task.Task):
    """Task that takes no arguments and does nothing."""

    def execute(self):
        pass


class DummyTask(task.Task):
    """Task that accepts a context (plus anything else) and does nothing."""

    def execute(self, context, *args, **kwargs):
        pass


class EmittingTask(task.Task):
    """Task that emits a custom ``'hi'`` notification when executed."""

    TASK_EVENTS = (task.EVENT_UPDATE_PROGRESS, 'hi')

    def execute(self, *args, **kwargs):
        self.notifier.notify('hi',
                             details={'sent_on': timeutils.utcnow(),
                                      'args': args, 'kwargs': kwargs})


class AddOneSameProvidesRequires(task.Task):
    """Task whose input argument and provided name are both ``value``."""

    default_provides = 'value'

    def execute(self, value):
        return value + 1


class AddOne(task.Task):
    """Task that provides ``result`` as ``source + 1``."""

    default_provides = 'result'

    def execute(self, source):
        return source + 1


class GiveBackRevert(task.Task):
    """Task whose revert hands back its execute result plus one."""

    def execute(self, value):
        return value + 1

    def revert(self, *args, **kwargs):
        result = kwargs.get('result')
        # If this somehow fails, timeout, or other don't send back a
        # valid result...
        if isinstance(result, six.integer_types):
            return result + 1


class FakeTask(object):
    """Object with an ``execute`` method that is *not* a ``task.Task``."""

    def execute(self, **kwargs):
        pass


class LongArgNameTask(task.Task):
    """Task that returns its single ``long_arg_name`` argument unchanged."""

    def execute(self, long_arg_name):
        return long_arg_name


# Names of the classes making up the RuntimeError hierarchy (python 2 has
# the additional StandardError class in the chain).
if six.PY3:
    RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception',
                             'BaseException', 'object']
else:
    RUNTIME_ERROR_CLASSES = ['RuntimeError', 'StandardError', 'Exception',
                             'BaseException', 'object']


class ProvidesRequiresTask(task.Task):
    """Task with configurable provides/requires and a matching result.

    :param name: task name
    :param provides: names this task provides
    :param requires: names this task requires
    :param return_tuple: accepted for signature compatibility only; the
                         value is ignored and the result shape is instead
                         derived from the type of ``provides``
    """

    def __init__(self, name, provides, requires, return_tuple=True):
        super(ProvidesRequiresTask, self).__init__(name=name,
                                                   provides=provides,
                                                   requires=requires)
        # NOTE: deliberately not using the ``return_tuple`` argument, a
        # tuple/list of provided names always produces a tuple result.
        self.return_tuple = isinstance(provides, (tuple, list))

    def execute(self, *args, **kwargs):
        if self.return_tuple:
            return tuple(range(len(self.provides)))
        else:
            # Map every provided name onto itself.
            return {k: k for k in self.provides}


# Used to format the captured values into strings (which are easier to
# check later in tests)...
# Maps a captured kind to the (name postfix, details key holding the name)
# pair used when formatting a capture.
LOOKUP_NAME_POSTFIX = {
    capturing.CaptureListener.TASK: ('.t', 'task_name'),
    capturing.CaptureListener.RETRY: ('.r', 'retry_name'),
    capturing.CaptureListener.FLOW: ('.f', 'flow_name'),
}


class CaptureListener(capturing.CaptureListener):
    """Capturing listener that formats each capture as a short string."""

    @staticmethod
    def _format_capture(kind, state, details):
        """Build ``'<name><postfix> <state>'`` (plus the result, if any).

        :param kind: one of the keys of ``LOOKUP_NAME_POSTFIX``
        :param state: state that was captured
        :param details: details dictionary; must contain the name key that
                        matches ``kind`` and may contain ``'result'``
        :returns: the formatted string
        """
        name_postfix, name_key = LOOKUP_NAME_POSTFIX[kind]
        name = details[name_key] + name_postfix
        if 'result' in details:
            name += ' %s(%s)' % (state, details['result'])
        else:
            name += " %s" % state
        return name


class MultiProgressingTask(task.Task):
    """Task that reports every given progress chunk, then their count."""

    def execute(self, progress_chunks):
        for chunk in progress_chunks:
            self.update_progress(chunk)
        return len(progress_chunks)


class ProgressingTask(task.Task):
    """Task that reports start/end progress and returns ``5``."""

    def execute(self, **kwargs):
        self.update_progress(0.0)
        self.update_progress(1.0)
        return 5

    def revert(self, **kwargs):
        self.update_progress(0)
        self.update_progress(1.0)


class FailingTask(ProgressingTask):
    """Task that reports partial progress and then fails."""

    def execute(self, **kwargs):
        self.update_progress(0)
        self.update_progress(0.99)
        raise RuntimeError('Woot!')


class OptionalTask(task.Task):
    """Task with one required and one optional (defaulted) argument."""

    def execute(self, a, b=5):
        result = a * b
        return result


class TaskWithFailure(task.Task):
    """Task that always fails when executed."""

    def execute(self, **kwargs):
        raise RuntimeError('Woot!')


class FailingTaskWithOneArg(ProgressingTask):
    """Task that always fails, embedding its argument in the message."""

    def execute(self, x, **kwargs):
        raise RuntimeError('Woot with %s' % x)


class NastyTask(task.Task):
    """Task that executes fine but fails when reverted."""

    def execute(self, **kwargs):
        pass

    def revert(self, **kwargs):
        raise RuntimeError('Gotcha!')


class NastyFailingTask(NastyTask):
    """Task that fails on execute and (via ``NastyTask``) on revert."""

    def execute(self, **kwargs):
        raise RuntimeError('Woot!')


class TaskNoRequiresNoReturns(task.Task):

    def execute(self, **kwargs):
        pass

    def revert(self, **kwargs):
        pass


class TaskOneArg(task.Task):

    def execute(self, x, **kwargs):
        pass

    def revert(self, x, **kwargs):
        pass


class TaskMultiArg(task.Task):

    def execute(self, x, y, z, **kwargs):
        pass

    def revert(self, x, y, z, **kwargs):
        pass


class TaskOneReturn(task.Task):

    def execute(self, **kwargs):
        return 1

    def revert(self, **kwargs):
        pass


class TaskMultiReturn(task.Task):

    def execute(self, **kwargs):
        return 1, 3, 5

    def revert(self, **kwargs):
        pass


class TaskOneArgOneReturn(task.Task):

    def execute(self, x, **kwargs):
        return 1

    def revert(self, x, **kwargs):
        pass


class TaskMultiArgOneReturn(task.Task):

    def execute(self, x, y, z, **kwargs):
        return x + y + z

    def revert(self, x, y, z, **kwargs):
        pass


class TaskMultiArgMultiReturn(task.Task):

    def execute(self, x, y, z, **kwargs):
        return 1, 3, 5

    def revert(self, x, y, z, **kwargs):
        pass


class TaskMultiDict(task.Task):
    """Task that maps each provided name to its index in sorted order."""

    def execute(self):
        return {k: i for i, k in enumerate(sorted(self.provides))}


class NeverRunningTask(task.Task):
    """Task whose execute and revert must never be invoked."""

    def execute(self, **kwargs):
        # Raised explicitly (instead of using ``assert False``) so that the
        # check is not stripped when python runs with optimizations (-O).
        raise AssertionError('This method should not be called')

    def revert(self, **kwargs):
        raise AssertionError('This method should not be called')


class TaskRevertExtraArgs(task.Task):
    """Task that fails on execute and whose revert needs an extra arg."""

    def execute(self, **kwargs):
        raise exceptions.ExecutionFailure("We want to force a revert here")

    def revert(self, revert_arg, flow_failures, result, **kwargs):
        pass


class SleepTask(task.Task):
    """Task that sleeps for ``duration`` seconds."""

    def execute(self, duration, **kwargs):
        time.sleep(duration)


class EngineTestBase(object):
    """Mixin giving engine tests a fresh in-memory persistence backend.

    The ``super()`` calls in ``setUp``/``tearDown`` require another class
    further along the MRO to provide those methods.
    """

    def setUp(self):
        super(EngineTestBase, self).setUp()
        self.backend = impl_memory.MemoryBackend(conf={})

    def tearDown(self):
        EngineTestBase.values = None
        # Wipe everything stored in the backend, closing the connection
        # and then the backend itself once done.
        with contextlib.closing(self.backend) as be:
            with contextlib.closing(be.get_connection()) as conn:
                conn.clear_all()
        super(EngineTestBase, self).tearDown()

    def _make_engine(self, flow, **kwargs):
        raise exceptions.NotImplementedError("_make_engine() must be"
                                             " overridden if an engine is"
                                             " desired")


class FailureMatcher(object):
    """Needed for failure objects comparison."""

    def __init__(self, failure):
        self._failure = failure

    def __repr__(self):
        return str(self._failure)

    def __eq__(self, other):
        # Equality is delegated to the wrapped failure's ``matches()``.
        return self._failure.matches(other)

    def __ne__(self, other):
        return not self.__eq__(other)


class OneReturnRetry(retry.AlwaysRevert):
    """Retry controller whose execute returns ``1``."""

    def execute(self, **kwargs):
        return 1

    def revert(self, **kwargs):
        pass


class ConditionalTask(ProgressingTask):
    """Task that reports progress and fails unless ``x`` equals ``y``."""

    def execute(self, x, y):
        super(ConditionalTask, self).execute()
        if x != y:
            raise RuntimeError('Woot!')


class WaitForOneFromTask(ProgressingTask):
    """Task that blocks until ``callback`` sees a matching task/state.

    :param name: task name
    :param wait_for: task name (or collection of names) to wait for
    :param wait_states: state (or collection of states) to wait for
    """

    def __init__(self, name, wait_for, wait_states, **kwargs):
        super(WaitForOneFromTask, self).__init__(name, **kwargs)
        # A single string is normalized into a one element list.
        if isinstance(wait_for, six.string_types):
            self.wait_for = [wait_for]
        else:
            self.wait_for = wait_for
        if isinstance(wait_states, six.string_types):
            self.wait_states = [wait_states]
        else:
            self.wait_states = wait_states
        self.event = threading.Event()

    def execute(self):
        # Give up after WAIT_TIMEOUT seconds instead of blocking forever.
        if not self.event.wait(WAIT_TIMEOUT):
            raise RuntimeError('%s second timeout occurred while waiting '
                               'for %s to change state to %s'
                               % (WAIT_TIMEOUT, self.wait_for,
                                  self.wait_states))
        return super(WaitForOneFromTask, self).execute()

    def callback(self, state, details):
        """Set the event when an awaited task is in an awaited state."""
        name = details.get('task_name', None)
        if name not in self.wait_for or state not in self.wait_states:
            return
        self.event.set()


def make_many(amount, task_cls=DummyTask, offset=0):
    """Create ``amount`` tasks named with consecutive ascii letters.

    :param amount: how many tasks to create
    :param task_cls: callable invoked as ``task_cls(name=<letter>)``
    :param offset: index into the lowercase + uppercase letter pool of the
                   first name to use
    :returns: list of the created tasks
    :raises AssertionError: if the letter pool runs out of names
    """
    name_pool = string.ascii_lowercase + string.ascii_uppercase
    tasks = []
    while amount > 0:
        if offset >= len(name_pool):
            raise AssertionError('Name pool size to small (%s < %s)'
                                 % (len(name_pool), offset + 1))
        tasks.append(task_cls(name=name_pool[offset]))
        offset += 1
        amount -= 1
    return tasks