1"""
2Scales define the distribution of data values on an axis, e.g. a log scaling.
3
4They are attached to an `~.axis.Axis` and hold a `.Transform`, which is
5responsible for the actual data transformation.
6
7See also `.axes.Axes.set_xscale` and the scales examples in the documentation.
8"""
9
10import inspect
11import textwrap
12
13import numpy as np
14from numpy import ma
15
16import matplotlib as mpl
17from matplotlib import _api, docstring
18from matplotlib.ticker import (
19    NullFormatter, ScalarFormatter, LogFormatterSciNotation, LogitFormatter,
20    NullLocator, LogLocator, AutoLocator, AutoMinorLocator,
21    SymmetricalLogLocator, LogitLocator)
22from matplotlib.transforms import Transform, IdentityTransform
23
24
class ScaleBase:
    """
    Abstract base class shared by every scale.

    A scale is a separable transformation that acts on one dimension only.

    Concrete scales are expected to provide:

    - :attr:`name`
    - :meth:`get_transform`
    - :meth:`set_default_locators_and_formatters`

    and may additionally override:

    - :meth:`limit_range_for_scale`

    """

    def __init__(self, axis):
        r"""
        Construct a new scale.

        Notes
        -----
        This note is aimed at scale implementors.

        For backward compatibility, scales receive an `~matplotlib.axis.Axis`
        as their first argument.  That argument should nevertheless be
        ignored: a single scale instance must be usable by several
        `~matplotlib.axis.Axis`\es simultaneously.
        """

    def get_transform(self):
        """
        Return the `~matplotlib.transforms.Transform` implementing this scale.
        """
        raise NotImplementedError

    def set_default_locators_and_formatters(self, axis):
        """
        Install on *axis* the default locators and formatters appropriate
        for this scale.
        """
        raise NotImplementedError

    def limit_range_for_scale(self, vmin, vmax, minpos):
        """
        Clip the interval (*vmin*, *vmax*) to the domain this scale supports.

        *minpos* is the smallest positive value in the data; log-like scales
        use it as a substitute bound.  The base implementation imposes no
        restriction and returns *vmin*, *vmax* unchanged.
        """
        return vmin, vmax
80
81
class LinearScale(ScaleBase):
    """
    The default linear scale.
    """

    name = 'linear'

    def __init__(self, axis):
        # Defined, with a blank docstring, solely so that the base class'
        # constructor docstring is not inherited: that docstring would
        # otherwise be interpolated into the docstring of Axis.set_scale.
        """
        """

    def set_default_locators_and_formatters(self, axis):
        # docstring inherited
        axis.set_major_locator(AutoLocator())
        axis.set_major_formatter(ScalarFormatter())
        axis.set_minor_formatter(NullFormatter())
        # Minor ticks are located only if the rcParam matching this axis
        # (xtick.minor.visible / ytick.minor.visible) asks for them.
        axis_name = axis.axis_name
        show_minor = (
            (axis_name == 'x' and mpl.rcParams['xtick.minor.visible'])
            or (axis_name == 'y' and mpl.rcParams['ytick.minor.visible']))
        minor_locator = AutoMinorLocator() if show_minor else NullLocator()
        axis.set_minor_locator(minor_locator)

    def get_transform(self):
        """
        Return the transform of a linear scale, i.e. an
        `~matplotlib.transforms.IdentityTransform`.
        """
        return IdentityTransform()
114
115
class FuncTransform(Transform):
    """
    A simple transform that takes an arbitrary function for the
    forward and inverse transform.
    """

    input_dims = output_dims = 1

    def __init__(self, forward, inverse):
        """
        Parameters
        ----------
        forward : callable
            The forward function for the transform.  This function must have
            an inverse and, for best behavior, be monotonic.
            It must have the signature::

               def forward(values: array-like) -> array-like

        inverse : callable
            The inverse of the forward function.  Signature as ``forward``.

        Raises
        ------
        ValueError
            If *forward* or *inverse* is not callable.
        """
        super().__init__()
        if not (callable(forward) and callable(inverse)):
            raise ValueError('arguments to FuncTransform must be functions')
        self._forward = forward
        self._inverse = inverse

    def transform_non_affine(self, values):
        # docstring inherited
        # Delegate entirely to the user-supplied forward function.
        return self._forward(values)

    def inverted(self):
        # docstring inherited
        # The inverse transform simply swaps the two user functions.
        return FuncTransform(self._inverse, self._forward)
