1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
5# Use of this source code is governed by a BSD-style license that can be
6# found in the LICENSE file.
7
8"""
9Tests for testing utils (psutil.tests namespace).
10"""
11
12import collections
13import contextlib
14import errno
15import os
16import socket
17import stat
18import subprocess
19
20from psutil import FREEBSD
21from psutil import NETBSD
22from psutil import POSIX
23from psutil._common import open_binary
24from psutil._common import open_text
25from psutil._common import supports_ipv6
26from psutil.tests import bind_socket
27from psutil.tests import bind_unix_socket
28from psutil.tests import call_until
29from psutil.tests import chdir
30from psutil.tests import CI_TESTING
31from psutil.tests import create_sockets
32from psutil.tests import get_free_port
33from psutil.tests import HAS_CONNECTIONS_UNIX
34from psutil.tests import is_namedtuple
35from psutil.tests import mock
36from psutil.tests import process_namespace
37from psutil.tests import PsutilTestCase
38from psutil.tests import PYTHON_EXE
39from psutil.tests import reap_children
40from psutil.tests import retry
41from psutil.tests import retry_on_failure
42from psutil.tests import safe_mkdir
43from psutil.tests import safe_rmpath
44from psutil.tests import serialrun
45from psutil.tests import system_namespace
46from psutil.tests import tcp_socketpair
47from psutil.tests import terminate
48from psutil.tests import TestMemoryLeak
49from psutil.tests import unittest
50from psutil.tests import unix_socketpair
51from psutil.tests import wait_for_file
52from psutil.tests import wait_for_pid
53import psutil
54import psutil.tests
55
56# ===================================================================
57# --- Unit tests for test utilities.
58# ===================================================================
59
60
class TestRetryDecorator(PsutilTestCase):
    """Checks for the ``retry`` decorator with ``time.sleep`` mocked out."""

    @mock.patch('time.sleep')
    def test_retry_success(self, sleep):
        # The wrapped callable fails 3 times against a budget of 5
        # retries, so it must eventually return its value.
        pending = list(range(3))

        @retry(retries=5, interval=1, logfun=None)
        def flaky():
            if pending:
                pending.pop()
                1 / 0
            return 1

        self.assertEqual(flaky(), 1)
        self.assertEqual(sleep.call_count, 3)

    @mock.patch('time.sleep')
    def test_retry_failure(self, sleep):
        # The wrapped callable fails 6 times against a budget of 5
        # retries; the function is supposed to raise the exception.
        pending = list(range(6))

        @retry(retries=5, interval=1, logfun=None)
        def flaky():
            if pending:
                pending.pop()
                1 / 0
            return 1

        with self.assertRaises(ZeroDivisionError):
            flaky()
        self.assertEqual(sleep.call_count, 5)

    @mock.patch('time.sleep')
    def test_exception_arg(self, sleep):
        # An exception other than the one passed as `exception` goes
        # straight through, without any sleep.
        @retry(exception=ValueError, interval=1)
        def raiser():
            raise TypeError

        with self.assertRaises(TypeError):
            raiser()
        self.assertEqual(sleep.call_count, 0)

    @mock.patch('time.sleep')
    def test_no_interval_arg(self, sleep):
        # With interval=None sleep is not supposed to be called.
        @retry(retries=5, interval=None, logfun=None)
        def always_fails():
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            always_fails()
        self.assertEqual(sleep.call_count, 0)

    @mock.patch('time.sleep')
    def test_retries_arg(self, sleep):
        # One sleep is expected for each of the 5 retries.
        @retry(retries=5, interval=1, logfun=None)
        def always_fails():
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            always_fails()
        self.assertEqual(sleep.call_count, 5)

    @mock.patch('time.sleep')
    def test_retries_and_timeout_args(self, sleep):
        # Passing both `retries` and `timeout` is rejected.
        with self.assertRaises(ValueError):
            retry(retries=5, timeout=1)
