1"""Correlation plot functions."""
2from statsmodels.compat.pandas import deprecate_kwarg
3
4import calendar
5import warnings
6
7import numpy as np
8import pandas as pd
9
10from statsmodels.graphics import utils
11from statsmodels.tsa.stattools import acf, pacf
12
13
def _prepare_data_corr_plot(x, lags, zero):
    """
    Normalize the ``lags`` argument used by the correlation plots.

    Parameters
    ----------
    x : array_like
        The data. Only inspected (for its first-axis length) when ``lags``
        is None.
    lags : {None, int, array_like}
        Requested lags. None selects ``min(ceil(10 * log10(nobs)), nobs - 1)``
        as the largest lag; a scalar selects every lag up to and including
        that value; anything else is converted to an integer array and used
        as given.
    zero : bool
        Whether lag 0 is included when the lags are generated here (None or
        scalar ``lags``).

    Returns
    -------
    lags : ndarray
        Integer lags to plot.
    nlags : int
        The largest value in ``lags``.
    irregular : bool
        True when the lags are not the full range ``0, 1, ..., nlags``
        (lag 0 omitted, or lags given explicitly), so the caller has to
        select the matching correlations with ``lags``.
    """
    zero = bool(zero)
    # Without lag 0 the lags no longer line up with positions 0..nlags.
    irregular = not zero
    first_lag = int(not zero)
    if lags is None:
        # GH 4663 - use a sensible default value
        # Convert first so plain lists/tuples (which have no ``shape``)
        # are accepted just like arrays and pandas objects.
        nobs = np.asanyarray(x).shape[0]
        lim = min(int(np.ceil(10 * np.log10(nobs))), nobs - 1)
        lags = np.arange(first_lag, lim + 1)
    elif np.isscalar(lags):
        lags = np.arange(first_lag, int(lags) + 1)  # +1 for zero lag
    else:
        irregular = True
        lags = np.asanyarray(lags).astype(int)
    nlags = lags.max(0)

    return lags, nlags, irregular
30
31
def _plot_corr(
    ax,
    title,
    acf_x,
    confint,
    lags,
    irregular,
    use_vlines,
    vlines_kwargs,
    auto_ylims=False,
    **kwargs,
):
    """
    Draw a correlation plot (markers, optional stems and confidence band).

    Parameters
    ----------
    ax : Axes
        Axes-like object to draw on.
    title : str
        Title set on ``ax``.
    acf_x : ndarray
        Correlation values. When ``irregular`` is True they are selected
        with ``acf_x[lags]``.
    confint : {ndarray, None}
        Confidence limits with the lower bound in column 0 and the upper
        bound in column 1, or None to draw no band.
    lags : ndarray
        Integer lags used as x positions.
    irregular : bool
        If True, ``acf_x`` and ``confint`` are indexed with ``lags`` before
        plotting.
    use_vlines : bool
        If True, vertical lines from 0 to each correlation and a horizontal
        line (``ax.axhline``) are drawn.
    vlines_kwargs : dict
        Keyword arguments passed to ``ax.vlines``.
    auto_ylims : bool, optional
        If True, the y-limits are set to 1.25 times the extremes of the
        plotted values (and of the band, when there is one) instead of the
        fixed ``(-1, 1)``.
    **kwargs
        Passed to ``ax.axhline`` and ``ax.plot``.
    """
    if irregular:
        acf_x = acf_x[lags]
        if confint is not None:
            confint = confint[lags]

    if use_vlines:
        ax.vlines(lags, [0], acf_x, **vlines_kwargs)
        # Must happen before the marker defaults are added below so the
        # horizontal line only receives the user-supplied options.
        ax.axhline(**kwargs)

    kwargs.setdefault("marker", "o")
    kwargs.setdefault("markersize", 5)
    if "ls" not in kwargs:
        # gh-2369
        kwargs.setdefault("linestyle", "None")
    ax.margins(0.05)
    ax.plot(lags, acf_x, **kwargs)
    ax.set_title(title)

    ax.set_ylim(-1, 1)
    if auto_ylims:
        lower, upper = min(acf_x), max(acf_x)
        if confint is not None:
            # The band is drawn re-centered on zero (confint - acf_x), so
            # that is what has to fit inside the limits.
            lower = np.minimum(lower, min(confint[:, 0] - acf_x))
            upper = np.maximum(upper, max(confint[:, 1] - acf_x))
        ax.set_ylim(1.25 * lower, 1.25 * upper)

    if confint is not None:
        if lags[0] == 0:
            # No band is drawn at lag 0.
            lags = lags[1:]
            confint = confint[1:]
            acf_x = acf_x[1:]
        # Nothing is left to shade when lag 0 was the only lag requested.
        if lags.size:
            lags = lags.astype(float)
            # Widen the band by half a lag on each side so it extends past
            # the first and last markers.
            lags[0] -= 0.5
            lags[-1] += 0.5
            ax.fill_between(
                lags, confint[:, 0] - acf_x, confint[:, 1] - acf_x, alpha=0.25
            )