150
151
class FuncScale(ScaleBase):
    """
    Scale whose forward and inverse transforms are user-supplied functions.
    """

    name = 'function'

    def __init__(self, axis, functions):
        """
        Parameters
        ----------
        axis : `~matplotlib.axis.Axis`
            The axis for the scale.
        functions : (callable, callable)
            Two-tuple ``(forward, inverse)``; *inverse* must undo *forward*,
            and *forward* must be monotonic.

            Both callables must have the signature::

               def forward(values: array-like) -> array-like
        """
        fwd, inv = functions
        self._transform = FuncTransform(fwd, inv)

    def get_transform(self):
        """Return the `.FuncTransform` built from the user functions."""
        return self._transform

    def set_default_locators_and_formatters(self, axis):
        # docstring inherited
        axis.set_major_locator(AutoLocator())
        axis.set_major_formatter(ScalarFormatter())
        axis.set_minor_formatter(NullFormatter())
        # Minor ticks follow the per-axis "minor.visible" rcParam.
        which = axis.axis_name
        want_minor = (
            (which == 'x' and mpl.rcParams['xtick.minor.visible'])
            or (which == 'y' and mpl.rcParams['ytick.minor.visible']))
        if want_minor:
            axis.set_minor_locator(AutoMinorLocator())
        else:
            axis.set_minor_locator(NullLocator())
192
193
class LogTransform(Transform):
    """Logarithm of base *base*, optionally clipping non-positive inputs."""

    input_dims = output_dims = 1

    @_api.rename_parameter("3.3", "nonpos", "nonpositive")
    def __init__(self, base, nonpositive='clip'):
        super().__init__()
        if base <= 0 or base == 1:
            raise ValueError('The log base cannot be <= 0 or == 1')
        self.base = base
        # True: non-positive inputs are clipped to a large negative value.
        # False: they are left as NumPy's log produces them (-inf / NaN).
        self._clip = _api.check_getitem(
            {"clip": True, "mask": False}, nonpositive=nonpositive)

    def __str__(self):
        mode = "clip" if self._clip else "mask"
        return f"{type(self).__name__}(base={self.base}, nonpositive={mode!r})"

    def transform_non_affine(self, a):
        # Silence warnings from log(0), log(<0) and NaNs in the input.
        with np.errstate(divide="ignore", invalid="ignore"):
            # Bases with a dedicated NumPy ufunc are computed in one call.
            native_log = {np.e: np.log, 2: np.log2, 10: np.log10}.get(
                self.base)
            if native_log is not None:
                result = native_log(a)
            else:
                result = np.log(a)
                result /= np.log(self.base)
            if self._clip:
                # The SVG spec requires viewers to handle values up to
                # 3.4e38 (C float), but Inkscape/cairo (and Agg) hit a 24-bit
                # limit, and Ghostscript (pdf) overflows around 2 ** 15.
                # Since np.log10(np.nextafter(0, 1)) ~ -323, a clip value of
                # -1000 is safely below any real log value.
                result[a <= 0] = -1000
        return result

    def inverted(self):
        return InvertedLogTransform(self.base)
234
235
class InvertedLogTransform(Transform):
    """Exponentiation ``base ** a``; the inverse of `LogTransform`."""

    input_dims = output_dims = 1

    def __init__(self, base):
        super().__init__()
        self.base = base

    def __str__(self):
        return f"{type(self).__name__}(base={self.base})"

    def transform_non_affine(self, a):
        # numpy.ma.power keeps masked entries of *a* masked.
        return ma.power(self.base, a)

    def inverted(self):
        return LogTransform(self.base)
