# Public names of this module: the two subprocess-creating coroutines.
__all__ = ('create_subprocess_exec', 'create_subprocess_shell')
2
3import subprocess
4import warnings
5
6from . import events
7from . import protocols
8from . import streams
9from . import tasks
10from .log import logger
11
12
# Aliases of the subprocess module's special stdin/stdout/stderr values,
# usable as the stream arguments of the create_subprocess_* coroutines.
PIPE, STDOUT, DEVNULL = subprocess.PIPE, subprocess.STDOUT, subprocess.DEVNULL
16
17
class SubprocessStreamProtocol(streams.FlowControlMixin,
                               protocols.SubprocessProtocol):
    """Like StreamReaderProtocol, but for a subprocess.

    Exposes the child's pipes as stream objects: ``stdout``/``stderr`` are
    StreamReaders fed from fds 1 and 2, and ``stdin`` is a StreamWriter over
    fd 0.  Each attribute stays None when the transport has no pipe for that
    fd.  The transport is closed once the process has exited and every
    tracked output pipe has been lost.
    """

    def __init__(self, limit, loop):
        super().__init__(loop=loop)
        # Buffer limit handed to each StreamReader created in connection_made.
        self._limit = limit
        self.stdin = self.stdout = self.stderr = None
        self._transport = None
        self._process_exited = False
        # Output fds (1 and/or 2) whose pipes are still connected.
        self._pipe_fds = []
        # Resolved when the stdin pipe is lost; handed out by
        # _get_close_waiter() (presumably awaited via the writer's
        # wait_closed() — confirm against streams.StreamWriter).
        self._stdin_closed = self._loop.create_future()

    def __repr__(self):
        info = [self.__class__.__name__]
        if self.stdin is not None:
            info.append(f'stdin={self.stdin!r}')
        if self.stdout is not None:
            info.append(f'stdout={self.stdout!r}')
        if self.stderr is not None:
            info.append(f'stderr={self.stderr!r}')
        return '<{}>'.format(' '.join(info))

    def connection_made(self, transport):
        """Wrap each pipe the transport provides in a stream object."""
        self._transport = transport

        stdout_transport = transport.get_pipe_transport(1)
        if stdout_transport is not None:
            self.stdout = streams.StreamReader(limit=self._limit,
                                               loop=self._loop)
            self.stdout.set_transport(stdout_transport)
            self._pipe_fds.append(1)

        stderr_transport = transport.get_pipe_transport(2)
        if stderr_transport is not None:
            self.stderr = streams.StreamReader(limit=self._limit,
                                               loop=self._loop)
            self.stderr.set_transport(stderr_transport)
            self._pipe_fds.append(2)

        stdin_transport = transport.get_pipe_transport(0)
        if stdin_transport is not None:
            self.stdin = streams.StreamWriter(stdin_transport,
                                              protocol=self,
                                              reader=None,
                                              loop=self._loop)

    def pipe_data_received(self, fd, data):
        """Feed *data* read from fd 1 or 2 into the matching reader."""
        if fd == 1:
            reader = self.stdout
        elif fd == 2:
            reader = self.stderr
        else:
            reader = None
        if reader is not None:
            reader.feed_data(data)

    def pipe_connection_lost(self, fd, exc):
        """Handle loss of the pipe for *fd*; *exc* is None on a clean close.

        For stdin (fd 0), the writer is closed and the close waiter is
        resolved with None or *exc*.  For stdout/stderr, the matching reader
        receives EOF or *exc*, and the transport may then be closed.
        """
        if fd == 0:
            pipe = self.stdin
            if pipe is not None:
                pipe.close()
            self.connection_lost(exc)
            if exc is None:
                self._stdin_closed.set_result(None)
            else:
                self._stdin_closed.set_exception(exc)
                # Awaiting the close waiter is optional, so an exception
                # nobody retrieves must not be reported as "Future
                # exception was never retrieved" when the future is GC'd.
                self._stdin_closed._log_traceback = False
            return
        if fd == 1:
            reader = self.stdout
        elif fd == 2:
            reader = self.stderr
        else:
            reader = None
        if reader is not None:
            if exc is None:
                reader.feed_eof()
            else:
                reader.set_exception(exc)

        if fd in self._pipe_fds:
            self._pipe_fds.remove(fd)
        self._maybe_close_transport()

    def process_exited(self):
        """Record that the child exited and close the transport if possible."""
        self._process_exited = True
        self._maybe_close_transport()

    def _maybe_close_transport(self):
        # Close only after the child exited AND all output pipes are gone,
        # so no buffered stdout/stderr data is lost.
        if len(self._pipe_fds) == 0 and self._process_exited:
            self._transport.close()
            self._transport = None

    def _get_close_waiter(self, stream):
        # Only the stdin writer has a close waiter; other streams get None.
        if stream is self.stdin:
            return self._stdin_closed