125
126
class TestSyncTestUtils(PsutilTestCase):
    """Checks for wait_for_pid(), wait_for_file() and call_until()."""

    def test_wait_for_pid(self):
        # Our own PID is always there.
        wait_for_pid(os.getpid())
        # A PID far beyond the highest one currently listed.
        missing_pid = max(psutil.pids()) + 99999
        with mock.patch('psutil.tests.retry.__iter__', return_value=iter([0])):
            with self.assertRaises(psutil.NoSuchProcess):
                wait_for_pid(missing_pid)

    def test_wait_for_file(self):
        path = self.get_testfn()
        with open(path, 'w') as fobj:
            fobj.write('foo')
        wait_for_file(path)
        # With default arguments the file is gone afterwards.
        assert not os.path.exists(path)

    def test_wait_for_file_empty(self):
        path = self.get_testfn()
        # Create a zero-length file.
        with open(path, 'w'):
            pass
        wait_for_file(path, empty=True)
        assert not os.path.exists(path)

    def test_wait_for_file_no_file(self):
        # The path is never created, hence the error.
        path = self.get_testfn()
        with mock.patch('psutil.tests.retry.__iter__', return_value=iter([0])):
            with self.assertRaises(IOError):
                wait_for_file(path)

    def test_wait_for_file_no_delete(self):
        path = self.get_testfn()
        with open(path, 'w') as fobj:
            fobj.write('foo')
        wait_for_file(path, delete=False)
        # delete=False leaves the file in place.
        assert os.path.exists(path)

    def test_call_until(self):
        result = call_until(lambda: 1, "ret == 1")
        self.assertEqual(result, 1)
164
165
class TestFSTestUtils(PsutilTestCase):
    """Checks for the file-system related helpers."""

    def test_open_text(self):
        with open_text(__file__) as fobj:
            self.assertEqual(fobj.mode, 'rt')

    def test_open_binary(self):
        with open_binary(__file__) as fobj:
            self.assertEqual(fobj.mode, 'rb')

    def test_safe_mkdir(self):
        path = self.get_testfn()
        # Calling it a second time on an existing directory must be fine.
        for _ in range(2):
            safe_mkdir(path)
            assert os.path.isdir(path)

    def test_safe_rmpath(self):
        path = self.get_testfn()
        # A regular file gets removed.
        with open(path, 'w'):
            pass
        safe_rmpath(path)
        assert not os.path.exists(path)
        # No exception if the path does not exist.
        safe_rmpath(path)
        # A directory gets removed.
        os.mkdir(path)
        safe_rmpath(path)
        assert not os.path.exists(path)
        # Any other error is propagated to the caller.
        stat_error = OSError(errno.EINVAL, "")
        with mock.patch('psutil.tests.os.stat', side_effect=stat_error) as m:
            self.assertRaises(OSError, safe_rmpath, path)
            assert m.called

    def test_chdir(self):
        dirname = self.get_testfn()
        startdir = os.getcwd()
        os.mkdir(dirname)
        with chdir(dirname):
            # Inside the block the cwd is the new directory...
            self.assertEqual(os.getcwd(), os.path.join(startdir, dirname))
        # ...and the previous one is restored on exit.
        self.assertEqual(os.getcwd(), startdir)
209
210
class TestProcessUtils(PsutilTestCase):
    """Checks for the helpers which spawn, reap and terminate processes."""

    def test_reap_children(self):
        subp = self.spawn_testproc()
        p = psutil.Process(subp.pid)
        assert p.is_running()
        reap_children()
        assert not p.is_running()
        # The module-level bookkeeping must be emptied as well.
        assert not psutil.tests._pids_started
        assert not psutil.tests._subprocesses_started

    def test_spawn_children_pair(self):
        child, grandchild = self.spawn_children_pair()
        self.assertNotEqual(child.pid, grandchild.pid)
        assert child.is_running()
        assert grandchild.is_running()
        # Only the child is a direct descendant of this process...
        children = psutil.Process().children()
        self.assertEqual(children, [child])
        # ...while the recursive listing includes the grandchild too.
        children = psutil.Process().children(recursive=True)
        self.assertEqual(len(children), 2)
        self.assertIn(child, children)
        self.assertIn(grandchild, children)
        self.assertEqual(child.ppid(), os.getpid())
        self.assertEqual(grandchild.ppid(), child.pid)

        # Terminating the child must leave the grandchild alive.
        terminate(child)
        assert not child.is_running()
        assert grandchild.is_running()

        terminate(grandchild)
        assert not grandchild.is_running()

    @unittest.skipIf(not POSIX, "POSIX only")
    def test_spawn_zombie(self):
        parent, zombie = self.spawn_zombie()
        self.assertEqual(zombie.status(), psutil.STATUS_ZOMBIE)

    def test_terminate(self):
        # Each section calls terminate() twice: the second call checks
        # that terminating an already gone process does not raise.

        # by subprocess.Popen
        p = self.spawn_testproc()
        terminate(p)
        self.assertProcessGone(p)
        terminate(p)
        # by psutil.Process
        p = psutil.Process(self.spawn_testproc().pid)
        terminate(p)
        self.assertProcessGone(p)
        terminate(p)
        # by psutil.Popen
        cmd = [PYTHON_EXE, "-c", "import time; time.sleep(60);"]
        p = psutil.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        terminate(p)
        self.assertProcessGone(p)
        terminate(p)
        # by PID
        pid = self.spawn_testproc().pid
        terminate(pid)
        # Assert on the PID which was just terminated; `p` still refers
        # to the psutil.Popen instance of the previous section.
        self.assertRaises(psutil.NoSuchProcess, psutil.Process, pid)
        assert not psutil.pid_exists(pid), pid
        terminate(pid)
        # zombie
        if POSIX:
            parent, zombie = self.spawn_zombie()
            terminate(parent)
            terminate(zombie)
            self.assertProcessGone(parent)
            self.assertProcessGone(zombie)