251
252
class LogScale(ScaleBase):
    """
    A standard logarithmic scale.  Care is taken to only plot positive values.
    """
    name = 'log'

    # Deprecated (3.3) aliases that expose the module-level transform classes
    # as attributes of the scale; ``alternative`` names the replacement.
    @_api.deprecated("3.3", alternative="scale.LogTransform")
    @property
    def LogTransform(self):
        # Resolves to the module-level LogTransform class (method bodies do
        # not see class-scope names).
        return LogTransform

    @_api.deprecated("3.3", alternative="scale.InvertedLogTransform")
    @property
    def InvertedLogTransform(self):
        return InvertedLogTransform

    def __init__(self, axis, **kwargs):
        """
        Parameters
        ----------
        axis : `~matplotlib.axis.Axis`
            The axis for the scale.
        base : float, default: 10
            The base of the logarithm.
        nonpositive : {'clip', 'mask'}, default: 'clip'
            Determines the behavior for non-positive values. They can either
            be masked as invalid, or clipped to a very small positive number.
        subs : sequence of int, default: None
            Where to place the subticks between each major tick.  For example,
            in a log10 scale, ``[2, 3, 4, 5, 6, 7, 8, 9]`` will place 8
            logarithmically spaced minor ticks between each major tick.
        """
        # After the deprecation, the whole (outer) __init__ can be replaced by
        # def __init__(self, axis, *, base=10, subs=None, nonpositive="clip")
        # The following is to emit the right warnings depending on the axis
        # used, as the *old* kwarg names depended on the axis.
        # *axis* may lack ``axis_name`` (e.g. be None); the "x" spelling of
        # the deprecated names (basex, subsx, nonposx) is then assumed.
        axis_name = getattr(axis, "axis_name", "x")
        @_api.rename_parameter("3.3", f"base{axis_name}", "base")
        @_api.rename_parameter("3.3", f"subs{axis_name}", "subs")
        @_api.rename_parameter("3.3", f"nonpos{axis_name}", "nonpositive")
        def __init__(*, base=10, subs=None, nonpositive="clip"):
            # Only collects the keyword arguments (applying the defaults and
            # rejecting unknown names) and hands them back.
            return base, subs, nonpositive

        base, subs, nonpositive = __init__(**kwargs)
        self._transform = LogTransform(base, nonpositive)
        self.subs = subs

    # Read-only view of the base, which is stored on the transform.
    base = property(lambda self: self._transform.base)

    def set_default_locators_and_formatters(self, axis):
        # docstring inherited
        axis.set_major_locator(LogLocator(self.base))
        axis.set_major_formatter(LogFormatterSciNotation(self.base))
        axis.set_minor_locator(LogLocator(self.base, self.subs))
        # labelOnlyBase is enabled for the minor formatter only when explicit
        # subs were requested.
        axis.set_minor_formatter(
            LogFormatterSciNotation(self.base,
                                    labelOnlyBase=(self.subs is not None)))

    def get_transform(self):
        """Return the `.LogTransform` associated with this scale."""
        return self._transform

    def limit_range_for_scale(self, vmin, vmax, minpos):
        """Limit the domain to positive values."""
        if not np.isfinite(minpos):
            minpos = 1e-300  # Should rarely (if ever) have a visible effect.

        # Each non-positive bound is replaced by minpos (the smallest positive
        # data value, or the fallback above).
        return (minpos if vmin <= 0 else vmin,
                minpos if vmax <= 0 else vmax)
322
323
class FuncScaleLog(LogScale):
    """
    Scale that first applies a user-supplied function and then a
    logarithmic transform on top of it.
    """

    name = 'functionlog'

    def __init__(self, axis, functions, base=10):
        """
        Parameters
        ----------
        axis : `matplotlib.axis.Axis`
            The axis for the scale.
        functions : (callable, callable)
            Two-tuple ``(forward, inverse)``; *inverse* must undo *forward*,
            and *forward* must be monotonic.

            Both callables must have the signature::

                def forward(values: array-like) -> array-like

        base : float, default: 10
            Base of the logarithm applied after *forward*.
        """
        fwd, inv = functions
        # No explicit minor subticks; read by the inherited
        # LogScale.set_default_locators_and_formatters.
        self.subs = None
        self._transform = FuncTransform(fwd, inv) + LogTransform(base)

    @property
    def base(self):
        # The composite transform's second part (_b) is the LogTransform.
        log_part = self._transform._b
        return log_part.base

    def get_transform(self):
        """
        Return the composite `.Transform` (user function, then logarithm)
        of this scale.
        """
        return self._transform