80
81
@deprecate_kwarg("unbiased", "adjusted")
def plot_acf(
    x,
    ax=None,
    lags=None,
    *,
    alpha=0.05,
    use_vlines=True,
    adjusted=False,
    fft=False,
    missing="none",
    title="Autocorrelation",
    zero=True,
    auto_ylims=False,
    bartlett_confint=True,
    vlines_kwargs=None,
    **kwargs,
):
    """
    Plot the autocorrelation function

    Lags are shown on the horizontal axis and the autocorrelations on the
    vertical axis.

    Parameters
    ----------
    x : array_like
        Array of time-series values
    ax : AxesSubplot, optional
        If given, this subplot is used to plot in instead of a new figure
        being created.
    lags : {int, array_like}, optional
        An int or array of lag values, used on the horizontal axis. When an
        int, every lag up to and including ``lags`` is plotted (starting at
        0, or at 1 if ``zero`` is False). If not provided, the largest lag
        is ``min(ceil(10 * log10(nobs)), nobs - 1)``.
    alpha : scalar, optional
        If a number is given, confidence intervals for that level are
        computed and drawn. For instance if alpha=.05, 95 % confidence
        intervals are shown. If None, no confidence intervals are plotted.
    use_vlines : bool, optional
        If True, vertical lines and markers are plotted.
        If False, only markers are plotted.  The default marker is 'o'; it
        can be overridden with a ``marker`` kwarg.
    adjusted : bool
        If True, then denominators for autocovariance are n-k, otherwise n
    fft : bool, optional
        If True, computes the ACF via FFT.
    missing : str, optional
        A string in ['none', 'raise', 'conservative', 'drop'] specifying how
        the NaNs are to be treated.
    title : str, optional
        Title to place on plot.  Default is 'Autocorrelation'
    zero : bool, optional
        Flag indicating whether to include the 0-lag autocorrelation.
        Default is True.
    auto_ylims : bool, optional
        If True, adjusts automatically the y-axis limits to ACF values.
    bartlett_confint : bool, default True
        Selects the standard-error formula behind the confidence intervals,
        which are generally placed at 2 standard errors around r_k.

        When the autocorrelations are used to test residuals for
        randomness as part of an ARIMA routine, the residuals are assumed
        to be white noise and the approximate standard error of each r_k is
        1/sqrt(N). See section 9.4 of [1] for details on the 1/sqrt(N)
        result and section 5.3.2 in [2] for a more elementary discussion.

        For the ACF of raw data, the standard error at lag k is found as if
        the right model was an MA(k-1). This allows the interpretation that
        if all autocorrelations past a certain lag are within the limits,
        the model might be an MA of the order given by the last significant
        autocorrelation. In this case the standard errors should be
        generated using Bartlett's formula; see section 7.2 in [1].
    vlines_kwargs : dict, optional
        Optional dictionary of keyword arguments that are passed to vlines.
    **kwargs : kwargs, optional
        Optional keyword arguments that are directly passed on to the
        Matplotlib ``plot`` and ``axhline`` functions.

    Returns
    -------
    Figure
        If `ax` is None, the created figure.  Otherwise the figure to which
        `ax` is connected.

    See Also
    --------
    matplotlib.pyplot.xcorr
    matplotlib.pyplot.acorr

    Notes
    -----
    Adapted from matplotlib's `xcorr`.

    Data are plotted as ``plot(lags, corr, **kwargs)``

    kwargs is used to pass matplotlib optional arguments to both the line
    tracing the autocorrelations and for the horizontal line at 0. These
    options must be valid for a Line2D object.

    vlines_kwargs is used to pass additional optional arguments to the
    vertical lines connecting each autocorrelation to the axis.  These
    options must be valid for a LineCollection object.

    References
    ----------
    [1] Brockwell and Davis, 1987. Time Series Theory and Methods
    [2] Brockwell and Davis, 2010. Introduction to Time Series and
    Forecasting, 2nd edition.

    Examples
    --------
    >>> import pandas as pd
    >>> import matplotlib.pyplot as plt
    >>> import statsmodels.api as sm

    >>> dta = sm.datasets.sunspots.load_pandas().data
    >>> dta.index = pd.Index(sm.tsa.datetools.dates_from_range('1700', '2008'))
    >>> del dta["YEAR"]
    >>> sm.graphics.tsa.plot_acf(dta.values.squeeze(), lags=40)
    >>> plt.show()

    .. plot:: plots/graphics_tsa_plot_acf.py
    """
    fig, ax = utils.create_mpl_ax(ax)

    lags, nlags, irregular = _prepare_data_corr_plot(x, lags, zero)
    if vlines_kwargs is None:
        vlines_kwargs = {}

    acf_options = dict(
        nlags=nlags,
        alpha=alpha,
        fft=fft,
        bartlett_confint=bartlett_confint,
        adjusted=adjusted,
        missing=missing,
    )
    estimate = acf(x, **acf_options)
    # acf's return type depends on alpha: only the correlations when alpha
    # is None, otherwise a sequence starting with (correlations, confint).
    if alpha is None:
        corr, confint = estimate, None
    else:
        corr, confint = estimate[:2]

    _plot_corr(
        ax,
        title,
        corr,
        confint,
        lags,
        irregular,
        use_vlines,
        vlines_kwargs,
        auto_ylims=auto_ylims,
        **kwargs,
    )

    return fig
