1#
2# Copyright (c) ZeroC, Inc. All rights reserved.
3#
4
5import atexit
6import os
7import re
8import signal
9import string
10import subprocess
11import sys
12import sys
13import threading
14import time
15import traceback
16import types
17
# Names exported by ``from <module> import *``.
__all__ = ["Expect", "EOF", "TIMEOUT"]

# True on Windows. The Windows-only code paths of this module (process
# groups, console control events through ctypes, taskkill) depend on it.
win32 = sys.platform == "win32"
if win32:
    import ctypes
23
class EOF(Exception):
    """Raised when EOF is read from a child.

    Derives from Exception so that it can actually be raised and caught
    (on Python 3 only BaseException subclasses may be raised). The
    constructor argument is kept in ``value`` and is what ``str()`` returns.
    """
    def __init__(self, value):
        Exception.__init__(self, value)
        self.value = value

    def __str__(self):
        return str(self.value)
31
class TIMEOUT(Exception):
    """Signals that a read or a wait did not complete within its timeout.

    The constructor argument is kept in ``value``; ``str()`` returns its
    string form.
    """

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return "%s" % (self.value,)
39
def getStringIO():
    """Return a new, empty in-memory text buffer suited to the running Python.

    Python 2 gets ``StringIO.StringIO``; any other version gets ``io.StringIO``.
    """
    if sys.version_info[0] != 2:
        import io
        return io.StringIO()
    import StringIO
    return StringIO.StringIO()
47
def escape(s, escapeNewlines = True):
    """Render *s* as a printable, backslash-escaped string for logs and messages.

    The TIMEOUT class itself is rendered as ``<TIMEOUT>``. Backslash, single
    and double quotes, backspace, form feed, carriage return and tab become
    C-style escapes. A newline becomes ``\\n`` unless *escapeNewlines* is
    false, in which case it is kept unchanged. Any other character outside
    ``string.printable`` is written as a backslash followed by its octal code,
    zero-padded to at least three digits.
    """
    if s == TIMEOUT:
        return "<TIMEOUT>"

    special = {
        '\\': '\\\\',
        "'": "\\'",
        '"': '\\"',
        '\b': '\\b',
        '\f': '\\f',
        '\n': '\\n' if escapeNewlines else '\n',
        '\r': '\\r',
        '\t': '\\t',
    }
    pieces = []
    for ch in s:
        if ch in special:
            pieces.append(special[ch])
        elif ch in string.printable:
            pieces.append(ch)
        else:
            pieces.append('\\%03o' % ord(ch))
    return "".join(pieces)
78
def taskkill(args):
    """Run the Windows ``taskkill`` command with the argument string *args*.

    The command's combined stdout/stderr is drained and discarded, and this
    function returns once ``taskkill`` has exited.
    """
    p = subprocess.Popen("taskkill {0}".format(args), shell=True, stdout=subprocess.PIPE,
                         stderr=subprocess.STDOUT)
    # The output is only drained, never inspected, so it is deliberately not
    # decoded: taskkill writes in the console code page, which is not UTF-8
    # on localized systems, and a UnicodeDecodeError here would abort the kill
    # logic of the callers. communicate() reads to EOF, waits for the process
    # and closes the pipe.
    p.communicate()
85
def killProcess(p):
    """Forcibly kill the process *p* (an object with a ``pid`` attribute).

    POSIX: send SIGKILL to ``p.pid``. Windows: run ``taskkill /F /T /PID <pid>``.
    """
    if not win32:
        os.kill(p.pid, signal.SIGKILL)
    else:
        taskkill("/F /T /PID %s" % p.pid)
91
def terminateProcess(p, hasInterruptSupport=True):
    """Ask the process *p* to terminate.

    POSIX: send SIGINT to ``p.pid``.

    Windows: when *hasInterruptSupport* is true, send CTRL_BREAK_EVENT with
    ``GenerateConsoleCtrlEvent`` to the process group identified by
    ``p.pid``. If that call cannot be made or raises, or when
    *hasInterruptSupport* is false, fall back to ``taskkill /F /T``.

    KeyboardInterrupt and SystemExit are never swallowed.
    """
    if win32:
        #
        # Signals under windows are all turned into CTRL_BREAK_EVENT, except with Java since
        # CTRL_BREAK_EVENT generates a stack trace. We don't use taskkill here because it
        # doesn't work with CLI processes (it sends a WM_CLOSE event).
        #
        if hasInterruptSupport:
            try:
                ctypes.windll.kernel32.GenerateConsoleCtrlEvent(1, p.pid) # 1 is CTRL_BREAK_EVENT
            except NameError:
                # ctypes is not in scope: fall back to a forced kill.
                taskkill("/F /T /PID {0}".format(p.pid))
            except Exception:
                # Catch Exception rather than everything, so that a
                # KeyboardInterrupt still reaches the caller (Expect.terminate
                # kills the process and re-raises it).
                traceback.print_exc(file=sys.stdout)
                taskkill("/F /T /PID {0}".format(p.pid))
        else:
            taskkill("/F /T /PID {0}".format(p.pid))
    else:
        os.kill(p.pid, signal.SIGINT)