360
361
class SymmetricalLogTransform(Transform):
    """
    Symmetric log transform: linear on [-linthresh, linthresh], logarithmic
    (with the sign preserved) outside of it.
    """

    input_dims = output_dims = 1

    def __init__(self, base, linthresh, linscale):
        super().__init__()
        if base <= 1.0:
            raise ValueError("'base' must be larger than 1")
        if linthresh <= 0.0:
            raise ValueError("'linthresh' must be positive")
        if linscale <= 0.0:
            raise ValueError("'linscale' must be positive")
        self.base = base
        self.linthresh = linthresh
        self.linscale = linscale
        # Slope of the linear region; both branches of transform_non_affine
        # equal linthresh * _linscale_adj at |a| == linthresh.
        self._linscale_adj = (linscale / (1.0 - self.base ** -1))
        self._log_base = np.log(base)

    def transform_non_affine(self, a):
        magnitude = np.abs(a)
        # log(0) and NaN inputs would otherwise emit warnings; those entries
        # are either overwritten below or legitimately NaN.
        with np.errstate(divide="ignore", invalid="ignore"):
            log_part = (self._linscale_adj
                        + np.log(magnitude / self.linthresh) / self._log_base)
            out = np.sign(a) * self.linthresh * log_part
            in_linear_region = magnitude <= self.linthresh
        out[in_linear_region] = a[in_linear_region] * self._linscale_adj
        return out

    def inverted(self):
        return InvertedSymmetricalLogTransform(self.base, self.linthresh,
                                               self.linscale)
392
393
class InvertedSymmetricalLogTransform(Transform):
    """Inverse of `SymmetricalLogTransform`."""

    input_dims = output_dims = 1

    def __init__(self, base, linthresh, linscale):
        super().__init__()
        # Building the forward transform first also validates the arguments.
        forward = SymmetricalLogTransform(base, linthresh, linscale)
        self.base = base
        self.linthresh = linthresh
        # Image of linthresh under the forward transform: the boundary of
        # the linear region in transformed coordinates.
        self.invlinthresh = forward.transform(linthresh)
        self.linscale = linscale
        self._linscale_adj = (linscale / (1.0 - self.base ** -1))

    def transform_non_affine(self, a):
        magnitude = np.abs(a)
        with np.errstate(divide="ignore", invalid="ignore"):
            exponent = magnitude / self.linthresh - self._linscale_adj
            out = np.sign(a) * self.linthresh * np.power(self.base, exponent)
            in_linear_region = magnitude <= self.invlinthresh
        out[in_linear_region] = a[in_linear_region] / self._linscale_adj
        return out

    def inverted(self):
        return SymmetricalLogTransform(self.base,
                                       self.linthresh, self.linscale)
419
420
class SymmetricalLogScale(ScaleBase):
    """
    The symmetrical logarithmic scale is logarithmic in both the
    positive and negative directions from the origin.

    Since the values close to zero tend toward infinity, there is a
    need to have a range around zero that is linear.  The parameter
    *linthresh* allows the user to specify the size of this range
    (-*linthresh*, *linthresh*).

    Parameters
    ----------
    base : float, default: 10
        The base of the logarithm.

    linthresh : float, default: 2
        Defines the range ``(-x, x)``, within which the plot is linear.
        This avoids having the plot go to infinity around zero.

    subs : sequence of int, default: None
        Where to place the subticks between each major tick.
        For example, in a log10 scale: ``[2, 3, 4, 5, 6, 7, 8, 9]`` will place
        8 logarithmically spaced minor ticks between each major tick.

    linscale : float, default: 1
        This allows the linear range ``(-linthresh, linthresh)`` to be
        stretched relative to the logarithmic range. Its value is the number of
        decades to use for each half of the linear range. For example, when
        *linscale* == 1.0 (the default), the space used for the positive and
        negative halves of the linear range will be equal to one decade in
        the logarithmic range.
    """
    name = 'symlog'

    # Deprecated (3.3) aliases exposing the module-level transform classes as
    # attributes of the scale.
    @_api.deprecated("3.3", alternative="scale.SymmetricalLogTransform")
    @property
    def SymmetricalLogTransform(self):
        return SymmetricalLogTransform

    @_api.deprecated(
        "3.3", alternative="scale.InvertedSymmetricalLogTransform")
    @property
    def InvertedSymmetricalLogTransform(self):
        return InvertedSymmetricalLogTransform

    def __init__(self, axis, **kwargs):
        """
        Parameters
        ----------
        axis : `~matplotlib.axis.Axis`
            The axis for the scale.  Only its ``axis_name`` (if present) is
            used, to select the deprecated axis-suffixed keyword names.
        base : float, default: 10
            The base of the logarithm.
        linthresh : float, default: 2
            Defines the range ``(-linthresh, linthresh)``, within which the
            plot is linear.
        subs : sequence of int, default: None
            Where to place the subticks between each major tick.
        linscale : float, default: 1
            Stretch of the linear range relative to the logarithmic range,
            expressed as the number of decades used for each half of the
            linear range.
        """
        axis_name = getattr(axis, "axis_name", "x")
        # See explanation in LogScale.__init__.
        @_api.rename_parameter("3.3", f"base{axis_name}", "base")
        @_api.rename_parameter("3.3", f"linthresh{axis_name}", "linthresh")
        @_api.rename_parameter("3.3", f"subs{axis_name}", "subs")
        @_api.rename_parameter("3.3", f"linscale{axis_name}", "linscale")
        def __init__(*, base=10, linthresh=2, subs=None, linscale=1):
            # Collects the keyword arguments (applying defaults and rejecting
            # unknown names) and hands them back.
            return base, linthresh, subs, linscale

        base, linthresh, subs, linscale = __init__(**kwargs)
        self._transform = SymmetricalLogTransform(base, linthresh, linscale)
        self.subs = subs

    # Read-only views of the parameters stored on the transform.
    base = property(lambda self: self._transform.base)
    linthresh = property(lambda self: self._transform.linthresh)
    linscale = property(lambda self: self._transform.linscale)

    def set_default_locators_and_formatters(self, axis):
        # docstring inherited
        axis.set_major_locator(SymmetricalLogLocator(self.get_transform()))
        axis.set_major_formatter(LogFormatterSciNotation(self.base))
        axis.set_minor_locator(SymmetricalLogLocator(self.get_transform(),
                                                     self.subs))
        # Minor ticks are located but left unlabeled.
        axis.set_minor_formatter(NullFormatter())

    def get_transform(self):
        """Return the `.SymmetricalLogTransform` associated with this scale."""
        return self._transform
