1import numpy as np
2import pytest
3
4from pandas import DataFrame, Series
5import pandas._testing as tm
6
7
@pytest.mark.parametrize("name", ["var", "vol", "mean"])
def test_ewma_series(series, name):
    """Each EWM aggregation called on a Series must return a Series."""
    aggregate = getattr(series.ewm(com=10), name)
    assert isinstance(aggregate(), Series)
12
13
@pytest.mark.parametrize("name", ["var", "vol", "mean"])
def test_ewma_frame(frame, name):
    """Each EWM aggregation called on a DataFrame must return a DataFrame."""
    aggregate = getattr(frame.ewm(com=10), name)
    assert isinstance(aggregate(), DataFrame)
18
19
def test_ewma_adjust():
    """A single unit impulse smoothed with ``adjust=False`` should sum to ~1."""
    # All zeros except one spike at position 5.
    impulse = Series(np.zeros(1000))
    impulse[5] = 1

    smoothed = impulse.ewm(span=100, adjust=False).mean()
    total = smoothed.sum()

    # The smoothed values are expected to add up to the spike's size (within 1e-2).
    assert np.abs(total - 1) < 1e-2
25
26
@pytest.mark.parametrize("adjust", [True, False])
@pytest.mark.parametrize("ignore_na", [True, False])
def test_ewma_cases(adjust, ignore_na):
    """Check ``ewm(com=2.0).mean()`` for every adjust/ignore_na combination.

    The input contains no missing values, so the expected numbers are
    selected by ``adjust`` alone.
    """
    # try adjust/ignore_na args matrix
    expected_values = (
        [1.0, 1.6, 2.736842, 4.923077]
        if adjust
        else [1.0, 1.333333, 2.222222, 4.148148]
    )

    data = Series([1.0, 2.0, 4.0, 8.0])
    smoothed = data.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean()

    tm.assert_series_equal(smoothed, Series(expected_values))
41
42
def test_ewma_nan_handling():
    """EWM mean of a constant series with NaN gaps should stay at that constant.

    Positions before the first valid observation are expected to remain NaN;
    every later position, including the gaps, is expected to be 1.0.
    """
    cases = [
        # (input values, expected mean)
        ([1.0] + [np.nan] * 5 + [1.0], [1.0] * 7),
        ([np.nan] * 2 + [1.0] + [np.nan] * 2 + [1.0], [np.nan] * 2 + [1.0] * 4),
    ]
    for values, expected in cases:
        smoothed = Series(values).ewm(com=5).mean()
        tm.assert_series_equal(smoothed, Series(expected))