112
class reader(threading.Thread):
    """Background thread that collects the output of a child process.

    Output is read one unit at a time from ``p.stdout``. Each unit is decoded
    when it arrives as bytes, has carriage returns removed, is optionally
    echoed to stdout (see ``enabletrace``), and is appended to an in-memory
    buffer that ``match`` searches.

    The buffer and the end-of-output flag are guarded by the condition
    variable ``cv``. It is notified whenever output arrives and when the
    output stream ends.
    """
    def __init__(self, desc, p, logfile):
        self.desc = desc                 # Label used in log messages.
        self.buf = getStringIO()         # Output not yet consumed by a match.
        self.cv = threading.Condition()  # Guards buf and _finish.
        self.p = p
        self._trace = False              # Echo output to stdout?
        self._tbuf = getStringIO()       # Current line, held back while trace filters are active.
        self._tracesuppress = None       # Trace filters (see trace()).
        self.logfile = logfile
        self.watchDog = None
        self._finish = False             # True once the output stream has ended.
        threading.Thread.__init__(self)

    def setWatchDog(self, watchDog):
        """Set an object whose ``reset()`` is called whenever output arrives."""
        self.watchDog = watchDog

    def run(self):
        """Collect the child's output until EOF or an I/O error, then flag the end of output."""
        # Bytes are decoded incrementally. Decoding every single-byte read on
        # its own raised UnicodeDecodeError on multi-byte UTF-8 characters,
        # which killed this thread without ever flagging the end of output.
        # Undecodable bytes are replaced for the same reason.
        import codecs
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        try:
            while True:
                c = self.p.stdout.read(1)
                if not c:
                    break
                # Depending on Python version and platform, the value c could be a
                # string or a bytes object.
                if not isinstance(c, str):
                    c = decoder.decode(c)
                self._append(c)
        except IOError as e:
            print(e)
        finally:
            # Emit any incomplete trailing character, then wake match(), which
            # must not keep waiting for output that can no longer arrive.
            self._append(decoder.decode(b"", True))
            with self.cv:
                self.trace(None)
                self._finish = True # We have finished processing output
                self.cv.notify()

    def _append(self, text):
        """Trace and buffer newly read *text*, minus carriage returns, and wake any waiter."""
        text = text.replace('\r', '')
        if not text:
            return
        with self.cv:
            for ch in text:
                self.trace(ch)
            if self.watchDog is not None:
                self.watchDog.reset()
            self.buf.write(text)
            self.cv.notify()

    def trace(self, c):
        """Echo the character *c* to stdout if tracing is enabled.

        *c* is None at the end of output. When trace filters are set, characters
        are accumulated, and each complete line (or the final partial line) is
        passed through the filters in order. Plain functions transform the
        line. Any other filter is expected to provide ``search`` (e.g. a
        compiled regex); a successful search suppresses the line.
        """
        if not self._trace:
            return
        if not self._tracesuppress:
            if c is not None:
                sys.stdout.write(c)
                sys.stdout.flush()
            return
        if c is not None:
            self._tbuf.write(c)
        if c == '\n' or c is None:
            content = self._tbuf.getvalue()
            suppress = False
            for p in self._tracesuppress:
                # types.LambdaType is an alias of types.FunctionType.
                if isinstance(p, types.FunctionType):
                    content = p(content)
                elif p.search(content):
                    suppress = True
                    break
            if not suppress:
                sys.stdout.write(content)
                sys.stdout.flush()
            self._tbuf.truncate(0)
            self._tbuf.seek(0)

    def enabletrace(self, suppress = None):
        """Start echoing output to stdout, first replaying the output still buffered.

        *suppress* is an optional list of trace filters (see ``trace``). The
        call has no effect if tracing is already enabled.
        """
        with self.cv:
            if not self._trace:
                self._trace = True
                self._tracesuppress = suppress
                for c in self.buf.getvalue():
                    self.trace(c)

    def getbuf(self):
        """Return the output that has not yet been consumed by a match."""
        with self.cv:
            return self.buf.getvalue()

    def _setbuf(self, data):
        """Replace the buffer contents with *data*. The caller must hold ``cv``."""
        self.buf.truncate(0)
        self.buf.seek(0)
        self.buf.write(data)

    def _describe(self, pattern):
        """Return an escaped description of the pattern list: "p", or "[p1,p2,...]"."""
        descs = [escape(s) for (s, r) in pattern]
        if len(descs) == 1:
            return descs[0]
        return "[" + ",".join(descs) + "]"

    def _timeoutError(self, pattern, buf):
        """Log a failed match, if logging is enabled, and return the TIMEOUT to raise."""
        desc = self._describe(pattern)
        if self.logfile:
            self.logfile.write('%s: match failed.\npattern: "%s"\nbuffer: "%s"\n' %
                               (self.desc, desc, escape(buf)))
            self.logfile.flush()
        return TIMEOUT('timeout exceeded in match\npattern: "%s"\nbuffer: "%s"\n' %
                       (desc, escape(buf, False)))

    def match(self, pattern, timeout, matchall = False):
        """Wait until the buffered output matches *pattern*.

        *pattern* is a list of ``(source, regexp)`` tuples, where *regexp* is the
        compiled form of *source*. An entry ``(TIMEOUT, None)`` turns an expired
        timeout into a successful match instead of an error. *timeout* is in
        seconds, or None to wait indefinitely.

        Without *matchall*, the first entry (in list order) whose regexp matches
        wins. The buffer is consumed through the end of the match, and
        ``(buf, before, after, match, index)`` is returned, *buf* being the
        whole buffer that was searched.

        With *matchall*, every entry must match, in any order. Matched entries
        are removed from *pattern* (the list is modified in place). Once all
        have matched, the buffer is consumed through the furthest match end and
        the searched buffer is returned.

        If the timeout expires, or the output ends, before a match, TIMEOUT is
        raised. When *pattern* contains ``(TIMEOUT, None)``,
        ``(buf, buf, TIMEOUT, None, index)`` is returned instead.
        """
        start = time.time()
        if timeout is not None:
            end = start + timeout

        # Trace the expectation.
        if self.logfile:
            if timeout is None:
                tdesc = "<infinite>"
            else:
                tdesc = "%.2fs" % timeout
            # Pattern descriptions are escaped once only (they were previously
            # escaped twice).
            self.logfile.write('%s: expect: "%s" timeout: %s\n' % (self.desc, self._describe(pattern), tdesc))
            self.logfile.flush()

        maxend = None # Furthest match end seen so far (used by matchall).
        with self.cv:
            try:
                while True:
                    buf = self.buf.getvalue()

                    # Try to match on the current buffer.
                    olen = len(pattern)
                    for index, (s, regexp) in enumerate(pattern):
                        if s == TIMEOUT:
                            continue
                        if not buf:
                            #
                            # Don't try to match on an empty buffer, http://bugs.python.org/issue17998
                            #
                            break
                        m = regexp.search(buf)
                        if m is None:
                            continue

                        before = buf[:m.start()]
                        matched = buf[m.start():m.end()]
                        after = buf[m.end():]
                        if maxend is None or m.end() > maxend:
                            maxend = m.end()

                        # Trace the match
                        if self.logfile:
                            if len(pattern) > 1:
                                self.logfile.write('%s: match found in %.2fs.\npattern: "%s"\nbuffer: "%s||%s||%s"\n' %
                                                   (self.desc, time.time() - start, escape(s), escape(before),
                                                    escape(matched), escape(after)))
                            else:
                                self.logfile.write('%s: match found in %.2fs.\nbuffer: "%s||%s||%s"\n' %
                                                   (self.desc, time.time() - start, escape(before), escape(matched),
                                                    escape(after)))
                            self.logfile.flush()

                        if matchall:
                            del pattern[index]
                            if not pattern:
                                # All patterns have been found: consume the buffer
                                # through the furthest match and return.
                                self._setbuf(buf[maxend:])
                                return buf
                            break # Rescan the buffer for the remaining patterns.

                        # Consume matched portion of the buffer.
                        self._setbuf(after)
                        return buf, before, after, m, index

                    # A matchall pattern was found and removed: rescan the buffer.
                    if len(pattern) != olen:
                        continue

                    # No match, and no more output will ever arrive.
                    if self._finish:
                        raise self._timeoutError(pattern, buf)

                    if timeout is None:
                        self.cv.wait()
                    else:
                        self.cv.wait(end - time.time())
                        if time.time() >= end:
                            raise self._timeoutError(pattern, buf)
            except TIMEOUT:
                if (TIMEOUT, None) in pattern:
                    return buf, buf, TIMEOUT, None, pattern.index((TIMEOUT, None))
                raise
