1#
2#   mtr  --  a network diagnostic tool
3#   Copyright (C) 2016  Matt Kimball
4#
5#   This program is free software; you can redistribute it and/or modify
6#   it under the terms of the GNU General Public License version 2 as
7#   published by the Free Software Foundation.
8#
9#   This program is distributed in the hope that it will be useful,
10#   but WITHOUT ANY WARRANTY; without even the implied warranty of
11#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12#   GNU General Public License for more details.
13#
14#   You should have received a copy of the GNU General Public License along
15#   with this program; if not, write to the Free Software Foundation, Inc.,
16#   51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
17#
18
19'''Infrastructure for running tests which invoke mtr-packet.'''
20
21import fcntl
22import os
23import select
24import socket
25import subprocess
26import sys
27import time
28import unittest
29
30#
31#  typing is used for mypy type checking, but isn't required to run,
32#  so it's okay if we can't import it.
33#
34try:
35    # pylint: disable=locally-disabled, unused-import
36    from typing import Dict, List
37except ImportError:
38    pass
39
40
#  Host name that check_for_local_ipv6() resolves (as AF_INET6) and tries
#  to connect a UDP socket to, in order to detect a usable IPv6 route.
IPV6_TEST_HOST = 'google-public-dns-a.google.com'
42
43
class MtrPacketExecuteError(Exception):
    '''Raised when MtrPacketTest fails to launch the mtr-packet executable.

    The exception argument is the path that could not be executed.'''
47
48
class ReadReplyTimeout(Exception):
    '''Raised by MtrPacketTest.read_reply when no complete reply line
    arrives before the timeout expires.'''
53
54
class WriteCommandTimeout(Exception):
    '''Raised by MtrPacketTest.write_command when the command cannot be
    fully written to mtr-packet before the timeout expires.'''
59
60
class MtrPacketReplyParseError(Exception):
    '''Raised when MtrPacketReply cannot parse a reply string.

    The unparseable reply string is passed as the exception argument.'''
65
66
class PacketListenError(Exception):
    '''Raised by PacketListen when mtr-packet-listen misbehaves: it cannot
    be launched, reports an unexpected status, times out, or exits with a
    non-zero status.'''
71
72
def set_nonblocking(file_descriptor):  # type: (int) -> None
    '''Switch an open file descriptor to non-blocking I/O.

    The descriptor's existing status flags are preserved; only O_NONBLOCK
    is added.'''

    # pylint: disable=locally-disabled, no-member
    current_flags = fcntl.fcntl(file_descriptor, fcntl.F_GETFL)
    wanted_flags = current_flags | os.O_NONBLOCK
    fcntl.fcntl(file_descriptor, fcntl.F_SETFL, wanted_flags)
