1import datetime
2import inspect
3from numbers import Integral
4
5import pandas as pd
6from pandas.api.types import is_datetime64_any_dtype
7from pandas.core.window import Rolling as pd_Rolling
8
9from ..base import tokenize
10from ..highlevelgraph import HighLevelGraph
11from ..utils import M, derived_from, funcname, has_keyword
12from . import methods
13from .core import _emulate
14from .utils import make_meta
15
16
def overlap_chunk(
    func, prev_part, current_part, next_part, before, after, args, kwargs
):
    """Apply ``func`` to one partition extended with its overlap rows.

    ``prev_part``/``next_part`` hold the rows borrowed from the adjacent
    partitions (either may be ``None`` at the frame edges).  After running
    ``func`` on the concatenation, the borrowed rows are trimmed back off so
    only the output for ``current_part`` remains.
    """
    too_small = (
        "Partition size is less than overlapping "
        "window size. Try using ``df.repartition`` "
        "to increase the partition size."
    )

    # Integer overlaps must have been satisfied in full by the neighbour;
    # a short tail/head means some partition was smaller than the window.
    if isinstance(before, Integral) and prev_part is not None:
        if before != prev_part.shape[0]:
            raise NotImplementedError(too_small)

    if isinstance(after, Integral) and next_part is not None:
        if after != next_part.shape[0]:
            raise NotImplementedError(too_small)

    pieces = [part for part in (prev_part, current_part, next_part) if part is not None]
    combined = methods.concat(pieces)
    result = func(combined, *args, **kwargs)

    # Normalize `before` to a row count for the trim below.
    if prev_part is None:
        before = None
    if isinstance(before, datetime.timedelta):
        before = len(prev_part)

    # Some functions return more/fewer rows per input row (e.g. resample-like
    # expansion); scale the trim counts accordingly.
    expansion = None
    if combined.shape[0] != 0:
        expansion = result.shape[0] // combined.shape[0]
    if before and expansion:
        before *= expansion
    if next_part is None:
        return result.iloc[before:]
    if isinstance(after, datetime.timedelta):
        after = len(next_part)
    if after and expansion:
        after *= expansion
    return result.iloc[before:-after]
