2# Copyright (c) ZeroC, Inc. All rights reserved.
5import atexit
6import os
7import re
8import signal
9import string
10import subprocess
11import sys
12import sys
13import threading
14import time
15import traceback
16import types
__all__ = ["Expect", "EOF", "TIMEOUT" ]

# True when running on Windows; selects the Windows-specific spawn/terminate code paths in this module.
win32 = (sys.platform == "win32")
if win32:
    # Only imported on Windows: used to call kernel32.GenerateConsoleCtrlEvent (CTRL_BREAK_EVENT).
    import ctypes
class EOF(Exception):
    """Raised when EOF is read from a child.

    ``value`` holds the object passed to the constructor; ``str()`` returns ``str(value)``.
    Derives from ``Exception`` so that it can actually be raised on Python 3 (a plain class
    cannot be raised there).
    """
    def __init__(self, value):
        Exception.__init__(self, value)
        self.value = value

    def __str__(self):
        return str(self.value)
class TIMEOUT(Exception):
    """Signals that reading from (or waiting for) a child did not finish before its deadline.

    The constructor argument is kept in ``value`` and is also what ``str()`` renders.
    """

    def __init__(self, value):
        Exception.__init__(self, value)
        self.value = value

    def __str__(self):
        return str(self.value)
def getStringIO():
    """Return a fresh, empty in-memory text buffer suited to the running Python major version."""
    if sys.version_info[0] != 2:
        import io
        return io.StringIO()
    import StringIO
    return StringIO.StringIO()
def escape(s, escapeNewlines = True):
    """Render ``s`` as a readable, quotable string for log and error messages.

    Backslash, quotes and the common control characters are written as C-style escapes; other
    characters outside ``string.printable`` become a three-digit octal escape. When
    ``escapeNewlines`` is False, newlines are kept verbatim. The ``TIMEOUT`` pseudo-pattern is
    rendered as ``<TIMEOUT>``.
    """
    if s == TIMEOUT:
        return "<TIMEOUT>"

    special = {
        '\\': '\\\\',
        '\'': "\\'",
        '\"': '\\"',
        '\b': '\\b',
        '\f': '\\f',
        '\n': '\\n' if escapeNewlines else '\n',
        '\r': '\\r',
        '\t': '\\t',
    }
    pieces = []
    for ch in s:
        if ch in special:
            pieces.append(special[ch])
        elif ch in string.printable:
            pieces.append(ch)
        else:
            pieces.append('\\%03o' % ord(ch))
    return ''.join(pieces)
