1# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
2# For details: https://bitbucket.org/ned/coveragepy/src/default/NOTICE.txt
3
4"""Tests for concurrency libraries."""
5
6import os
7import random
8import sys
9import threading
10import time
11
12from flaky import flaky
13
14import coverage
15from coverage import env
16from coverage.backward import import_local_file
17from coverage.files import abs_file
18
19from tests.coveragetest import CoverageTest
20
21
# These libraries aren't always available. When one is missing, its tests either skip or check that coverage.py reports the module as not installed.
23
24try:
25    import multiprocessing
26except ImportError:         # pragma: only jython
27    multiprocessing = None
28
29try:
30    import eventlet
31except ImportError:
32    eventlet = None
33
34try:
35    import gevent
36except ImportError:
37    gevent = None
38
39try:
40    import greenlet
41except ImportError:         # pragma: only jython
42    greenlet = None
43
44
def measurable_line(l):
    """Decide whether coverage.py would measure the source line `l`.

    Blank lines, comments, and "else:" clauses never count.  Under Jython a
    few more statement kinds (try/except/break/with) aren't measured either.
    """
    stripped = l.strip()
    if not stripped or stripped.startswith(('#', 'else:')):
        return False
    jython_unmeasured = ('try:', 'except:', 'except ', 'break', 'with ')
    if env.JYTHON and stripped.startswith(jython_unmeasured):
        # Jython doesn't measure these statements.
        return False                    # pragma: only jython
    return True
61
62
def line_count(s):
    """Count the lines of source text `s` that coverage.py would measure."""
    return sum(1 for line in s.splitlines() if measurable_line(line))
66
67
def print_simple_annotation(code, linenos):
    """Echo `code` line by line, flagging with "X" each line in `linenos`.

    Line numbers are 1-based.  Unflagged lines get a blank in the marker
    column so the code stays aligned.
    """
    annotated = (
        ("X" if number in linenos else " ", text)
        for number, text in enumerate(code.splitlines(), 1)
    )
    for marker, text in annotated:
        print(" {0} {1}".format(marker, text))
72
73
class LineCountTest(CoverageTest):
    """Sanity checks for the line-counting helpers defined in this module."""

    run_in_temp_dir = False

    def test_line_count(self):
        # Only five lines count: the blank lines, the comment and "else:"
        # are all skipped.
        source = """
            # Hey there!
            x = 1
            if x:
                print("hello")
            else:
                print("bye")

            print("done")
            """

        expected = 5
        self.assertEqual(line_count(source), expected)
92
93
# The code common to all the concurrency models.
#
# A producer thread feeds 0..limit-1 into a queue, then None as a sentinel; a
# consumer sums them and posts the result to a second queue.  Callers prepend
# an import header that binds `threading` and `queue` (THREAD, EVENTLET or
# GEVENT below), so one program exercises every regime.
#
# The `while "no peephole".upper():` condition is always truthy; presumably it
# is written that way, rather than as a constant, so the loop line isn't
# optimized away and stays measurable.
#
# This text is combined with `.format()`, so it must not contain literal
# braces.
SUM_RANGE_Q = """
    # Above this will be imports defining queue and threading.

    class Producer(threading.Thread):
        def __init__(self, limit, q):
            threading.Thread.__init__(self)
            self.limit = limit
            self.q = q

        def run(self):
            for i in range(self.limit):
                self.q.put(i)
            self.q.put(None)

    class Consumer(threading.Thread):
        def __init__(self, q, qresult):
            threading.Thread.__init__(self)
            self.q = q
            self.qresult = qresult

        def run(self):
            sum = 0
            while "no peephole".upper():
                i = self.q.get()
                if i is None:
                    break
                sum += i
            self.qresult.put(sum)

    def sum_range(limit):
        q = queue.Queue()
        qresult = queue.Queue()
        c = Consumer(q, qresult)
        p = Producer(limit, q)
        c.start()
        p.start()

        p.join()
        c.join()
        return qresult.get()

    # Below this will be something using sum_range.
    """

# Trailer that prints sum_range's result; {QLIMIT} is filled in by `.format()`.
PRINT_SUM_RANGE = """
    print(sum_range({QLIMIT}))
    """
