1#!/usr/bin/env python
2#######################################################################
3# Copyright (C) 2008-2020 by Carnegie Mellon University.
4#
5# @OPENSOURCE_LICENSE_START@
6# See license information in ../../../LICENSE.txt
7# @OPENSOURCE_LICENSE_END@
8#
9#######################################################################
10
11#######################################################################
12# $SiLK: sendrcv_tests.py ef14e54179be 2020-04-14 21:57:45Z mthomas $
13#######################################################################
14
15import sys
16import re
17import os
18import os.path
19import tempfile
20import random
21import stat
22import fcntl
23import re
24import shutil
25import itertools
26import optparse
27import time
28import signal
29import select
30import subprocess
31import socket
32import json
33import traceback
34import struct
35import datetime
36
# Serializer used to encode the messages exchanged between the test driver
# and the per-daemon helper processes (see Daemon.dump / Daemon.load).
conv = json

# $srcdir (presumably exported by the build system when tests run from a
# separate build tree -- confirm) locates the source tree; the shared test
# helper modules live in its "tests" directory, so that directory must be on
# sys.path before the project imports below.
srcdir = os.environ.get("srcdir")
if srcdir:
    sys.path.insert(0, os.path.join(srcdir, "tests"))
from gencerts import reset_all_certs_and_keys, generate_signed_cert, generate_ca_cert, PASSWORD
from config_vars import config_vars
from daemon_test import get_ephemeral_port, Dirobject

# Export the password of the generated TLS keys to the environment of the
# child programs; presumably rwsender/rwreceiver read these variables to
# unlock their keys (verify against their documentation).
os.environ["RWSENDER_TLS_PASSWORD"] = PASSWORD
os.environ["RWRECEIVER_TLS_PASSWORD"] = PASSWORD
48
# Digest constructors.  Prefer hashlib; fall back to the legacy md5/sha
# modules on interpreters that do not provide hashlib.
try:
    import hashlib
    md5_new = hashlib.md5
    sha1_new = hashlib.sha1
except ImportError:
    import md5
    md5_new = md5.new
    import sha
    sha1_new = sha.new
58
59
# Counter used by Daemon.__init__ to build unique default daemon names.
global_int     = 0

# Seconds Daemon._end() waits for a program to exit before sending SIGKILL.
KILL_DELAY     = 20
# Bytes per read/write when creating or checksumming test files.
CHUNKSIZE      = 2048
# Default 'overwrite' setting for Rwsender/Rwreceiver objects.
OVERWRITE      = False
# Default --log-level passed to rwsender/rwreceiver.
LOG_LEVEL      = "info"
# Writable streams that global_log() writes to (empty list: log nothing).
# These defaults are presumably overridden by command-line handling outside
# this view -- confirm.
LOG_OUTPUT     = []

# do not remove the directory after the test
NO_REMOVE      = False
# File read by setup(); each line is "<relative path> <size> <md5>".
FILE_LIST_FILE = None

# tests to run (in the order in which to run them) if no tests are
# specified on the command line
ALL_TESTS = ['testConnectOnlyIPv4Addr', 'testConnectOnlyHostname',
             'testConnectOnlyTLS', 'testConnectOnlyIPv6Addr',
             'testSendRcvStopReceiverServer', 'testSendRcvStopSenderServer',
             'testSendRcvStopReceiverClient', 'testSendRcvStopSenderClient',
             'testSendRcvKillReceiverServer', 'testSendRcvKillSenderServer',
             'testSendRcvKillReceiverClient', 'testSendRcvKillSenderClient',
             'testSendRcvStopReceiverServerTLS',
             'testSendRcvStopSenderServerTLS',
             'testSendRcvStopReceiverClientTLS',
             'testSendRcvStopSenderClientTLS',
             'testSendRcvKillReceiverServerTLS',
             'testSendRcvKillSenderServerTLS',
             'testSendRcvKillReceiverClientTLS',
             'testSendRcvKillSenderClientTLS',
             'testMultiple', 'testMultipleTLS',
             'testFilter', 'testPostCommand']

# List of (path, (size, sha1_placeholder, md5)) records filled in by setup().
rfiles = None

# Multiplier applied to every trigger timeout by t().
TIMEOUT_FACTOR = 1.0


# Keyword arguments for str() when decoding program output in
# Daemon._handle_log: Python 3 must decode bytes, while Python 2's str()
# accepts no encoding argument.
if sys.version_info[0] >= 3:
    coding = {"encoding": "ascii"}
else:
    coding = {}
100
101
class TriggerError(Exception):
    """Raised by trigger() when an expected log line is not observed."""
104
class FileTransferError(Exception):
    """Raised when a file is missing or corrupted at the receiving side."""
107
def global_log(name, msg, timestamp=True):
    """Write one log record to every stream in LOG_OUTPUT.

    The record is "<name>: [<Mon DD HH:MM:SS> ]<msg>" followed by a newline
    if *msg* does not already end with one.  A falsy *name* is replaced by
    the script name (sys.argv[0]).  Every stream is flushed after writing.
    """
    prefix = (name or sys.argv[0]) + ": "
    if timestamp:
        stamp = datetime.datetime.now().strftime("%b %d %H:%M:%S ")
    else:
        stamp = ""
    record = prefix + stamp + msg
    if not record.endswith("\n"):
        record += "\n"
    for stream in LOG_OUTPUT:
        stream.write(record)
        stream.flush()
120
def t(dur):
    """Return timeout *dur* (seconds) scaled by TIMEOUT_FACTOR, truncated to int."""
    scaled = TIMEOUT_FACTOR * dur
    return int(scaled)
123
124def setup():
125    global rfiles
126    rfiles = []
127    top_tests = os.path.dirname(FILE_LIST_FILE)
128    file_list = open(FILE_LIST_FILE, "r", 1)
129    for line in file_list:
130        (path, size, md5) = line.split()
131        path = os.path.join(top_tests, path)
132        # empty string is placeholder for SHA1 digest
133        rfiles.append((path, (int(size), "", md5)))
134
def teardown():
    """Module-level teardown hook; nothing needs to be cleaned up."""
137
def _config_flag(name):
    """Return the boolean value of build-configuration variable *name*.

    config_vars maps names to Python-literal strings (e.g. "True", "False",
    "1").  The value is parsed with ast.literal_eval instead of eval(), so
    config data cannot execute arbitrary code; ValueError is raised for a
    value that is not a plain literal.  A missing or empty value means the
    feature is disabled.  Previously eval(None) raised TypeError here.
    """
    import ast
    value = config_vars.get(name, None)
    if not value:
        return False
    if not hasattr(value, "strip"):
        # already a non-string value (e.g. a bool)
        return bool(value)
    text = value.strip()
    if not text:
        return False
    return bool(ast.literal_eval(text))

def tls_supported():
    """Return True if the build configuration enables GnuTLS (SK_ENABLE_GNUTLS)."""
    return _config_flag('SK_ENABLE_GNUTLS')

def ipv6_supported():
    """Return True if the build configuration enables IPv6 networking
    (SK_ENABLE_INET6_NETWORKING)."""
    return _config_flag('SK_ENABLE_INET6_NETWORKING')
143
# Sequence number of the most recent trigger() call.  Helper replies carry
# the id they were issued with, so replies left over from an earlier call
# (e.g. one that returned early with first=True) can be ignored.
trigger_id = 0

