"""
Script to autogenerate pyplot wrappers.

When this script is run, the current contents of pyplot are
split into generatable and non-generatable content (via the magic header
:attr:`PYPLOT_MAGIC_HEADER`) and the generatable content is overwritten.
Hence, the non-generatable content should be edited in the pyplot.py file
itself, whereas the generatable content must be edited via templates in
this file.
"""

# Although it is possible to dynamically generate the pyplot functions at
# runtime with the proper signatures, a static pyplot.py is simpler for static
# analysis tools to parse.

from enum import Enum
import inspect
from inspect import Parameter
from pathlib import Path
import sys
import textwrap

# The imports below pick up the installed copy of matplotlib, and not the
# local copy.
import numpy as np
from matplotlib import _api, mlab
from matplotlib.axes import Axes
from matplotlib.backend_bases import MouseButton
from matplotlib.figure import Figure


# We need to define a custom str because of a change in Python 3.10.
# In Python 3.10 the repr and str representation of Enums changed from
#
#  str: 'ClassName.NAME' -> 'NAME'
#  repr: '<ClassName.NAME: value>' -> 'ClassName.NAME'
#
# which is more consistent with what str/repr should do, however this breaks
# boilerplate which needs to get the ClassName.NAME version in all versions of
# Python. Thus, we locally monkey patch our preferred str representation in
# here.
#
# bpo-40066
# https://github.com/python/cpython/pull/22392/
def enum_str_back_compat_patch(self):
    """Return the ``'ClassName.NAME'`` spelling of the Enum member *self*."""
    return f'{type(self).__name__}.{self.name}'


# Only monkey patch if we have to, i.e. if str() does not already produce the
# 'ClassName.NAME' form.  The expected string must use the member name exactly
# as defined (LEFT); any other spelling would make this check always true.
if str(MouseButton.LEFT) != 'MouseButton.LEFT':
    MouseButton.__str__ = enum_str_back_compat_patch


# This is the magic line that must exist in pyplot, after which the boilerplate
# content will be appended.
PYPLOT_MAGIC_HEADER = (
    "################# REMAINING CONTENT GENERATED BY boilerplate.py "
    "##############\n")

# Banner emitted in front of every generated definition.
AUTOGEN_MSG = """

# Autogenerated by boilerplate.py. Do not edit as changes will be lost."""

# Wrapper for Axes methods whose result must additionally be registered via
# the ``{sci_command}`` statement before being returned.
AXES_CMAPPABLE_METHOD_TEMPLATE = AUTOGEN_MSG + """
@_copy_docstring_and_deprecators(Axes.{called_name})
def {name}{signature}:
    __ret = gca().{called_name}{call}
    {sci_command}
    return __ret
"""

# Plain wrapper forwarding to the method of the same (or mapped) name on the
# object returned by ``gca()``.
AXES_METHOD_TEMPLATE = AUTOGEN_MSG + """
@_copy_docstring_and_deprecators(Axes.{called_name})
def {name}{signature}:
    return gca().{called_name}{call}
"""

# Plain wrapper forwarding to a method on the object returned by ``gcf()``.
FIGURE_METHOD_TEMPLATE = AUTOGEN_MSG + """
@_copy_docstring_and_deprecators(Figure.{called_name})
def {name}{signature}:
    return gcf().{called_name}{call}
"""

CMAP_TEMPLATE = "def {name}(): set_cmap({name!r})\n"  # Colormap functions.


class value_formatter:
    """
    Wrap a parameter default so that its ``repr`` is valid pyplot source.

    `generate_function` substitutes instances of this class for the defaults
    in the wrapped method's signature before converting the signature to a
    string.  The interesting part is a hard-coded list of functions used
    as defaults in pyplot methods, which are spelled by their qualified
    name instead of their regular repr.
    """

    def __init__(self, value):
        if value is mlab.detrend_none:
            self._repr = "mlab.detrend_none"
        elif value is mlab.window_hanning:
            self._repr = "mlab.window_hanning"
        elif value is np.mean:
            self._repr = "np.mean"
        elif value is _api.deprecation._deprecated_parameter:
            self._repr = "_api.deprecation._deprecated_parameter"
        elif isinstance(value, Enum):
            # Enum str is Class.Name whereas their repr is <Class.Name: value>.
            self._repr = str(value)
        else:
            self._repr = repr(value)

    def __repr__(self):
        return self._repr


