1import numpy as np 2import pytest 3 4from pandas import DataFrame, Series 5import pandas._testing as tm 6 7 8@pytest.mark.parametrize("name", ["var", "vol", "mean"]) 9def test_ewma_series(series, name): 10 series_result = getattr(series.ewm(com=10), name)() 11 assert isinstance(series_result, Series) 12 13 14@pytest.mark.parametrize("name", ["var", "vol", "mean"]) 15def test_ewma_frame(frame, name): 16 frame_result = getattr(frame.ewm(com=10), name)() 17 assert isinstance(frame_result, DataFrame) 18 19 20def test_ewma_adjust(): 21 vals = Series(np.zeros(1000)) 22 vals[5] = 1 23 result = vals.ewm(span=100, adjust=False).mean().sum() 24 assert np.abs(result - 1) < 1e-2 25 26 27@pytest.mark.parametrize("adjust", [True, False]) 28@pytest.mark.parametrize("ignore_na", [True, False]) 29def test_ewma_cases(adjust, ignore_na): 30 # try adjust/ignore_na args matrix 31 32 s = Series([1.0, 2.0, 4.0, 8.0]) 33 34 if adjust: 35 expected = Series([1.0, 1.6, 2.736842, 4.923077]) 36 else: 37 expected = Series([1.0, 1.333333, 2.222222, 4.148148]) 38 39 result = s.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean() 40 tm.assert_series_equal(result, expected) 41 42 43def test_ewma_nan_handling(): 44 s = Series([1.0] + [np.nan] * 5 + [1.0]) 45 result = s.ewm(com=5).mean() 46 tm.assert_series_equal(result, Series([1.0] * len(s))) 47 48 s = Series([np.nan] * 2 + [1.0] + [np.nan] * 2 + [1.0]) 49 result = s.ewm(com=5).mean() 50 tm.assert_series_equal(result, Series([np.nan] * 2 + [1.0] * 4)) 51 52 53@pytest.mark.parametrize( 54 "s, adjust, ignore_na, w", 55 [ 56 ( 57 Series([np.nan, 1.0, 101.0]), 58 True, 59 False, 60 [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0], 61 ), 62 ( 63 Series([np.nan, 1.0, 101.0]), 64 True, 65 True, 66 [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0], 67 ), 68 ( 69 Series([np.nan, 1.0, 101.0]), 70 False, 71 False, 72 [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))], 73 ), 74 ( 75 Series([np.nan, 1.0, 101.0]), 76 False, 77 True, 78 [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))], 79 ), 80 ( 81 Series([1.0, np.nan, 101.0]), 82 True, 83 False, 84 [(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, 1.0], 85 ), 86 ( 87 Series([1.0, np.nan, 101.0]), 88 True, 89 True, 90 [(1.0 - (1.0 / (1.0 + 2.0))), np.nan, 1.0], 91 ), 92 ( 93 Series([1.0, np.nan, 101.0]), 94 False, 95 False, 96 [(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, (1.0 / (1.0 + 2.0))], 97 ), 98 ( 99 Series([1.0, np.nan, 101.0]), 100 False, 101 True, 102 [(1.0 - (1.0 / (1.0 + 2.0))), np.nan, (1.0 / (1.0 + 2.0))], 103 ), 104 ( 105 Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]), 106 True, 107 False, 108 [np.nan, (1.0 - (1.0 / (1.0 + 2.0))) ** 3, np.nan, np.nan, 1.0, np.nan], 109 ), 110 ( 111 Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]), 112 True, 113 True, 114 [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), np.nan, np.nan, 1.0, np.nan], 115 ), 116 ( 117 Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]), 118 False, 119 False, 120 [ 121 np.nan, 122 (1.0 - (1.0 / (1.0 + 2.0))) ** 3, 123 np.nan, 124 np.nan, 125 (1.0 / (1.0 + 2.0)), 126 np.nan, 127 ], 128 ), 129 ( 130 Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]), 131 False, 132 True, 133 [ 134 np.nan, 135 (1.0 - (1.0 / (1.0 + 2.0))), 136 np.nan, 137 np.nan, 138 (1.0 / (1.0 + 2.0)), 139 np.nan, 140 ], 141 ), 142 ( 143 Series([1.0, np.nan, 101.0, 50.0]), 144 True, 145 False, 146 [ 147 (1.0 - (1.0 / (1.0 + 2.0))) ** 3, 148 np.nan, 149 (1.0 - (1.0 / (1.0 + 2.0))), 150 1.0, 151 ], 152 ), 153 ( 154 Series([1.0, np.nan, 101.0, 50.0]), 155 True, 156 True, 157 [ 158 (1.0 - (1.0 / (1.0 + 2.0))) ** 2, 159 np.nan, 160 (1.0 - (1.0 / (1.0 + 2.0))), 161 1.0, 162 ], 163 ), 164 ( 165 Series([1.0, np.nan, 101.0, 50.0]), 166 False, 167 False, 168 [ 169 (1.0 - (1.0 / (1.0 + 2.0))) ** 3, 170 np.nan, 171 (1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)), 172 (1.0 / (1.0 + 2.0)) 173 * ((1.0 - (1.0 / (1.0 + 2.0))) ** 2 + (1.0 / (1.0 + 2.0))), 174 ], 175 ), 176 ( 177 Series([1.0, np.nan, 101.0, 50.0]), 178 False, 179 True, 180 [ 181 (1.0 - (1.0 / (1.0 + 2.0))) ** 2, 182 np.nan, 183 (1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)), 184 (1.0 / (1.0 + 2.0)), 185 ], 186 ), 187 ], 188) 189def test_ewma_nan_handling_cases(s, adjust, ignore_na, w): 190 # GH 7603 191 expected = (s.multiply(w).cumsum() / Series(w).cumsum()).fillna(method="ffill") 192 result = s.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean() 193 194 tm.assert_series_equal(result, expected) 195 if ignore_na is False: 196 # check that ignore_na defaults to False 197 result = s.ewm(com=2.0, adjust=adjust).mean() 198 tm.assert_series_equal(result, expected) 199 200 201def test_ewma_span_com_args(series): 202 A = series.ewm(com=9.5).mean() 203 B = series.ewm(span=20).mean() 204 tm.assert_almost_equal(A, B) 205 msg = "comass, span, halflife, and alpha are mutually exclusive" 206 with pytest.raises(ValueError, match=msg): 207 series.ewm(com=9.5, span=20) 208 209 msg = "Must pass one of comass, span, halflife, or alpha" 210 with pytest.raises(ValueError, match=msg): 211 series.ewm().mean() 212 213 214def test_ewma_halflife_arg(series): 215 A = series.ewm(com=13.932726172912965).mean() 216 B = series.ewm(halflife=10.0).mean() 217 tm.assert_almost_equal(A, B) 218 msg = "comass, span, halflife, and alpha are mutually exclusive" 219 with pytest.raises(ValueError, match=msg): 220 series.ewm(span=20, halflife=50) 221 with pytest.raises(ValueError): 222 series.ewm(com=9.5, halflife=50) 223 with pytest.raises(ValueError): 224 series.ewm(com=9.5, span=20, halflife=50) 225 with pytest.raises(ValueError): 226 series.ewm() 227 228 229def test_ewm_alpha(): 230 # GH 10789 231 arr = np.random.randn(100) 232 locs = np.arange(20, 40) 233 arr[locs] = np.NaN 234 235 s = Series(arr) 236 a = s.ewm(alpha=0.61722699889169674).mean() 237 b = s.ewm(com=0.62014947789973052).mean() 238 c = s.ewm(span=2.240298955799461).mean() 239 d = s.ewm(halflife=0.721792864318).mean() 240 tm.assert_series_equal(a, b) 241 tm.assert_series_equal(a, c) 242 tm.assert_series_equal(a, d) 243 244 245def test_ewm_alpha_arg(series): 246 # GH 10789 247 s = series 248 msg = "Must pass one of comass, span, halflife, or alpha" 249 with pytest.raises(ValueError, match=msg): 250 s.ewm() 251 252 msg = "comass, span, halflife, and alpha are mutually exclusive" 253 with pytest.raises(ValueError, match=msg): 254 s.ewm(com=10.0, alpha=0.5) 255 with pytest.raises(ValueError, match=msg): 256 s.ewm(span=10.0, alpha=0.5) 257 with pytest.raises(ValueError, match=msg): 258 s.ewm(halflife=10.0, alpha=0.5) 259 260 261def test_ewm_domain_checks(): 262 # GH 12492 263 arr = np.random.randn(100) 264 locs = np.arange(20, 40) 265 arr[locs] = np.NaN 266 267 s = Series(arr) 268 msg = "comass must satisfy: comass >= 0" 269 with pytest.raises(ValueError, match=msg): 270 s.ewm(com=-0.1) 271 s.ewm(com=0.0) 272 s.ewm(com=0.1) 273 274 msg = "span must satisfy: span >= 1" 275 with pytest.raises(ValueError, match=msg): 276 s.ewm(span=-0.1) 277 with pytest.raises(ValueError, match=msg): 278 s.ewm(span=0.0) 279 with pytest.raises(ValueError, match=msg): 280 s.ewm(span=0.9) 281 s.ewm(span=1.0) 282 s.ewm(span=1.1) 283 284 msg = "halflife must satisfy: halflife > 0" 285 with pytest.raises(ValueError, match=msg): 286 s.ewm(halflife=-0.1) 287 with pytest.raises(ValueError, match=msg): 288 s.ewm(halflife=0.0) 289 s.ewm(halflife=0.1) 290 291 msg = "alpha must satisfy: 0 < alpha <= 1" 292 with pytest.raises(ValueError, match=msg): 293 s.ewm(alpha=-0.1) 294 with pytest.raises(ValueError, match=msg): 295 s.ewm(alpha=0.0) 296 s.ewm(alpha=0.1) 297 s.ewm(alpha=1.0) 298 with pytest.raises(ValueError, match=msg): 299 s.ewm(alpha=1.1) 300 301 302@pytest.mark.parametrize("method", ["mean", "vol", "var"]) 303def test_ew_empty_series(method): 304 vals = Series([], dtype=np.float64) 305 306 ewm = vals.ewm(3) 307 result = getattr(ewm, method)() 308 tm.assert_almost_equal(result, vals) 309 310 311@pytest.mark.parametrize("min_periods", [0, 1]) 312@pytest.mark.parametrize("name", ["mean", "var", "vol"]) 313def test_ew_min_periods(min_periods, name): 314 # excluding NaNs correctly 315 arr = np.random.randn(50) 316 arr[:10] = np.NaN 317 arr[-10:] = np.NaN 318 s = Series(arr) 319 320 # check min_periods 321 # GH 7898 322 result = getattr(s.ewm(com=50, min_periods=2), name)() 323 assert result[:11].isna().all() 324 assert not result[11:].isna().any() 325 326 result = getattr(s.ewm(com=50, min_periods=min_periods), name)() 327 if name == "mean": 328 assert result[:10].isna().all() 329 assert not result[10:].isna().any() 330 else: 331 # ewm.std, ewm.vol, ewm.var (with bias=False) require at least 332 # two values 333 assert result[:11].isna().all() 334 assert not result[11:].isna().any() 335 336 # check series of length 0 337 result = getattr(Series(dtype=object).ewm(com=50, min_periods=min_periods), name)() 338 tm.assert_series_equal(result, Series(dtype="float64")) 339 340 # check series of length 1 341 result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)() 342 if name == "mean": 343 tm.assert_series_equal(result, Series([1.0])) 344 else: 345 # ewm.std, ewm.vol, ewm.var with bias=False require at least 346 # two values 347 tm.assert_series_equal(result, Series([np.NaN])) 348 349 # pass in ints 350 result2 = getattr(Series(np.arange(50)).ewm(span=10), name)() 351 assert result2.dtype == np.float_ 352