1"""
2Script to autogenerate pyplot wrappers.
3
4When this script is run, the current contents of pyplot are
5split into generatable and non-generatable content (via the magic header
6:attr:`PYPLOT_MAGIC_HEADER`) and the generatable content is overwritten.
7Hence, the non-generatable content should be edited in the pyplot.py file
8itself, whereas the generatable content must be edited via templates in
9this file.
10"""
11
12# Although it is possible to dynamically generate the pyplot functions at
13# runtime with the proper signatures, a static pyplot.py is simpler for static
14# analysis tools to parse.
15
16from enum import Enum
17import inspect
18from inspect import Parameter
19from pathlib import Path
20import sys
21import textwrap
22
23# This line imports the installed copy of matplotlib, and not the local copy.
24import numpy as np
25from matplotlib import _api, mlab
26from matplotlib.axes import Axes
27from matplotlib.backend_bases import MouseButton
28from matplotlib.figure import Figure
29
30
# We need to define a custom str because of a change in Python 3.10.
32# In Python 3.10 the repr and str representation of Enums changed from
33#
34#  str: 'ClassName.NAME' -> 'NAME'
35#  repr: '<ClassName.NAME: value>' -> 'ClassName.NAME'
36#
# which is more consistent with what str/repr should do; however, this breaks
# boilerplate.py, which needs to get the ClassName.NAME version in all versions
# of Python. Thus, we locally monkey patch our preferred str representation in
# here.
41#
42# bpo-40066
43# https://github.com/python/cpython/pull/22392/
def enum_str_back_compat_patch(self):
    """Render the enum member *self* as ``'ClassName.NAME'``."""
    return '{}.{}'.format(type(self).__name__, self.name)
46
# Only monkey patch if we have to, i.e. if this Python's str() of an enum
# member is not already 'ClassName.NAME'.  The member is spelled LEFT, so the
# comparison must use exactly that spelling (comparing against 'Left' made
# this check unconditionally true).
if str(MouseButton.LEFT) != 'MouseButton.LEFT':
    MouseButton.__str__ = enum_str_back_compat_patch
50
51
# This is the magic line that must exist in pyplot, after which the boilerplate
# content will be appended.  build_pyplot() keeps pyplot.py up to and including
# this line and regenerates everything after it.  It must match a whole line of
# the file exactly, including the trailing newline.
PYPLOT_MAGIC_HEADER = (
    "################# REMAINING CONTENT GENERATED BY boilerplate.py "
    "##############\n")

# Marker (preceded by two newlines) emitted ahead of every generated wrapper
# and ahead of the generated colormap functions.
AUTOGEN_MSG = """

# Autogenerated by boilerplate.py.  Do not edit as changes will be lost."""

# The templates below are filled in by generate_function(), which supplies
# {name}, {signature}, {called_name} and {call}; any other placeholder comes
# from its extra keyword arguments.  Names that the generated code uses but
# this script does not define (_copy_docstring_and_deprecators, gca, gcf, sci,
# set_cmap, Axes, Figure) must resolve in pyplot.py, to which the generated
# code is appended.

# Wrapper for the Axes methods listed in boilerplate_gen's ``cmappable``
# table: the result is kept in __ret so that {sci_command} (e.g.
# ``sci(__ret)``) can act on it before it is returned.
AXES_CMAPPABLE_METHOD_TEMPLATE = AUTOGEN_MSG + """
@_copy_docstring_and_deprecators(Axes.{called_name})
def {name}{signature}:
    __ret = gca().{called_name}{call}
    {sci_command}
    return __ret
"""

# Wrapper that forwards the call to ``gca().{called_name}``.
AXES_METHOD_TEMPLATE = AUTOGEN_MSG + """
@_copy_docstring_and_deprecators(Axes.{called_name})
def {name}{signature}:
    return gca().{called_name}{call}
"""

# Wrapper that forwards the call to ``gcf().{called_name}``.
FIGURE_METHOD_TEMPLATE = AUTOGEN_MSG + """
@_copy_docstring_and_deprecators(Figure.{called_name})
def {name}{signature}:
    return gcf().{called_name}{call}
"""