def generate_function(name, called_fullname, template, **kwargs):
    """
    Create a wrapper function *name* calling *called_fullname*.

    Parameters
    ----------
    name : str
        The function to be created.
    called_fullname : str
        The method to be wrapped in the format ``"Class.method"``.
    template : str
        The template to be used. The template must contain {}-style format
        placeholders. The following placeholders are filled in:

        - name: The function name.
        - signature: The function signature (including parentheses).
        - called_name: The name of the called function.
        - call: Parameters passed to *called_name* (including parentheses).

    **kwargs
        Additional parameters are passed to ``template.format()``.

    Returns
    -------
    str
        The formatted template, i.e. the source code of the wrapper.

    Raises
    ------
    ValueError
        If the wrapped method has a parameter whose name collides with a
        name used inside the generated wrappers (``gca``, ``gci``, ``gcf``
        or ``__ret``).
    """
    text_wrapper = textwrap.TextWrapper(
        break_long_words=False, width=70,
        initial_indent=' ' * 8, subsequent_indent=' ' * 8)

    # Get signature of wrapped function.
    class_name, called_name = called_fullname.split('.')
    class_ = {'Axes': Axes, 'Figure': Figure}[class_name]

    signature = inspect.signature(getattr(class_, called_name))
    # Replace self argument.
    params = list(signature.parameters.values())[1:]
    signature = str(signature.replace(parameters=[
        param.replace(default=value_formatter(param.default))
        if param.default is not param.empty else param
        for param in params]))
    if len('def ' + name + signature) >= 80:
        # Move opening parenthesis before newline.
        signature = '(\n' + text_wrapper.fill(signature).replace('(', '', 1)
    # How to call the wrapped function.
    call = '(' + ', '.join((
        # Pass "intended-as-positional" parameters positionally to avoid
        # forcing third-party subclasses to reproduce the parameter names.
        '{0}'
        if param.kind in [
            Parameter.POSITIONAL_OR_KEYWORD]
        and param.default is Parameter.empty else
        # Only pass the data kwarg if it is actually set, to avoid forcing
        # third-party subclasses to support it.
        '**({{"data": data}} if data is not None else {{}})'
        # Avoid linebreaks in the middle of the expression, by using \0 as a
        # placeholder that will be substituted after wrapping.
        .replace(' ', '\0')
        if param.name == "data" else
        '{0}={0}'
        if param.kind in [
            Parameter.POSITIONAL_OR_KEYWORD,
            Parameter.KEYWORD_ONLY] else
        '*{0}'
        if param.kind is Parameter.VAR_POSITIONAL else
        '**{0}'
        if param.kind is Parameter.VAR_KEYWORD else
        # Intentionally crash for Parameter.POSITIONAL_ONLY.
        None).format(param.name)
        for param in params) + ')'
    MAX_CALL_PREFIX = 18  # len('    __ret = gca().')
    if MAX_CALL_PREFIX + max(len(name), len(called_name)) + len(call) >= 80:
        call = '(\n' + text_wrapper.fill(call[1:])
    # Restore the protected spaces whether or not the call was wrapped, so
    # that no \0 placeholder can leak into the generated source.
    call = call.replace('\0', ' ')
    # Bail out in case of name collision.  Compare against the parameter
    # *names*: ``params`` holds Parameter objects, which never equal a str.
    param_names = {param.name for param in params}
    for reserved in ('gca', 'gci', 'gcf', '__ret'):
        if reserved in param_names:
            raise ValueError(
                f'Method {called_fullname} has kwarg named {reserved}')

    return template.format(
        name=name,
        called_name=called_name,
        signature=signature,
        call=call,
        **kwargs)


def _split_spec(spec):
    """Split ``"pyplot_name:real_name"`` (or a bare name) into both names."""
    if ':' in spec:
        name, called_name = spec.split(':')
    else:
        name = called_name = spec
    return name, called_name