242
243
def plot_pacf(
    x,
    ax=None,
    lags=None,
    alpha=0.05,
    method=None,
    use_vlines=True,
    title="Partial Autocorrelation",
    zero=True,
    vlines_kwargs=None,
    **kwargs,
):
    """
    Plot the partial autocorrelation function

    Parameters
    ----------
    x : array_like
        Array of time-series values
    ax : AxesSubplot, optional
        If given, this subplot is used to plot in instead of a new figure
        being created.
    lags : {int, array_like}, optional
        An int or array of lag values, used on the horizontal axis. When an
        int, every lag up to and including ``lags`` is plotted (starting at
        0, or at 1 if ``zero`` is False). If not provided, the largest lag
        is ``min(ceil(10 * log10(nobs)), nobs - 1)``.
    alpha : float, optional
        If a number is given, the confidence intervals for the given level
        are plotted. For instance if alpha=.05, 95 % confidence intervals
        are shown where the standard deviation is computed according to
        1/sqrt(len(x)). If None, no confidence intervals are plotted.
    method : str, optional
        Specifies which method for the calculations to use:

        - "ywm" or "ywmle" : Yule-Walker without adjustment.
        - "yw" or "ywadjusted" : Yule-Walker with sample-size adjustment in
          denominator for acovf. Used when ``method`` is None, in which
          case a FutureWarning announcing the switch to "ywm" is issued.
        - "ols" : regression of time series on lags of it and on constant.
        - "ols-inefficient" : regression of time series on lags using a
          single common sample to estimate all pacf coefficients.
        - "ols-adjusted" : regression of time series on lags with a bias
          adjustment.
        - "ld" or "ldadjusted" : Levinson-Durbin recursion with bias
          correction.
        - "ldb" or "ldbiased" : Levinson-Durbin recursion without bias
          correction.

    use_vlines : bool, optional
        If True, vertical lines and markers are plotted.
        If False, only markers are plotted.  The default marker is 'o'; it
        can be overridden with a ``marker`` kwarg.
    title : str, optional
        Title to place on plot.  Default is 'Partial Autocorrelation'
    zero : bool, optional
        Flag indicating whether to include the 0-lag autocorrelation.
        Default is True.
    vlines_kwargs : dict, optional
        Optional dictionary of keyword arguments that are passed to vlines.
    **kwargs : kwargs, optional
        Optional keyword arguments that are directly passed on to the
        Matplotlib ``plot`` and ``axhline`` functions.

    Returns
    -------
    Figure
        If `ax` is None, the created figure.  Otherwise the figure to which
        `ax` is connected.

    See Also
    --------
    matplotlib.pyplot.xcorr
    matplotlib.pyplot.acorr

    Notes
    -----
    Plots lags on the horizontal and the correlations on vertical axis.
    Adapted from matplotlib's `xcorr`.

    Data are plotted as ``plot(lags, corr, **kwargs)``

    kwargs is used to pass matplotlib optional arguments to both the line
    tracing the autocorrelations and for the horizontal line at 0. These
    options must be valid for a Line2D object.

    vlines_kwargs is used to pass additional optional arguments to the
    vertical lines connecting each autocorrelation to the axis.  These
    options must be valid for a LineCollection object.

    Examples
    --------
    >>> import pandas as pd
    >>> import matplotlib.pyplot as plt
    >>> import statsmodels.api as sm

    >>> dta = sm.datasets.sunspots.load_pandas().data
    >>> dta.index = pd.Index(sm.tsa.datetools.dates_from_range('1700', '2008'))
    >>> del dta["YEAR"]
    >>> sm.graphics.tsa.plot_pacf(dta.values.squeeze(), lags=40, method="ywm")
    >>> plt.show()

    .. plot:: plots/graphics_tsa_plot_pacf.py
    """
    if method is None:
        method = "yw"
        warnings.warn(
            "The default method 'yw' can produce PACF values outside of "
            "the [-1,1] interval. After 0.13, the default will change to "
            "unadjusted Yule-Walker ('ywm'). You can use this method now "
            "by setting method='ywm'.",
            FutureWarning,
            # Attribute the warning to the caller's line, not this module.
            stacklevel=2,
        )
    fig, ax = utils.create_mpl_ax(ax)
    vlines_kwargs = {} if vlines_kwargs is None else vlines_kwargs
    lags, nlags, irregular = _prepare_data_corr_plot(x, lags, zero)

    confint = None
    # pacf returns only the values when alpha is None and a
    # (values, confint) pair otherwise.
    if alpha is None:
        acf_x = pacf(x, nlags=nlags, alpha=alpha, method=method)
    else:
        acf_x, confint = pacf(x, nlags=nlags, alpha=alpha, method=method)

    _plot_corr(
        ax,
        title,
        acf_x,
        confint,
        lags,
        irregular,
        use_vlines,
        vlines_kwargs,
        **kwargs,
    )

    return fig