142
# Import the things to use threads.
# The stdlib queue module is named `Queue` on Python 2, so the header aliases
# it so that the shared code can always say `queue`.
if env.PY2:
    THREAD = """
    import threading
    import Queue as queue
    """
else:
    THREAD = """
    import threading
    import queue
    """
154
# Import the things to use eventlet.
# eventlet's green modules are bound to the names `threading` and `queue`, so
# SUM_RANGE_Q runs on green threads without any change.
EVENTLET = """
    import eventlet.green.threading as threading
    import eventlet.queue as queue
    """

# Import the things to use gevent.
# monkey.patch_thread() runs before `import threading`, so the threading names
# seen by SUM_RANGE_Q are gevent's patched versions.
GEVENT = """
    from gevent import monkey
    monkey.patch_thread()
    import threading
    import gevent.queue as queue
    """
168
# Uncomplicated code that doesn't use any of the concurrency stuff, to test
# the simple case under each of the regimes.  It prints sum(range(QLIMIT)),
# the same output as the SUM_RANGE_Q programs.
SIMPLE = """
    total = 0
    for i in range({QLIMIT}):
        total += i
    print(total)
    """
177
178
def cant_trace_msg(concurrency, the_module):
    """What might coverage.py say about a concurrency setting and imported module?

    `concurrency` is the comma-separated --concurrency value.  `the_module` is
    the imported module that setting needs, or None if it couldn't be
    imported.  Returns the exact output coverage.py is expected to produce
    instead of the program's own output, or None if it should trace normally.
    """
    # In the concurrency choices, "multiprocessing" doesn't count, so remove it.
    # Compare whole comma-separated parts: a substring test could match a part
    # that isn't exactly "multiprocessing" and then make list.remove() fail.
    concurrency = ",".join(
        part for part in concurrency.split(",") if part != "multiprocessing"
    )

    if the_module is None:
        # We don't even have the underlying module installed, we expect
        # coverage to alert us to this fact.
        expected_out = (
            "Couldn't trace with concurrency=%s, "
            "the module isn't installed.\n" % concurrency
        )
    elif env.C_TRACER or concurrency == "thread" or concurrency == "":
        # The C tracer handles every regime, and plain threads (or nothing
        # beyond multiprocessing) work with any tracer.
        expected_out = None
    else:
        expected_out = (
            "Can't support concurrency=%s with PyTracer, "
            "only threads are supported\n" % concurrency
        )
    return expected_out