# One-line function per colormap; {name!r} passes the colormap name to
# set_cmap as a quoted string literal.
CMAP_TEMPLATE = "def {name}(): set_cmap({name!r})\n"  # Colormap functions.
83
84
class value_formatter:
    """
    Control how a default value is rendered in ``str(inspect.Signature)``.

    `inspect.Parameter` renders its default with ``repr()``.  Wrapping the
    default in this class lets us substitute source text that is valid in the
    generated pyplot code.  The interesting part is a hard-coded list of
    objects used as defaults in pyplot methods, which are spelled by the name
    under which the generated code can reach them.  Enum members are spelled
    ``ClassName.NAME``.
    """

    def __init__(self, value):
        if value is mlab.detrend_none:
            self._repr = "mlab.detrend_none"
        elif value is mlab.window_hanning:
            self._repr = "mlab.window_hanning"
        elif value is np.mean:
            self._repr = "np.mean"
        elif value is _api.deprecation._deprecated_parameter:
            self._repr = "_api.deprecation._deprecated_parameter"
        elif isinstance(value, Enum) and value.name is not None:
            # Spell members as ``ClassName.NAME`` explicitly.  Their repr is
            # ``<ClassName.NAME: value>``, and their str() varies between
            # Python versions (e.g. an IntEnum's str() is just the value on
            # Python 3.11+), so neither can be relied upon.
            self._repr = f'{type(value).__name__}.{value.name}'
        elif isinstance(value, Enum):
            # Unnamed pseudo-members (e.g. composite Flag values on older
            # Pythons) have no ``ClassName.NAME`` spelling; keep str().
            self._repr = str(value)
        else:
            self._repr = repr(value)

    def __repr__(self):
        return self._repr
109
110
def _format_call_argument(param):
    """
    Return the source text that forwards *param* to the wrapped method.

    Raises
    ------
    ValueError
        For positional-only parameters, which cannot be forwarded this way.
    """
    if (param.kind is Parameter.POSITIONAL_OR_KEYWORD
            and param.default is Parameter.empty):
        # Pass "intended-as-positional" parameters positionally to avoid
        # forcing third-party subclasses to reproduce the parameter names.
        return param.name
    if param.name == "data":
        # Only pass the data kwarg if it is actually set, to avoid forcing
        # third-party subclasses to support it.  Avoid linebreaks in the
        # middle of the expression by using \0 as a placeholder for spaces;
        # generate_function substitutes the spaces back after wrapping.
        return ('**({"data": data} if data is not None else {})'
                .replace(' ', '\0'))
    if param.kind in (Parameter.POSITIONAL_OR_KEYWORD,
                      Parameter.KEYWORD_ONLY):
        return f'{param.name}={param.name}'
    if param.kind is Parameter.VAR_POSITIONAL:
        return f'*{param.name}'
    if param.kind is Parameter.VAR_KEYWORD:
        return f'**{param.name}'
    # Intentionally crash for Parameter.POSITIONAL_ONLY.
    raise ValueError(
        f'Cannot generate a wrapper forwarding positional-only parameter '
        f'{param.name!r}')


