1from functools import wraps
2
3from statsmodels.tools.data import _is_using_pandas
4from statsmodels.tsa.tsatools import freq_to_period
5
6
7def _get_pandas_wrapper(X, trim_head=None, trim_tail=None, names=None):
8    index = X.index
9    #TODO: allow use index labels
10    if trim_head is None and trim_tail is None:
11        index = index
12    elif trim_tail is None:
13        index = index[trim_head:]
14    elif trim_head is None:
15        index = index[:-trim_tail]
16    else:
17        index = index[trim_head:-trim_tail]
18    if hasattr(X, "columns"):
19        if names is None:
20            names = X.columns
21        return lambda x : X.__class__(x, index=index, columns=names)
22    else:
23        if names is None:
24            names = X.name
25        return lambda x : X.__class__(x, index=index, name=names)
26
27
def pandas_wrapper(func, trim_head=None, trim_tail=None, names=None, *args,
                   **kwargs):
    """
    Decorate `func` so that pandas input produces pandas output.

    Non-pandas input is forwarded to `func` untouched. For pandas input the
    result of `func` is re-wrapped with the (optionally trimmed) index of X
    via ``_get_pandas_wrapper(X, trim_head, trim_tail, names)``.

    The extra ``*args`` / ``**kwargs`` accepted by this factory are ignored;
    the ones given to the returned function are forwarded to `func`.
    """
    @wraps(func)
    def wrapped(X, *args, **kwargs):
        if _is_using_pandas(X, None):
            # Build the re-wrapper before calling func, so X's index is
            # captured as it was on entry.
            rewrap = _get_pandas_wrapper(X, trim_head, trim_tail, names)
            return rewrap(func(X, *args, **kwargs))
        # Plain (non-pandas) input: nothing to re-wrap.
        return func(X, *args, **kwargs)

    return wrapped
43
44
def pandas_wrapper_bunch(func, trim_head=None, trim_tail=None,
                         names=None, *args, **kwargs):
    """
    Decorate `func` so that pandas input produces pandas output.

    Behaves exactly like ``pandas_wrapper`` with the same `trim_head`,
    `trim_tail` and `names`, and delegates to it instead of duplicating its
    body. The extra ``*args`` / ``**kwargs`` accepted here are ignored.

    NOTE(review): despite the name, no Bunch-specific handling happens; the
    result of `func` is passed straight to the pandas constructor of X's
    class. Confirm with callers whether that is intended.
    """
    return pandas_wrapper(func, trim_head=trim_head, trim_tail=trim_tail,
                          names=names)
60
61
def pandas_wrapper_predict(func, trim_head=None, trim_tail=None,
                           columns=None, *args, **kwargs):
    """
    Placeholder for a prediction-aware pandas wrapper.

    Not implemented: every call raises regardless of the arguments given.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError()
65
66
def pandas_wrapper_freq(func, trim_head=None, trim_tail=None,
                        freq_kw='freq', columns=None, *args, **kwargs):
    """
    Return a new function that catches the incoming X, checks if it's pandas,
    calls the function, and then wraps the result in the incoming index.

    Deals with frequencies. For pandas input, the period derived from the
    index's inferred frequency (via ``freq_to_period``) is passed to `func`
    as the keyword argument named `freq_kw`. If the caller already supplies
    a non-None value for that keyword, the caller's value is used unchanged.

    Parameters
    ----------
    func : callable
        Called as ``func(X, *args, **kwargs)``.
    trim_head, trim_tail : int or None
        Number of leading / trailing index labels to drop when re-wrapping
        the output (forwarded to ``_get_pandas_wrapper``).
    freq_kw : str
        Name of the keyword argument of `func` that receives the period.
    columns : sequence, hashable or None
        Column labels (DataFrame input) or name (Series input) for the
        output. Defaults to those of X.

    Returns
    -------
    callable
        The wrapped function. Non-pandas input is passed through to `func`
        untouched.

    Raises
    ------
    ValueError
        (From the returned function.) X is a pandas object, no value for
        `freq_kw` was supplied, and no frequency can be inferred from
        ``X.index``.

    Notes
    -----
    The result of `func` is passed as-is to the constructor of X's class.
    NOTE(review): tuple or Bunch return values are not unpacked here; confirm
    that callers return array-like data.
    The extra ``*args`` / ``**kwargs`` accepted by this factory are ignored.
    """

    @wraps(func)
    def new_func(X, *args, **kwargs):
        # quick pass-through for do nothing case
        if not _is_using_pandas(X, None):
            return func(X, *args, **kwargs)

        wrapper_func = _get_pandas_wrapper(X, trim_head, trim_tail,
                                           columns)
        # An explicit, non-None value from the caller wins over inference.
        if kwargs.get(freq_kw) is None:
            # Non-datetime indexes may lack ``inferred_freq`` entirely, and
            # irregular ones report None.
            freq = getattr(X.index, "inferred_freq", None)
            if freq is None:
                raise ValueError(
                    "Cannot infer a frequency from the index of X; pass "
                    "%r explicitly." % freq_kw)
            kwargs[freq_kw] = freq_to_period(freq)
        ret = func(X, *args, **kwargs)
        return wrapper_func(ret)

    return new_func
93