202
203
class ConcurrencyTest(CoverageTest):
    """Tests of the concurrency support in coverage.py.

    Each regime (thread, eventlet, gevent, greenlet) is run both on a program
    that really uses it and on SIMPLE code that doesn't.
    """

    # How many numbers the generated programs sum; the default expected output
    # in `try_some_code` is sum(range(QLIMIT)).
    QLIMIT = 1000

    def try_some_code(self, code, concurrency, the_module, expected_out=None):
        """Run some concurrency testing code and see that it was all covered.

        `code` is the Python code to execute.  `concurrency` is the name of
        the concurrency regime to test it under.  `the_module` is the imported
        module that must be available for this to work at all. `expected_out`
        is the text we expect the code to produce; when None, it defaults to
        sum(range(QLIMIT)) followed by a newline.

        If `cant_trace_msg` predicts that coverage.py can't trace this
        combination, only that message is checked.  Otherwise the output is
        checked, and the recorded data must cover every measurable line of
        `code`, as counted by `line_count`.

        """

        self.make_file("try_it.py", code)

        cmd = "coverage run --concurrency=%s try_it.py" % concurrency
        out = self.run_command(cmd)

        expected_cant_trace = cant_trace_msg(concurrency, the_module)

        if expected_cant_trace is not None:
            self.assertEqual(out, expected_cant_trace)
        else:
            # We can fully measure the code if we are using the C tracer, which
            # can support all the concurrency, or if we are using threads.
            if expected_out is None:
                expected_out = "%d\n" % (sum(range(self.QLIMIT)))
            # Show the program, to help diagnose a failure (like the prints
            # below).
            print(code)
            self.assertEqual(out, expected_out)

            # Read the coverage file and see that try_it.py has all its lines
            # executed.
            data = coverage.CoverageData()
            data.read_file(".coverage")

            # If the test fails, it's helpful to see this info:
            fname = abs_file("try_it.py")
            linenos = data.lines(fname)
            print("{0}: {1}".format(len(linenos), linenos))
            print_simple_annotation(code, linenos)

            lines = line_count(code)
            # line_counts() is looked up here by the bare file name.
            self.assertEqual(data.line_counts()['try_it.py'], lines)

    # Real threads: the queue-based producer/consumer program.
    def test_threads(self):
        code = (THREAD + SUM_RANGE_Q + PRINT_SUM_RANGE).format(QLIMIT=self.QLIMIT)
        self.try_some_code(code, "thread", threading)

    def test_threads_simple_code(self):
        code = SIMPLE.format(QLIMIT=self.QLIMIT)
        self.try_some_code(code, "thread", threading)

    # eventlet green threads running the same producer/consumer program.
    def test_eventlet(self):
        code = (EVENTLET + SUM_RANGE_Q + PRINT_SUM_RANGE).format(QLIMIT=self.QLIMIT)
        self.try_some_code(code, "eventlet", eventlet)

    def test_eventlet_simple_code(self):
        code = SIMPLE.format(QLIMIT=self.QLIMIT)
        self.try_some_code(code, "eventlet", eventlet)

    # gevent (via monkey-patched threading) running the same program.
    def test_gevent(self):
        code = (GEVENT + SUM_RANGE_Q + PRINT_SUM_RANGE).format(QLIMIT=self.QLIMIT)
        self.try_some_code(code, "gevent", gevent)

    def test_gevent_simple_code(self):
        code = SIMPLE.format(QLIMIT=self.QLIMIT)
        self.try_some_code(code, "gevent", gevent)

    # Two greenlets switching back and forth explicitly.  The output isn't the
    # QLIMIT sum, so it is given directly.
    def test_greenlet(self):
        GREENLET = """\
            from greenlet import greenlet

            def test1(x, y):
                z = gr2.switch(x+y)
                print(z)

            def test2(u):
                print(u)
                gr1.switch(42)

            gr1 = greenlet(test1)
            gr2 = greenlet(test2)
            gr1.switch("hello", " world")
            """
        self.try_some_code(GREENLET, "greenlet", greenlet, "hello world\n42\n")

    def test_greenlet_simple_code(self):
        code = SIMPLE.format(QLIMIT=self.QLIMIT)
        self.try_some_code(code, "greenlet", greenlet)

    # Regression test for bug 330: many short-lived eventlet greenthreads are
    # held only as WeakKeyDictionary keys, and the program must print 0, i.e.
    # none of them are still alive at the end (presumably checking that
    # coverage.py doesn't keep finished greenthreads referenced).
    def test_bug_330(self):
        BUG_330 = """\
            from weakref import WeakKeyDictionary
            import eventlet

            def do():
                eventlet.sleep(.01)

            gts = WeakKeyDictionary()
            for _ in range(100):
                gts[eventlet.spawn(do)] = True
                eventlet.sleep(.005)

            eventlet.sleep(.1)
            print(len(gts))
            """
        self.try_some_code(BUG_330, "eventlet", eventlet, "0\n")
313
314
# A work() for MULTI_CODE that squares odd inputs and cubes even ones, so
# different worker processes execute different lines.
SQUARE_OR_CUBE_WORK = """
    def work(x):
        # Use different lines in different subprocesses.
        if x % 2:
            y = x*x
        else:
            y = x*x*x
        return y
    """

# A work() for MULTI_CODE that calls sum_range(), so it must be combined with a
# concurrency header and SUM_RANGE_Q.
SUM_RANGE_WORK = """
    def work(x):
        return sum_range((x+1)*100)
    """

# The multiprocessing driver, filled in with `.format(NPROCS=..., UPTO=...)`.
# It maps work() over range(UPTO) with a Pool of NPROCS processes and prints
# how many distinct worker pids reported results and the total.  An optional
# command-line argument selects the multiprocessing start method.
MULTI_CODE = """
    # Above this will be a definition of work().
    import multiprocessing
    import os
    import time
    import sys

    def process_worker_main(args):
        # Need to pause, or the tasks go too quick, and some processes
        # in the pool don't get any work, and then don't record data.
        time.sleep(0.02)
        ret = work(*args)
        return os.getpid(), ret

    if __name__ == "__main__":      # pragma: no branch
        # This if is on a single line so we can get 100% coverage
        # even if we have no arguments.
        if len(sys.argv) > 1: multiprocessing.set_start_method(sys.argv[1])
        pool = multiprocessing.Pool({NPROCS})
        inputs = [(x,) for x in range({UPTO})]
        outputs = pool.imap_unordered(process_worker_main, inputs)
        pids = set()
        total = 0
        for pid, sq in outputs:
            pids.add(pid)
            total += sq
        print("%d pids, total = %d" % (len(pids), total))
        pool.close()
        pool.join()
    """