277
278
class TestNetUtils(PsutilTestCase):
    """Checks for the socket helpers imported from psutil.tests.

    Every method carries the ``test`` prefix: unittest only collects
    methods whose name starts with it, so an unprefixed method would
    never be run.
    """

    def test_bind_socket(self):
        port = get_free_port()
        with contextlib.closing(bind_socket(addr=('', port))) as s:
            # The socket must be bound to the port we asked for.
            self.assertEqual(s.getsockname()[1], port)

    @unittest.skipIf(not POSIX, "POSIX only")
    def test_bind_unix_socket(self):
        name = self.get_testfn()
        sock = bind_unix_socket(name)
        with contextlib.closing(sock):
            self.assertEqual(sock.family, socket.AF_UNIX)
            self.assertEqual(sock.type, socket.SOCK_STREAM)
            self.assertEqual(sock.getsockname(), name)
            # Binding creates a socket file at the given path.
            assert os.path.exists(name)
            assert stat.S_ISSOCK(os.stat(name).st_mode)
        # UDP
        name = self.get_testfn()
        sock = bind_unix_socket(name, type=socket.SOCK_DGRAM)
        with contextlib.closing(sock):
            self.assertEqual(sock.type, socket.SOCK_DGRAM)

    def test_tcp_socketpair(self):
        addr = ("127.0.0.1", get_free_port())
        server, client = tcp_socketpair(socket.AF_INET, addr=addr)
        with contextlib.closing(server):
            with contextlib.closing(client):
                # Ensure they are connected and the positions are
                # correct.
                self.assertEqual(server.getsockname(), addr)
                self.assertEqual(client.getpeername(), addr)
                self.assertNotEqual(client.getsockname(), addr)

    @unittest.skipIf(not POSIX, "POSIX only")
    @unittest.skipIf(NETBSD or FREEBSD,
                     "/var/run/log UNIX socket opened by default")
    def test_unix_socketpair(self):
        p = psutil.Process()
        # Baseline taken before the pair is created.
        num_fds = p.num_fds()
        assert not p.connections(kind='unix')
        name = self.get_testfn()
        server, client = unix_socketpair(name)
        try:
            assert os.path.exists(name)
            assert stat.S_ISSOCK(os.stat(name).st_mode)
            # Two new descriptors / connections: server and client.
            self.assertEqual(p.num_fds() - num_fds, 2)
            self.assertEqual(len(p.connections(kind='unix')), 2)
            self.assertEqual(server.getsockname(), name)
            self.assertEqual(client.getpeername(), name)
        finally:
            client.close()
            server.close()

    def test_create_sockets(self):
        with create_sockets() as socks:
            # Count the sockets per address family and per type.
            fams = collections.defaultdict(int)
            types = collections.defaultdict(int)
            for s in socks:
                fams[s.family] += 1
                # work around http://bugs.python.org/issue30204
                types[s.getsockopt(socket.SOL_SOCKET, socket.SO_TYPE)] += 1
            self.assertGreaterEqual(fams[socket.AF_INET], 2)
            if supports_ipv6():
                self.assertGreaterEqual(fams[socket.AF_INET6], 2)
            if POSIX and HAS_CONNECTIONS_UNIX:
                self.assertGreaterEqual(fams[socket.AF_UNIX], 2)
            self.assertGreaterEqual(types[socket.SOCK_STREAM], 2)
            self.assertGreaterEqual(types[socket.SOCK_DGRAM], 2)