51
52
# Each case is (s, adjust, ignore_na, w): the input series, the two ewm flags
# under test, and the per-position weights used to build the expected mean.
# The recurring term ``1.0 / (1.0 + 2.0)`` is the smoothing factor
# alpha = 1 / (1 + com) for the ``com=2.0`` used in the test body.
@pytest.mark.parametrize(
    "s, adjust, ignore_na, w",
    [
        (
            Series([np.nan, 1.0, 101.0]),
            True,
            False,
            [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0],
        ),
        (
            Series([np.nan, 1.0, 101.0]),
            True,
            True,
            [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0],
        ),
        (
            Series([np.nan, 1.0, 101.0]),
            False,
            False,
            [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))],
        ),
        (
            Series([np.nan, 1.0, 101.0]),
            False,
            True,
            [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))],
        ),
        (
            Series([1.0, np.nan, 101.0]),
            True,
            False,
            [(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, 1.0],
        ),
        (
            Series([1.0, np.nan, 101.0]),
            True,
            True,
            [(1.0 - (1.0 / (1.0 + 2.0))), np.nan, 1.0],
        ),
        (
            Series([1.0, np.nan, 101.0]),
            False,
            False,
            [(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, (1.0 / (1.0 + 2.0))],
        ),
        (
            Series([1.0, np.nan, 101.0]),
            False,
            True,
            [(1.0 - (1.0 / (1.0 + 2.0))), np.nan, (1.0 / (1.0 + 2.0))],
        ),
        (
            Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
            True,
            False,
            [np.nan, (1.0 - (1.0 / (1.0 + 2.0))) ** 3, np.nan, np.nan, 1.0, np.nan],
        ),
        (
            Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
            True,
            True,
            [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), np.nan, np.nan, 1.0, np.nan],
        ),
        (
            Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
            False,
            False,
            [
                np.nan,
                (1.0 - (1.0 / (1.0 + 2.0))) ** 3,
                np.nan,
                np.nan,
                (1.0 / (1.0 + 2.0)),
                np.nan,
            ],
        ),
        (
            Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
            False,
            True,
            [
                np.nan,
                (1.0 - (1.0 / (1.0 + 2.0))),
                np.nan,
                np.nan,
                (1.0 / (1.0 + 2.0)),
                np.nan,
            ],
        ),
        (
            Series([1.0, np.nan, 101.0, 50.0]),
            True,
            False,
            [
                (1.0 - (1.0 / (1.0 + 2.0))) ** 3,
                np.nan,
                (1.0 - (1.0 / (1.0 + 2.0))),
                1.0,
            ],
        ),
        (
            Series([1.0, np.nan, 101.0, 50.0]),
            True,
            True,
            [
                (1.0 - (1.0 / (1.0 + 2.0))) ** 2,
                np.nan,
                (1.0 - (1.0 / (1.0 + 2.0))),
                1.0,
            ],
        ),
        (
            Series([1.0, np.nan, 101.0, 50.0]),
            False,
            False,
            [
                (1.0 - (1.0 / (1.0 + 2.0))) ** 3,
                np.nan,
                (1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)),
                (1.0 / (1.0 + 2.0))
                * ((1.0 - (1.0 / (1.0 + 2.0))) ** 2 + (1.0 / (1.0 + 2.0))),
            ],
        ),
        (
            Series([1.0, np.nan, 101.0, 50.0]),
            False,
            True,
            [
                (1.0 - (1.0 / (1.0 + 2.0))) ** 2,
                np.nan,
                (1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)),
                (1.0 / (1.0 + 2.0)),
            ],
        ),
    ],
)
def test_ewma_nan_handling_cases(s, adjust, ignore_na, w):
    """Compare ``ewm(com=2.0).mean()`` against a hand-weighted running average.

    Parameters
    ----------
    s : Series
        Input values, possibly containing NaN.
    adjust, ignore_na : bool
        Flags forwarded to ``Series.ewm``.
    w : list of float
        Per-position weights; NaN where ``s`` has no observation.

    The expected value at each position is ``cumsum(s * w) / cumsum(w)``,
    forward-filled across positions where that ratio is NaN.
    """
    # GH 7603
    # ``.ffill()`` replaces ``fillna(method="ffill")``, which pandas deprecated.
    expected = (s.multiply(w).cumsum() / Series(w).cumsum()).ffill()
    result = s.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean()

    tm.assert_series_equal(result, expected)
    if ignore_na is False:
        # check that ignore_na defaults to False
        result = s.ewm(com=2.0, adjust=adjust).mean()
        tm.assert_series_equal(result, expected)
199
200
def test_ewma_span_com_args(series):
    """``com=9.5`` and ``span=20`` must agree; invalid argument sets must raise."""
    from_com = series.ewm(com=9.5).mean()
    from_span = series.ewm(span=20).mean()
    tm.assert_almost_equal(from_com, from_span)

    # Supplying two decay parameters at once is rejected.
    with pytest.raises(
        ValueError, match="comass, span, halflife, and alpha are mutually exclusive"
    ):
        series.ewm(com=9.5, span=20)

    # Supplying none of them is rejected as well.
    with pytest.raises(
        ValueError, match="Must pass one of comass, span, halflife, or alpha"
    ):
        series.ewm().mean()
212
213
def test_ewma_halflife_arg(series):
    """``halflife=10`` must match its ``com`` equivalent; bad combos must raise.

    Every ``pytest.raises`` pins the expected message, so an unrelated
    ValueError cannot make the test pass by accident.
    """
    from_com = series.ewm(com=13.932726172912965).mean()
    from_halflife = series.ewm(halflife=10.0).mean()
    tm.assert_almost_equal(from_com, from_halflife)

    # Any combination of two or more decay parameters is rejected.
    msg = "comass, span, halflife, and alpha are mutually exclusive"
    with pytest.raises(ValueError, match=msg):
        series.ewm(span=20, halflife=50)
    with pytest.raises(ValueError, match=msg):
        series.ewm(com=9.5, halflife=50)
    with pytest.raises(ValueError, match=msg):
        series.ewm(com=9.5, span=20, halflife=50)

    # Passing no decay parameter at all is rejected with a different message.
    msg = "Must pass one of comass, span, halflife, or alpha"
    with pytest.raises(ValueError, match=msg):
        series.ewm()
227
228
def test_ewm_alpha():
    """The alpha, com, span and halflife spellings below must give equal means.

    The input is random data with a block of NaN in positions 20-39.
    """
    # GH 10789
    arr = np.random.randn(100)
    locs = np.arange(20, 40)
    # ``np.nan`` is the canonical spelling; the ``np.NaN`` alias is gone in NumPy 2.0.
    arr[locs] = np.nan

    s = Series(arr)
    a = s.ewm(alpha=0.61722699889169674).mean()
    b = s.ewm(com=0.62014947789973052).mean()
    c = s.ewm(span=2.240298955799461).mean()
    d = s.ewm(halflife=0.721792864318).mean()
    tm.assert_series_equal(a, b)
    tm.assert_series_equal(a, c)
    tm.assert_series_equal(a, d)