def generate_function(name, called_fullname, template, **kwargs):
    """
    Create a wrapper function *name* calling *called_fullname*.

    Parameters
    ----------
    name : str
        The function to be created.
    called_fullname : str
        The method to be wrapped in the format ``"Class.method"``, where
        ``Class`` is ``Axes`` or ``Figure``.
    template : str
        The template to be used. The template must contain {}-style format
        placeholders. The following placeholders are filled in:

        - name: The function name.
        - signature: The function signature (including parentheses).
        - called_name: The name of the called function.
        - call: Parameters passed to *called_name* (including parentheses).

    **kwargs
        Additional parameters are passed to ``template.format()``.

    Returns
    -------
    str
        The source code of the wrapper function.

    Raises
    ------
    ValueError
        If the wrapped method has a parameter named ``gca``, ``gci``,
        ``gcf`` or ``__ret``, which would shadow names the generated code
        uses, or a positional-only parameter.
    """
    text_wrapper = textwrap.TextWrapper(
        break_long_words=False, width=70,
        initial_indent=' ' * 8, subsequent_indent=' ' * 8)

    # Get signature of wrapped function.
    class_name, called_name = called_fullname.split('.')
    class_ = {'Axes': Axes, 'Figure': Figure}[class_name]

    signature = inspect.signature(getattr(class_, called_name))
    # Drop the leading self argument.
    params = list(signature.parameters.values())[1:]

    # Bail out in case of name collision.  Compare against the parameter
    # *names*: ``params`` holds Parameter objects, so testing a plain string
    # for membership in that list would never match.
    param_names = {param.name for param in params}
    for reserved in ('gca', 'gci', 'gcf', '__ret'):
        if reserved in param_names:
            raise ValueError(
                f'Method {called_fullname} has kwarg named {reserved}')

    # Render defaults through value_formatter so the signature is valid code.
    signature = str(signature.replace(parameters=[
        param.replace(default=value_formatter(param.default))
        if param.default is not param.empty else param
        for param in params]))
    if len('def ' + name + signature) >= 80:
        # Move opening parenthesis before newline.
        signature = '(\n' + text_wrapper.fill(signature).replace('(', '', 1)

    # How to call the wrapped function.
    call = '(' + ', '.join(map(_format_call_argument, params)) + ')'
    MAX_CALL_PREFIX = 18  # len('    __ret = gca().')
    if MAX_CALL_PREFIX + max(len(name), len(called_name)) + len(call) >= 80:
        call = '(\n' + text_wrapper.fill(call[1:])
    # Turn the \0 placeholders back into spaces.  This happens after wrapping
    # (so they prevented line breaks) but whether or not wrapping happened, so
    # that short calls never leak NUL characters into the generated source.
    call = call.replace('\0', ' ')

    return template.format(
        name=name,
        called_name=called_name,
        signature=signature,
        call=call,
        **kwargs)
192
193
def _split_command_spec(spec):
    """
    Split a command spec into ``(pyplot_name, method_name)``.

    A spec is either ``"name"`` (the pyplot function and the wrapped method
    share the name) or ``"pyplot_name:method_name"``.
    """
    if ':' in spec:
        name, called_name = spec.split(':')
    else:
        name = called_name = spec
    return name, called_name