360
361
@flaky(max_runs=10)         # Sometimes a test fails due to inherent randomness. Retry a few times.
class MultiprocessingTest(CoverageTest):
    """Test support of the multiprocessing module."""

    def setUp(self):
        super(MultiprocessingTest, self).setUp()
        if not multiprocessing:
            self.skipTest("No multiprocessing in this Python")      # pragma: only jython

    @staticmethod
    def _start_methods():
        """Return the multiprocessing start methods to exercise on this Python.

        On Python 3.4+ this is whichever of 'fork' and 'spawn' the platform
        offers.  Older Pythons get [''], so no start method is passed to
        multi.py at all.
        """
        if env.PYVERSION >= (3, 4):
            candidates = ['fork', 'spawn']
        else:
            candidates = ['']
        return [
            method for method in candidates
            # '' short-circuits: get_all_start_methods() only exists on 3.4+.
            if not method or method in multiprocessing.get_all_start_methods()
        ]

    def try_multiprocessing_code(
        self, code, expected_out, the_module, concurrency="multiprocessing"
    ):
        """Run code using multiprocessing, it should produce `expected_out`.

        `code` is written to multi.py and run once per start method from
        `_start_methods`, with `concurrency` set in .coveragerc.  If
        `cant_trace_msg` predicts that coverage.py can't trace `concurrency`
        (`the_module` being the module that setting needs), only that message
        is checked.  Otherwise the combined data must show multi.py fully
        covered.
        """
        self.make_file("multi.py", code)
        self.make_file(".coveragerc", """\
            [run]
            concurrency = %s
            source = .
            """ % concurrency)

        # This depends only on the arguments, so compute it once rather than
        # once per start method.
        expected_cant_trace = cant_trace_msg(concurrency, the_module)

        for start_method in self._start_methods():
            out = self.run_command("coverage run multi.py %s" % (start_method,))

            if expected_cant_trace is not None:
                self.assertEqual(out, expected_cant_trace)
            else:
                self.assertEqual(out.rstrip(), expected_out)

                out = self.run_command("coverage combine")
                self.assertEqual(out, "")
                out = self.run_command("coverage report -m")

                last_line = self.squeezed_lines(out)[-1]
                self.assertRegex(last_line, r"multi.py \d+ 0 100%")

    def test_multiprocessing(self):
        nprocs = 3
        upto = 30
        code = (SQUARE_OR_CUBE_WORK + MULTI_CODE).format(NPROCS=nprocs, UPTO=upto)
        # Mirror SQUARE_OR_CUBE_WORK: odd numbers squared, even numbers cubed.
        total = sum(x*x if x%2 else x*x*x for x in range(upto))
        expected_out = "{nprocs} pids, total = {total}".format(nprocs=nprocs, total=total)
        self.try_multiprocessing_code(code, expected_out, threading)

    # NOTE: despite the name, this combines multiprocessing with eventlet.
    def test_multiprocessing_and_gevent(self):
        nprocs = 3
        upto = 30
        code = (
            SUM_RANGE_WORK + EVENTLET + SUM_RANGE_Q + MULTI_CODE
        ).format(NPROCS=nprocs, UPTO=upto)
        total = sum(sum(range((x + 1) * 100)) for x in range(upto))
        expected_out = "{nprocs} pids, total = {total}".format(nprocs=nprocs, total=total)
        self.try_multiprocessing_code(
            code, expected_out, eventlet, concurrency="multiprocessing,eventlet"
        )

    def try_multiprocessing_code_with_branching(self, code, expected_out):
        """Run code using multiprocessing and branch coverage, expecting `expected_out`.

        Like `try_multiprocessing_code`, but always uses
        concurrency=multiprocessing with branch=True (from multi.rc).  The
        report's last line must show 0 missed statements, 0 partial branches
        and 100%.
        """
        self.make_file("multi.py", code)
        self.make_file("multi.rc", """\
            [run]
            concurrency = multiprocessing
            branch = True
            """)

        for start_method in self._start_methods():
            out = self.run_command("coverage run --rcfile=multi.rc multi.py %s" % (start_method,))
            self.assertEqual(out.rstrip(), expected_out)

            out = self.run_command("coverage combine")
            self.assertEqual(out, "")
            out = self.run_command("coverage report -m")

            last_line = self.squeezed_lines(out)[-1]
            self.assertRegex(last_line, r"multi.py \d+ 0 \d+ 0 100%")

    def test_multiprocessing_with_branching(self):
        nprocs = 3
        upto = 30
        code = (SQUARE_OR_CUBE_WORK + MULTI_CODE).format(NPROCS=nprocs, UPTO=upto)
        total = sum(x*x if x%2 else x*x*x for x in range(upto))
        expected_out = "{nprocs} pids, total = {total}".format(nprocs=nprocs, total=total)
        self.try_multiprocessing_code_with_branching(code, expected_out)