def trigger(*specs, **kwd):
    """Block until each daemon logs a line matching a regular expression.

    *specs* are ``(daemon, timeout, match)`` tuples.  Each daemon's helper
    process is sent a 'trigger' request carrying *match* and the scaled
    timeout t(timeout).  This function then select()s on the daemons' pipes
    and collects the helpers' replies.

    Keyword arguments:
      pid   -- (default True) passed to the helper; when true, only log lines
               containing "[<pid of the current program process>]" followed
               by *match* count.
      first -- (default False) return as soon as any one spec is satisfied.

    Returns a list parallel to *specs* holding the matching log line for
    each spec.  Entries stay None for specs not yet satisfied when
    first=True.

    Raises TriggerError, whose argument is the ``(daemon, timeout, match)``
    tuple, in three cases: the daemon's load() raises EOFError, a
    non-'trigger' message arrives, or the helper reports the trigger timed
    out.

    NOTE(review): replies are keyed by daemon.pipe, so a daemon should appear
    at most once in *specs*.  A duplicate overwrites the earlier pipes entry
    while still being counted, so the loop would block indefinitely.
    """
    global trigger_id
    trigger_id += 1
    pid = kwd.get('pid', True)
    first = kwd.get('first', False)
    pipes = {}
    count = len(specs)
    retval = []
    for daemon, timeout, match in specs:
        daemon.dump(('trigger',
                     {'pid': pid, 'id': trigger_id,
                      'match': match, 'timeout': t(timeout)}))
        # remember the spec and its slot in retval, keyed by the pipe
        pipes[daemon.pipe] = (daemon, timeout, match, len(retval))
        retval.append(None)
    while count:
        # no select() timeout: the helpers enforce the trigger timeouts
        (readers, writers, x) = select.select(pipes.keys(), [], [])
        for reader in readers:
            data = pipes[reader]
            try:
                rv = data[0].load()
            except EOFError:
                raise TriggerError(data[:3])
            if rv[0] != 'trigger':
                # Trigger failed
                raise TriggerError(data[:3])
            if rv[1] != trigger_id:
                # stale reply to an earlier trigger() call; ignore it
                continue
            if rv[2] == False:
                # the helper reported that the trigger timed out
                raise TriggerError(data[:3])
            count -= 1
            retval[data[3]] = rv[3]
            del pipes[reader]
            if first:
                return retval
    return retval
181
def check_started(clients, servers, tls):
    """Wait until every client and server logs that networking has started.

    Clients must log "Attempting to connect to <addr> (TCP[, TLS])" and
    servers "Bound to <addr> for listening (TCP[, TLS])".  The wait is 120
    seconds with TLS and 10 without, before t() scaling.
    """
    transport = r"\(TCP, TLS\)" if tls else r"\(TCP\)"
    timeout = 120 if tls else 10
    client_re = r"Attempting to connect to \S+ %s" % transport
    server_re = r"Bound to \S+ for listening %s" % transport
    watches = [(cli, timeout, client_re) for cli in clients]
    watches.extend((srv, timeout, server_re) for srv in servers)
    trigger(*watches)
195
def check_connected(clients, servers, timeout=25):
    """Wait until every client and server report being connected to each other.

    Clients and servers are walked in pairs, padding the shorter list with
    None.  For the pair's client *c*, every server must log
    "Connected to remote <c.name>".  For the pair's server *s*, every client
    must log "Connected to remote <s.name>".  Each pair's watches are awaited
    with one trigger() call using *timeout* seconds (before t() scaling).
    """
    if sys.version_info[0] < 3:
        zipper = itertools.izip_longest
    else:
        zipper = itertools.zip_longest
    for client, server in zipper(clients, servers):
        watches = []
        # The inner iteration variables (srv, cli) are deliberately distinct
        # from client/server.  Python 2 list comprehensions leak their loop
        # variable, which used to overwrite the pair under examination.
        if client is not None:
            watches.extend(
                (srv, timeout, "Connected to remote %s" % client.name)
                for srv in servers)
        if server is not None:
            watches.extend(
                (cli, timeout, "Connected to remote %s" % server.name)
                for cli in clients)
        trigger(*watches)
212
def create_random_file(suffix="", prefix="random", dir=None, size=(0, 0)):
    """Create a temporary file of random bytes and describe it.

    The file is created with tempfile.mkstemp(suffix, prefix, dir).  Its
    length is drawn uniformly from the inclusive range ``size = (min, max)``,
    and the data is written CHUNKSIZE bytes at a time.

    Returns ``(path, (length, "", md5_hexdigest))``.  The empty string is a
    placeholder for the unused SHA1 digest, matching checksum_file().
    """
    (handle, path) = tempfile.mkstemp(suffix, prefix, dir)
    numbytes = random.randint(size[0], size[1])
    totalbytes = numbytes
    checksum_md5 = md5_new()
    # The content is binary.  Text mode ("w") rejects bytes on Python 3.
    with os.fdopen(handle, "wb") as f:
        while numbytes:
            length = min(numbytes, CHUNKSIZE)
            try:
                chunk = os.urandom(length)
            except NotImplementedError:
                # No OS randomness source.  Build the chunk from the random
                # module as bytes, so both md5 and the binary file accept it.
                chunk = bytes(bytearray(random.getrandbits(8)
                                        for _ in range(length)))
            f.write(chunk)
            checksum_md5.update(chunk)
            numbytes -= length
    return (path, (totalbytes, "", checksum_md5.hexdigest()))
234
def checksum_file(path):
    """Return ``(size, "", md5_hexdigest)`` for the file at *path*.

    The empty string is a placeholder for the unused SHA1 digest, so the
    result compares directly with create_random_file() and setup() records.
    The file is read CHUNKSIZE bytes at a time.
    """
    checksum_md5 = md5_new()
    # 'with' guarantees the descriptor is closed even if a read fails.
    with open(path, 'rb') as f:
        size = os.fstat(f.fileno()).st_size
        for chunk in iter(lambda: f.read(CHUNKSIZE), b""):
            checksum_md5.update(chunk)
    return (size, "", checksum_md5.hexdigest())
