1import asyncio
2import errno
3
4from pexpect import EOF
5
try:
    # types.coroutine (Python 3.5+) is the primitive that asyncio.coroutine
    # applies to generator functions. asyncio.coroutine itself is deprecated
    # since Python 3.8 and was removed in 3.11, where using it as a decorator
    # makes this module fail at import time.
    from types import coroutine as _coroutine
except ImportError:  # Python 3.4: only asyncio.coroutine is available.
    _coroutine = asyncio.coroutine


@_coroutine
def expect_async(expecter, timeout=None):
    """Wait on the event loop until *expecter* matches its spawn's output.

    Data already held in the spawn's buffer is searched first. If it matches,
    the result is returned without touching the event loop.

    Otherwise the spawn is attached to the running loop as a read pipe driven
    by a ``PatternWaiter``. The protocol/transport pair is created on the
    first call, stored on ``spawn.async_pw_transport``, and re-armed on later
    calls. The coroutine then waits for the waiter's future.

    :param expecter: the expecter object doing the pattern search.
    :param timeout: seconds to wait (passed to ``asyncio.wait_for``), or
        ``None`` to wait indefinitely.
    :returns: the non-``None`` value produced by ``expecter.new_data`` (or
        by the waiter on EOF), or the result of ``expecter.timeout(...)`` if
        the wait times out.
    :raises: any exception set on the waiter's future. Whatever
        ``expecter.timeout`` raises is also propagated.
    """
    # First process data that was previously read - if it matches, we don't
    # need async stuff.
    spawn = expecter.spawn
    previously_read = spawn.buffer
    # Reset both buffers; new_data() is handed the previously read text as
    # if it had just arrived.
    spawn._buffer = spawn.buffer_type()
    spawn._before = spawn.buffer_type()
    idx = expecter.new_data(previously_read)
    if idx is not None:
        return idx

    if not spawn.async_pw_transport:
        pw = PatternWaiter()
        pw.set_expecter(expecter)
        transport, pw = yield from asyncio.get_event_loop().connect_read_pipe(
            lambda: pw, spawn)
        spawn.async_pw_transport = pw, transport
    else:
        # Reuse the existing pipe: point the waiter at this expecter (which
        # also gives it a fresh future) and resume the paused reader.
        pw, transport = spawn.async_pw_transport
        pw.set_expecter(expecter)
        transport.resume_reading()

    try:
        return (yield from asyncio.wait_for(pw.fut, timeout))
    except asyncio.TimeoutError as e:
        # Stop reading while nobody is waiting; the next call resumes it.
        transport.pause_reading()
        return expecter.timeout(e)
31
32
class PatternWaiter(asyncio.Protocol):
    """Read-pipe protocol that feeds spawn output to the current expecter.

    ``set_expecter`` installs an expecter together with a fresh future.
    Incoming bytes are decoded, logged, and passed to ``expecter.new_data``.
    The future is resolved with the first non-``None`` result (``found``) or
    failed with an exception (``error``), and reading is paused at that
    point. Text that arrives after the future is done is appended to the
    spawn's buffer instead of being searched.
    """

    # Read-pipe transport; assigned in connection_made().
    transport = None

    def set_expecter(self, expecter):
        """Make *expecter* current and give it a new, pending future."""
        self.expecter = expecter
        self.fut = asyncio.Future()

    def found(self, result):
        """Resolve the pending future with *result*, then pause reading."""
        if self.fut.done():
            return
        self.fut.set_result(result)
        self.transport.pause_reading()

    def error(self, exc):
        """Fail the pending future with *exc*, then pause reading."""
        if self.fut.done():
            return
        self.fut.set_exception(exc)
        self.transport.pause_reading()

    def connection_made(self, transport):
        """Remember the transport so reading can be paused later."""
        self.transport = transport

    def data_received(self, data):
        """Decode and log *data*, then either search it or buffer it."""
        spawn = self.expecter.spawn
        text = spawn._decoder.decode(data)
        spawn._log(text, 'read')

        if not self.fut.done():
            try:
                match = self.expecter.new_data(text)
                if match is not None:
                    # A pattern matched.
                    self.found(match)
            except Exception as exc:
                self.expecter.errored()
                self.error(exc)
        else:
            # Nobody is waiting on the future any more: keep the text in
            # the spawn's buffer rather than searching it now.
            spawn._buffer.write(text)

    def eof_received(self):
        """Flag EOF on the spawn and report the expecter's EOF outcome.

        Returns None, so asyncio closes the transport (and with it the spawn
        object used as the pipe).
        """
        expecter = self.expecter
        try:
            expecter.spawn.flag_eof = True
            outcome = expecter.eof()
        except EOF as exc:
            self.error(exc)
        else:
            self.found(outcome)
        return None

    def connection_lost(self, exc):
        """Treat EIO as EOF; fail the future on any other error."""
        if exc is None:
            return
        if isinstance(exc, OSError) and exc.errno == errno.EIO:
            # We may get here without eof_received being called, e.g. on
            # Linux.
            self.eof_received()
        else:
            self.error(exc)
88