461
462
def test_coverage_stop_in_threads():
    """Stopping coverage in the main thread also stops tracing in other threads."""
    has_started_coverage = []
    has_stopped_coverage = []
    # Set by the thread once it has looked at sys.gettrace().  The main thread
    # waits for it instead of sleeping a fixed time, so a slow thread start
    # can't make coverage stop before the thread has seen it running.
    thread_checked = threading.Event()

    def run_thread():
        """Check that coverage is stopping properly in threads."""
        deadline = time.time() + 5
        ident = threading.current_thread().ident
        if sys.gettrace() is not None:
            has_started_coverage.append(ident)
        thread_checked.set()
        while sys.gettrace() is not None:
            # Wait for coverage to stop
            time.sleep(0.01)
            if time.time() > deadline:
                return
        has_stopped_coverage.append(ident)

    cov = coverage.coverage()
    cov.start()

    t = threading.Thread(target=run_thread)
    t.start()

    thread_checked.wait(5)
    cov.stop()
    # run_thread gives up by itself after its 5-second deadline, so this join
    # is bounded; once it returns both lists are final, with no sleep race.
    t.join()

    assert has_started_coverage == [t.ident]
    assert has_stopped_coverage == [t.ident]
493
494
def test_thread_safe_save_data(tmpdir):
    """Getting coverage data while traced threads are still running must not crash.

    Non-regression test for:
    https://bitbucket.org/ned/coveragepy/issues/581
    """
    # Create some Python modules and put them in the path
    modules_dir = tmpdir.mkdir('test_modules')
    module_names = ["m{0:03d}".format(i) for i in range(1000)]
    for module_name in module_names:
        modules_dir.join(module_name + ".py").write("def f(): pass\n")

    # Shared variables for threads: a one-element list so the workers see the
    # flag flip, and a record of every module they imported.
    should_run = [True]
    imported = []

    old_dir = os.getcwd()
    os.chdir(modules_dir.strpath)
    try:
        # Make sure that all dummy modules can be imported.
        for module_name in module_names:
            import_local_file(module_name)

        def random_load():
            """Import modules randomly to stress coverage."""
            while should_run[0]:
                module_name = random.choice(module_names)
                mod = import_local_file(module_name)
                mod.f()
                imported.append(mod)

        # Spawn some threads with coverage enabled and attempt to read the
        # results right after stopping coverage collection with the threads
        # still running.
        duration = 0.01
        for _ in range(3):
            cov = coverage.coverage()
            cov.start()

            threads = [threading.Thread(target=random_load) for _i in range(10)]
            should_run[0] = True
            for t in threads:
                t.start()

            try:
                time.sleep(duration)

                cov.stop()

                # The following call used to crash with running background threads.
                cov.get_data()
            finally:
                # Stop the threads even if the code above raised, so they
                # don't keep importing modules after the cwd is restored.
                should_run[0] = False
                for t in threads:
                    t.join()

            # If the threads didn't get to import anything, give the next
            # round longer, up to a limit.
            if (not imported) and duration < 10:
                duration *= 2

    finally:
        # Signal any remaining workers before leaving the modules directory.
        should_run[0] = False
        os.chdir(old_dir)
555