55
56
def map_overlap(func, df, before, after, *args, **kwargs):
    """Apply a function to each partition, sharing rows with adjacent partitions.

    Parameters
    ----------
    func : function
        Function applied to each partition.
    df : dd.DataFrame, dd.Series
    before : int or timedelta
        The rows to prepend to partition ``i`` from the end of
        partition ``i - 1``.
    after : int or timedelta
        The rows to append to partition ``i`` from the beginning
        of partition ``i + 1``.
    args, kwargs :
        Arguments and keywords to pass to the function. The partition will
        be the first argument, and these will be passed *after*.

    Returns
    -------
    A new collection of the same type as ``df`` with ``func`` applied to
    every (overlap-extended) partition.

    See Also
    --------
    dd.DataFrame.map_overlap
    """
    # Validate the overlap spec: timedelta overlaps need a datetime-like
    # index; integer overlaps must be non-negative.
    if isinstance(before, datetime.timedelta) or isinstance(after, datetime.timedelta):
        if not is_datetime64_any_dtype(df.index._meta_nonempty.inferred_type):
            raise TypeError(
                "Must have a `DatetimeIndex` when using string offset "
                "for `before` and `after`"
            )
    else:
        if not (
            isinstance(before, Integral)
            and before >= 0
            and isinstance(after, Integral)
            and after >= 0
        ):
            raise ValueError("before and after must be positive integers")

    if "token" in kwargs:
        func_name = kwargs.pop("token")
        token = tokenize(df, before, after, *args, **kwargs)
    else:
        func_name = "overlap-" + funcname(func)
        token = tokenize(func, df, before, after, *args, **kwargs)

    if "meta" in kwargs:
        meta = kwargs.pop("meta")
    else:
        meta = _emulate(func, df, *args, **kwargs)
    meta = make_meta(meta, index=df._meta.index, parent_meta=df._meta)

    name = f"{func_name}-{token}"
    name_a = "overlap-prepend-" + tokenize(df, before)
    name_b = "overlap-append-" + tokenize(df, after)
    df_name = df._name

    dsk = {}

    timedelta_partition_message = (
        "Partition size is less than specified window. "
        "Try using ``df.repartition`` to increase the partition size"
    )

    if before and isinstance(before, Integral):
        # Prepend the last `before` rows of partition i to partition i + 1.
        prevs = [None]
        for i in range(df.npartitions - 1):
            key = (name_a, i)
            dsk[key] = (M.tail, (df_name, i), before)
            prevs.append(key)

    elif isinstance(before, datetime.timedelta):
        # Assumes monotonic (increasing?) index
        divs = pd.Series(df.divisions)
        deltas = divs.diff().iloc[1:-1]

        # In the first case window-size is larger than at least one partition, thus it is
        # necessary to calculate how many partitions must be used for each rolling task.
        # Otherwise, these calculations can be skipped (faster)

        if (before > deltas).any():
            pt_z = divs[0]
            prevs = [None]
            for i in range(df.npartitions - 1):
                # Select all indexes of relevant partitions between the current partition and
                # the partition with the highest division outside the rolling window (before)
                pt_i = divs[i + 1]

                # lower-bound the search to the first division
                lb = max(pt_i - before, pt_z)

                first, j = divs[i], i
                while first > lb and j > 0:
                    first = first - deltas[j]
                    j = j - 1

                key = (name_a, i)
                dsk[key] = (
                    _tail_timedelta,
                    [(df_name, k) for k in range(j, i + 1)],
                    (df_name, i + 1),
                    before,
                )
                prevs.append(key)

        else:
            prevs = [None]
            for i in range(df.npartitions - 1):
                key = (name_a, i)
                dsk[key] = (
                    _tail_timedelta,
                    [(df_name, i)],
                    (df_name, i + 1),
                    before,
                )
                prevs.append(key)
    else:
        prevs = [None] * df.npartitions

    if after and isinstance(after, Integral):
        # Append the first `after` rows of partition i to partition i - 1.
        nexts = []
        for i in range(1, df.npartitions):
            key = (name_b, i)
            dsk[key] = (M.head, (df_name, i), after)
            nexts.append(key)
        nexts.append(None)
    elif isinstance(after, datetime.timedelta):
        # TODO: Do we have a use-case for this? Pandas doesn't allow negative rolling windows
        deltas = pd.Series(df.divisions).diff().iloc[1:-1]
        if (after > deltas).any():
            raise ValueError(timedelta_partition_message)

        nexts = []
        for i in range(1, df.npartitions):
            key = (name_b, i)
            # BUG FIX: the "current" partition whose last timestamp bounds the
            # overlap appended to partition i - 1 is partition i - 1 itself.
            # This previously passed (df_name, i - 0), comparing partition i
            # against its own index and selecting the wrong rows.
            dsk[key] = (_head_timedelta, (df_name, i - 1), (df_name, i), after)
            nexts.append(key)
        nexts.append(None)
    else:
        nexts = [None] * df.npartitions

    # Stitch each partition together with its prepended/appended overlap rows.
    # (`next_` avoids shadowing the builtin `next`.)
    for i, (prev, current, next_) in enumerate(zip(prevs, df.__dask_keys__(), nexts)):
        dsk[(name, i)] = (
            overlap_chunk,
            func,
            prev,
            current,
            next_,
            before,
            after,
            args,
            kwargs,
        )

    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[df])
    return df._constructor(graph, name, meta, df.divisions)
212
213
214def _head_timedelta(current, next_, after):
215    """Return rows of ``next_`` whose index is before the last
216    observation in ``current`` + ``after``.
217
218    Parameters
219    ----------
220    current : DataFrame
221    next_ : DataFrame
222    after : timedelta
223
224    Returns
225    -------
226    overlapped : DataFrame
227    """
228    return next_[next_.index < (current.index.max() + after)]
229
230
def _tail_timedelta(prevs, current, before):
    """Return the concatenated rows of each dataframe in ``prevs`` whose
    index is after the first observation in ``current`` - ``before``.

    Parameters
    ----------
    current : DataFrame
    prevs : list of DataFrame objects
    before : timedelta

    Returns
    -------
    overlapped : DataFrame
    """
    cutoff = current.index.min() - before
    pieces = [frame[frame.index > cutoff] for frame in prevs]
    return methods.concat(pieces)