378
379
def seasonal_plot(grouped_x, xticklabels, ylabel=None, ax=None):
    """
    Plot each seasonal group side by side with its mean.

    Consider using one of month_plot or quarter_plot unless you need
    irregular plotting.

    Parameters
    ----------
    grouped_x : iterable of DataFrames
        Should be a GroupBy object (or similar pair of group_names and groups
        as DataFrames) with a DatetimeIndex or PeriodIndex
    xticklabels : list of str
        List of season labels, one for each group.
    ylabel : str
        Label for y axis
    ax : AxesSubplot, optional
        If given, this subplot is used to plot in instead of a new figure
        being created.

    Returns
    -------
    Figure
        The figure returned by ``utils.create_mpl_ax`` for ``ax``.
    """
    fig, ax = utils.create_mpl_ax(ax)
    start = 0
    ticks = []
    for _, df in grouped_x:
        # sort_index returns a new object; the result has to be kept or
        # the group is plotted in its original order.
        df = df.sort_index()
        nobs = len(df)
        values = df.values
        # Groups are laid out one after another along the x axis.
        x_plot = np.arange(start, start + nobs)
        ticks.append(x_plot.mean())
        ax.plot(x_plot, values, "k")
        # Red bar marking the mean of this group across its x range.
        ax.hlines(
            values.mean(), x_plot[0], x_plot[-1], colors="r", linewidth=3
        )
        start += nobs

    ax.set_xticks(ticks)
    ax.set_xticklabels(xticklabels)
    ax.set_ylabel(ylabel)
    ax.margins(0.1, 0.05)
    return fig