248
249
# Framing for driver<->helper messages: each JSON payload is preceded by its
# byte length packed with struct format "L" (native size and byte order).
# That is adequate because both ends are processes forked on this host.
sconv = "L"
slen  = struct.calcsize(sconv)
252
253
class Daemon(Dirobject):
    """Run one external program under a forked helper process.

    The parent (the test driver) talks to the helper over a socketpair using
    length-prefixed JSON messages (dump()/load()).  The helper does four
    things:
      - starts the program with subprocess;
      - captures its stderr and echoes it through global_log();
      - signals the program on request;
      - answers 'trigger' requests by watching that output for a regular
        expression.
    Subclasses set ``exe_name`` and extend get_args().
    """

    def __init__(self, name=None, log_level="info", prog_env=None,
                 verbose=True, **kwds):
        global global_int
        Dirobject.__init__(self, **kwds)
        if name:
            self.name = name
        else:
            # Default to a unique name such as "Rwsender0".
            self.name = type(self).__name__ + str(global_int)
            global_int += 1
        self.process = None        # Popen of the program (helper side)
        self.logdata = []          # complete stderr lines captured so far
        self.log_level = log_level
        self._prog_env = prog_env  # env var that may override the command
        self.daemon = True
        self.verbose = verbose
        self.pipe = None           # this side's end of the socketpair
        self.pid = None            # fork() result: helper pid (parent), 0 (helper)
        self.trigger = None        # pending trigger request (helper side)
        self.timeout = None        # absolute deadline for self.trigger
        self.channels = []         # objects the helper select()s on
        self.pending_line = None   # partial stderr line awaiting its newline

    def printv(self, *args):
        """If verbose, log *args* tagged "parent"/"child" and this pid."""
        if self.pid is None or self.pid != 0:
            me = "parent"
        else:
            me = "child"
        if self.verbose:
            global_log(self.name, me + ("[%s]:" % os.getpid()) +
                       ' '.join(str(x) for x in args))

    def get_executable(self):
        """Return the program command: $<prog_env> if set, else ./<exe_name>."""
        default = os.path.join(".", self.exe_name)
        if not self._prog_env:
            # os.environ.get(None) raises TypeError on Python 3
            return default
        return os.environ.get(self._prog_env, default)

    def get_args(self):
        """Return the argv list for the program.

        The command string is split on whitespace, so an override may carry
        extra words.  The program is told to stay in the foreground and to
        log to stderr at self.log_level.
        """
        args = self.get_executable().split()
        args.extend([ '--no-daemon',
                      '--log-dest', 'stderr',
                      '--log-level', self.log_level])
        return args

    def init(self):
        """Hook run before start(); subclasses prepare directories here."""
        pass

    def log_verbose(self, msg):
        """Send *msg* to global_log() under this daemon's name if verbose."""
        if self.verbose:
            global_log(self.name, msg)

    def _handle_log(self, fd):
        """Consume the available stderr output of the program (helper side).

        Each complete line is appended to self.logdata, echoed via
        global_log(), and tested against the pending trigger, if any.  A
        match sends ('trigger', id, True, line) to the parent.  A trailing
        partial line is kept in self.pending_line until its newline arrives.

        Returns True if at least one complete line was read, else False.
        NOTE(review): a read that yields only a partial line also returns
        False, which makes _child() stop watching stderr; confirm that is
        acceptable.
        """
        # 'fd' is the log of the process, which supports non-blocking
        # reads.  read from the log until there is nothing left to read,
        # so that select() on 'fd' will work correctly
        got_line = False
        try:
            # when no-more-data, following returns empty string in
            # Python3, but throws IOError [EAGAIN] in Python2
            for line in fd:
                line = str(line, **coding)
                # handle reading a partial line last time and this time
                if self.pending_line:
                    line = self.pending_line + line
                    self.pending_line = None
                if line[-1] != "\n":
                    self.pending_line = line
                    break
                got_line = True
                self.logdata.append(line)
                global_log(self.name, line, timestamp=False)
                if self.trigger:
                    match = self.trigger['re'].search(line)
                    if match:
                        self.log_verbose(
                            'Trigger fired for "%s"' % self.trigger['match'])
                        self.dump(('trigger', self.trigger['id'], True, line))
                        self.timeout = None
                        self.trigger = None
        except IOError:
            pass
        if not got_line:
            return False
        return True

    def _handle_parent(self):
        """Read one request from the parent and act on it (helper side).

        The first element of the request names the action:
          'stop'    -- send SIGTERM;
          'start'   -- launch the program (RuntimeError if it is still
                       running);
          'kill'    -- send SIGKILL;
          'end'     -- wait for exit, then reply ('stopped', returncode);
          'exit'    -- terminate the helper;
          'trigger' -- watch the log for a regular expression.

        Always returns False.  A true value would make _child() close the
        pipe and stop listening to the parent.
        """
        retval = False
        request = self.load()
        # self.log_verbose("Handling %s" % request)
        if request[0] == 'stop':
            self._stop()
        elif request[0] == 'start':
            if self.process is not None:
                self.process.poll()
                if self.process.returncode is None:
                    # refuse to launch a second copy while one is running
                    raise RuntimeError()
            self._start()
        elif request[0] == 'kill':
            self._kill()
        elif request[0] == 'end':
            self._end()
            self.dump(("stopped", self.process.returncode))
        elif request[0] == 'exit':
            try:
                self.pipe.close()
            except:
                pass
            os._exit(0)
        elif request[0] == 'trigger':
            # search previously captured output from the application
            self.trigger = request[1]
            if self.trigger['pid']:
                # Only lines tagged with the current process's "[pid]" count,
                # so output from an earlier run (before a restart) is ignored.
                regexp = re.compile(r"\[%s\].*%s" %
                                    (self.process.pid, self.trigger['match']))
            else:
                regexp = re.compile(self.trigger['match'])
            for line in self.logdata:
                match = regexp.search(line)
                if match:
                    self.log_verbose(
                        'Trigger fired for "%s"' % self.trigger['match'])
                    self.dump(('trigger', self.trigger['id'], True, line))
                    self.trigger = None
                    return retval
            # Not seen yet: _handle_log() will test new lines, and _child()
            # will report a timeout once the deadline passes.
            self.trigger['re'] = regexp
            self.timeout = time.time() + self.trigger['timeout']
        return retval

    def _child(self):
        """Main loop of the helper process.

        Waits (1 second select() timeout) on the parent pipe and the
        program's stderr and dispatches whichever is readable.  A trigger
        whose deadline has passed is reported as ('trigger', id, False).
        """
        while self.channels:
            # channels contains a socket to the parent and to the
            # stderr of the application's process
            (readers, writers, x) = select.select(self.channels, [], [], 1)
            if self.process is not None and self.process.stderr in readers:
                rv = self._handle_log(self.process.stderr)
                if not rv:
                    self.channels.remove(self.process.stderr)
            if self.pipe in readers:
                if self._handle_parent():
                    try:
                        self.pipe.close()
                    except:
                        pass
                    self.channels.remove(self.pipe)
            if self.timeout is not None and time.time() > self.timeout:
                self.log_verbose(
                    'Trigger timed out after %s seconds: "%s"' %
                    (self.trigger['timeout'], self.trigger['match']))
                self.dump(('trigger', self.trigger['id'], False))
                self.timeout = None
                self.trigger = None

    def expect(self, cmd):
        """Read helper messages until one whose first element is *cmd*.

        Other messages are discarded.  If the helper has gone away (EOFError)
        while waiting for a stop-type reply, ('stopped', None) is returned.
        Otherwise EOFError propagates.
        """
        try:
            while True:
                rv = self.load()
                # self.log_verbose("Retrieved %s" % rv)
                if cmd == rv[0]:
                    break
            return rv
        except EOFError:
            # 'stopped' is what end() waits for; a vanished helper means the
            # program is gone, which end() reports as an unknown status.
            if cmd in ['stop', 'kill', 'end', 'stopped']:
                return ('stopped', None)
            raise

    def start(self):
        """Start the program, forking the helper process on first use.

        Blocks until the helper acknowledges with 'start'.  In the forked
        helper this method never returns: it runs _child() and calls
        os._exit().
        """
        if self.pid is not None:
            self.dump(('start',))
            self.expect('start')
            return None
        pipes = socket.socketpair()
        self.pid = os.fork()
        if self.pid != 0:
            pipes[1].close()
            self.pipe = pipes[0]
            self.dump(('start',))
            self.expect('start')
            return None
        try:
            pipes[0].close()
            self.pipe = pipes[1]
            self.channels = [self.pipe]
            self._child()
        except:
            traceback.print_exc()
            self._kill()
        finally:
            try:
                self.pipe.close()
            except:
                pass
            os._exit(0)

    def _start(self):
        """Launch the program with stderr captured (helper side), then
        acknowledge with ('start',)."""
        if self.process is not None and self.process.stderr in self.channels:
            self.channels.remove(self.process.stderr)
        # work around issue #11459 in python 3.1.[0-3], 3.2.0 (where a
        # process is line buffered despite bufsize=0) by making the
        # buffer large, making the stream non-blocking, and getting
        # everything available from the stream when we read
        self.process = subprocess.Popen(self.get_args(), bufsize = -1,
                                        stderr=subprocess.PIPE)
        fcntl.fcntl(self.process.stderr, fcntl.F_SETFL,
                    (os.O_NONBLOCK
                     | fcntl.fcntl(self.process.stderr, fcntl.F_GETFL)))
        self.channels.append(self.process.stderr)
        self.dump(('start',))

    def dump(self, arg):
        """Send *arg* to the peer as a length-prefixed JSON message.

        Send errors are ignored (the peer may already have exited).
        """
        value = conv.dumps(arg).encode('ascii')
        data = struct.pack(sconv, len(value)) + value
        try:
            self.pipe.sendall(data)
        except IOError:
            pass

    def _recv_exactly(self, count):
        """Read exactly *count* bytes from the pipe.

        socket.recv() may return fewer bytes than requested, so keep reading.
        Raises EOFError if the peer closes the connection first.
        """
        data = b""
        while len(data) < count:
            chunk = self.pipe.recv(count - len(data))
            if not chunk:
                raise EOFError("connection closed by peer")
            data += chunk
        return data

    def load(self):
        """Receive one message sent by the peer's dump().

        Raises EOFError if the peer has closed its end of the pipe.
        """
        (length,) = struct.unpack(sconv, self._recv_exactly(slen))
        # Request only the bytes of this message.  Asking for more, as
        # recv(length) in a loop did, can swallow the start of the next
        # message from the stream.
        value = self._recv_exactly(length)
        return conv.loads(value.decode('ascii'))

    def kill(self):
        """Ask the helper to SIGKILL the program; wait for the acknowledgement."""
        if self.pid is None:
            return None
        self.dump(('kill',))
        self.expect('kill')

    def stop(self):
        """Ask the helper to SIGTERM the program; wait for the acknowledgement."""
        if self.pid is None:
            return None
        self.dump(('stop',))
        self.expect('stop')

    def end(self):
        """Wait for the program to exit and return its status.

        A negative status means the program died from a signal.  Returns
        None if the status is unknown or the helper was never started.
        """
        if self.pid is None:
            return None
        self.dump(('end',))
        rv = self.expect('stopped')[1]
        if rv is not None:
            if rv >= 0:
                self.log_verbose("Exited with status %s" % rv)
            else:
                self.log_verbose("Exited with signal %s" % (-rv))
        return rv

    def exit(self):
        """End the program, then tell the helper to exit and reap it."""
        self.end()
        if self.pid is None:
            return None
        self.dump(('exit',))
        try:
            os.waitpid(self.pid, 0)
        except OSError:
            pass

    def _kill(self):
        """SIGKILL the program if no exit status has been recorded, then
        send ('kill',) to the parent (helper side)."""
        if self.process is not None and self.process.returncode is None:
            try:
                self.log_verbose("Sending SIGKILL")
                os.kill(self.process.pid, signal.SIGKILL)
            except OSError:
                pass
        self.dump(('kill',))

    def _stop(self):
        """SIGTERM the program if no exit status has been recorded, then
        send ('stop',) to the parent (helper side)."""
        if self.process is not None and self.process.returncode is None:
            try:
                self.log_verbose("Sending SIGTERM")
                os.kill(self.process.pid, signal.SIGTERM)
            except OSError:
                pass
        self.dump(('stop',))

    def _end(self):
        """Wait for the program to exit (helper side).

        Polls once a second for up to KILL_DELAY seconds.  If the program is
        still running it is SIGKILLed and waited for.  Remaining stderr
        output is drained in both cases.  Returns True if the program exited
        on its own, False if it had to be killed.
        """
        target = time.time() + KILL_DELAY
        self.process.poll()
        while self.process.returncode is None and time.time() < target:
            self.process.poll()
            time.sleep(1)
        if self.process.returncode is not None:
            while self._handle_log(self.process.stderr):
                pass
            return True
        # _kill() also sends ('kill',) to the parent; end() skips it while
        # waiting for 'stopped'.
        self._kill()
        self.process.poll()
        while self.process.returncode is None:
            self.process.poll()
            time.sleep(1)
        while self._handle_log(self.process.stderr):
            pass
        return False
