1import os
2import sys
3import py
4import tempfile
5
6try:
7    from io import StringIO
8except ImportError:
9    from StringIO import StringIO
10
if sys.version_info >= (3, 0):
    # Python 3's io.StringIO already only accepts text, use it directly.
    TextIO = StringIO
else:
    class TextIO(StringIO):
        """ In-memory text buffer that also accepts byte strings.

            Byte strings are decoded (using self._encoding if set,
            else UTF-8, replacing undecodable bytes) before being
            stored, so the buffer only ever holds unicode text.
        """

        def write(self, data):
            if isinstance(data, unicode):
                text = data
            else:
                codec = getattr(self, '_encoding', 'UTF-8')
                text = unicode(data, codec, 'replace')
            StringIO.write(self, text)
19
try:
    from io import BytesIO
except ImportError:
    class BytesIO(StringIO):
        """ Fallback in-memory byte buffer that rejects unicode text. """

        def write(self, data):
            is_text = isinstance(data, unicode)
            if is_text:
                message = "not a byte value: %r" % (data,)
                raise TypeError(message)
            StringIO.write(self, data)
28
# Standard file descriptor number -> name of the ``sys`` attribute that
# FDCapture replaces when patchsys is enabled.
patchsysdict = dict(enumerate(('stdin', 'stdout', 'stderr')))
30
class FDCapture:
    """ Capture IO to/from a given os-level filedescriptor. """

    def __init__(self, targetfd, tmpfile=None, now=True, patchsys=False):
        """ save targetfd descriptor, and prepare a capture target.

            If no tmpfile is specified and targetfd is not 0, a binary
            tempfile.TemporaryFile() is duplicated via dupfile() into a
            UTF-8 encoded file object.  For targetfd 0 without a
            tmpfile, start() connects the null device instead.

            If patchsys is true, the matching sys.stdin/stdout/stderr
            attribute is replaced while capturing.  If now is true,
            capturing starts immediately.

            Raises OSError if targetfd cannot be duplicated; in that
            case no temporary file has been opened.
        """
        self.targetfd = targetfd
        # Duplicate the target first: an invalid descriptor then fails
        # before any temporary file exists.  Callers may swallow that
        # OSError, which previously leaked the already-opened tmpfile.
        self._savefd = os.dup(targetfd)
        try:
            if tmpfile is None and targetfd != 0:
                f = tempfile.TemporaryFile('wb+')
                try:
                    tmpfile = dupfile(f, encoding="UTF-8")
                finally:
                    # dupfile() works on its own duplicate descriptor,
                    # so the original temporary file object can go.
                    f.close()
            if patchsys:
                self._oldsys = getattr(sys, patchsysdict[targetfd])
        except Exception:
            # Don't leak the saved descriptor if setup fails half-way.
            os.close(self._savefd)
            raise
        self.tmpfile = tmpfile
        if now:
            self.start()

    def start(self):
        """ redirect targetfd to the capture target.

            For targetfd 0 without a tmpfile the null device is
            connected; otherwise tmpfile's descriptor is dup'ed onto
            targetfd.  With patchsys the sys attribute is replaced too.

            Raises ValueError if the saved descriptor is no longer
            valid (done() closes it).
        """
        try:
            os.fstat(self._savefd)
        except OSError:
            raise ValueError("saved filedescriptor not valid, "
                "did you call start() twice?")
        if self.targetfd == 0 and not self.tmpfile:
            fd = os.open(devnullpath, os.O_RDONLY)
            os.dup2(fd, 0)
            os.close(fd)
            if hasattr(self, '_oldsys'):
                setattr(sys, patchsysdict[self.targetfd], DontReadFromInput())
        else:
            os.dup2(self.tmpfile.fileno(), self.targetfd)
            if hasattr(self, '_oldsys'):
                setattr(sys, patchsysdict[self.targetfd], self.tmpfile)

    def done(self):
        """ unpatch and clean up, returns the self.tmpfile (file object)

            Restores the original descriptor onto targetfd, closes the
            saved duplicate, rewinds tmpfile (except for targetfd 0) and
            restores the patched sys attribute, if any.
        """
        os.dup2(self._savefd, self.targetfd)
        os.close(self._savefd)
        if self.targetfd != 0:
            self.tmpfile.seek(0)
        if hasattr(self, '_oldsys'):
            setattr(sys, patchsysdict[self.targetfd], self._oldsys)
        return self.tmpfile

    def writeorg(self, data):
        """ write a string to the original file descriptor

            Text is encoded as UTF-8; byte strings are written as-is.
            The data goes straight to the saved descriptor, bypassing
            any capture currently redirecting targetfd.
        """
        if not isinstance(data, bytes):
            data = data.encode("utf-8")
        # os.write may write fewer bytes than requested; keep going.
        while data:
            written = os.write(self._savefd, data)
            data = data[written:]
