1# encoding: utf-8
2"""IO capturing utilities."""
3
4# Copyright (c) IPython Development Team.
5# Distributed under the terms of the Modified BSD License.
6
7
8import sys
9from io import StringIO
10
11#-----------------------------------------------------------------------------
12# Classes and functions
13#-----------------------------------------------------------------------------
14
15
class RichOutput(object):
    """Holder for one rich display message: data, metadata and transient dicts.

    ``data`` and ``metadata`` are looked up by MIME type by the ``_repr_*_``
    methods below.  ``update`` is stored as given and forwarded, together
    with the three dicts, to ``publish_display_data`` by :meth:`display`.
    """

    def __init__(self, data=None, metadata=None, transient=None, update=False):
        # Any falsy argument (None, empty dict, ...) is replaced by a new dict.
        self.data = data if data else {}
        self.metadata = metadata if metadata else {}
        self.transient = transient if transient else {}
        self.update = update

    def display(self):
        """Publish the stored message through ``publish_display_data``."""
        # Imported lazily so that constructing a RichOutput does not require
        # IPython.display to be imported.
        from IPython.display import publish_display_data
        publish_display_data(
            data=self.data,
            metadata=self.metadata,
            transient=self.transient,
            update=self.update,
        )

    def _repr_mime_(self, mime):
        """Return the stored representation for ``mime``.

        Returns None when ``mime`` is not a key of ``self.data``, a
        ``(data, metadata)`` tuple when ``self.metadata`` also has an entry
        for ``mime``, and the bare data otherwise.
        """
        if mime in self.data:
            payload = self.data[mime]
            if mime in self.metadata:
                return payload, self.metadata[mime]
            return payload
        return None

    def _repr_mimebundle_(self, include=None, exclude=None):
        """Return the full ``(data, metadata)`` pair.

        ``include`` and ``exclude`` are accepted but not used.
        """
        return self.data, self.metadata

    def _repr_html_(self):
        """Return the ``text/html`` representation, if stored."""
        return self._repr_mime_("text/html")

    def _repr_latex_(self):
        """Return the ``text/latex`` representation, if stored."""
        return self._repr_mime_("text/latex")

    def _repr_json_(self):
        """Return the ``application/json`` representation, if stored."""
        return self._repr_mime_("application/json")

    def _repr_javascript_(self):
        """Return the ``application/javascript`` representation, if stored."""
        return self._repr_mime_("application/javascript")

    def _repr_png_(self):
        """Return the ``image/png`` representation, if stored."""
        return self._repr_mime_("image/png")

    def _repr_jpeg_(self):
        """Return the ``image/jpeg`` representation, if stored."""
        return self._repr_mime_("image/jpeg")

    def _repr_svg_(self):
        """Return the ``image/svg+xml`` representation, if stored."""
        return self._repr_mime_("image/svg+xml")
60
61
class CapturedIO(object):
    """Container for captured stdout/stderr buffers and rich display outputs.

    An instance ``c`` exposes three read-only attributes:

    - ``c.stdout`` : the captured standard output, as a string
    - ``c.stderr`` : the captured standard error, as a string
    - ``c.outputs``: a list of ``RichOutput`` objects, one per captured
      rich display output

    ``c.show()`` replays all of the above, in that order; calling the
    instance (``c()``) is the same as calling ``c.show()``.
    """

    def __init__(self, stdout, stderr, outputs=None):
        self._stdout = stdout
        self._stderr = stderr
        # A fresh list per instance when the caller supplies none.
        self._outputs = [] if outputs is None else outputs

    def __str__(self):
        return self.stdout

    @property
    def stdout(self):
        "Captured standard output"
        # A falsy buffer (e.g. None when stdout was not captured) reads as ''.
        return self._stdout.getvalue() if self._stdout else ''

    @property
    def stderr(self):
        "Captured standard error"
        # A falsy buffer (e.g. None when stderr was not captured) reads as ''.
        return self._stderr.getvalue() if self._stderr else ''

    @property
    def outputs(self):
        """The captured rich display outputs, as a new list of RichOutput.

        Given a CapturedIO object ``c``, they can be displayed in IPython
        with::

            from IPython.display import display
            for o in c.outputs:
                display(o)
        """
        rich = []
        for spec in self._outputs:
            # Each stored entry is a dict of RichOutput keyword arguments.
            rich.append(RichOutput(**spec))
        return rich

    def show(self):
        """write my output to sys.stdout/err as appropriate"""
        sys.stdout.write(self.stdout)
        sys.stderr.write(self.stderr)
        for stream in (sys.stdout, sys.stderr):
            stream.flush()
        # Replay rich outputs only after the text streams are flushed.
        for spec in self._outputs:
            rich = RichOutput(**spec)
            rich.display()

    __call__ = show
122
123
class capture_output(object):
    """Context manager capturing stdout, stderr and rich display output.

    Each of the ``stdout``, ``stderr`` and ``display`` flags enables the
    corresponding capture.  Entering the context returns a ``CapturedIO``
    built from whatever was captured; leaving it puts back the streams
    (and, when display capture was active, the display publisher and
    ``sys.displayhook``) that were in place on entry.
    """
    stdout = True
    stderr = True
    display = True

    def __init__(self, stdout=True, stderr=True, display=True):
        self.shell = None
        self.display = display
        self.stderr = stderr
        self.stdout = stdout

    def __enter__(self):
        """Install the capturing objects and return a ``CapturedIO``."""
        from IPython.core.getipython import get_ipython
        from IPython.core.displaypub import CapturingDisplayPublisher
        from IPython.core.displayhook import CapturingDisplayHook

        # Remember the current streams so __exit__ can put them back.
        self.sys_stdout, self.sys_stderr = sys.stdout, sys.stderr

        if self.display:
            shell = self.shell = get_ipython()
            if shell is None:
                # get_ipython() gave no shell: switch display capture off.
                self.save_display_pub = None
                self.display = False

        captured_out = captured_err = rich_outputs = None
        if self.stdout:
            captured_out = StringIO()
            sys.stdout = captured_out
        if self.stderr:
            captured_err = StringIO()
            sys.stderr = captured_err
        if self.display:
            self.save_display_pub = self.shell.display_pub
            self.shell.display_pub = CapturingDisplayPublisher()
            # The publisher and the display hook share one outputs object.
            rich_outputs = self.shell.display_pub.outputs
            self.save_display_hook = sys.displayhook
            sys.displayhook = CapturingDisplayHook(
                shell=self.shell, outputs=rich_outputs
            )

        return CapturedIO(captured_out, captured_err, rich_outputs)

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the streams and display machinery saved by ``__enter__``."""
        sys.stdout = self.sys_stdout
        sys.stderr = self.sys_stderr
        if not (self.display and self.shell):
            return
        self.shell.display_pub = self.save_display_pub
        sys.displayhook = self.save_display_hook
171