551
552
class Sndrcv_base(Daemon):
    """Configuration shared by the rwsender and rwreceiver wrappers.

    Holds the connection mode ("client" or "server"), the listen address
    and port, the peer identities, and optional TLS material, and turns
    them into command-line options in get_args().
    """

    def __init__(self, name=None, **kwds):
        Daemon.__init__(self, name, **kwds)
        self.mode = "client"
        self.listen = None     # server: host to bind (None: pass port only)
        self.port = None       # server: listening port
        self.clients = []      # server: identifiers of accepted clients
        self.servers = []      # client: (ident, host, port) of each server
        self.ca_cert = None
        self.ca_key = None
        self.cert = None

    def create_cert(self):
        """Generate this daemon's certificate, signed by the configured CA."""
        key_pem = os.path.join(self.basedir, "key.pem")
        key_p12 = os.path.join(self.basedir, "key.p12")
        ca_pair = (self.ca_key, self.ca_cert)
        self.cert = generate_signed_cert(self.basedir, ca_pair,
                                         key_pem, key_p12)

    def set_ca(self, ca_key, ca_cert):
        """Record the CA key and certificate used to sign our certificate."""
        self.ca_key, self.ca_cert = ca_key, ca_cert

    def init(self):
        """Create working directories (adding "cert" when a CA is set) and,
        with TLS, this daemon's certificate."""
        Daemon.init(self)
        if self.ca_cert:
            self.dirs.append("cert")
        self.create_dirs()
        if self.ca_cert:
            self.create_cert()

    def get_args(self):
        """Extend Daemon's arguments with mode, identity, TLS options (when a
        CA is set), and either the server port or the server addresses."""
        args = Daemon.get_args(self)
        args.extend(['--mode', self.mode, '--identifier', self.name])
        if self.ca_cert:
            args.extend([
                '--tls-ca', os.path.abspath(self.ca_cert),
                '--tls-pkcs12', os.path.abspath(self.cert),
                '--tls-priority=NORMAL',
                #'--tls-priority=SECURE128:+SECURE192:-VERS-ALL:+VERS-TLS1.2',
                #'--tls-security=ultra',
                '--tls-debug-level=0',
            ])
        if self.mode != "server":
            for ident, addr, port in self.servers:
                address = ':'.join((ident, addr, str(port)))
                args.extend(['--server-address', address])
            return args
        if self.listen is None:
            port_spec = str(self.port)
        else:
            port_spec = "%s:%s" % (self.listen, self.port)
        args.extend(['--server-port', port_spec])
        for client in self.clients:
            args.extend(['--client-ident', client])
        return args

    def _check_file(self, dir, finfo):
        """Check that the file described by *finfo* arrived intact in the
        working directory named *dir*.

        *finfo* is ``(path, (size, sha1, md5))``; only the basename of path
        is used.  Returns ``(error_message_or_None, checked_path)``.
        """
        (src_path, (want_size, want_sha, want_md5)) = finfo
        target = os.path.join(self.dirname[dir], os.path.basename(src_path))
        if not os.path.exists(target):
            return ("Does not exist", target)
        (got_size, got_sha, got_md5) = checksum_file(target)
        if got_size != want_size:
            return ("Size mismatch (%s != %s)" % (want_size, got_size), target)
        if got_sha != want_sha:
            return ("SHA mismatch (%s != %s)" % (want_sha, got_sha), target)
        if got_md5 != want_md5:
            return ("MD5 mismatch (%s != %s)" % (want_md5, got_md5), target)
        return (None, target)
621
622
class Rwsender(Sndrcv_base):
    """Wrapper that runs rwsender over "in", "proc" and "error" directories.

    *filters* is a sequence of ``(ident, regexp)`` pairs, each passed as
    ``--filter ident:regexp``.  *polling_interval* is passed as
    --polling-interval.  *overwrite* and *log_level* default to the module
    settings OVERWRITE and LOG_LEVEL.
    """

    def __init__(self, name=None, polling_interval=5, filters=None,
                 overwrite=None, log_level=None, **kwds):
        if log_level is None:
            log_level = LOG_LEVEL
        if overwrite is None:
            overwrite = OVERWRITE
        Sndrcv_base.__init__(self, name, overwrite=overwrite,
                             log_level=log_level, prog_env="RWSENDER", **kwds)
        self.exe_name = "rwsender"
        # A None default avoids one mutable [] shared by every instance.
        self.filters = [] if filters is None else filters
        self.polling_interval = polling_interval
        self.dirs = ["in", "proc", "error"]

    def get_args(self):
        """Extend the common arguments with rwsender's directories,
        polling interval, and filters."""
        args = Sndrcv_base.get_args(self)
        args += ['--incoming-directory', os.path.abspath(self.dirname["in"]),
                 '--processing-directory',
                 os.path.abspath(self.dirname["proc"]),
                 '--error-directory', os.path.abspath(self.dirname["error"]),
                 '--polling-interval', str(self.polling_interval)]
        for ident, regexp in self.filters:
            args.extend(["--filter", ident + ':' + regexp])
        return args

    def send_random_file(self, suffix="", prefix="random", size=(0, 0)):
        """Create a random file directly in the incoming directory; returns
        create_random_file()'s ``(path, (size, "", md5))`` record."""
        return create_random_file(suffix = suffix, prefix = prefix,
                                  dir = self.dirname["in"], size = size)

    def send_files(self, files):
        """Copy each ``(path, info)`` record's file into the incoming directory."""
        for path, _info in files:
            shutil.copy(path, self.dirname["in"])

    def check_error(self, data):
        """Check whether the file described by *data* is in the error
        directory; see Sndrcv_base._check_file for the result format."""
        return self._check_file("error", data)
659
660
class Rwreceiver(Sndrcv_base):
    """Wrapper that runs rwreceiver, writing received files into "dest".

    An optional *post_command* is passed as --post-command.  *overwrite*
    and *log_level* default to the module settings OVERWRITE and LOG_LEVEL.
    """

    def __init__(self, name=None, post_command=None,
                 overwrite=None, log_level=None, **kwds):
        log_level = LOG_LEVEL if log_level is None else log_level
        overwrite = OVERWRITE if overwrite is None else overwrite
        Sndrcv_base.__init__(self, name, overwrite=overwrite,
                             log_level=log_level, prog_env="RWRECEIVER", **kwds)
        self.post_command = post_command
        self.dirs = ["dest"]
        self.exe_name = "rwreceiver"

    def get_args(self):
        """Extend the common arguments with the destination directory and
        the optional post command."""
        args = Sndrcv_base.get_args(self)
        dest_dir = os.path.abspath(self.dirname["dest"])
        args.extend(['--destination-directory', dest_dir])
        if self.post_command:
            args.extend(['--post-command', self.post_command])
        return args

    def check_sent(self, data):
        """Check that the file described by *data* arrived intact in the
        destination directory; see Sndrcv_base._check_file."""
        return self._check_file("dest", data)