89
90
def dupfile(f, mode=None, buffering=0, raising=False, encoding=None):
    """ return a new open file object that's a duplicate of f

        mode is duplicated if not given, 'buffering' controls
        buffer size (defaulting to no buffering) and 'raising'
        defines whether an exception is raised when an incompatible
        file object is passed in (if raising is False, the file
        object itself will be returned).

        If 'encoding' is given, the duplicate accepts text: on
        Python 3 it is opened in line-buffered text mode with that
        encoding, on Python 2 it is wrapped in an EncodedFile.
    """
    try:
        fd = f.fileno()
        mode = mode or f.mode
    except (AttributeError, ValueError):
        # AttributeError: no fileno()/mode at all.  ValueError also
        # covers io.UnsupportedOperation (raised by e.g.
        # io.StringIO().fileno()) and fileno() on a closed file.
        if raising:
            raise
        return f
    newfd = os.dup(fd)
    if sys.version_info >= (3, 0):
        if encoding is not None:
            # Text mode must not contain "b", and unbuffered text I/O is
            # not allowed, so use line buffering.
            mode = mode.replace("b", "")
            buffering = 1
        return os.fdopen(newfd, mode, buffering, encoding, closefd=True)
    else:
        f = os.fdopen(newfd, mode, buffering)
        if encoding is not None:
            return EncodedFile(f, encoding)
        return f
118
class EncodedFile(object):
    """ Wrap a byte stream so that text written to it gets encoded.

        Unicode text is encoded with 'encoding', byte strings pass
        through unchanged, and any other object is converted with
        str() (and encoded if that yields text).  All other attribute
        access is delegated to the wrapped stream.
    """
    try:
        _text_type = unicode  # Python 2
    except NameError:
        _text_type = str  # Python 3

    def __init__(self, _stream, encoding):
        self._stream = _stream
        self.encoding = encoding

    def write(self, obj):
        if isinstance(obj, self._text_type):
            obj = obj.encode(self.encoding)
        elif not isinstance(obj, bytes):
            obj = str(obj)
            # On Python 3 str() returns text, which still needs encoding;
            # on Python 2 it already returns a byte string.
            if isinstance(obj, self._text_type):
                obj = obj.encode(self.encoding)
        self._stream.write(obj)

    def writelines(self, linelist):
        data = ''.join(linelist)
        self.write(data)

    def __getattr__(self, name):
        # Fetch _stream without going through __getattr__ again: on an
        # instance whose __init__ never ran (e.g. created by copy/pickle)
        # a plain self._stream would recurse forever instead of raising
        # AttributeError.
        return getattr(object.__getattribute__(self, "_stream"), name)