311
def splitCommand(command_line):
    """Split *command_line* into an argument list, roughly like a POSIX shell.

    - Runs of whitespace separate arguments. Leading and trailing whitespace
      is ignored.
    - Single or double quotes group characters, including whitespace, into
      the current argument. A quoted empty string ('' or "") yields an empty
      argument.
    - A backslash makes the next character literal, both outside and inside
      quotes (unlike a POSIX shell, this includes single quotes).
    - An unterminated quote extends to the end of the input. A trailing lone
      backslash is ignored.
    """
    args = []
    arg = ''
    in_arg = False   # An argument has started, even if it is still empty ("").
    quote = None     # The open quote character, or None outside quotes.
    escaped = False  # The previous character was an unconsumed backslash.

    for c in command_line:
        if escaped:
            arg += c
            in_arg = True
            escaped = False
        elif c == '\\':
            escaped = True
        elif quote is not None:
            if c == quote:
                quote = None
            else:
                arg += c
        elif c == "'" or c == '"':
            quote = c
            in_arg = True
        elif c.isspace():
            if in_arg:
                args.append(arg)
                arg = ''
                in_arg = False
        else:
            arg += c
            in_arg = True

    if in_arg:
        args.append(arg)
    return args
361
# pid -> Popen object of every process spawned by Expect that has not yet
# been reaped by Expect.wait().
processes = {}