685
class System(Dirobject):
    """A group of connected rwsender/rwreceiver daemons run as one unit.

    connect() wires clients to servers, assigning server ports and, with
    TLS, a shared CA.  start(), stop() and end() then act on every member.
    Pass *which* = "clients" or "servers" to act on only that role.
    """

    def __init__(self):
        Dirobject.__init__(self)
        self.create_dirs()
        # Fixed by the first connection: Rwsender or Rwreceiver.
        self.client_type = None
        self.server_type = None
        self.clients = set()
        self.servers = set()
        self.ca_cert = None
        self.ca_key = None

    def create_ca_cert(self):
        """Generate the CA key and certificate in this system's directory."""
        ca_path = os.path.join(self.basedir, 'ca_cert.pem')
        self.ca_key, self.ca_cert = generate_ca_cert(self.basedir, ca_path)

    def connect(self, clients, servers, tls=False, hostname=None):
        """Connect every client to every server.

        *clients* and *servers* may be single daemons or iterables of them.
        A None *hostname* is taken from $SK_TESTS_SENDRCV_HOSTNAME, falling
        back to "localhost".  With *tls*, the CA is created first.
        """
        if tls:
            self.create_ca_cert()
        if hostname is None:
            hostname = os.environ.get("SK_TESTS_SENDRCV_HOSTNAME",
                                      "localhost")
        if isinstance(clients, Sndrcv_base):
            clients = [clients]
        if isinstance(servers, Sndrcv_base):
            servers = [servers]
        for srv in servers:
            srv.listen = hostname
        for cli in clients:
            for srv in servers:
                self._connect(cli, srv, tls, hostname)

    def _connect(self, client, server, tls, hostname):
        """Configure one client/server pair and record both daemons."""
        if not isinstance(client, Sndrcv_base):
            raise ValueError("Can only connect rwsenders and rwreceivers")
        if not self.client_type:
            # The first client seen decides the roles for the whole system.
            if isinstance(client, Rwsender):
                self.client_type, self.server_type = Rwsender, Rwreceiver
            else:
                self.client_type, self.server_type = Rwreceiver, Rwsender
        if not isinstance(client, self.client_type):
            raise ValueError("Client must be of type %s" %
                             self.client_type.__name__)
        if not isinstance(server, self.server_type):
            raise ValueError("Server must be of type %s" %
                             self.server_type.__name__)
        client.mode, server.mode = "client", "server"

        if server.port is None:
            # Allocate once; later pairs with this server reuse the port.
            server.port = get_ephemeral_port()

        client.servers.append((server.name, hostname, server.port))
        server.clients.append(client.name)

        self.clients.add(client)
        self.servers.add(server)

        if tls:
            for member in (client, server):
                member.set_ca(self.ca_key, self.ca_cert)

    def _forall(self, call, which, *args, **kwds):
        """Call method *call* on the selected members; return the results."""
        if which == "clients":
            members = self.clients
        elif which == "servers":
            members = self.servers
        else:
            members = itertools.chain(self.clients, self.servers)
        results = []
        for member in members:
            results.append(getattr(member, call)(*args, **kwds))
        return results

    def start(self, which=None):
        """Initialize, then start, the selected daemons."""
        self._forall("init", which)
        self._forall("start", which)

    def end(self, which=None, noremove=False):
        """Shut down the selected daemons' helpers; unless *noremove*,
        delete this system's base directory."""
        self._forall("exit", which)
        if not noremove:
            self.remove_basedir()

    def stop(self, which=None):
        """Ask the selected daemons to stop (SIGTERM)."""
        self._forall("stop", which)
771
772
773def _rename_pkcs12(x):
774    if x == "--tls-pkcs12":
775        return "--tls-cert"
776    return x
777
778# Like Rwsender but uses customized TLS keys+certificates
class RwsenderCert(Rwsender):
    """Rwsender that uses caller-supplied TLS files instead of generated ones.

    *ca_cert* and *cert* are used as given.  When *key* is set, --tls-key
    is added and the certificate is passed with --tls-cert instead of
    --tls-pkcs12.
    """
    def __init__(self, name, ca_cert, key, cert, **kwds):
        Rwsender.__init__(self, name, **kwds)
        self.ca_cert = ca_cert
        self.key = key
        self.cert = cert

    def create_cert(self):
        # The certificate is supplied by the caller; nothing to generate.
        pass

    def set_ca(self, x, y):
        # Ignore the System's CA and keep the caller-supplied ca_cert.
        pass

    def get_args(self):
        args = Rwsender.get_args(self)
        if self.key:
            args += ['--tls-key', os.path.abspath(self.key)]
            # Build a real list: on Python 3, map() is a one-shot iterator,
            # unlike the lists every other get_args() returns.
            args = [_rename_pkcs12(arg) for arg in args]
        return args
798
799# Like Rwreceiver but uses customized TLS keys+certificates
class RwreceiverCert(Rwreceiver):
    """Rwreceiver that uses caller-supplied TLS files instead of generated ones.

    *ca_cert* and *cert* are used as given.  When *key* is set, --tls-key
    is added and the certificate is passed with --tls-cert instead of
    --tls-pkcs12.
    """
    def __init__(self, name, ca_cert, key, cert, **kwds):
        Rwreceiver.__init__(self, name, **kwds)
        self.ca_cert = ca_cert
        self.key = key
        self.cert = cert

    def create_cert(self):
        # The certificate is supplied by the caller; nothing to generate.
        pass

    def set_ca(self, x, y):
        # Ignore the System's CA and keep the caller-supplied ca_cert.
        pass

    def get_args(self):
        args = Rwreceiver.get_args(self)
        if self.key:
            args += ['--tls-key', os.path.abspath(self.key)]
            # Build a real list: on Python 3, map() is a one-shot iterator,
            # unlike the lists every other get_args() returns.
            args = [_rename_pkcs12(arg) for arg in args]
        return args
819
820# Like System but uses customized TLS keys+certificates
class SystemCert(System):
    """System that never generates a CA certificate.

    Presumably used with RwsenderCert/RwreceiverCert, which carry their own
    TLS files (confirm against callers).
    """

    def create_ca_cert(self):
        # Deliberately a no-op: no shared CA is created.
        pass
827
828
829#def Sender(**kwds):
830#    return Rwsender(overwrite=OVERWRITE, log_level=LOG_LEVEL, **kwds)
831#
832#def Receiver(**kwds):
833#    return Rwreceiver(overwrite=OVERWRITE, log_level=LOG_LEVEL, **kwds)
834
def _testConnectAndClose(tls=False, hostname="localhost"):
    """Start a sender (client) and receiver (server), wait until they are
    connected, then stop both and wait for a clean shutdown.

    *hostname* is the address the receiver listens on and the sender
    connects to.  Returns None without testing anything when *tls* is
    requested but TLS support is not enabled.
    """
    if tls and not tls_supported():
        return None
    reset_all_certs_and_keys()
    s1 = Rwsender()
    r1 = Rwreceiver()
    sy = System()
    try:
        sy.connect(s1, r1, tls=tls, hostname=hostname)
        sy.start()
        check_started([s1], [r1], tls=tls)
        # s1 is the client and r1 the server; pass them in the
        # check_connected(clients, servers) order.
        check_connected([s1], [r1])
        sy.stop()
        trigger((s1, 20, "Finished shutting down"),
                (r1, 20, "Finished shutting down"))
        trigger((s1, 25, "Stopped logging"),
                (r1, 25, "Stopped logging"))
    except BaseException:
        traceback.print_exc()
        sy.stop()
        raise
    finally:
        sy.end(noremove=NO_REMOVE)
858
def testConnectOnlyIPv4Addr():
    """
    Start a sender/receiver pair on the IPv4 loopback address
    (127.0.0.1), check that they connect, and that they shut down cleanly.
    """
    loopback = "127.0.0.1"
    _testConnectAndClose(hostname=loopback)