418
419
def month_plot(x, dates=None, ylabel=None, ax=None):
    """
    Seasonal plot of monthly data.

    Parameters
    ----------
    x : array_like
        Seasonal data to plot. If dates is None, x must be a pandas object
        with a PeriodIndex or DatetimeIndex with a monthly frequency.
    dates : array_like, optional
        If `x` is not a pandas object, then dates must be supplied. They
        are converted to a monthly PeriodIndex used as the index of `x`.
    ylabel : str, optional
        The label for the y-axis, passed on to ``seasonal_plot``.
    ax : Axes, optional
        Existing axes instance.

    Returns
    -------
    Figure
       If `ax` is provided, the Figure instance attached to `ax`. Otherwise
       a new Figure instance.

    Examples
    --------
    >>> import statsmodels.api as sm
    >>> import pandas as pd

    >>> dta = sm.datasets.elnino.load_pandas().data
    >>> dta['YEAR'] = dta.YEAR.astype(int).astype(str)
    >>> dta = dta.set_index('YEAR').T.unstack()
    >>> dates = pd.to_datetime(list(map(lambda x: '-'.join(x) + '-1',
    ...                                 dta.index.values)))
    >>> dta.index = pd.DatetimeIndex(dates, freq='MS')
    >>> fig = sm.graphics.tsa.month_plot(dta)

    .. plot:: plots/graphics_tsa_month_plot.py
    """

    if dates is not None:
        monthly_index = pd.PeriodIndex(dates, freq="M")
        x = pd.Series(x, index=monthly_index)
    else:
        from statsmodels.tools.data import _check_period_index

        _check_period_index(x, freq="M")

    def _month_of(stamp):
        return stamp.month

    # calendar.month_abbr has an empty entry at position 0; months are 1-12
    month_labels = [calendar.month_abbr[number] for number in range(1, 13)]
    by_month = x.groupby(_month_of)
    return seasonal_plot(by_month, month_labels, ylabel=ylabel, ax=ax)