348
349
@serialrun
class TestMemLeakClass(TestMemoryLeak):
    """Checks for the ``execute*`` methods inherited from TestMemoryLeak."""

    def test_times(self):
        # Count how many times execute() invokes the callable.
        counter = {'cnt': 0}

        def fun():
            counter['cnt'] += 1

        self.execute(fun, times=10, warmup_times=15)
        self.assertEqual(counter['cnt'], 26)

    def test_param_err(self):
        # Each of these out-of-range arguments must be rejected.
        invalid_kwargs = (
            {'times': 0},
            {'times': -1},
            {'warmup_times': -1},
            {'tolerance': -1},
            {'retries': -1},
        )
        for kwargs in invalid_kwargs:
            self.assertRaises(ValueError, self.execute, lambda: 0, **kwargs)

    @retry_on_failure()
    @unittest.skipIf(CI_TESTING, "skipped on CI")
    def test_leak_mem(self):
        ls = []

        # The list is bound as a default argument, so it keeps growing
        # across calls.
        def fun(ls=ls):
            ls.append("x" * 24 * 1024)

        try:
            # will consume around 3M in total
            with self.assertRaisesRegex(AssertionError, "extra-mem"):
                self.execute(fun, times=50)
        finally:
            del ls

    def test_unclosed_files(self):
        opened = []

        def fun():
            # Keep every file object referenced so none gets closed
            # before the test is over.
            fobj = open(__file__)
            self.addCleanup(fobj.close)
            opened.append(fobj)

        kind = "fd" if POSIX else "handle"
        with self.assertRaisesRegex(AssertionError, "unclosed " + kind):
            self.execute(fun)

    def test_tolerance(self):
        chunks = []

        def fun():
            chunks.append("x" * 24 * 1024)

        times = 100
        # The tolerance passed here is 200M.
        self.execute(fun, times=times, warmup_times=0,
                     tolerance=200 * 1024 * 1024)
        self.assertEqual(len(chunks), times + 1)

    def test_execute_w_exc(self):
        def failing():
            1 / 0

        # The expected exception type.
        self.execute_w_exc(ZeroDivisionError, failing)
        # A different exception type: the original error comes out.
        with self.assertRaises(ZeroDivisionError):
            self.execute_w_exc(OSError, failing)

        def noop():
            pass

        # The callable does not raise at all.
        with self.assertRaises(AssertionError):
            self.execute_w_exc(ZeroDivisionError, noop)
414
class TestTestingUtils(PsutilTestCase):
    """Checks for process_namespace() and system_namespace()."""

    def test_process_namespace(self):
        proc = psutil.Process()
        ns = process_namespace(proc)
        ns.test()
        # Pick the entry whose second item is 'ppid'; its first item
        # is the callable.
        matches = [x for x in ns.iter(ns.getters) if x[1] == 'ppid']
        fun = matches[0][0]
        self.assertEqual(fun(), proc.ppid())

    def test_system_namespace(self):
        ns = system_namespace()
        # Same lookup as above, for 'net_if_addrs'.
        matches = [x for x in ns.iter(ns.getters) if x[1] == 'net_if_addrs']
        fun = matches[0][0]
        self.assertEqual(fun(), psutil.net_if_addrs())
428
429
class TestOtherUtils(PsutilTestCase):
    """Checks for miscellaneous helpers."""

    def test_is_namedtuple(self):
        # A namedtuple instance is recognized...
        nt_class = collections.namedtuple('foo', 'a b c')
        assert is_namedtuple(nt_class(1, 2, 3))
        # ...a plain tuple is not.
        assert not is_namedtuple(())
435
436
if __name__ == "__main__":
    # When executed as a script, hand this file over to the psutil
    # test runner.
    from psutil.tests import runner
    runner.run_from_name(__file__)
440