865
def testConnectOnlyHostname():
    """
    Start a sender/receiver pair addressed by the default hostname
    ("localhost"), check that they connect, and that they shut down cleanly.
    """
    _testConnectAndClose(tls=False)
872
def testConnectOnlyTLS():
    """
    Start a sender/receiver pair that talk over TLS, check that they
    connect, and that they shut down cleanly.  Silently skipped when TLS
    support is not enabled.
    """
    use_tls = True
    _testConnectAndClose(tls=use_tls)
879
def testConnectOnlyIPv6Addr():
    """
    Start a sender/receiver pair on the IPv6 loopback address ([::1]),
    check that they connect, and that they shut down cleanly.  Silently
    skipped when IPv6 networking is not enabled.
    """
    if ipv6_supported():
        _testConnectAndClose(hostname="[::1]")
888
def _testSendRcv(tls=False,
                 sender_client=True,
                 stop_sender=False,
                 kill=False):
    """Transfer the files in ``rfiles`` across an interrupted connection.

    1. Connect a sender and receiver, optionally over TLS.  The sender is
       the client when *sender_client* is true, otherwise the server.
    2. Once the sender reports a successful send, interrupt one side: the
       sender if *stop_sender*, else the receiver.  It gets SIGTERM, or
       SIGKILL when *kill*.
    3. Wait for that side to exit, then restart it.
    4. After the pair reconnects, every file must reach the receiver with
       matching size and MD5; otherwise FileTransferError is raised.

    Returns None without testing anything when *tls* is requested but TLS
    support is not enabled.
    """
    if tls and not tls_supported():
        return None
    reset_all_certs_and_keys()
    s1 = Rwsender()
    r1 = Rwreceiver()
    # The daemon that is interrupted mid-transfer and then restarted.
    stopped = s1 if stop_sender else r1
    stop = stopped.kill if kill else stopped.stop
    end = stopped.end
    start = stopped.start
    # Queue every test file in the sender's incoming directory up front.
    s1.create_dirs()
    s1.send_files(rfiles)
    sy = System()
    try:
        if sender_client:
            cli, srv = s1, r1
        else:
            cli, srv = r1, s1
        sy.connect(cli, srv, tls=tls)
        sy.start()
        check_started([cli], [srv], tls=tls)
        check_connected([cli], [srv], 75)
        trigger((s1, 40, "Succeeded sending .* to %s" % r1.name))
        stop()
        if not kill:
            # A SIGKILLed daemon cannot log its shutdown.
            trigger((stopped, 25, "Stopped logging"))
        end()
        start()
        if stopped is cli:
            check_started([cli], [], tls=tls)
        else:
            check_started([], [srv], tls=tls)
        check_connected([cli], [srv], 75)
        try:
            for path, _info in rfiles:
                patvars = {"name": re.escape(os.path.basename(path)),
                           "rname": r1.name, "sname": s1.name}
                # Either side's report is enough (first=True): the sender's
                # success/rejection, or the receiver finishing the file.
                trigger((s1, 40,
                         ("Succeeded sending .*/%(name)s to %(rname)s|"
                          "Remote side %(rname)s rejected .*/%(name)s")
                         % patvars),
                        (r1, 40,
                         "Finished receiving from %(sname)s: %(name)s"
                         % patvars),
                        pid=False, first=True)
        except TriggerError:
            # A missed log line stops the waiting but is not fatal by itself;
            # the checksum comparison below is the authoritative check.
            pass
        for f in rfiles:
            (error, path) = r1.check_sent(f)
            if error:
                global_log(False, ("Error receiving %s: %s" %
                                   (os.path.basename(f[0]), error)))
                raise FileTransferError()
        sy.stop()
        trigger((s1, 25, "Stopped logging"),
                (r1, 25, "Stopped logging"))
    except KeyboardInterrupt:
        # Format the pid into the message.  Passing it as a second argument
        # would bind it to global_log()'s 'timestamp' flag and leave "%s"
        # unformatted.
        global_log(False, "%s: Interrupted by C-c" % os.getpid())
        traceback.print_exc()
        sy.stop()
        raise
    except BaseException:
        traceback.print_exc()
        sy.stop()
        raise
    finally:
        sy.end(noremove=NO_REMOVE)
974
975
def testSendRcvStopReceiverServer():
    """
    Test a sender/receiver connection, with receiver as server,
    sending files.  Midway the connection is terminated by
    stopping the receiver.  The connection is restarted and resumed.
    """
    _testSendRcv(sender_client=True, stop_sender=False, kill=False)
983
def testSendRcvStopSenderServer():
    """
    Sender acts as server, receiver as client.  After a file has been
    delivered the sender is stopped cleanly and restarted; the
    transfer must resume and every file must arrive.
    """
    _testSendRcv(kill=False, stop_sender=True, sender_client=False)
991
def testSendRcvStopReceiverClient():
    """
    Sender acts as server, receiver as client.  After a file has been
    delivered the receiver is stopped cleanly and restarted; the
    transfer must resume and every file must arrive.
    """
    _testSendRcv(kill=False, stop_sender=False, sender_client=False)
999
def testSendRcvStopSenderClient():
    """
    Receiver acts as server, sender as client.  After a file has been
    delivered the sender is stopped cleanly and restarted; the
    transfer must resume and every file must arrive.
    """
    _testSendRcv(kill=False, stop_sender=True, sender_client=True)
1007
def testSendRcvKillReceiverServer():
    """
    Test a sender/receiver connection, with receiver as server,
    sending files.  Midway the connection is terminated by
    killing the receiver.  The connection is restarted and resumed.
    """
    _testSendRcv(sender_client=True, stop_sender=False, kill=True)
1015
def testSendRcvKillSenderServer():
    """
    Sender acts as server, receiver as client.  After a file has been
    delivered the sender is killed and restarted; the transfer must
    resume and every file must arrive.
    """
    _testSendRcv(kill=True, stop_sender=True, sender_client=False)
1023
def testSendRcvKillReceiverClient():
    """
    Sender acts as server, receiver as client.  After a file has been
    delivered the receiver is killed and restarted; the transfer must
    resume and every file must arrive.
    """
    _testSendRcv(kill=True, stop_sender=False, sender_client=False)
1031
def testSendRcvKillSenderClient():
    """
    Receiver acts as server, sender as client.  After a file has been
    delivered the sender is killed and restarted; the transfer must
    resume and every file must arrive.
    """
    _testSendRcv(kill=True, stop_sender=True, sender_client=True)
1039
1040
def testSendRcvStopReceiverServerTLS():
    """
    Test a TLS sender/receiver connection, with receiver as server,
    sending files.  Midway the connection is terminated by
    stopping the receiver.  The connection is restarted and resumed.
    """
    _testSendRcv(sender_client=True, stop_sender=False, kill=False, tls=True)
1048
def testSendRcvStopSenderServerTLS():
    """
    Over TLS: sender acts as server, receiver as client.  After a file
    has been delivered the sender is stopped cleanly and restarted;
    the transfer must resume and every file must arrive.
    """
    _testSendRcv(tls=True, kill=False, stop_sender=True, sender_client=False)
1056
def testSendRcvStopReceiverClientTLS():
    """
    Over TLS: sender acts as server, receiver as client.  After a file
    has been delivered the receiver is stopped cleanly and restarted;
    the transfer must resume and every file must arrive.
    """
    _testSendRcv(tls=True, kill=False, stop_sender=False, sender_client=False)
1064
def testSendRcvStopSenderClientTLS():
    """
    Over TLS: receiver acts as server, sender as client.  After a file
    has been delivered the sender is stopped cleanly and restarted;
    the transfer must resume and every file must arrive.
    """
    _testSendRcv(tls=True, kill=False, stop_sender=True, sender_client=True)
1072
def testSendRcvKillReceiverServerTLS():
    """
    Test a TLS sender/receiver connection, with receiver as server,
    sending files.  Midway the connection is terminated by
    killing the receiver.  The connection is restarted and resumed.
    """
    _testSendRcv(sender_client=True, stop_sender=False, kill=True, tls=True)
1080
def testSendRcvKillSenderServerTLS():
    """
    Over TLS: sender acts as server, receiver as client.  After a file
    has been delivered the sender is killed and restarted; the
    transfer must resume and every file must arrive.
    """
    _testSendRcv(tls=True, kill=True, stop_sender=True, sender_client=False)