495
496
class LogitTransform(Transform):
    """Base-10 logit, mapping ]0, 1[ onto the whole real line."""

    input_dims = output_dims = 1

    @_api.rename_parameter("3.3", "nonpos", "nonpositive")
    def __init__(self, nonpositive='mask'):
        super().__init__()
        _api.check_in_list(['mask', 'clip'], nonpositive=nonpositive)
        self._nonpositive = nonpositive
        # nonpositive is known to be 'mask' or 'clip' at this point.
        self._clip = nonpositive == "clip"

    def transform_non_affine(self, a):
        """
        Return ``log10(a / (1 - a))``; with ``nonpositive='clip'``, inputs
        <= 0 map to -1000 and inputs >= 1 map to 1000.
        """
        with np.errstate(divide="ignore", invalid="ignore"):
            odds = a / (1 - a)
            out = np.log10(odds)
        if self._clip:
            # Same clip magnitude as LogTransform uses.
            out[a <= 0] = -1000
            out[1 <= a] = 1000
        return out

    def inverted(self):
        return LogisticTransform(self._nonpositive)

    def __str__(self):
        return f"{type(self).__name__}({self._nonpositive!r})"
521
522
class LogisticTransform(Transform):
    """Base-10 logistic function; the inverse of `LogitTransform`."""

    input_dims = output_dims = 1

    @_api.rename_parameter("3.3", "nonpos", "nonpositive")
    def __init__(self, nonpositive='mask'):
        super().__init__()
        # Not used here; stored so that inverted() round-trips it.
        self._nonpositive = nonpositive

    def transform_non_affine(self, a):
        """Return ``1 / (1 + 10 ** -a)``."""
        return 1.0 / (1 + 10**(-a))

    def inverted(self):
        return LogitTransform(self._nonpositive)

    def __str__(self):
        return f"{type(self).__name__}({self._nonpositive!r})"