243
244
def test_ewm_alpha_arg(series):
    """``ewm`` must reject a missing decay parameter and alpha paired with another."""
    # GH 10789
    missing_msg = "Must pass one of comass, span, halflife, or alpha"
    with pytest.raises(ValueError, match=missing_msg):
        series.ewm()

    exclusive_msg = "comass, span, halflife, and alpha are mutually exclusive"
    # alpha together with each of the other decay parameters in turn
    for other in ("com", "span", "halflife"):
        with pytest.raises(ValueError, match=exclusive_msg):
            series.ewm(alpha=0.5, **{other: 10.0})
259
260
def test_ewm_domain_checks():
    """Each decay parameter must be validated against its allowed range.

    For every parameter the test tries values outside the range (expecting a
    ValueError with the matching message) and values inside it (expecting
    the ``ewm`` call to succeed).
    """
    # GH 12492
    arr = np.random.randn(100)
    locs = np.arange(20, 40)
    # ``np.nan`` is the canonical spelling; the ``np.NaN`` alias is gone in NumPy 2.0.
    arr[locs] = np.nan

    s = Series(arr)
    msg = "comass must satisfy: comass >= 0"
    with pytest.raises(ValueError, match=msg):
        s.ewm(com=-0.1)
    s.ewm(com=0.0)
    s.ewm(com=0.1)

    msg = "span must satisfy: span >= 1"
    with pytest.raises(ValueError, match=msg):
        s.ewm(span=-0.1)
    with pytest.raises(ValueError, match=msg):
        s.ewm(span=0.0)
    with pytest.raises(ValueError, match=msg):
        s.ewm(span=0.9)
    s.ewm(span=1.0)
    s.ewm(span=1.1)

    msg = "halflife must satisfy: halflife > 0"
    with pytest.raises(ValueError, match=msg):
        s.ewm(halflife=-0.1)
    with pytest.raises(ValueError, match=msg):
        s.ewm(halflife=0.0)
    s.ewm(halflife=0.1)

    msg = "alpha must satisfy: 0 < alpha <= 1"
    with pytest.raises(ValueError, match=msg):
        s.ewm(alpha=-0.1)
    with pytest.raises(ValueError, match=msg):
        s.ewm(alpha=0.0)
    s.ewm(alpha=0.1)
    s.ewm(alpha=1.0)
    with pytest.raises(ValueError, match=msg):
        s.ewm(alpha=1.1)
300
301
@pytest.mark.parametrize("method", ["mean", "vol", "var"])
def test_ew_empty_series(method):
    """EWM aggregations of an empty float64 Series should equal that Series."""
    empty = Series([], dtype=np.float64)

    aggregate = getattr(empty.ewm(3), method)
    tm.assert_almost_equal(aggregate(), empty)
309
310
@pytest.mark.parametrize("min_periods", [0, 1])
@pytest.mark.parametrize("name", ["mean", "var", "vol"])
def test_ew_min_periods(min_periods, name):
    """Check how ``min_periods`` and NaN padding affect EWM aggregations.

    The main input is 50 random values whose first and last 10 entries are
    NaN. The test also covers a length-0 Series, a length-1 Series and an
    integer Series.
    """
    # excluding NaNs correctly
    arr = np.random.randn(50)
    # ``np.nan`` is the canonical spelling; the ``np.NaN`` alias is gone in NumPy 2.0.
    arr[:10] = np.nan
    arr[-10:] = np.nan
    s = Series(arr)

    # check min_periods
    # GH 7898
    result = getattr(s.ewm(com=50, min_periods=2), name)()
    assert result[:11].isna().all()
    assert not result[11:].isna().any()

    result = getattr(s.ewm(com=50, min_periods=min_periods), name)()
    if name == "mean":
        assert result[:10].isna().all()
        assert not result[10:].isna().any()
    else:
        # ewm.std, ewm.vol, ewm.var (with bias=False) require at least
        # two values
        assert result[:11].isna().all()
        assert not result[11:].isna().any()

    # check series of length 0
    result = getattr(Series(dtype=object).ewm(com=50, min_periods=min_periods), name)()
    tm.assert_series_equal(result, Series(dtype="float64"))

    # check series of length 1
    result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)()
    if name == "mean":
        tm.assert_series_equal(result, Series([1.0]))
    else:
        # ewm.std, ewm.vol, ewm.var with bias=False require at least
        # two values
        tm.assert_series_equal(result, Series([np.nan]))

    # pass in ints
    result2 = getattr(Series(np.arange(50)).ewm(span=10), name)()
    # ``np.float64`` is the concrete type behind the removed ``np.float_`` alias.
    assert result2.dtype == np.float64
352