1088
def testSendRcvKillReceiverClientTLS():
    """
    Over TLS: sender acts as server, receiver as client.  After a file
    has been delivered the receiver is killed and restarted; the
    transfer must resume and every file must arrive.
    """
    _testSendRcv(tls=True, kill=True, stop_sender=False, sender_client=False)
1096
def testSendRcvKillSenderClientTLS():
    """
    Over TLS: receiver acts as server, sender as client.  After a file
    has been delivered the sender is killed and restarted; the
    transfer must resume and every file must arrive.
    """
    _testSendRcv(tls=True, kill=True, stop_sender=True, sender_client=True)
1104
1105
def _testMultiple(tls=False):
    """
    Connect two receivers (as clients) to two senders (as servers),
    have each sender transmit one file, and verify that both
    receivers got both files.

    tls -- connect over TLS; returns None without running when
           tls_supported() is false.

    Raises FileTransferError when check_sent() reports an error for
    any file at any receiver.  Any other exception is printed and
    re-raised after the daemons are stopped.
    """
    global rfiles
    if tls and not tls_supported():
        return None
    reset_all_certs_and_keys()
    s1 = Rwsender()
    s2 = Rwsender()
    r1 = Rwreceiver()
    r2 = Rwreceiver()
    sy = System()
    try:
        sy.connect([r1, r2], [s1, s2], tls=tls)
        sy.start()
        check_started([r1, r2], [s1, s2], tls=tls)
        check_connected([r1, r2], [s1, s2], timeout=70)

        # s1 carries the first file, s2 the second.
        first, second = rfiles[0], rfiles[1]
        s1.send_files([first])
        s2.send_files([second])
        subst = {"filea": re.escape(os.path.basename(first[0])),
                 "fileb": re.escape(os.path.basename(second[0])),
                 "rnamec": r1.name, "rnamed": r2.name}

        # Both senders must report delivery to r1 ...
        trigger((s1, 40,
                 "Succeeded sending .*/%(filea)s to %(rnamec)s" % subst),
                (s2, 40,
                 "Succeeded sending .*/%(fileb)s to %(rnamec)s" % subst))
        # ... and to r2.
        trigger((s1, 40,
                 "Succeeded sending .*/%(filea)s to %(rnamed)s" % subst),
                (s2, 40,
                 "Succeeded sending .*/%(fileb)s to %(rnamed)s" % subst))

        for f, rcv in itertools.product([first, second], [r1, r2]):
            error, _ = rcv.check_sent(f)
            if not error:
                continue
            global_log(False, ("Error receiving %s: %s" %
                               (os.path.basename(f[0]), error)))
            raise FileTransferError()

        sy.stop()
        trigger((s1, 25, "Stopped logging"),
                (r1, 25, "Stopped logging"),
                (s2, 25, "Stopped logging"),
                (r2, 25, "Stopped logging"))
    except BaseException:
        traceback.print_exc()
        sy.stop()
        raise
    finally:
        sy.end(noremove=NO_REMOVE)
1156
def testMultiple():
    """
    Two senders, two receivers, plain (non-TLS) connections: every
    sender's file must reach every receiver.
    """
    _testMultiple(tls=False)
1163
def testMultipleTLS():
    """
    Two senders, two receivers, TLS connections: every sender's file
    must reach every receiver.
    """
    _testMultiple(tls=True)
1170
1171
def _testFilter(tls=False):
    """
    Check per-receiver filtering with one sender and two receivers.

    The sender is built with the filters ``[a-g]$`` for r1 and
    ``[d-j]$`` for r2.  Files whose names end in a-g should reach r1,
    files ending in d-j should reach r2, so files ending in d-g reach
    both.  After every file in ``rfiles`` is sent, each receiver must
    hold every file selected for it.  It must also not hold the files
    selected only for the other receiver.

    tls -- connect over TLS; returns None without running when
           tls_supported() is false.

    Raises FileTransferError when a receiver is missing a file it
    should have, or holds a file meant only for the other receiver.
    """
    global rfiles
    if tls and not tls_supported():
        return None
    reset_all_certs_and_keys()
    r1 = Rwreceiver()
    r2 = Rwreceiver()
    s1 = Rwsender(filters=[(r1.name, "[a-g]$"), (r2.name, "[d-j]$")])
    sy = System()
    try:
        sy.connect([r1, r2], s1, tls=tls)
        sy.start()
        check_started([r1, r2], [s1], tls=tls)
        check_connected([r1, r2], [s1], timeout=70)

        s1.send_files(rfiles)
        # Mirror the sender's regex filters on the file's last character.
        cfiles = [x for x in rfiles if 'a' <= x[0][-1] <= 'g']
        dfiles = [x for x in rfiles if 'd' <= x[0][-1] <= 'j']
        for (f, data) in cfiles:
            trigger((s1, 25,
                     "Succeeded sending .*/%(file)s to %(name)s"
                     % {"file": re.escape(os.path.basename(f)),
                        "name" : r1.name}))
        for (f, data) in dfiles:
            trigger((s1, 25,
                     "Succeeded sending .*/%(file)s to %(name)s"
                     % {"file": re.escape(os.path.basename(f)),
                        "name" : r2.name}))
        for f in cfiles:
            (error, path) = r1.check_sent(f)
            if error:
                global_log(False, ("Error receiving %s: %s" %
                                   (os.path.basename(f[0]), error)))
                raise FileTransferError()
        for f in dfiles:
            (error, path) = r2.check_sent(f)
            if error:
                global_log(False, ("Error receiving %s: %s" %
                                   (os.path.basename(f[0]), error)))
                # Previously missing: a failure at r2 was logged but the
                # test still passed.
                raise FileTransferError()
        cset = set(cfiles)
        dset = set(dfiles)
        # Files selected for only one receiver must not reach the other.
        for f in cset - dset:
            (error, path) = r2.check_sent(f)
            if not error:
                global_log(False, ("Unexpectedly received file %s" %
                                   os.path.basename(f[0])))
                raise FileTransferError()
        for f in dset - cset:
            (error, path) = r1.check_sent(f)
            if not error:
                global_log(False, ("Unexpectedly received file %s" %
                                   os.path.basename(f[0])))
                raise FileTransferError()
        sy.stop()
        trigger((s1, 25, "Stopped logging"),
                (r1, 25, "Stopped logging"),
                (r2, 25, "Stopped logging"))
    except BaseException:
        traceback.print_exc()
        sy.stop()
        raise
    finally:
        sy.end(noremove=NO_REMOVE)
1235
def testFilter():
    """
    Non-TLS filtering test: one sender, two receivers.  The sender's
    per-receiver filters route some files only to the first receiver,
    some only to the second, and some to both.
    """
    _testFilter(tls=False)
1243
1244
def testPostCommand():
    """
    Exercise rwreceiver's post-command option.

    The receiver is given ``tests/post-command.sh %I %s`` as its post
    command.  The test waits for the receiver to log a "Post command"
    line naming the sender and each file in ``rfiles``.

    Exits the process with status 77 (conventionally the automake
    "skipped" code) when the script is not executable.
    """
    global rfiles
    # Same as tests/ under srcdir, or ./tests when srcdir is unset/empty.
    command = os.path.join(srcdir or ".", "tests", "post-command.sh")
    post_command = command + " %I %s"
    if not os.access(command, os.X_OK):
        sys.exit(77)
    s1 = Rwsender()
    r1 = Rwreceiver(post_command=post_command)
    s1.create_dirs()
    s1.send_files(rfiles)
    sy = System()
    try:
        sy.connect(s1, r1)
        sy.start()
        check_connected([s1], [r1], timeout=70)
        for fpath, _ in rfiles:
            expected = (("Post command: Ident: %(sname)s  "
                         "Filename: .*/%(file)s") %
                        {"file": re.escape(os.path.basename(fpath)),
                         "sname": s1.name})
            trigger((r1, 40, expected), pid=False)
        sy.stop()
        trigger((s1, 25, "Stopped logging"),
                (r1, 25, "Stopped logging"))
    except BaseException:
        traceback.print_exc()
        sy.stop()
        raise
    finally:
        sy.end(noremove=NO_REMOVE)
