""" Script to autogenerate pyplot wrappers. When this script is run, the current contents of pyplot are split into generatable and non-generatable content (via the magic header :attr:`PYPLOT_MAGIC_HEADER`) and the generatable content is overwritten. Hence, the non-generatable content should be edited in the pyplot.py file itself, whereas the generatable content must be edited via templates in this file. """ # Although it is possible to dynamically generate the pyplot functions at # runtime with the proper signatures, a static pyplot.py is simpler for static # analysis tools to parse. from enum import Enum import inspect from inspect import Parameter from pathlib import Path import sys import textwrap # This line imports the installed copy of matplotlib, and not the local copy. import numpy as np from matplotlib import _api, mlab from matplotlib.axes import Axes from matplotlib.backend_bases import MouseButton from matplotlib.figure import Figure # we need to define a custom str because py310 change # In Python 3.10 the repr and str representation of Enums changed from # # str: 'ClassName.NAME' -> 'NAME' # repr: '' -> 'ClassName.NAME' # # which is more consistent with what str/repr should do, however this breaks # boilerplate which needs to get the ClassName.NAME version in all versions of # Python. Thus, we locally monkey patch our preferred str representation in # here. # # bpo-40066 # https://github.com/python/cpython/pull/22392/ def enum_str_back_compat_patch(self): return f'{type(self).__name__}.{self.name}' # only monkey patch if we have to. if str(MouseButton.LEFT) != 'MouseButton.Left': MouseButton.__str__ = enum_str_back_compat_patch # This is the magic line that must exist in pyplot, after which the boilerplate # content will be appended. PYPLOT_MAGIC_HEADER = ( "################# REMAINING CONTENT GENERATED BY boilerplate.py " "##############\n") AUTOGEN_MSG = """ # Autogenerated by boilerplate.py. Do not edit as changes will be lost.""" AXES_CMAPPABLE_METHOD_TEMPLATE = AUTOGEN_MSG + """ @_copy_docstring_and_deprecators(Axes.{called_name}) def {name}{signature}: __ret = gca().{called_name}{call} {sci_command} return __ret """ AXES_METHOD_TEMPLATE = AUTOGEN_MSG + """ @_copy_docstring_and_deprecators(Axes.{called_name}) def {name}{signature}: return gca().{called_name}{call} """ FIGURE_METHOD_TEMPLATE = AUTOGEN_MSG + """ @_copy_docstring_and_deprecators(Figure.{called_name}) def {name}{signature}: return gcf().{called_name}{call} """ CMAP_TEMPLATE = "def {name}(): set_cmap({name!r})\n" # Colormap functions. class value_formatter: """ Format function default values as needed for inspect.formatargspec. The interesting part is a hard-coded list of functions used as defaults in pyplot methods. """ def __init__(self, value): if value is mlab.detrend_none: self._repr = "mlab.detrend_none" elif value is mlab.window_hanning: self._repr = "mlab.window_hanning" elif value is np.mean: self._repr = "np.mean" elif value is _api.deprecation._deprecated_parameter: self._repr = "_api.deprecation._deprecated_parameter" elif isinstance(value, Enum): # Enum str is Class.Name whereas their repr is . self._repr = str(value) else: self._repr = repr(value) def __repr__(self): return self._repr def generate_function(name, called_fullname, template, **kwargs): """ Create a wrapper function *pyplot_name* calling *call_name*. Parameters ---------- name : str The function to be created. called_fullname : str The method to be wrapped in the format ``"Class.method"``. template : str The template to be used. The template must contain {}-style format placeholders. The following placeholders are filled in: - name: The function name. - signature: The function signature (including parentheses). - called_name: The name of the called function. - call: Parameters passed to *called_name* (including parentheses). **kwargs Additional parameters are passed to ``template.format()``. """ text_wrapper = textwrap.TextWrapper( break_long_words=False, width=70, initial_indent=' ' * 8, subsequent_indent=' ' * 8) # Get signature of wrapped function. class_name, called_name = called_fullname.split('.') class_ = {'Axes': Axes, 'Figure': Figure}[class_name] signature = inspect.signature(getattr(class_, called_name)) # Replace self argument. params = list(signature.parameters.values())[1:] signature = str(signature.replace(parameters=[ param.replace(default=value_formatter(param.default)) if param.default is not param.empty else param for param in params])) if len('def ' + name + signature) >= 80: # Move opening parenthesis before newline. signature = '(\n' + text_wrapper.fill(signature).replace('(', '', 1) # How to call the wrapped function. call = '(' + ', '.join(( # Pass "intended-as-positional" parameters positionally to avoid # forcing third-party subclasses to reproduce the parameter names. '{0}' if param.kind in [ Parameter.POSITIONAL_OR_KEYWORD] and param.default is Parameter.empty else # Only pass the data kwarg if it is actually set, to avoid forcing # third-party subclasses to support it. '**({{"data": data}} if data is not None else {{}})' # Avoid linebreaks in the middle of the expression, by using \0 as a # placeholder that will be substituted after wrapping. .replace(' ', '\0') if param.name == "data" else '{0}={0}' if param.kind in [ Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY] else '*{0}' if param.kind is Parameter.VAR_POSITIONAL else '**{0}' if param.kind is Parameter.VAR_KEYWORD else # Intentionally crash for Parameter.POSITIONAL_ONLY. None).format(param.name) for param in params) + ')' MAX_CALL_PREFIX = 18 # len(' __ret = gca().') if MAX_CALL_PREFIX + max(len(name), len(called_name)) + len(call) >= 80: call = '(\n' + text_wrapper.fill(call[1:]).replace('\0', ' ') # Bail out in case of name collision. for reserved in ('gca', 'gci', 'gcf', '__ret'): if reserved in params: raise ValueError( f'Method {called_fullname} has kwarg named {reserved}') return template.format( name=name, called_name=called_name, signature=signature, call=call, **kwargs) def boilerplate_gen(): """Generator of lines for the automated part of pyplot.""" _figure_commands = ( 'figimage', 'figtext:text', 'gca', 'gci:_gci', 'ginput', 'subplots_adjust', 'suptitle', 'waitforbuttonpress', ) # These methods are all simple wrappers of Axes methods by the same name. _axes_commands = ( 'acorr', 'angle_spectrum', 'annotate', 'arrow', 'autoscale', 'axhline', 'axhspan', 'axis', 'axline', 'axvline', 'axvspan', 'bar', 'barbs', 'barh', 'bar_label', 'boxplot', 'broken_barh', 'clabel', 'cohere', 'contour', 'contourf', 'csd', 'errorbar', 'eventplot', 'fill', 'fill_between', 'fill_betweenx', 'grid', 'hexbin', 'hist', 'stairs', 'hist2d', 'hlines', 'imshow', 'legend', 'locator_params', 'loglog', 'magnitude_spectrum', 'margins', 'minorticks_off', 'minorticks_on', 'pcolor', 'pcolormesh', 'phase_spectrum', 'pie', 'plot', 'plot_date', 'psd', 'quiver', 'quiverkey', 'scatter', 'semilogx', 'semilogy', 'specgram', 'spy', 'stackplot', 'stem', 'step', 'streamplot', 'table', 'text', 'tick_params', 'ticklabel_format', 'tricontour', 'tricontourf', 'tripcolor', 'triplot', 'violinplot', 'vlines', 'xcorr', # pyplot name : real name 'sci:_sci', 'title:set_title', 'xlabel:set_xlabel', 'ylabel:set_ylabel', 'xscale:set_xscale', 'yscale:set_yscale', ) cmappable = { 'contour': 'if __ret._A is not None: sci(__ret) # noqa', 'contourf': 'if __ret._A is not None: sci(__ret) # noqa', 'hexbin': 'sci(__ret)', 'scatter': 'sci(__ret)', 'pcolor': 'sci(__ret)', 'pcolormesh': 'sci(__ret)', 'hist2d': 'sci(__ret[-1])', 'imshow': 'sci(__ret)', 'spy': 'if isinstance(__ret, cm.ScalarMappable): sci(__ret) # noqa', 'quiver': 'sci(__ret)', 'specgram': 'sci(__ret[-1])', 'streamplot': 'sci(__ret.lines)', 'tricontour': 'if __ret._A is not None: sci(__ret) # noqa', 'tricontourf': 'if __ret._A is not None: sci(__ret) # noqa', 'tripcolor': 'sci(__ret)', } for spec in _figure_commands: if ':' in spec: name, called_name = spec.split(':') else: name = called_name = spec yield generate_function(name, f'Figure.{called_name}', FIGURE_METHOD_TEMPLATE) for spec in _axes_commands: if ':' in spec: name, called_name = spec.split(':') else: name = called_name = spec template = (AXES_CMAPPABLE_METHOD_TEMPLATE if name in cmappable else AXES_METHOD_TEMPLATE) yield generate_function(name, f'Axes.{called_name}', template, sci_command=cmappable.get(name)) yield AUTOGEN_MSG yield '\n' cmaps = ( 'autumn', 'bone', 'cool', 'copper', 'flag', 'gray', 'hot', 'hsv', 'jet', 'pink', 'prism', 'spring', 'summer', 'winter', 'magma', 'inferno', 'plasma', 'viridis', "nipy_spectral" ) # add all the colormaps (autumn, hsv, ....) for name in cmaps: yield CMAP_TEMPLATE.format(name=name) yield '\n\n' yield '_setup_pyplot_info_docstrings()' def build_pyplot(pyplot_path): pyplot_orig = pyplot_path.read_text().splitlines(keepends=True) try: pyplot_orig = pyplot_orig[:pyplot_orig.index(PYPLOT_MAGIC_HEADER) + 1] except IndexError as err: raise ValueError('The pyplot.py file *must* have the exact line: %s' % PYPLOT_MAGIC_HEADER) from err with pyplot_path.open('w') as pyplot: pyplot.writelines(pyplot_orig) pyplot.writelines(boilerplate_gen()) pyplot.write('\n') if __name__ == '__main__': # Write the matplotlib.pyplot file. if len(sys.argv) > 1: pyplot_path = Path(sys.argv[1]) else: pyplot_path = Path(__file__).parent / "../lib/matplotlib/pyplot.py" build_pyplot(pyplot_path)