471
472
def quarter_plot(x, dates=None, ylabel=None, ax=None):
    """
    Seasonal plot of quarterly data

    Parameters
    ----------
    x : array_like
        Seasonal data to plot. If dates is None, x must be a pandas object
        with a PeriodIndex or DatetimeIndex with a quarterly frequency.
    dates : array_like, optional
        If `x` is not a pandas object, then dates must be supplied. They
        are converted to a quarterly PeriodIndex used as the index of `x`.
    ylabel : str, optional
        The label for the y-axis, passed on to ``seasonal_plot``.
    ax : matplotlib.axes, optional
        Existing axes instance.

    Returns
    -------
    Figure
       If `ax` is provided, the Figure instance attached to `ax`. Otherwise
       a new Figure instance.

    Examples
    --------
    >>> import statsmodels.api as sm
    >>> import pandas as pd

    >>> dta = sm.datasets.elnino.load_pandas().data
    >>> dta['YEAR'] = dta.YEAR.astype(int).astype(str)
    >>> dta = dta.set_index('YEAR').T.unstack()
    >>> dates = pd.to_datetime(list(map(lambda x: '-'.join(x) + '-1',
    ...                                 dta.index.values)))
    >>> dta.index = dates.to_period('Q')
    >>> fig = sm.graphics.tsa.quarter_plot(dta)

    .. plot:: plots/graphics_tsa_quarter_plot.py
    """

    if dates is not None:
        quarterly_index = pd.PeriodIndex(dates, freq="Q")
        x = pd.Series(x, index=quarterly_index)
    else:
        from statsmodels.tools.data import _check_period_index

        _check_period_index(x, freq="Q")

    def _quarter_of(stamp):
        return stamp.quarter

    quarter_labels = ["q1", "q2", "q3", "q4"]
    by_quarter = x.groupby(_quarter_of)
    return seasonal_plot(by_quarter, quarter_labels, ylabel=ylabel, ax=ax)
523
524
def plot_predict(
    result,
    start=None,
    end=None,
    dynamic=False,
    alpha=0.05,
    ax=None,
    **predict_kwargs,
):
    """
    Plot predictions from a fitted model with an optional confidence band.

    Parameters
    ----------
    result : Result
        Any model result supporting ``get_prediction``.
    start : int, str, or datetime, optional
        Zero-indexed observation number at which to start forecasting,
        i.e., the first forecast is start. Can also be a date string to
        parse or a datetime type. Default is the zeroth observation.
    end : int, str, or datetime, optional
        Zero-indexed observation number at which to end forecasting, i.e.,
        the last forecast is end. Can also be a date string to
        parse or a datetime type. However, if the dates index does not
        have a fixed frequency, end must be an integer index if you
        want out of sample prediction. Default is the last observation in
        the sample.
    dynamic : bool, int, str, or datetime, optional
        Integer offset relative to `start` at which to begin dynamic
        prediction. Can also be an absolute date string to parse or a
        datetime type (these are not interpreted as offsets).
        Prior to this observation, true endogenous values will be used for
        prediction; starting with this observation and continuing through
        the end of prediction, forecasted endogenous values will be used
        instead.
    alpha : {float, None}
        The tail probability not covered by the confidence interval. Must
        be in (0, 1). Confidence interval is constructed assuming normally
        distributed shocks. If None, figure will not show the confidence
        interval.
    ax : AxesSubplot
        matplotlib Axes instance to use
    **predict_kwargs
        Any additional keyword arguments to pass to ``result.get_prediction``.

    Returns
    -------
    Figure
        matplotlib Figure containing the prediction plot
    """
    from statsmodels.graphics.utils import _import_mpl, create_mpl_ax

    _ = _import_mpl()
    fig, ax = create_mpl_ax(ax)
    from statsmodels.tsa.base.prediction import PredictionResults

    # use predict so you set dates
    pred: PredictionResults = result.get_prediction(
        start=start, end=end, dynamic=dynamic, **predict_kwargs
    )
    mean = pred.predicted_mean
    if isinstance(mean, (pd.Series, pd.DataFrame)):
        x = mean.index
        mean.plot(ax=ax, label="forecast")
    else:
        x = np.arange(mean.shape[0])
        # Label the line like the pandas branch so the legend always has a
        # forecast entry (and is never empty when alpha is None).
        ax.plot(x, mean, label="forecast")

    if alpha is not None:
        label = f"{1-alpha:.0%} confidence interval"
        ci = pred.conf_int(alpha)
        # Column 0 is used as the lower bound and column 1 as the upper.
        conf_int = np.asarray(ci)

        ax.fill_between(
            x,
            conf_int[:, 0],
            conf_int[:, 1],
            color="gray",
            alpha=0.5,
            label=label,
        )

    ax.legend(loc="best")

    return fig
609