1282
1283
def _testFailedConnection(ca_cert, key, cert, hostname="127.0.0.1"):
    """
    Start a TLS sender/receiver pair from the given certificate files.
    Verify that the connection attempt fails and that both daemons still
    shut down cleanly.

    ca_cert, key, cert -- file names relative to the ``tests`` directory
        under ``srcdir`` (or ``./tests`` when srcdir is unset).  Both
        daemons use the same three files.
    hostname -- passed to SystemCert.connect(); presumably the address
        the client dials.

    Returns None without running when tls_supported() is false.
    """
    if not tls_supported():
        return None
    # os.path.join() raises TypeError on None, so fall back to ./tests
    # when srcdir is not set in the environment.
    testsdir = os.path.join(srcdir or ".", "tests")
    ca_cert = os.path.join(testsdir, ca_cert)
    key = os.path.join(testsdir, key)
    cert = os.path.join(testsdir, cert)
    s1 = RwsenderCert(None, ca_cert, key, cert)
    r1 = RwreceiverCert(None, ca_cert, key, cert)
    sy = SystemCert()
    try:
        sy.connect(s1, r1, tls=True, hostname=hostname)
        sy.start()
        check_started([s1], [r1], tls=True)
        trigger((s1, 50, "Attempt to connect to %s failed" % r1.name),
                (r1, 50, "Unable to initialize connection with"))
        sy.stop()
        trigger((s1, 20, "Finished shutting down"),
                (r1, 20, "Finished shutting down"))
        trigger((s1, 25, "Stopped logging"),
                (r1, 25, "Stopped logging"))
    except BaseException:
        traceback.print_exc()
        sy.stop()
        raise
    finally:
        sy.end(noremove=NO_REMOVE)
1310
def testExpiredAuthorityTLS():
    """
    Check that a TLS sender/receiver pair fails to connect when it is
    configured with an expired certificate authority.  The fixture is
    ca-expired-cert.pem plus a key/cert signed by it.  Both daemons must
    still shut down cleanly.
    """
    _testFailedConnection("ca-expired-cert.pem", "signed-expired-ca-key.pem",
                          "signed-expired-ca-cert.pem")
1318
def testExpiredCertificateTLS():
    """
    Check that a TLS sender/receiver pair fails to connect when its
    certificate is expired.  The fixture is expired-cert.pem with
    expired-key.pem under ca_cert_key8.pem.  Both daemons must still
    shut down cleanly.
    """
    _testFailedConnection("ca_cert_key8.pem", "expired-key.pem",
                          "expired-cert.pem")
1326
def testMismatchedCertsTLS():
    """
    Check that a TLS connection fails when the two sides trust
    different certificate authorities, and that both daemons still shut
    down cleanly.

    The sender trusts ca_cert_key8.pem and uses the PKCS#12 file
    cert-key5-ca_cert_key8.p12 with key=None (presumably the key is
    bundled in the .p12).  The receiver trusts other-ca-cert.pem and
    uses other-key.pem/other-cert.pem.  The receiver is the client.

    Returns None without running when tls_supported() is false.
    """
    hostname = "127.0.0.1"
    if not tls_supported():
        return None
    # os.path.join() raises TypeError on None, so fall back to ./tests
    # when srcdir is not set in the environment.
    testsdir = os.path.join(srcdir or ".", "tests")
    s1 = RwsenderCert(
        name=None, ca_cert=os.path.join(testsdir, "ca_cert_key8.pem"),
        key=None,
        cert=os.path.join(testsdir, "cert-key5-ca_cert_key8.p12"))
    r1 = RwreceiverCert(
        name=None, ca_cert=os.path.join(testsdir, "other-ca-cert.pem"),
        key=os.path.join(testsdir, "other-key.pem"),
        cert=os.path.join(testsdir, "other-cert.pem"))
    sy = SystemCert()
    try:
        sy.connect(r1, s1, tls=True, hostname=hostname)
        sy.start()
        check_started([r1], [s1], tls=True)
        trigger((r1, 50, "Attempt to connect to %s failed" % s1.name),
                (s1, 50, "Unable to initialize connection with"))
        sy.stop()
        trigger((s1, 20, "Finished shutting down"),
                (r1, 20, "Finished shutting down"))
        trigger((s1, 25, "Stopped logging"),
                (r1, 25, "Stopped logging"))
    except BaseException:
        traceback.print_exc()
        sy.stop()
        raise
    finally:
        sy.end(noremove=NO_REMOVE)
1357
def testOtherCertsTLS():
    """
    Check that a sender and receiver can connect over TLS when both use
    the "other" CA, key, and certificate fixtures (other-ca-cert.pem,
    other-key.pem, other-cert.pem), and that both shut down cleanly.
    The receiver is the client.

    Returns None without running when tls_supported() is false.
    """
    hostname = "127.0.0.1"
    if not tls_supported():
        return None
    # os.path.join() raises TypeError on None, so fall back to ./tests
    # when srcdir is not set in the environment.
    testsdir = os.path.join(srcdir or ".", "tests")
    s1 = RwsenderCert(
        name=None, ca_cert=os.path.join(testsdir, "other-ca-cert.pem"),
        key=os.path.join(testsdir, "other-key.pem"),
        cert=os.path.join(testsdir, "other-cert.pem"))
    r1 = RwreceiverCert(
        name=None, ca_cert=os.path.join(testsdir, "other-ca-cert.pem"),
        key=os.path.join(testsdir, "other-key.pem"),
        cert=os.path.join(testsdir, "other-cert.pem"))
    sy = SystemCert()
    try:
        sy.connect(r1, s1, tls=True, hostname=hostname)
        sy.start()
        check_started([r1], [s1], tls=True)
        check_connected([r1], [s1])
        sy.stop()
        trigger((s1, 20, "Finished shutting down"),
                (r1, 20, "Finished shutting down"))
        trigger((s1, 25, "Stopped logging"),
                (r1, 25, "Stopped logging"))
    except BaseException:
        traceback.print_exc()
        sy.stop()
        raise
    finally:
        sy.end(noremove=NO_REMOVE)
1387
1388
if __name__ == '__main__':
    # Command-line entry point: parse options, configure the module-level
    # settings, then run the named tests (or ALL_TESTS) between setup()
    # and teardown().
    parser = optparse.OptionParser()
    parser.add_option("--verbose", action="store_true", dest="verbose",
                      default=False)
    parser.add_option("--overwrite-dirs", action="store_true",
                      dest="overwrite", default=False)
    parser.add_option("--save-output", action="store_true", dest="save_output",
                      default=False)
    parser.add_option("--log-level", action="store", type="string",
                      dest="log_level", default="info")
    parser.add_option("--log-output-to", action="store", type="string",
                      dest="log_output", default=None)
    parser.add_option("--file-list-file", action="store", type="string",
                      dest="file_list_file", default=None)
    parser.add_option("--print-test-names", action="store_true",
                      dest="print_test_names", default=False)
    parser.add_option("--timeout-factor", action="store", type="float",
                      dest="timeout_factor", default = 1.0)
    (options, args) = parser.parse_args()

    if options.print_test_names:
        print_test_names()
        sys.exit()

    if not options.file_list_file:
        sys.exit("The --file-list-file switch is required when running tests")

    # These assignments run at module scope, so they rebind the globals
    # declared near the top of the file.
    FILE_LIST_FILE = options.file_list_file
    OVERWRITE = options.overwrite
    LOG_LEVEL = options.log_level
    NO_REMOVE = options.save_output

    # Always log to a fresh temp file; optionally also to stdout and/or
    # append to a user-named file.
    (fd, path) = tempfile.mkstemp(".log", "sendrcv-", None)
    LOG_OUTPUT.append(os.fdopen(fd, "a"))
    if options.verbose:
        LOG_OUTPUT.append(sys.stdout)
    if options.log_output:
        LOG_OUTPUT.append(open(options.log_output, "a"))

    TIMEOUT_FACTOR = options.timeout_factor

    if not args:
        args = ALL_TESTS

    # Resolve every requested test before setup() so a misspelled name
    # fails immediately with a clear message instead of a KeyError
    # part-way through the run.  globals() is the module namespace; the
    # old locals() only worked because it aliases globals() here.
    tests = []
    for name in args:
        func = globals().get(name)
        if not callable(func):
            sys.exit("Unknown test: %s" % name)
        tests.append(func)

    setup()
    # NOTE(review): retval appears unused in this block; kept in case
    # teardown() or other module code reads this global.
    retval = 1

    try:
        for func in tests:
            func()
    finally:
        teardown()
1443