def boilerplate_gen():
    """Generator of lines for the automated part of pyplot."""

    # Entries are either a bare method name or "pyplot name:real name".
    _figure_commands = (
        'figimage',
        'figtext:text',
        'gca',
        'gci:_gci',
        'ginput',
        'subplots_adjust',
        'suptitle',
        'waitforbuttonpress',
    )

    # These methods are all simple wrappers of Axes methods by the same name.
    _axes_commands = (
        'acorr',
        'angle_spectrum',
        'annotate',
        'arrow',
        'autoscale',
        'axhline',
        'axhspan',
        'axis',
        'axline',
        'axvline',
        'axvspan',
        'bar',
        'barbs',
        'barh',
        'bar_label',
        'boxplot',
        'broken_barh',
        'clabel',
        'cohere',
        'contour',
        'contourf',
        'csd',
        'errorbar',
        'eventplot',
        'fill',
        'fill_between',
        'fill_betweenx',
        'grid',
        'hexbin',
        'hist',
        'stairs',
        'hist2d',
        'hlines',
        'imshow',
        'legend',
        'locator_params',
        'loglog',
        'magnitude_spectrum',
        'margins',
        'minorticks_off',
        'minorticks_on',
        'pcolor',
        'pcolormesh',
        'phase_spectrum',
        'pie',
        'plot',
        'plot_date',
        'psd',
        'quiver',
        'quiverkey',
        'scatter',
        'semilogx',
        'semilogy',
        'specgram',
        'spy',
        'stackplot',
        'stem',
        'step',
        'streamplot',
        'table',
        'text',
        'tick_params',
        'ticklabel_format',
        'tricontour',
        'tricontourf',
        'tripcolor',
        'triplot',
        'violinplot',
        'vlines',
        'xcorr',
        # pyplot name : real name
        'sci:_sci',
        'title:set_title',
        'xlabel:set_xlabel',
        'ylabel:set_ylabel',
        'xscale:set_xscale',
        'yscale:set_yscale',
    )

    # Maps a pyplot function name to the statement inserted between the call
    # and the ``return`` in AXES_CMAPPABLE_METHOD_TEMPLATE.
    cmappable = {
        'contour': 'if __ret._A is not None: sci(__ret)  # noqa',
        'contourf': 'if __ret._A is not None: sci(__ret)  # noqa',
        'hexbin': 'sci(__ret)',
        'scatter': 'sci(__ret)',
        'pcolor': 'sci(__ret)',
        'pcolormesh': 'sci(__ret)',
        'hist2d': 'sci(__ret[-1])',
        'imshow': 'sci(__ret)',
        'spy': 'if isinstance(__ret, cm.ScalarMappable): sci(__ret)  # noqa',
        'quiver': 'sci(__ret)',
        'specgram': 'sci(__ret[-1])',
        'streamplot': 'sci(__ret.lines)',
        'tricontour': 'if __ret._A is not None: sci(__ret)  # noqa',
        'tricontourf': 'if __ret._A is not None: sci(__ret)  # noqa',
        'tripcolor': 'sci(__ret)',
    }

    for spec in _figure_commands:
        name, called_name = _split_spec(spec)
        yield generate_function(name, f'Figure.{called_name}',
                                FIGURE_METHOD_TEMPLATE)

    for spec in _axes_commands:
        name, called_name = _split_spec(spec)

        template = (AXES_CMAPPABLE_METHOD_TEMPLATE if name in cmappable else
                    AXES_METHOD_TEMPLATE)
        # sci_command is None for non-cmappable functions; their template has
        # no such placeholder, so the extra keyword is simply unused.
        yield generate_function(name, f'Axes.{called_name}', template,
                                sci_command=cmappable.get(name))

    yield AUTOGEN_MSG
    yield '\n'
    cmaps = (
        'autumn',
        'bone',
        'cool',
        'copper',
        'flag',
        'gray',
        'hot',
        'hsv',
        'jet',
        'pink',
        'prism',
        'spring',
        'summer',
        'winter',
        'magma',
        'inferno',
        'plasma',
        'viridis',
        "nipy_spectral"
    )
    # add all the colormaps (autumn, hsv, ....)
    for name in cmaps:
        yield CMAP_TEMPLATE.format(name=name)

    yield '\n\n'
    yield '_setup_pyplot_info_docstrings()'


def build_pyplot(pyplot_path):
    """
    Regenerate the autogenerated tail of the pyplot file at *pyplot_path*.

    Everything up to and including the `PYPLOT_MAGIC_HEADER` line is kept;
    everything after it is replaced by the output of `boilerplate_gen`.

    Parameters
    ----------
    pyplot_path : pathlib.Path
        The pyplot.py file to rewrite in place.

    Raises
    ------
    ValueError
        If the file does not contain the exact `PYPLOT_MAGIC_HEADER` line.
    """
    pyplot_orig = pyplot_path.read_text().splitlines(keepends=True)
    try:
        # list.index signals a missing element with ValueError.
        header_index = pyplot_orig.index(PYPLOT_MAGIC_HEADER)
    except ValueError as err:
        raise ValueError('The pyplot.py file *must* have the exact line: %s'
                         % PYPLOT_MAGIC_HEADER) from err
    pyplot_orig = pyplot_orig[:header_index + 1]

    # Generate everything before opening the file for writing, so that a
    # failure during generation does not leave a truncated pyplot.py behind.
    generated = list(boilerplate_gen())

    with pyplot_path.open('w') as pyplot:
        pyplot.writelines(pyplot_orig)
        pyplot.writelines(generated)
        pyplot.write('\n')


if __name__ == '__main__':
    # Write the matplotlib.pyplot file.
    if len(sys.argv) > 1:
        pyplot_path = Path(sys.argv[1])
    else:
        pyplot_path = Path(__file__).parent / "../lib/matplotlib/pyplot.py"
    build_pyplot(pyplot_path)