249
250
class Rolling:
    """Provides rolling window calculations."""

    def __init__(
        self, obj, window=None, min_periods=None, center=False, win_type=None, axis=0
    ):
        self.obj = obj  # dataframe or series
        self.window = window
        self.min_periods = min_periods
        self.center = center
        self.axis = axis
        self.win_type = win_type
        # Allow pandas to raise if appropriate
        obj._meta.rolling(**self._rolling_kwargs())
        # Using .rolling(window='2s'), pandas will convert the
        # offset str to a window in nanoseconds. But pandas doesn't
        # accept the integer window with win_type='freq', so we store
        # that information here.
        # See https://github.com/pandas-dev/pandas/issues/15969
        self._win_type = None if isinstance(self.window, int) else "freq"

    def _rolling_kwargs(self):
        # Keyword arguments forwarded verbatim to ``DataFrame.rolling``.
        return {
            "window": self.window,
            "min_periods": self.min_periods,
            "center": self.center,
            "win_type": self.win_type,
            "axis": self.axis,
        }

    @property
    def _has_single_partition(self):
        """
        Indicator for whether the object has a single partition (True)
        or multiple (False).
        """
        return (
            self.axis in (1, "columns")
            or (isinstance(self.window, Integral) and self.window <= 1)
            or self.obj.npartitions == 1
        )

    @staticmethod
    def pandas_rolling_method(df, rolling_kwargs, name, *args, **kwargs):
        # Run the named pandas rolling method on a concrete partition.
        rolling = df.rolling(**rolling_kwargs)
        return getattr(rolling, name)(*args, **kwargs)

    def _call_method(self, method_name, *args, **kwargs):
        """Dispatch ``method_name`` across partitions, using map_overlap
        to share the window rows between neighbouring partitions."""
        rolling_kwargs = self._rolling_kwargs()
        meta = self.pandas_rolling_method(
            self.obj._meta_nonempty, rolling_kwargs, method_name, *args, **kwargs
        )

        if self._has_single_partition:
            # There's no overlap just use map_partitions
            return self.obj.map_partitions(
                self.pandas_rolling_method,
                rolling_kwargs,
                method_name,
                *args,
                token=method_name,
                meta=meta,
                **kwargs,
            )
        # Convert window to overlap
        if self.center:
            before = self.window // 2
            after = self.window - before - 1
        elif self._win_type == "freq":
            before = pd.Timedelta(self.window)
            after = 0
        else:
            before = self.window - 1
            after = 0
        return map_overlap(
            self.pandas_rolling_method,
            self.obj,
            before,
            after,
            rolling_kwargs,
            method_name,
            *args,
            token=method_name,
            meta=meta,
            **kwargs,
        )

    @derived_from(pd_Rolling)
    def count(self):
        return self._call_method("count")

    @derived_from(pd_Rolling)
    def cov(self):
        return self._call_method("cov")

    @derived_from(pd_Rolling)
    def sum(self):
        return self._call_method("sum")

    @derived_from(pd_Rolling)
    def mean(self):
        return self._call_method("mean")

    @derived_from(pd_Rolling)
    def median(self):
        return self._call_method("median")

    @derived_from(pd_Rolling)
    def min(self):
        return self._call_method("min")

    @derived_from(pd_Rolling)
    def max(self):
        return self._call_method("max")

    @derived_from(pd_Rolling)
    def std(self, ddof=1):
        # BUG FIX: previously hard-coded ddof=1, silently ignoring the
        # caller-supplied delta degrees of freedom.
        return self._call_method("std", ddof=ddof)

    @derived_from(pd_Rolling)
    def var(self, ddof=1):
        # BUG FIX: forward the caller's ddof instead of a hard-coded 1.
        return self._call_method("var", ddof=ddof)

    @derived_from(pd_Rolling)
    def skew(self):
        return self._call_method("skew")

    @derived_from(pd_Rolling)
    def kurt(self):
        return self._call_method("kurt")

    @derived_from(pd_Rolling)
    def quantile(self, quantile):
        return self._call_method("quantile", quantile)

    @derived_from(pd_Rolling)
    def apply(
        self,
        func,
        raw=None,
        engine="cython",
        engine_kwargs=None,
        args=None,
        kwargs=None,
    ):
        compat_kwargs = {}
        kwargs = kwargs or {}
        args = args or ()
        meta = self.obj._meta.rolling(0)
        if has_keyword(meta.apply, "engine"):
            # PANDAS_GT_100
            compat_kwargs = dict(engine=engine, engine_kwargs=engine_kwargs)
        if raw is None:
            # PANDAS_GT_100: The default changed from None to False.
            # BUG FIX: take the parameter's ``.default`` value; previously the
            # bare ``inspect.Parameter`` object was passed through as ``raw``.
            raw = inspect.signature(meta.apply).parameters["raw"].default

        return self._call_method(
            "apply", func, raw=raw, args=args, kwargs=kwargs, **compat_kwargs
        )

    @derived_from(pd_Rolling)
    def aggregate(self, func, args=(), kwargs=None, **kwds):
        # ``kwargs=None`` avoids the shared-mutable-default pitfall; callers
        # that passed an explicit dict are unaffected.
        kwargs = {} if kwargs is None else kwargs
        return self._call_method("agg", func, args=args, kwargs=kwargs, **kwds)

    agg = aggregate

    def __repr__(self):
        # Render kwargs in a fixed, pandas-like order, omitting unset values.
        def order(item):
            k, v = item
            _order = {
                "window": 0,
                "min_periods": 1,
                "center": 2,
                "win_type": 3,
                "axis": 4,
            }
            return _order[k]

        rolling_kwargs = self._rolling_kwargs()
        rolling_kwargs["window"] = self.window
        rolling_kwargs["win_type"] = self._win_type
        return "Rolling [{}]".format(
            ",".join(
                f"{k}={v}"
                for k, v in sorted(rolling_kwargs.items(), key=order)
                if v is not None
            )
        )