139
class Capture(object):
    """ Shared behaviour of the stdout/stderr capturing classes.

        Subclasses supply done(save=...) returning an (outfile, errfile)
        pair and readouterr() returning the current captured snapshot.
    """

    @classmethod
    def call(cls, func, *args, **kwargs):
        """ run func(*args, **kwargs) while capturing its output.

            Returns a (res, out, err) tuple: the function's return value
            plus the output and error output produced during the call.
            Capturing is always stopped, even if func raises.
        """
        capture = cls()
        try:
            result = func(*args, **kwargs)
        finally:
            captured_out, captured_err = capture.reset()
        return result, captured_out, captured_err

    def reset(self):
        """ reset sys.stdout/stderr and return captured output as strings. """
        if hasattr(self, '_reset'):
            raise ValueError("was already reset")
        self._reset = True

        def drain(fileobj):
            # read everything that was captured, then release the file
            contents = fileobj.read()
            fileobj.close()
            return contents

        outfile, errfile = self.done(save=False)
        out = ""
        if outfile and not outfile.closed:
            out = drain(outfile)
        err = ""
        # in mixed mode both streams share one file; only read it once
        if errfile and errfile != outfile and not errfile.closed:
            err = drain(errfile)
        return out, err

    def suspend(self):
        """ return current snapshot captures, memorize tempfiles. """
        snapshot = self.readouterr()
        _outfile, _errfile = self.done()
        return snapshot
176
177
class StdCaptureFD(Capture):
    """ This class allows to capture writes to FD1 and FD2
        and may connect a NULL file to FD0 (and prevent
        reads from sys.stdin).  If any of the 0,1,2 file descriptors
        is invalid it will not be captured.
    """
    def __init__(self, out=True, err=True, mixed=False,
        in_=True, patchsys=True, now=True):
        """ out/err: true to capture into a fresh temporary file, or a
            file object (anything with write()) to capture into.
            mixed: send FD2 output to the same file as FD1.
            in_: connect FD0 to the null device while capturing.
            patchsys: also replace sys.stdin/stdout/stderr.
            now: start capturing immediately.
        """
        self._options = {
            "out": out,
            "err": err,
            "mixed": mixed,
            "in_": in_,
            "patchsys": patchsys,
            "now": now,
        }
        self._save()
        if now:
            self.startall()

    def _save(self):
        """ create (not yet started) FDCapture objects from self._options.

            Descriptors that cannot be duplicated (OSError) are skipped.
            The tmpfiles in use are stored back into self._options, so a
            later call (from done(save=True)) reuses the same files.
        """
        in_ = self._options['in_']
        out = self._options['out']
        err = self._options['err']
        mixed = self._options['mixed']
        patchsys = self._options['patchsys']
        if in_:
            try:
                self.in_ = FDCapture(0, tmpfile=None, now=False,
                    patchsys=patchsys)
            except OSError:
                pass
        if out:
            tmpfile = None
            if hasattr(out, 'write'):
                tmpfile = out
            try:
                self.out = FDCapture(1, tmpfile=tmpfile,
                           now=False, patchsys=patchsys)
                self._options['out'] = self.out.tmpfile
            except OSError:
                pass
        if err:
            # Share stdout's file only if stdout capturing was actually
            # set up: an invalid FD1 leaves no self.out, which used to
            # crash here with AttributeError.
            if out and mixed and hasattr(self, 'out'):
                tmpfile = self.out.tmpfile
            elif hasattr(err, 'write'):
                tmpfile = err
            else:
                tmpfile = None
            try:
                self.err = FDCapture(2, tmpfile=tmpfile,
                           now=False, patchsys=patchsys)
                self._options['err'] = self.err.tmpfile
            except OSError:
                pass

    def startall(self):
        """ start every FDCapture that _save() managed to create. """
        if hasattr(self, 'in_'):
            self.in_.start()
        if hasattr(self, 'out'):
            self.out.start()
        if hasattr(self, 'err'):
            self.err.start()

    def resume(self):
        """ resume capturing with original temp files. """
        self.startall()

    def done(self, save=True):
        """ return (outfile, errfile) and stop capturing.

            Either entry is None when that stream is not captured or its
            file is already closed.  With save=True fresh FDCapture
            objects are prepared so resume() can capture into the same
            files again.
        """
        outfile = errfile = None
        if hasattr(self, 'out') and not self.out.tmpfile.closed:
            outfile = self.out.done()
        if hasattr(self, 'err') and not self.err.tmpfile.closed:
            errfile = self.err.done()
        if hasattr(self, 'in_'):
            self.in_.done()
        if save:
            self._save()
        return outfile, errfile

    def readouterr(self):
        """ return snapshot value of stdout/stderr capturings. """
        if hasattr(self, "out"):
            out = self._readsnapshot(self.out.tmpfile)
        else:
            out = ""
        if hasattr(self, "err"):
            err = self._readsnapshot(self.err.tmpfile)
        else:
            err = ""
        return [out, err]

    def _readsnapshot(self, f):
        """ return f's whole content as text and empty the file. """
        f.seek(0)
        res = f.read()
        enc = getattr(f, "encoding", None)
        if enc:
            res = py.builtin._totext(res, enc, "replace")
        f.truncate(0)
        f.seek(0)
        return res