540
541
class LogitScale(ScaleBase):
    """
    Logit scale for data between zero and one, both excluded.

    This scale is similar to a log scale close to zero and to one, and almost
    linear around 0.5. It maps the interval ]0, 1[ onto ]-infty, +infty[.
    """
    name = 'logit'

    @_api.rename_parameter("3.3", "nonpos", "nonpositive")
    def __init__(self, axis, nonpositive='mask', *,
                 one_half=r"\frac{1}{2}", use_overline=False):
        r"""
        Parameters
        ----------
        axis : `matplotlib.axis.Axis`
            Currently unused.
        nonpositive : {'mask', 'clip'}, default: 'mask'
            Determines the behavior for values beyond the open interval ]0, 1[.
            They can either be masked as invalid, or clipped to a number very
            close to 0 or 1.
        one_half : str, default: r"\frac{1}{2}"
            The string used by the tick formatters to represent 1/2.
        use_overline : bool, default: False
            Whether to use survival notation (\overline{x}) instead of the
            standard notation (1-x) for probabilities close to one.
        """
        self._transform = LogitTransform(nonpositive)
        self._use_overline = use_overline
        self._one_half = one_half

    def get_transform(self):
        """Return the `.LogitTransform` of this scale."""
        return self._transform

    def set_default_locators_and_formatters(self, axis):
        # docstring inherited
        # Options shared by the major and minor LogitFormatter.
        fmt_options = dict(one_half=self._one_half,
                           use_overline=self._use_overline)
        # Major ticks at ..., 0.01, 0.1, 0.5, 0.9, 0.99, ...
        axis.set_major_locator(LogitLocator())
        axis.set_major_formatter(LogitFormatter(**fmt_options))
        axis.set_minor_locator(LogitLocator(minor=True))
        axis.set_minor_formatter(LogitFormatter(minor=True, **fmt_options))

    def limit_range_for_scale(self, vmin, vmax, minpos):
        """
        Restrict (*vmin*, *vmax*) to the open interval ]0, 1[.

        A lower bound <= 0 becomes *minpos*, and an upper bound >= 1 becomes
        ``1 - minpos``.
        """
        if not np.isfinite(minpos):
            minpos = 1e-7  # Should rarely (if ever) have a visible effect.
        low = minpos if vmin <= 0 else vmin
        high = 1 - minpos if vmax >= 1 else vmax
        return low, high
604
605
# Registry of scale name -> scale class; extended by `register_scale`.
# Insertion order is preserved and determines the order of the generated
# scale documentation.
_scale_mapping = dict(
    linear=LinearScale,
    log=LogScale,
    symlog=SymmetricalLogScale,
    logit=LogitScale,
    function=FuncScale,
    functionlog=FuncScaleLog,
)
614
615
def get_scale_names():
    """Return the names of all registered scales, sorted alphabetically."""
    names = list(_scale_mapping)
    names.sort()
    return names
619
620
def scale_factory(scale, axis, **kwargs):
    """
    Return a scale instance by name.

    Parameters
    ----------
    scale : {%(names)s}
        Name of the scale (case-insensitive).
    axis : `matplotlib.axis.Axis`
        The axis passed as first argument to the scale's constructor.
    **kwargs
        Forwarded to the scale's constructor.

    Returns
    -------
    ScaleBase
        A new instance of the scale class registered under *scale*.
    """
    scale = scale.lower()
    _api.check_in_list(_scale_mapping, scale=scale)
    return _scale_mapping[scale](axis, **kwargs)
633
634
# Fill the "%(names)s" placeholder with the scales registered at import time.
# Skipped when docstrings are stripped (python -OO), where __doc__ is None.
if scale_factory.__doc__:
    scale_factory.__doc__ %= {
        "names": ", ".join(repr(name) for name in get_scale_names())}
638
639
def register_scale(scale_class):
    """
    Add *scale_class* to the scale registry, keyed by its ``name``.

    A scale previously registered under the same name is replaced.

    Parameters
    ----------
    scale_class : subclass of `ScaleBase`
        The scale to register; it must define a ``name`` class attribute.
    """
    _scale_mapping.update({scale_class.name: scale_class})
650
651
def _get_scale_docs():
    """
    Helper function for generating docstrings related to scales.

    For every registered scale, emit the quoted scale name (indented by 4
    spaces), a blank line, the docstring of the scale's ``__init__``
    (indented by 8 spaces), and another blank line.

    Returns
    -------
    str
        The newline-joined entries.
    """
    docs = []
    for name, scale_class in _scale_mapping.items():
        # inspect.getdoc returns None when no docstring can be found, e.g.
        # under ``python -OO`` where docstrings are stripped.  Fall back to
        # an empty string: textwrap.indent(None, ...) would raise, and this
        # helper runs at module import time.
        init_doc = inspect.getdoc(scale_class.__init__) or ""
        docs.extend([
            f"    {name!r}",
            "",
            textwrap.indent(init_doc, " " * 8),
            ""
        ])
    return "\n".join(docs)
665
666
# Expose the registered scale names and their constructor docs to
# matplotlib's docstring interpolation under the keys "scale_type" and
# "scale_docs".
docstring.interpd.update(
    scale_type="{" + ", ".join(map(repr, get_scale_names())) + "}",
    scale_docs=_get_scale_docs().rstrip(),
)
671