114
115
class Process:
    """Handle for a child process started by create_subprocess_*().

    ``stdin``/``stdout``/``stderr`` are the stream objects built by the
    protocol (None for a stream that is not a pipe).  ``pid`` is read from
    the transport at construction time.
    """

    def __init__(self, transport, protocol, loop):
        self._transport = transport
        self._protocol = protocol
        self._loop = loop
        self.stdin = protocol.stdin
        self.stdout = protocol.stdout
        self.stderr = protocol.stderr
        self.pid = transport.get_pid()

    def __repr__(self):
        return f'<{self.__class__.__name__} {self.pid}>'

    @property
    def returncode(self):
        """Return code as reported by the transport's get_returncode()."""
        return self._transport.get_returncode()

    async def wait(self):
        """Wait until the process exit and return the process return code."""
        return await self._transport._wait()

    def send_signal(self, signal):
        """Send *signal* to the child via the transport."""
        self._transport.send_signal(signal)

    def terminate(self):
        """Ask the transport to terminate the child."""
        self._transport.terminate()

    def kill(self):
        """Ask the transport to kill the child."""
        self._transport.kill()

    async def _feed_stdin(self, input):
        """Write *input* (if not None) to stdin, flush it, then close stdin.

        stdin is closed even when *input* is None: that is how the child
        sees EOF, matching subprocess.Popen.communicate().  Without it, a
        child reading stdin until EOF never exits and communicate() hangs.
        BrokenPipeError and ConnectionResetError from write() or drain()
        are ignored, because the child may exit without reading everything.
        """
        debug = self._loop.get_debug()
        try:
            if input is not None:
                self.stdin.write(input)
                if debug:
                    logger.debug('%r communicate: feed stdin (%s bytes)',
                                 self, len(input))
            await self.stdin.drain()
        except (BrokenPipeError, ConnectionResetError) as exc:
            # communicate() ignores BrokenPipeError and ConnectionResetError
            if debug:
                logger.debug('%r communicate: stdin got %r', self, exc)

        if debug:
            logger.debug('%r communicate: close stdin', self)
        self.stdin.close()

    async def _noop(self):
        # Placeholder awaitable for streams that are not pipes.
        return None

    async def _read_stream(self, fd):
        """Read fd 1 (stdout) or 2 (stderr) to EOF, then close its pipe."""
        transport = self._transport.get_pipe_transport(fd)
        if fd == 2:
            stream = self.stderr
        else:
            assert fd == 1
            stream = self.stdout
        if self._loop.get_debug():
            name = 'stdout' if fd == 1 else 'stderr'
            logger.debug('%r communicate: read %s', self, name)
        output = await stream.read()
        if self._loop.get_debug():
            name = 'stdout' if fd == 1 else 'stderr'
            logger.debug('%r communicate: close %s', self, name)
        transport.close()
        return output

    async def communicate(self, input=None):
        """Send *input*, read all output, and wait for the process to exit.

        When stdin is a pipe, *input* (if any) is written and stdin is then
        closed.  *input* is ignored when stdin is not a pipe.  stdout and
        stderr are read to EOF concurrently.

        Returns:
            ``(stdout_data, stderr_data)``; an entry is None when the
            corresponding stream is not a pipe.
        """
        if self.stdin is not None:
            stdin = self._feed_stdin(input)
        else:
            stdin = self._noop()
        if self.stdout is not None:
            stdout = self._read_stream(1)
        else:
            stdout = self._noop()
        if self.stderr is not None:
            stderr = self._read_stream(2)
        else:
            stderr = self._noop()
        # gather() uses the running loop; its loop= argument is deprecated
        # since Python 3.8 and removed in 3.10.
        _, stdout, stderr = await tasks.gather(stdin, stdout, stderr)
        await self.wait()
        return (stdout, stderr)
200
201
async def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None,
                                  loop=None, limit=streams._DEFAULT_LIMIT,
                                  **kwds):
    """Run *cmd* via ``loop.subprocess_shell()`` and return a Process.

    stdin/stdout/stderr and any extra keyword arguments are forwarded to
    the loop; *limit* is the buffer limit of the stdout/stderr readers.
    Passing *loop* explicitly emits a DeprecationWarning.

    NOTE(review): *cmd* is handed to a shell; quote any untrusted parts
    (e.g. with shlex.quote) before calling.
    """
    if loop is not None:
        warnings.warn("The loop argument is deprecated since Python 3.8 "
                      "and scheduled for removal in Python 3.10.",
                      DeprecationWarning,
                      stacklevel=2)
    else:
        loop = events.get_event_loop()

    def make_protocol():
        return SubprocessStreamProtocol(limit=limit, loop=loop)

    transport, protocol = await loop.subprocess_shell(
        make_protocol, cmd,
        stdin=stdin, stdout=stdout, stderr=stderr, **kwds)
    return Process(transport, protocol, loop)
221
222
async def create_subprocess_exec(program, *args, stdin=None, stdout=None,
                                 stderr=None, loop=None,
                                 limit=streams._DEFAULT_LIMIT, **kwds):
    """Start *program* with *args* via ``loop.subprocess_exec()``.

    stdin/stdout/stderr and any extra keyword arguments are forwarded to
    the loop; *limit* is the buffer limit of the stdout/stderr readers.
    Passing *loop* explicitly emits a DeprecationWarning.

    Returns:
        A Process wrapping the resulting transport and protocol.
    """
    if loop is not None:
        warnings.warn("The loop argument is deprecated since Python 3.8 "
                      "and scheduled for removal in Python 3.10.",
                      DeprecationWarning,
                      stacklevel=2)
    else:
        loop = events.get_event_loop()

    def make_protocol():
        return SubprocessStreamProtocol(limit=limit, loop=loop)

    transport, protocol = await loop.subprocess_exec(
        make_protocol, program, *args,
        stdin=stdin, stdout=stdout, stderr=stderr, **kwds)
    return Process(transport, protocol, loop)
242