def taskkill(args):
    """Run the Windows ``taskkill`` command with the argument string ``args`` and wait for it.

    The command's combined stdout/stderr is drained (so taskkill can't block on a full pipe) and
    discarded.
    """
    p = subprocess.Popen("taskkill {0}".format(args), shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    try:
        # The output is not used. It used to be decoded strictly as UTF-8, which raises on
        # localized consoles using an OEM code page and then skipped wait()/close().
        p.stdout.read()
        p.wait()
    finally:
        p.stdout.close()
def killProcess(p):
    """Forcefully kill process ``p``.

    POSIX: send SIGKILL. Windows: ``taskkill /F /T``, which also kills the process tree.
    """
    if not win32:
        os.kill(p.pid, signal.SIGKILL)
        return
    taskkill("/F /T /PID {0}".format(p.pid))
def terminateProcess(p, hasInterruptSupport=True):
    """Ask process ``p`` to terminate gracefully.

    POSIX: send SIGINT. Windows: send CTRL_BREAK_EVENT to the process group when
    ``hasInterruptSupport`` is True, otherwise (or if sending the event fails) force-kill it with
    ``taskkill /F /T``.
    """
    if not win32:
        os.kill(p.pid, signal.SIGINT)
        return

    #
    # Signals under windows are all turned into CTRL_BREAK_EVENT, except with Java since
    # CTRL_BREAK_EVENT generates a stack trace. We don't use taskkill here because it
    # doesn't work with CLI processes (it sends a WM_CLOSE event).
    #
    if not hasInterruptSupport:
        taskkill("/F /T /PID {0}".format(p.pid))
        return

    try:
        ctypes.windll.kernel32.GenerateConsoleCtrlEvent(1, p.pid) # 1 is CTRL_BREAK_EVENT
    except NameError:
        # ctypes isn't available as a name here; fall back to a forced kill.
        taskkill("/F /T /PID {0}".format(p.pid))
    except:
        traceback.print_exc(file=sys.stdout)
        taskkill("/F /T /PID {0}".format(p.pid))
class reader(threading.Thread):
    """Background thread that collects a child's output and matches patterns against it.

    Characters read from ``p.stdout`` are appended to ``buf`` (carriage returns are dropped).
    ``match`` waits on the condition variable ``cv``, which is notified whenever output arrives
    and when the output ends.
    """
    def __init__(self, desc, p, logfile):
        self.desc = desc                # Prefix for log lines.
        self.buf = getStringIO()        # Output not yet consumed by a match.
        self.cv = threading.Condition() # Protects buf/_finish; notified on new output and at EOF.
        self.p = p                      # The subprocess.Popen whose stdout is read.
        self._trace = False             # Echo output to sys.stdout once enabletrace() is called.
        self._tbuf = getStringIO()      # Partial line held back while trace filters are active.
        self._tracesuppress = None      # Trace filters, see trace().
        self.logfile = logfile
        self.watchDog = None            # Optional object whose reset() is called when output arrives.
        self._finish = False            # True once no more output will arrive.
        threading.Thread.__init__(self)

    def setWatchDog(self, watchDog):
        self.watchDog = watchDog

    def run(self):
        """Thread body: copy the child's output into ``buf`` until EOF (or an I/O error)."""
        import codecs
        # With a binary pipe (Python 3 outside Windows) read(1) returns single bytes. An
        # incremental decoder reassembles multi-byte UTF-8 sequences instead of failing on their
        # first byte; undecodable bytes become U+FFFD rather than killing this thread.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        try:
            while True:
                c = self.p.stdout.read(1)
                eof = not c
                # Depending on Python version and platform, the value c could be a
                # string or a bytes object.
                if eof:
                    text = decoder.decode(b"", True)  # flush an incomplete trailing sequence
                elif isinstance(c, str):
                    text = c
                else:
                    text = decoder.decode(c)
                # Drop carriage returns for both str and bytes reads (comparing a bytes value
                # with '\r' never matched, so they used to leak into the buffer).
                text = text.replace('\r', '')

                self.cv.acquire()
                try:
                    if text:
                        for ch in text:
                            self.trace(ch)
                        if self.watchDog is not None:
                            self.watchDog.reset()
                        self.buf.write(text)
                    if eof:
                        self.trace(None)
                        self._finish = True # We have finished processing output
                    self.cv.notify()
                finally:
                    self.cv.release()
                if eof:
                    break
        except IOError as e:
            print(e)
            # No more output can arrive: let a pending match() fail now instead of waiting
            # for its whole timeout.
            self.cv.acquire()
            try:
                self._finish = True
                self.cv.notify()
            finally:
                self.cv.release()

    def trace(self, c):
        """Echo output character ``c`` to stdout when tracing is enabled.

        ``c`` is None at EOF. When trace filters are set, output is held back a line at a time
        and only written on a newline or at EOF. Each filter is either a function, which is
        applied to transform the line, or an object with ``search`` (e.g. a compiled regexp);
        a line it matches is not written.
        """
        if self._trace:
            if self._tracesuppress:
                if not c is None:
                    self._tbuf.write(c)
                if c == '\n' or c is None:
                    content = self._tbuf.getvalue()
                    suppress = False
                    for p in self._tracesuppress:
                        if isinstance(p, types.LambdaType) or isinstance(p, types.FunctionType):
                            content = p(content)
                        elif p.search(content):
                            suppress = True
                            break
                    if not suppress:
                        sys.stdout.write(content)
                    self._tbuf.truncate(0)
                    self._tbuf.seek(0)
            elif not c is None:
                sys.stdout.write(c)
                sys.stdout.flush()

    def enabletrace(self, suppress = None):
        """Turn on tracing (once) with the optional filters ``suppress``, replaying buffered output."""
        self.cv.acquire()
        try:
            if not self._trace:
                self._trace = True
                self._tracesuppress = suppress
                for c in self.buf.getvalue():
                    self.trace(c)
        finally:
            self.cv.release()

    def getbuf(self):
        """Return a snapshot of the output not yet consumed by a match."""
        self.cv.acquire()
        try:
            buf = self.buf.getvalue()
        finally:
            self.cv.release()
        return buf

    def match(self, pattern, timeout, matchall = False):
        """Block until the buffered output matches ``pattern``.

        pattern  -- list of ``(s, regexp)`` pairs: ``s`` is the original pattern (used for
                    logging) and ``regexp`` its compiled form. A ``(TIMEOUT, None)`` entry makes a
                    timeout return a result instead of raising.
        timeout  -- seconds to wait, or None to wait indefinitely.
        matchall -- when True, every pattern must match; entries are removed from ``pattern``
                    as they are found.

        Returns ``(buf, before, after, match, index)``. With ``matchall``, it returns the searched
        buffer instead. Matched output is removed from ``buf``. On a timeout with a
        ``(TIMEOUT, None)`` entry, it returns ``(buf, buf, TIMEOUT, None, index)``.

        Raises TIMEOUT if nothing matches before the deadline or before the output ends.
        """
        start = time.time()
        if timeout is not None:
            end = start + timeout

        # Trace the match
        if self.logfile:
            if timeout is None:
                tdesc = "<infinite>"
            else:
                tdesc = "%.2fs" % timeout
            descs = [escape(s) for (s, r) in pattern]
            if len(descs) == 1:
                pdesc = descs[0]
            else:
                pdesc = '[' + ','.join(descs) + ']'
            self.logfile.write('%s: expect: "%s" timeout: %s\n' % (self.desc, pdesc, tdesc))
            self.logfile.flush()

        def fail(s, buf):
            # Log the failure and build the TIMEOUT to raise.
            if self.logfile:
                self.logfile.write('%s: match failed.\npattern: "%s"\nbuffer: "%s"\n' %
                                   (self.desc, escape(s), escape(buf)))
                self.logfile.flush()
            return TIMEOUT('timeout exceeded in match\npattern: "%s"\nbuffer: "%s"\n' %
                           (escape(s), escape(buf, False)))

        maxend = None
        s = ""
        self.cv.acquire()
        try:
            while True:
                buf = self.buf.getvalue()

                # Try to match on the current buffer.
                olen = len(pattern)
                for index, (s, regexp) in enumerate(pattern):
                    if s == TIMEOUT:
                        continue
                    if not buf:
                        # Don't try to match on an empty buffer, http://bugs.python.org/issue17998
                        break
                    m = regexp.search(buf)
                    if m is None:
                        continue

                    before = buf[:m.start()]
                    matched = buf[m.start():m.end()]
                    after = buf[m.end():]

                    if maxend is None or m.end() > maxend:
                        maxend = m.end()

                    # Trace the match
                    if self.logfile:
                        if len(pattern) > 1:
                            self.logfile.write('%s: match found in %.2fs.\npattern: "%s"\nbuffer: "%s||%s||%s"\n' %
                                               (self.desc, time.time() - start, escape(s), escape(before),
                                                escape(matched), escape(after)))
                        else:
                            self.logfile.write('%s: match found in %.2fs.\nbuffer: "%s||%s||%s"\n' %
                                               (self.desc, time.time() - start, escape(before), escape(matched),
                                                escape(after)))
                        self.logfile.flush()

                    if matchall:
                        del pattern[index]
                        # Once every pattern has been found, drop the buffer up to the
                        # furthest match end and return.
                        if not pattern:
                            self.buf.truncate(0)
                            self.buf.seek(0)
                            self.buf.write(buf[maxend:])
                            return buf
                        break

                    # Consume matched portion of the buffer.
                    self.buf.truncate(0)
                    self.buf.seek(0)
                    self.buf.write(after)
                    return buf, before, after, m, index

                # matchall found one of the patterns: look for the others in the same buffer.
                if len(pattern) != olen:
                    continue

                # No match, and no more output will come.
                if self._finish:
                    raise fail(s, buf)

                if timeout is None:
                    self.cv.wait()
                else:
                    # The deadline is checked only after matching, so output that arrived during
                    # the last wait is still examined before giving up.
                    remaining = end - time.time()
                    if remaining <= 0:
                        raise fail(s, buf)
                    self.cv.wait(remaining)
        except TIMEOUT:
            if (TIMEOUT, None) in pattern:
                return buf, buf, TIMEOUT, None, pattern.index((TIMEOUT, None))
            raise
        finally:
            self.cv.release()
def splitCommand(command_line):
    """Split ``command_line`` into an argument list, roughly like a POSIX shell.

    - Unquoted whitespace separates arguments; leading, trailing and repeated whitespace
      is ignored.
    - Text inside single or double quotes is taken literally (whitespace included) and the
      quotes are removed. An empty quoted string such as ``''`` yields an empty argument.
    - A backslash makes the next character literal, inside or outside quotes. A trailing
      lone backslash is dropped.

    No variable expansion, globbing or other shell substitution is performed.
    """
    STATE_BASIC, STATE_ESC, STATE_SINGLEQUOTE, STATE_DOUBLEQUOTE = range(4)

    arg_list = []
    arg = ''
    # Whether an argument has been started; needed because an argument can be empty ('').
    # Tracking this (rather than a separate "whitespace" state) avoids emitting an empty
    # first argument for leading whitespace and losing separators after an escape.
    in_arg = False
    state = STATE_BASIC
    pre_esc_state = STATE_BASIC

    for c in command_line:
        if state != STATE_ESC and c == '\\':
            pre_esc_state = state
            state = STATE_ESC
        elif state == STATE_ESC:
            arg += c
            in_arg = True
            state = pre_esc_state
        elif state == STATE_SINGLEQUOTE:
            if c == "'":
                state = STATE_BASIC
            else:
                arg += c
        elif state == STATE_DOUBLEQUOTE:
            if c == '"':
                state = STATE_BASIC
            else:
                arg += c
        elif c == "'":
            state = STATE_SINGLEQUOTE
            in_arg = True
        elif c == '"':
            state = STATE_DOUBLEQUOTE
            in_arg = True
        elif c.isspace():
            if in_arg:
                arg_list.append(arg)
                arg = ''
                in_arg = False
        else:
            arg += c
            in_arg = True

    if in_arg:
        arg_list.append(arg)

    return arg_list
# pid -> subprocess.Popen for each child spawned by Expect that has not been reaped by wait().
processes = {}

def cleanup():
    """Kill every child still registered in ``processes``, then empty the registry.

    Best effort: a failure to kill one child is ignored so the remaining ones are still handled.
    """
    for proc in list(processes.values()):
        try:
            killProcess(proc)
        except:
            # Deliberately swallow everything: this is last-resort cleanup.
            pass
    processes.clear()
class Expect (object):
    """Spawn a child process and drive it by matching patterns against its output.

    The child's stdout and stderr are merged and collected by a background ``reader`` thread.
    ``expect``/``expectall`` wait for patterns, ``sendline`` writes to the child's stdin, and
    ``wait``/``waitSuccess``/``terminate``/``kill`` manage its lifetime. Each spawned process is
    registered in the module-level ``processes`` dict until ``wait`` reaps it.
    """
    def __init__(self, command, startReader=True, timeout=30, logfile=None, mapping=None, desc=None, cwd=None, env=None,
                 preexec_fn=None):
        self.buf = "" # The buffer searched by the last match (the whole output after wait())
        self.before = "" # The part before the match
        self.after = "" # The part after the match
        self.matchindex = 0 # the index of the matched pattern
        self.match = None # The last match
        self.mapping = mapping # The mapping of the test.
        self.exitstatus = None # The exitstatus, either -signal or, if positive, the exit code.
        self.killed = None # If killed, the signal that was sent.
        self.desc = desc
        self.logfile = logfile
        self.timeout = timeout # Default timeout used when expect()/expectall() get timeout == -1
        self.p = None

        if self.logfile:
            self.logfile.write('spawn: "%s"\n' % command)
            self.logfile.flush()

        if win32:
            # Don't rely on win32api
            # import win32process
            # creationflags = win32process.CREATE_NEW_PROCESS_GROUP)
            #
            # universal_newlines = True is necessary for Python 3 on Windows
            #
            # We can't use shell=True because terminate() wouldn't
            # work. This means the PATH isn't searched for the
            # command.
            #
            CREATE_NEW_PROCESS_GROUP = 512
            self.p = subprocess.Popen(command, env=env, cwd=cwd, shell=False, bufsize=0, stdin=subprocess.PIPE,
                                      stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                      creationflags = CREATE_NEW_PROCESS_GROUP, universal_newlines=True)
        else:
            self.p = subprocess.Popen(splitCommand(command), env=env, cwd=cwd, shell=False, bufsize=0,
                                      stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                      preexec_fn=preexec_fn)
        processes[self.p.pid] = self.p

        self.r = reader(desc, self.p, logfile)

        # The thread is marked as a daemon thread. This is done so that if
        # an expect script runs off the end of main without kill/wait on each
        # spawned process the script will not hang trying to join with the
        # reader thread. (The ``daemon`` attribute replaces the deprecated setDaemon().)
        self.r.daemon = True

        if startReader:
            self.startReader()

    def __str__(self):
        return "{0} pid={1}".format(self.desc, "<none>" if self.p is None else self.p.pid)

    def startReader(self, watchDog = None):
        """Start the output reader thread, optionally resetting ``watchDog`` whenever output arrives."""
        if watchDog is not None:
            self.r.setWatchDog(watchDog)
        self.r.start()

    def expect(self, pattern, timeout = 60):
        """pattern is either a string, or a list of string regexp patterns.

           timeout == None expect can block indefinitely.

           timeout == -1 then the default is used.

           The TIMEOUT class may be included in the list to get a timeout reported as a match
           (its index is returned) instead of raised. Returns the index of the matched pattern.
        """
        if timeout == -1:
            timeout = self.timeout

        if not isinstance(pattern, list):
            pattern = [ pattern ]

        def compilePattern(s):
            # Non-string entries (i.e. TIMEOUT) have no regexp.
            if type(s) == str:
                return re.compile(s, re.S)
            return None

        pattern = [ ( p, compilePattern(p) ) for p in pattern ]
        try:
            self.buf, self.before, self.after, self.match, self.matchindex = self.r.match(pattern, timeout)
        except TIMEOUT:
            self.buf = ""
            self.before = ""
            self.after = ""
            self.match = None
            self.matchindex = 0
            raise
        return self.matchindex

    def expectall(self, pattern, timeout = 60):
        """pattern is a list of string regexp patterns, all of which must match (in any order).

           timeout == None expect can block indefinitely.

           timeout == -1 then the default is used.
        """
        if timeout == -1:
            timeout = self.timeout

        pattern = [ ( p, re.compile(p, re.S) ) for p in pattern ]
        try:
            self.buf = self.r.match(pattern, timeout, matchall = True)
            self.before = ""
            self.after = ""
            self.matchindex = 0
            self.match = None
        except TIMEOUT:
            self.buf = ""
            self.before = ""
            self.after = ""
            self.matchindex = 0
            self.match = None
            raise

    def sendline(self, data):
        """send data followed by a newline to the application.
        """
        if self.logfile:
            self.logfile.write('%s: sendline: "%s"\n' % (self.desc, escape(data)))
            self.logfile.flush()
        data = data + "\n"
        # stdin is a text stream on Windows (universal_newlines) and a byte stream elsewhere on Python 3.
        if win32 or sys.version_info[0] == 2:
            self.p.stdin.write(data)
        else:
            self.p.stdin.write(data.encode("utf-8"))

    def wait(self, timeout = None):
        """Wait for the application to terminate for up to timeout seconds, or
           raises a TIMEOUT exception. If timeout is None, the wait is
           indefinite.

           The exit status is returned. A negative exit status means
           the application was killed by a signal.
        """
        if self.p is None:
            return self.exitstatus

        # Unfortunately, with the subprocess module there is no
        # better method of doing a timed wait.
        if timeout is not None:
            end = time.time() + timeout
            while time.time() < end and self.p and self.p.poll() is None:
                time.sleep(0.1)
            if self.p and self.p.poll() is None:
                raise TIMEOUT ('timed wait exceeded timeout')
        elif win32:
            # We poll on Windows or otherwise KeyboardInterrupt isn't delivered
            while self.p.poll() is None:
                time.sleep(0.5)

        # self.p is re-checked defensively, as in the polling loop above.
        if self.p is None:
            return self.exitstatus

        self.exitstatus = self.p.wait()

        # A Windows application killed with CTRL_BREAK. Fudge the exit status.
        if win32 and self.exitstatus != 0 and self.killed is not None:
            self.exitstatus = -self.killed
        processes.pop(self.p.pid, None)
        self.p = None

        # join() raises RuntimeError on a thread that was never started (startReader=False).
        if self.r.ident is not None:
            self.r.join()
        # Simulate a match on EOF
        self.buf = self.r.getbuf()
        self.before = self.buf
        self.after = ""
        #
        # Without this we get warnings when running with python_d on Windows
        #
        # ResourceWarning: unclosed file <_io.TextIOWrapper name=3 encoding='cp1252'>
        #
        self.r.p.stdout.close()
        self.r.p.stdin.close()
        self.r = None

        return self.exitstatus

    def terminate(self):
        """Terminate the process.

        Waits 0.5s for a natural exit, then interrupts the process (terminateProcess), then
        waits up to 5s more before killing it outright.
        """
        if self.p is None:
            return

        def forceKill():
            # Kill the process outright and reap it.
            if self.p is None:
                return
            try:
                killProcess(self.p)
            except Exception:
                # Killing can fail because the process has already exited; it only needs to be
                # reaped then. If it is still running, waiting would hang, so propagate.
                if self.p is not None and self.p.poll() is None:
                    raise
            self.wait()

        try:
            self.wait(timeout = 0.5)
            return
        except KeyboardInterrupt:
            forceKill()
            raise
        except TIMEOUT:
            pass

        try:
            terminateProcess(self.p, self.hasInterruptSupport())
        except KeyboardInterrupt:
            forceKill()
            raise
        except Exception:
            traceback.print_exc(file=sys.stdout)

        # If the break does not terminate the process within 5
        # seconds, then kill the process.
        try:
            self.wait(timeout = 5)
        except KeyboardInterrupt:
            forceKill()
            raise
        except TIMEOUT:
            forceKill()

    def kill(self, sig):
        """Send the signal to the process (on Windows: terminateProcess, whatever sig is)."""
        self.killed = sig # Save the sent signal.
        if win32:
            terminateProcess(self.p, self.hasInterruptSupport())
        else:
            os.kill(self.p.pid, sig)

    def trace(self, suppress = None):
        """Echo the child's output (past and future) to stdout, filtered by ``suppress``."""
        self.r.enabletrace(suppress)

    def waitSuccess(self, exitstatus = 0, timeout = None):
        """Wait for the process to terminate for up to timeout seconds, and
           validate the exit status is as expected.

           Raises RuntimeError on an unexpected exit status, and KeyboardInterrupt (outside
           Windows) when the process died from SIGINT (exit status -2) that we did not send.
        """

        def test(result, expected):
            if not win32 and result == -2: # Interrupted by Ctrl-C, simulate KeyboardInterrupt
                raise KeyboardInterrupt()
            if expected != result:
                raise RuntimeError("unexpected exit status: expected: %d, got %d\n" % (expected, result))

        self.wait(timeout)
        if self.mapping in ["java", "java-compat"] and self.killed is not None:
            if win32:
                test(self.exitstatus, -self.killed)
            elif self.killed == signal.SIGINT:
                # We sent the SIGINT ourselves, so a Java process is expected to exit with
                # 130 (128 + SIGINT). Checked directly: the Ctrl-C heuristic in test() must not apply.
                if self.exitstatus != 130:
                    raise RuntimeError("unexpected exit status: expected: %d, got %d\n" % (130, self.exitstatus))
            else:
                test(self.exitstatus, exitstatus)
        else:
            test(self.exitstatus, exitstatus)

    def getOutput(self):
        """Return the output collected so far (all of it once the process has been waited for)."""
        return self.buf if self.p is None else self.r.getbuf()

    def hasInterruptSupport(self):
        """Return False if the process must be killed rather than interrupted (Java on Windows, where
        CTRL_BREAK_EVENT generates a stack trace), True otherwise."""
        if win32 and (self.mapping == "java" or self.mapping == "java-compat"):
            return False
        return True