def boilerplate_gen():
    """
    Generator of lines for the automated part of pyplot.

    Yields, in order: one wrapper per Figure command, one wrapper per Axes
    command, the autogeneration marker followed by one function per colormap,
    and finally the call to ``_setup_pyplot_info_docstrings()``.
    """

    # Figure methods wrapped in pyplot ("pyplot_name:method_name" entries
    # wrap a method under a different name).
    _figure_commands = (
        'figimage',
        'figtext:text',
        'gca',
        'gci:_gci',
        'ginput',
        'subplots_adjust',
        'suptitle',
        'waitforbuttonpress',
    )

    # Axes methods wrapped in pyplot.  Plain entries are wrapped under their
    # own name; the "pyplot_name:method_name" entries at the end are wrapped
    # under a different name.
    _axes_commands = (
        'acorr',
        'angle_spectrum',
        'annotate',
        'arrow',
        'autoscale',
        'axhline',
        'axhspan',
        'axis',
        'axline',
        'axvline',
        'axvspan',
        'bar',
        'barbs',
        'barh',
        'bar_label',
        'boxplot',
        'broken_barh',
        'clabel',
        'cohere',
        'contour',
        'contourf',
        'csd',
        'errorbar',
        'eventplot',
        'fill',
        'fill_between',
        'fill_betweenx',
        'grid',
        'hexbin',
        'hist',
        'stairs',
        'hist2d',
        'hlines',
        'imshow',
        'legend',
        'locator_params',
        'loglog',
        'magnitude_spectrum',
        'margins',
        'minorticks_off',
        'minorticks_on',
        'pcolor',
        'pcolormesh',
        'phase_spectrum',
        'pie',
        'plot',
        'plot_date',
        'psd',
        'quiver',
        'quiverkey',
        'scatter',
        'semilogx',
        'semilogy',
        'specgram',
        'spy',
        'stackplot',
        'stem',
        'step',
        'streamplot',
        'table',
        'text',
        'tick_params',
        'ticklabel_format',
        'tricontour',
        'tricontourf',
        'tripcolor',
        'triplot',
        'violinplot',
        'vlines',
        'xcorr',
        # pyplot name : real name
        'sci:_sci',
        'title:set_title',
        'xlabel:set_xlabel',
        'ylabel:set_ylabel',
        'xscale:set_xscale',
        'yscale:set_yscale',
    )

    # Keyed by pyplot name: wrappers generated from
    # AXES_CMAPPABLE_METHOD_TEMPLATE, with the statement run on the method's
    # result (__ret) before it is returned.
    cmappable = {
        'contour': 'if __ret._A is not None: sci(__ret)  # noqa',
        'contourf': 'if __ret._A is not None: sci(__ret)  # noqa',
        'hexbin': 'sci(__ret)',
        'scatter': 'sci(__ret)',
        'pcolor': 'sci(__ret)',
        'pcolormesh': 'sci(__ret)',
        'hist2d': 'sci(__ret[-1])',
        'imshow': 'sci(__ret)',
        'spy': 'if isinstance(__ret, cm.ScalarMappable): sci(__ret)  # noqa',
        'quiver': 'sci(__ret)',
        'specgram': 'sci(__ret[-1])',
        'streamplot': 'sci(__ret.lines)',
        'tricontour': 'if __ret._A is not None: sci(__ret)  # noqa',
        'tricontourf': 'if __ret._A is not None: sci(__ret)  # noqa',
        'tripcolor': 'sci(__ret)',
    }

    for spec in _figure_commands:
        name, called_name = _split_command_spec(spec)
        yield generate_function(name, f'Figure.{called_name}',
                                FIGURE_METHOD_TEMPLATE)

    for spec in _axes_commands:
        name, called_name = _split_command_spec(spec)
        template = (AXES_CMAPPABLE_METHOD_TEMPLATE if name in cmappable else
                    AXES_METHOD_TEMPLATE)
        # sci_command is None for non-cmappable wrappers; their template has
        # no {sci_command} placeholder, so it is simply ignored.
        yield generate_function(name, f'Axes.{called_name}', template,
                                sci_command=cmappable.get(name))

    yield AUTOGEN_MSG
    yield '\n'
    cmaps = (
        'autumn',
        'bone',
        'cool',
        'copper',
        'flag',
        'gray',
        'hot',
        'hsv',
        'jet',
        'pink',
        'prism',
        'spring',
        'summer',
        'winter',
        'magma',
        'inferno',
        'plasma',
        'viridis',
        'nipy_spectral',
    )
    # add all the colormaps (autumn, hsv, ....)
    for name in cmaps:
        yield CMAP_TEMPLATE.format(name=name)

    yield '\n\n'
    yield '_setup_pyplot_info_docstrings()'
355
356
def build_pyplot(pyplot_path):
    """
    Regenerate the autogenerated part of the pyplot module at *pyplot_path*.

    Everything up to and including the `PYPLOT_MAGIC_HEADER` line is kept
    verbatim; everything after it is replaced by the output of
    `boilerplate_gen` followed by a final newline.

    Parameters
    ----------
    pyplot_path : pathlib.Path
        The ``pyplot.py`` file to rewrite in place.

    Raises
    ------
    ValueError
        If the file does not contain `PYPLOT_MAGIC_HEADER` as a line of its
        own.
    """
    pyplot_orig = pyplot_path.read_text(encoding='utf-8').splitlines(
        keepends=True)
    try:
        header_index = pyplot_orig.index(PYPLOT_MAGIC_HEADER)
    except ValueError as err:
        # list.index reports a missing item with ValueError, not IndexError.
        raise ValueError('The pyplot.py file *must* have the exact line: %s'
                         % PYPLOT_MAGIC_HEADER) from err
    pyplot_orig = pyplot_orig[:header_index + 1]

    # Generate all content before opening the file for writing, so an error
    # raised during generation cannot leave a truncated pyplot.py behind.
    new_content = ''.join([*pyplot_orig, *boilerplate_gen(), '\n'])
    pyplot_path.write_text(new_content, encoding='utf-8')
369
370
if __name__ == '__main__':
    # Rewrite matplotlib's pyplot.py: the path given as the first command-line
    # argument if there is one, else ../lib/matplotlib/pyplot.py relative to
    # this script's directory.
    pyplot_path = (Path(sys.argv[1]) if sys.argv[1:]
                   else Path(__file__).parent / "../lib/matplotlib/pyplot.py")
    build_pyplot(pyplot_path)
378