def cleanup():
    """Forcibly kill every process still recorded in ``processes``, then forget them all.

    Best effort: failing to kill one process (for example because it has
    already gone away) does not prevent the others from being killed.
    """
    for proc in list(processes.values()):
        try:
            killProcess(proc)
        except Exception:
            # Deliberately ignored. Catching Exception rather than everything
            # lets KeyboardInterrupt and SystemExit through.
            pass
    processes.clear()
371
class Expect (object):
    """Spawn a child process and script interaction with its output.

    The child's stdout and stderr are merged into one pipe, whose contents a
    background ``reader`` thread collects. ``expect``/``expectall`` wait for
    regular-expression matches in that output, ``sendline`` writes to the
    child's stdin, and ``wait``/``terminate``/``kill`` manage its lifetime.

    Each spawned process is recorded, keyed by pid, in the module-level
    ``processes`` dict until ``wait`` reaps it.
    """
    def __init__(self, command, startReader=True, timeout=30, logfile=None, mapping=None, desc=None, cwd=None, env=None,
                 preexec_fn=None):
        """Spawn *command* and create (and by default start) its output reader.

        command -- the command line. On Windows it is passed to Popen as-is;
            elsewhere it is split into arguments with ``splitCommand``.
        startReader -- start the reader thread now. Otherwise ``startReader()``
            must be called before output can be matched.
        timeout -- default timeout in seconds, used when ``expect`` or
            ``expectall`` is called with ``timeout=-1``.
        logfile -- optional file-like object that receives a transcript.
        mapping -- the test's language mapping (e.g. "java"). It changes how the
            process is interrupted and how its exit status is checked.
        desc -- label used in log messages and by ``str()``.
        cwd, env -- passed to ``subprocess.Popen``.
        preexec_fn -- passed to ``subprocess.Popen`` (not used on Windows).
        """
        self.buf = "" # The whole buffer searched by the last match
        self.before = "" # The part before the match
        self.after = "" # The part after the match
        self.matchindex = 0 # the index of the matched pattern
        self.match = None # The last match
        self.mapping = mapping # The mapping of the test.
        self.exitstatus = None # The exitstatus, either -signal or, if positive, the exit code.
        self.killed = None # If killed, the signal that was sent.
        self.desc = desc
        self.logfile = logfile
        self.timeout = timeout
        self.p = None

        if self.logfile:
            self.logfile.write('spawn: "%s"\n' % command)
            self.logfile.flush()

        if win32:
            # Don't rely on win32api
            # import win32process
            # creationflags = win32process.CREATE_NEW_PROCESS_GROUP)
            #
            # universal_newlines = True is necessary for Python 3 on Windows
            #
            # We can't use shell=True because terminate() wouldn't
            # work. This means the PATH isn't searched for the
            # command.
            #
            CREATE_NEW_PROCESS_GROUP = 512
            self.p = subprocess.Popen(command, env=env, cwd=cwd, shell=False, bufsize=0, stdin=subprocess.PIPE,
                                      stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                      creationflags = CREATE_NEW_PROCESS_GROUP, universal_newlines=True)
        else:
            self.p = subprocess.Popen(splitCommand(command), env=env, cwd=cwd, shell=False, bufsize=0,
                                      stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                      preexec_fn=preexec_fn)
        processes[self.p.pid] = self.p

        self.r = reader(desc, self.p, logfile)

        # The thread is marked as a daemon thread. This is done so that if
        # an expect script runs off the end of main without kill/wait on each
        # spawned process the script will not hang trying to join with the
        # reader thread. (The attribute is used because Thread.setDaemon()
        # is deprecated since Python 3.10.)
        self.r.daemon = True

        if startReader:
            self.startReader()

    def __str__(self):
        return "{0} pid={1}".format(self.desc, "<none>" if self.p is None else self.p.pid)

    def startReader(self, watchDog = None):
        """Start the output reader thread, optionally with a watchdog whose reset() is called on output."""
        if watchDog is not None:
            self.r.setWatchDog(watchDog)
        self.r.start()

    def expect(self, pattern, timeout = 60):
        """Wait for the output to match one of *pattern*; return the matched index.

        pattern -- a regular-expression string, or a list of them, compiled
            with re.S. A list may also contain TIMEOUT: if the timeout expires,
            the call then succeeds with that entry's index instead of raising.
        timeout -- seconds to wait. None waits indefinitely; -1 uses the
            default given to the constructor.

        On success, ``buf``, ``before``, ``after``, ``match`` and
        ``matchindex`` describe the match. When no pattern matches in time, or
        the output ends first, those attributes are reset and TIMEOUT is
        raised.
        """
        if timeout == -1:
            timeout = self.timeout

        if not isinstance(pattern, list):
            pattern = [ pattern ]

        def _compile(s):
            # Non-string entries (i.e. TIMEOUT) have no regular expression.
            if isinstance(s, str):
                return re.compile(s, re.S)
            return None
        pattern = [ ( p, _compile(p) ) for p in pattern ]
        try:
            self.buf, self.before, self.after, self.match, self.matchindex = self.r.match(pattern, timeout)
        except TIMEOUT:
            self.buf = ""
            self.before = ""
            self.after = ""
            self.match = None
            self.matchindex = 0
            raise
        return self.matchindex

    def expectall(self, pattern, timeout = 60):
        """Wait until every regular expression in the list *pattern* has matched the output.

        The patterns may match in any order. timeout -- seconds to wait; None
        waits indefinitely; -1 uses the constructor's default.

        On success ``buf`` holds the searched buffer, and ``before``, ``after``,
        ``match`` and ``matchindex`` are reset. Raises TIMEOUT (after resetting
        all of them) if some pattern does not match in time.
        """
        if timeout == -1:
            timeout = self.timeout

        pattern = [ ( p, re.compile(p, re.S) ) for p in pattern ]
        try:
            self.buf = self.r.match(pattern, timeout, matchall = True)
            self.before = ""
            self.after = ""
            self.matchindex = 0
            self.match = None
        except TIMEOUT:
            self.buf = ""
            self.before = ""
            self.after = ""
            self.matchindex = 0
            self.match = None
            raise

    def sendline(self, data):
        """Send *data* followed by a newline to the application's stdin.
        """
        if self.logfile:
            self.logfile.write('%s: sendline: "%s"\n' % (self.desc, escape(data)))
            self.logfile.flush()
        data = data + "\n"
        # The Windows pipe is in text mode (universal_newlines=True), and
        # Python 2 pipes take str; otherwise the pipe expects bytes.
        if win32 or sys.version_info[0] == 2:
            self.p.stdin.write(data)
        else:
            self.p.stdin.write(data.encode("utf-8"))

    def wait(self, timeout = None):
        """Wait for the application to terminate for up to timeout seconds, or
           raises a TIMEOUT exception. If timeout is None, the wait is
           indefinite.

           The exit status is returned. A negative exit status means
           the application was killed by a signal.
        """
        if self.p is None:
            return self.exitstatus

        # Unfortunately, with the subprocess module there is no
        # better method of doing a timed wait.
        if timeout is not None:
            end = time.time() + timeout
            while time.time() < end and self.p and self.p.poll() is None:
                time.sleep(0.1)
            if self.p and self.p.poll() is None:
                raise TIMEOUT ('timed wait exceeded timeout')
        elif win32:
            # We poll on Windows or otherwise KeyboardInterrupt isn't delivered
            while self.p.poll() is None:
                time.sleep(0.5)

        if self.p is None:
            return self.exitstatus

        self.exitstatus = self.p.wait()

        # A Windows application killed with CTRL_BREAK. Fudge the exit status.
        if win32 and self.exitstatus != 0 and self.killed is not None:
            self.exitstatus = -self.killed
        if self.p.pid in processes:
            del processes[self.p.pid]
        self.p = None
        # Thread.join() raises RuntimeError for a thread that was never
        # started (startReader=False and startReader() never called).
        if self.r.ident is not None:
            self.r.join()
        # Simulate a match on EOF
        self.buf = self.r.getbuf()
        self.before = self.buf
        self.after = ""
        #
        # Without this we get warnings when runing with python_d on Windows
        #
        # ResourceWarning: unclosed file <_io.TextIOWrapper name=3 encoding='cp1252'>
        #
        self.r.p.stdout.close()
        self.r.p.stdin.close()
        self.r = None

        return self.exitstatus

    def terminate(self):
        """Terminate the process, escalating to a forced kill if needed.

        The process first gets 0.5s to exit on its own. It is then asked to
        terminate (``terminateProcess``) and, if it is still running 5s later,
        killed. A KeyboardInterrupt at any stage kills the process and is
        re-raised.
        """

        if self.p is None:
            return

        def kill():
            # Forcibly kill the process (unless it has already exited) and reap
            # it. Any error, including KeyboardInterrupt, propagates unchanged.
            # (The previous "except e:" clause turned every other error into a
            # NameError.)
            if self.p is None:
                return
            if self.p.poll() is None:
                killProcess(self.p)
            self.wait()

        try:
            self.wait(timeout = 0.5)
            return
        except KeyboardInterrupt:
            kill()
            raise
        except TIMEOUT:
            pass

        try:
            terminateProcess(self.p, self.hasInterruptSupport())
        except KeyboardInterrupt:
            kill()
            raise
        except Exception:
            traceback.print_exc(file=sys.stdout)

        # If the break does not terminate the process within 5
        # seconds, then kill the process.
        try:
            self.wait(timeout = 5)
            return
        except KeyboardInterrupt:
            kill()
            raise
        except TIMEOUT:
            kill()

    def kill(self, sig):
        """Send the signal *sig* to the process and remember it in ``killed``.

        On Windows the signal itself is not delivered. The process is asked to
        terminate with ``terminateProcess`` instead.
        """
        self.killed = sig # Save the sent signal.
        if win32:
            terminateProcess(self.p, self.hasInterruptSupport())
        else:
            os.kill(self.p.pid, sig)

    def trace(self, suppress = None):
        """Echo the process output to stdout (see reader.enabletrace for *suppress*)."""
        self.r.enabletrace(suppress)

    def waitSuccess(self, exitstatus = 0, timeout = None):
        """Wait for the process to terminate for up to timeout seconds, and
           validate the exit status is as expected."""

        def test(result, expected):
            if not win32 and result == -2: # Interrupted by Ctrl-C, simulate KeyboardInterrupt
                raise KeyboardInterrupt()
            if expected != result:
                raise RuntimeError("unexpected exit status: expected: %d, got %d\n" % (expected, result))

        self.wait(timeout)
        if self.mapping in ["java", "java-compat"]:
            if self.killed is not None:
                if win32:
                    test(self.exitstatus, -self.killed)
                else:
                    if self.killed == signal.SIGINT:
                        # A Java process sent SIGINT is expected to exit with
                        # 130 (128 + SIGINT). Arguments are (result, expected);
                        # they were previously swapped.
                        test(self.exitstatus, 130)
                    else:
                        test(self.exitstatus, exitstatus)
            else:
                test(self.exitstatus, exitstatus)
        else:
            test(self.exitstatus, exitstatus)

    def getOutput(self):
        """Return the output collected so far that has not been consumed by a match."""
        return self.buf if self.p is None else self.r.getbuf()

    def hasInterruptSupport(self):
        """Return whether the process may be asked to terminate with an interrupt.

        False for Java mappings on Windows, where CTRL_BREAK_EVENT generates a
        stack trace instead; True otherwise.
        """
        if win32 and (self.mapping == "java" or self.mapping == "java-compat"):
            return False
        return True
645