80
81
def check_for_local_ipv6():
    '''Check for IPv6 support on the test host, to see if we should skip
    the IPv6 tests.

    Resolve IPV6_TEST_HOST to an IPv6 address and try to connect a UDP
    socket to it.  (Connecting UDP requires no packets sent, just a route
    present.)  Return True if that succeeds.  Otherwise write a notice to
    stderr and return False.

    Any socket.error raised along the way is treated as "no IPv6".  That
    includes socket.gaierror when the name has no IPv6 address or cannot
    be resolved at all, and failure to create an AF_INET6 socket.'''

    connect_success = False
    sock = None

    #  This function runs at module import time (for HAVE_IPV6), so an
    #  escaping exception would abort loading every test in the module.
    try:
        addrinfo = socket.getaddrinfo(IPV6_TEST_HOST, 1, socket.AF_INET6)
        if addrinfo:
            addr = addrinfo[0][4]

            sock = socket.socket(
                socket.AF_INET6, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
            sock.connect(addr)
            connect_success = True
    except socket.error:
        #  socket.gaierror is a subclass of socket.error, so resolution
        #  failures land here too.
        pass
    finally:
        if sock is not None:
            sock.close()

    if not connect_success:
        sys.stderr.write(
            'This host has no IPv6.  Skipping IPv6 tests.\n')

    return connect_success
110
111
#  Evaluated once, when this module is imported.  True when this host has
#  a usable IPv6 route (see check_for_local_ipv6); IPv6 tests can consult
#  it to decide whether to skip.
HAVE_IPV6 = check_for_local_ipv6()
113
114
115# pylint: disable=locally-disabled, too-few-public-methods
class MtrPacketReply(object):
    '''A parsed reply from mtr-packet.

    A reply is a whitespace-separated line of the form
    "<token> <command-name> [<name> <value>]...".  After construction:

      token         -- the integer value of the first field
      command_name  -- the second field
      argument      -- dict mapping each argument name to its value
    '''

    def __init__(self, reply):  # type: (unicode) -> None
        self.token = 0  # type: int
        self.command_name = None  # type: unicode
        self.argument = {}  # type: Dict[unicode, unicode]

        self.parse_reply(reply)

    def parse_reply(self, reply):  # type: (unicode) -> None
        '''Parse a reply string into members of this instance.

        Raises MtrPacketReplyParseError (with the reply string as its
        argument) in three cases: the reply has fewer than two fields, the
        first field is not an integer, or an argument name is not followed
        by a value.'''

        tokens = reply.split()  # type: List[unicode]

        try:
            self.token = int(tokens[0])
            self.command_name = tokens[1]
        except (IndexError, ValueError):
            #  IndexError: too few fields.  ValueError: the token field is
            #  not an integer.  Both mean the reply is malformed.
            raise MtrPacketReplyParseError(reply)

        arg_tokens = tokens[2:]
        if len(arg_tokens) % 2 != 0:
            #  A trailing argument name with no value after it.
            raise MtrPacketReplyParseError(reply)

        for i in range(0, len(arg_tokens), 2):
            self.argument[arg_tokens[i]] = arg_tokens[i + 1]
147
148
class PacketListen(object):
    '''A test process which listens for a single packet.

    Used as a context manager.  Entering launches ./mtr-packet-listen with
    the constructor arguments and waits for it to print
    "status listening".  Exiting waits for the process to terminate, then
    parses each remaining output line of the form "<name> <value> ..."
    into the `attrib` dict.  Lines with fewer than two fields are ignored.
    '''

    def __init__(self, *args):
        self.process_args = list(args)  # type: List[unicode]
        self.listen_process = None  # type: subprocess.Popen
        self.attrib = None  # type: Dict[unicode, unicode]

    def __enter__(self):
        try:
            self.listen_process = subprocess.Popen(
                ['./mtr-packet-listen'] + self.process_args,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE)
        except OSError:
            raise PacketListenError('unable to launch mtr-packet-listen')

        status = self.listen_process.stdout.readline().decode('utf-8')
        if status != 'status listening\n':
            #  __exit__ does not run when __enter__ raises, so release the
            #  child process and its pipes here instead of leaking them.
            self._kill_and_reap()
            self.listen_process.stdin.close()
            self.listen_process.stdout.close()
            raise PacketListenError('unexpected status')

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            self.wait_for_exit()

            self.attrib = {}
            for line in self.listen_process.stdout.readlines():
                tokens = line.decode('utf-8').split()

                if len(tokens) >= 2:
                    name = tokens[0]
                    value = tokens[1]

                    self.attrib[name] = value
        finally:
            #  Close the pipes even when wait_for_exit raises.
            self.listen_process.stdin.close()
            self.listen_process.stdout.close()

    def wait_for_exit(self):
        '''Poll the subprocess for up to ten seconds, until it exits.

        We need to wait for its exit to ensure we are able to read its
        output.

        Raises PacketListenError if the process is still running after the
        wait (it is killed first) or if it exits with a non-zero status.'''

        wait_time = 10
        wait_step = 0.1

        steps = int(wait_time / wait_step)

        exit_value = None

        for _ in range(steps):
            exit_value = self.listen_process.poll()
            if exit_value is not None:
                break

            time.sleep(wait_step)

        if exit_value is None:
            #  Don't leave a hung listener running after the test.
            self._kill_and_reap()
            raise PacketListenError('mtr-packet-listen timeout')

        if exit_value != 0:
            raise PacketListenError('mtr-packet-listen unexpected error')

    def _kill_and_reap(self):
        'Kill the listener process and wait for it so it is not left behind.'

        try:
            self.listen_process.kill()
        except OSError:
            #  kill() can raise if the process is already gone.
            pass

        self.listen_process.wait()
214
215
class MtrPacketTest(unittest.TestCase):
    '''Base class for tests invoking mtr-packet.

    Start a new mtr-packet subprocess for each test, and kill it
    at the conclusion of the test.

    Provide methods for writing commands and reading replies.
    '''

    def __init__(self, *args):
        self.reply_buffer = None  # type: unicode
        self.packet_process = None  # type: subprocess.Popen
        self.stdout_fd = None  # type: int
        self.stdin_fd = None  # type: int

        super(MtrPacketTest, self).__init__(*args)

    def setUp(self):
        'Set up a test case by spawning a mtr-packet process'

        #  The MTR_PACKET environment variable overrides the binary path.
        packet_path = os.environ.get('MTR_PACKET', './mtr-packet')

        self.reply_buffer = ''
        try:
            self.packet_process = subprocess.Popen(
                [packet_path],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE)
        except OSError:
            raise MtrPacketExecuteError(packet_path)

        #  Put both pipes to mtr-packet in non-blocking mode.  read_reply
        #  and write_command can then wait with select() and enforce
        #  their own timeouts instead of blocking indefinitely.
        self.stdout_fd = self.packet_process.stdout.fileno()
        set_nonblocking(self.stdout_fd)

        self.stdin_fd = self.packet_process.stdin.fileno()
        set_nonblocking(self.stdin_fd)

    def tearDown(self):
        'After a test, kill the running mtr-packet instance and reap it'

        self.packet_process.stdin.close()
        self.packet_process.stdout.close()

        try:
            self.packet_process.kill()
        except OSError:
            #  Best effort: if the kill fails, don't risk blocking in
            #  wait() on a process we could not signal.
            return

        #  Reap the killed child so it does not linger as a zombie.
        #  Newer Python 3 versions also warn when a Popen is never waited on.
        self.packet_process.wait()

    def parse_reply(self, timeout=10.0):  # type: (float) -> MtrPacketReply
        '''Read the next reply from mtr-packet and parse it into
        an MtrPacketReply object.'''

        reply_str = self.read_reply(timeout)

        return MtrPacketReply(reply_str)

    def read_reply(self, timeout=10.0):  # type: (float) -> unicode
        '''Read the next reply from mtr-packet.

        Attempt to read the next command reply from mtr-packet.  If no reply
        is available within the timeout time, raise ReadReplyTimeout
        instead.

        The reply is returned without its trailing newline.  Any data read
        past that newline is kept in reply_buffer for the next call.'''

        start_time = time.time()

        #  A previous read may already have buffered more than one reply.
        #  Check the buffer before waiting; otherwise select() would block
        #  for the whole timeout even though a reply is already available.
        newline_ix = self.reply_buffer.find('\n')

        #  Read from mtr-packet until either the timeout time has elapsed
        #  or the buffer holds a newline character, which indicates a
        #  finished reply.
        while newline_ix == -1:
            elapsed = time.time() - start_time

            select_time = timeout - elapsed
            if select_time < 0:
                select_time = 0

            select.select([self.stdout_fd], [], [], select_time)

            reply_bytes = None
            try:
                reply_bytes = os.read(self.stdout_fd, 1024)
            except OSError:
                #  stdout_fd is expected to be non-blocking (set in
                #  setUp), so "no data yet" surfaces as an OSError here.
                pass

            if reply_bytes:
                self.reply_buffer += reply_bytes.decode('utf-8')

            newline_ix = self.reply_buffer.find('\n')
            if newline_ix == -1 and elapsed >= timeout:
                raise ReadReplyTimeout()

        reply = self.reply_buffer[:newline_ix]
        self.reply_buffer = self.reply_buffer[newline_ix + 1:]
        return reply

    def write_command(self, cmd, timeout=10.0):
        # type: (unicode, float) -> None

        '''Send a command string to the mtr-packet instance, timing out
        if we are unable to write for an extended period of time.  The
        timeout is to avoid deadlocks with the child process where both
        the parent and the child are writing to their end of the pipe
        and expecting the other end to be reading.

        A newline is appended to cmd before it is written.  Raises
        WriteCommandTimeout if the whole command cannot be written within
        the timeout.'''

        command_bytes = (cmd + '\n').encode('utf-8')

        start_time = time.time()

        while True:
            elapsed = time.time() - start_time

            select_time = timeout - elapsed
            if select_time < 0:
                select_time = 0

            select.select([], [self.stdin_fd], [], select_time)

            bytes_written = 0
            try:
                bytes_written = os.write(self.stdin_fd, command_bytes)
            except OSError:
                #  The non-blocking pipe may be full; retry after select().
                pass

            #  os.write may accept only part of the data; keep the rest.
            command_bytes = command_bytes[bytes_written:]
            if not command_bytes:
                break

            if elapsed >= timeout:
                raise WriteCommandTimeout()
355
356
def check_running_as_root():
    '''Emit a warning on stderr when the tests lack root privileges.

    The check is skipped entirely on Cygwin.'''

    # pylint: disable=locally-disabled, no-member
    if sys.platform == 'cygwin':
        return

    if os.getuid() > 0:
        sys.stderr.write(
            'Warning: many tests require running as root\n')
364