439
440
class RollingGroupby(Rolling):
    """Rolling window calculations applied per group of a groupby."""

    def __init__(
        self,
        groupby,
        window=None,
        min_periods=None,
        center=False,
        win_type=None,
        axis=0,
    ):
        self._groupby_kwargs = groupby._groupby_kwargs
        self._groupby_slice = groupby._slice

        obj = groupby.obj
        if self._groupby_slice is not None:
            # Restrict the frame to the sliced columns plus the grouping
            # key(s), so the keys survive the column selection.
            if isinstance(self._groupby_slice, str):
                selection = [self._groupby_slice]
            else:
                selection = list(self._groupby_slice)
            grouping = groupby.index
            if isinstance(grouping, str):
                selection.append(grouping)
            else:
                selection.extend(grouping)
            obj = obj[selection]

        super().__init__(
            obj,
            window=window,
            min_periods=min_periods,
            center=center,
            win_type=win_type,
            axis=axis,
        )

    @staticmethod
    def pandas_rolling_method(
        df,
        rolling_kwargs,
        name,
        *args,
        groupby_kwargs=None,
        groupby_slice=None,
        **kwargs,
    ):
        # Group, optionally slice, roll, then restore row order within each
        # group by sorting on the innermost index level.
        grouped = df.groupby(**groupby_kwargs)
        if groupby_slice:
            grouped = grouped[groupby_slice]
        rolled = grouped.rolling(**rolling_kwargs)
        return getattr(rolled, name)(*args, **kwargs).sort_index(level=-1)

    def _call_method(self, method_name, *args, **kwargs):
        # Inject the stored groupby configuration before deferring to Rolling.
        return super()._call_method(
            method_name,
            *args,
            groupby_kwargs=self._groupby_kwargs,
            groupby_slice=self._groupby_slice,
            **kwargs,
        )
499