280
281
class StdCapture(Capture):
    """ This class allows to capture writes to sys.stdout|stderr "in-memory"
        and will raise errors on tries to read from sys.stdin. It only
        modifies sys.stdout|stderr|stdin attributes and does not
        touch underlying File Descriptors (use StdCaptureFD for that).
    """
    def __init__(self, out=True, err=True, in_=True, mixed=False, now=True):
        """ out/err: true to capture into a fresh TextIO buffer, or an
            object with a write() method to capture into (readouterr()
            additionally needs getvalue(), truncate() and seek()).
            mixed: stderr writes go to the same object as stdout.
            in_: replace sys.stdin with DontReadFromInput while capturing.
            now: start capturing immediately.
        """
        self._oldout = sys.stdout
        self._olderr = sys.stderr
        self._oldin  = sys.stdin
        # Keep any caller-supplied stream that can be written to, using
        # the same test as for err below (this used to check for a 'file'
        # attribute, so most custom streams were silently replaced).
        if out and not hasattr(out, 'write'):
            out = TextIO()
        self.out = out
        if err:
            if mixed:
                err = out
            elif not hasattr(err, 'write'):
                err = TextIO()
        self.err = err
        self.in_ = in_
        if now:
            self.startall()

    def startall(self):
        """ install the capture objects on sys.stdout/stderr/stdin. """
        if self.out:
            sys.stdout = self.out
        if self.err:
            sys.stderr = self.err
        if self.in_:
            sys.stdin  = self.in_  = DontReadFromInput()

    def done(self, save=True):
        """ return (outfile, errfile) and stop capturing.

            The returned streams are rewound to the start; an entry is
            None when that stream is not captured or already closed.
            'save' is accepted for interface compatibility and unused.
        """
        outfile = errfile = None
        if self.out and not self.out.closed:
            sys.stdout = self._oldout
            outfile = self.out
            outfile.seek(0)
        if self.err and not self.err.closed:
            sys.stderr = self._olderr
            errfile = self.err
            errfile.seek(0)
        if self.in_:
            sys.stdin = self._oldin
        return outfile, errfile

    def resume(self):
        """ resume capturing with original temp files. """
        self.startall()

    def readouterr(self):
        """ return snapshot value of stdout/stderr capturings. """
        out = err = ""
        if self.out:
            out = self.out.getvalue()
            self.out.truncate(0)
            self.out.seek(0)
        if self.err:
            err = self.err.getvalue()
            self.err.truncate(0)
            self.err.seek(0)
        return out, err
344
class DontReadFromInput:
    """Stand-in for sys.stdin while output is being captured.

    Every attempt to read raises IOError instead of blocking.  Ideally,
    accessing stdin would switch capturing off and show everything
    captured so far; that should be configurable, because in automated
    test runs crashing is better than hanging indefinitely.
    """

    def read(self, *args):
        raise IOError("reading from stdin while output is captured")

    # Line reads and iteration fail exactly like read().
    readline = readlines = __iter__ = read

    def fileno(self):
        raise ValueError("redirected Stdin is pseudofile, has no fileno()")

    def isatty(self):
        return False

    def close(self):
        # Nothing to release; closing the stub is a no-op.
        pass
364
# Path of the platform's null device; os.devnull is preferred, with a
# hard-coded fallback for interpreters that lack it.
devnullpath = getattr(os, 'devnull', None)
if devnullpath is None:
    devnullpath = 'NUL' if os.name == 'nt' else '/dev/null'
372