# ----------------------------------------------------------------------------
# Copyright (c) 2013--, scikit-bio development team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
# ----------------------------------------------------------------------------

import inspect
import os
import sys

import numpy as np
import numpy.testing as npt
import pandas.testing as pdt

from ._decorator import experimental


class ReallyEqualMixin:
    """Use this for testing __eq__/__ne__.

    Intended to be mixed into a ``unittest.TestCase`` subclass: the helper
    methods rely on ``assertEqual``, ``assertNotEqual``, ``assertTrue`` and
    ``assertFalse`` being available on ``self``.

    Taken and modified from the following public domain code:
      https://ludios.org/testing-your-eq-ne-cmp/

    """

    def assertReallyEqual(self, a, b):
        """Assert that `a` and `b` compare equal in both operand orders.

        Both ``==`` and ``!=`` are exercised, with each object on either
        side of the operator.
        """
        # assertEqual first, because it will have a good message if the
        # assertion fails.
        self.assertEqual(a, b)
        self.assertEqual(b, a)
        self.assertTrue(a == b)
        self.assertTrue(b == a)
        self.assertFalse(a != b)
        self.assertFalse(b != a)

    def assertReallyNotEqual(self, a, b):
        """Assert that `a` and `b` compare unequal in both operand orders.

        Both ``==`` and ``!=`` are exercised, with each object on either
        side of the operator.
        """
        # assertNotEqual first, because it will have a good message if the
        # assertion fails.
        self.assertNotEqual(a, b)
        self.assertNotEqual(b, a)
        self.assertFalse(a == b)
        self.assertFalse(b == a)
        self.assertTrue(a != b)
        self.assertTrue(b != a)


@experimental(as_of="0.4.0")
def get_data_path(fn, subfolder='data'):
    """Return path to filename ``fn`` in the data folder.

    During testing it is often necessary to load data files. This
    function returns the full path to files in the ``data`` subfolder
    by default.

    Parameters
    ----------
    fn : str
        File name.

    subfolder : str, defaults to ``data``
        Name of the subfolder that contains the data.


    Returns
    -------
    str
        Inferred absolute path to the test data for the module where
        ``get_data_path(fn)`` is called.

    Notes
    -----
    The requested path may not point to an existing file, as its
    existence is not checked.

    """
    # getouterframes returns a list of tuples: the second tuple
    # contains info about the caller, and the second element is its
    # filename
    callers_filename = inspect.getouterframes(inspect.currentframe())[1][1]
    path = os.path.dirname(os.path.abspath(callers_filename))
    data_path = os.path.join(path, subfolder, fn)
    return data_path


@experimental(as_of="0.4.0")
def assert_ordination_results_equal(left, right, ignore_method_names=False,
                                    ignore_axis_labels=False,
                                    ignore_directionality=False,
                                    decimal=7):
    """Assert that ordination results objects are equal.

    This is a helper function intended to be used in unit tests that need to
    compare ``OrdinationResults`` objects.

    Parameters
    ----------
    left, right : OrdinationResults
        Ordination results to be compared for equality.
    ignore_method_names : bool, optional
        Ignore differences in `short_method_name` and `long_method_name`.
    ignore_axis_labels : bool, optional
        Ignore differences in axis labels (i.e., column labels).
    ignore_directionality : bool, optional
        Ignore differences in directionality (i.e., differences in signs) for
        attributes `samples`, `features` and `biplot_scores`.
    decimal : int, optional
        Number of decimal places used when comparing numeric values; it is
        forwarded to ``numpy.testing.assert_almost_equal``. Defaults to 7.

    Raises
    ------
    AssertionError
        If the two objects are not equal.

    """
    npt.assert_equal(type(left) is type(right), True)

    if not ignore_method_names:
        npt.assert_equal(left.short_method_name, right.short_method_name)
        npt.assert_equal(left.long_method_name, right.long_method_name)

    # For data frames only the column labels are optionally ignored; the row
    # index is always compared.
    _assert_frame_equal(left.samples, right.samples,
                        ignore_columns=ignore_axis_labels,
                        ignore_directionality=ignore_directionality,
                        decimal=decimal)

    _assert_frame_equal(left.features, right.features,
                        ignore_columns=ignore_axis_labels,
                        ignore_directionality=ignore_directionality,
                        decimal=decimal)

    _assert_frame_equal(left.biplot_scores, right.biplot_scores,
                        ignore_columns=ignore_axis_labels,
                        ignore_directionality=ignore_directionality,
                        decimal=decimal)

    _assert_frame_equal(left.sample_constraints, right.sample_constraints,
                        ignore_columns=ignore_axis_labels,
                        ignore_directionality=ignore_directionality,
                        decimal=decimal)

    # For series the index holds the axis labels, so `ignore_axis_labels`
    # controls whether the index is compared.
    _assert_series_equal(left.eigvals, right.eigvals,
                         ignore_index=ignore_axis_labels,
                         decimal=decimal)

    _assert_series_equal(left.proportion_explained,
                         right.proportion_explained,
                         ignore_index=ignore_axis_labels,
                         decimal=decimal)


def _assert_series_equal(left_s, right_s, ignore_index=False, decimal=7):
    """Assert two series (or ``None``) are almost equal.

    Values are compared to `decimal` places; the index is compared unless
    `ignore_index` is true. If either argument is ``None``, both must be.
    """
    # assert_series_equal doesn't like None...
    if left_s is None or right_s is None:
        assert left_s is None and right_s is None
    else:
        npt.assert_almost_equal(left_s.values, right_s.values,
                                decimal=decimal)
        if not ignore_index:
            pdt.assert_index_equal(left_s.index, right_s.index)


def _assert_frame_equal(left_df, right_df, ignore_index=False,
                        ignore_columns=False, ignore_directionality=False,
                        decimal=7):
    """Assert two data frames (or ``None``) are almost equal.

    Values are compared to `decimal` places, optionally after normalizing
    column signs (`ignore_directionality`). The index and column labels are
    compared unless `ignore_index` / `ignore_columns` are true. If either
    argument is ``None``, both must be.
    """
    # assert_frame_equal doesn't like None...
    if left_df is None or right_df is None:
        assert left_df is None and right_df is None
    else:
        left_values = left_df.values
        right_values = right_df.values
        if ignore_directionality:
            left_values, right_values = _normalize_signs(left_values,
                                                         right_values)
        npt.assert_almost_equal(left_values, right_values, decimal=decimal)

        if not ignore_index:
            pdt.assert_index_equal(left_df.index, right_df.index)
        if not ignore_columns:
            pdt.assert_index_equal(left_df.columns, right_df.columns)


def _normalize_signs(arr1, arr2):
    """Change column signs so that "column" and "-column" compare equal.

    This is needed because results of eigenproblems can have signs
    flipped, but they're still right.

    Parameters
    ----------
    arr1, arr2 : array_like
        Two-dimensional arrays with the same shape.

    Returns
    -------
    tuple of (ndarray, ndarray)
        `arr1` with column signs flipped where needed to match `arr2`, and
        `arr2` converted to a float64 array (otherwise unchanged).

    Raises
    ------
    ValueError
        If the arrays do not have the same shape.

    Notes
    -----
    This function tries hard to make sure that, if you find "column"
    and "-column" almost equal, calling a function like np.allclose to
    compare them after calling `normalize_signs` succeeds.

    To do so, it distinguishes two cases for every column:

    - It can be all almost equal to 0 (this includes a column of
      zeros).
    - Otherwise, it has a value that isn't close to 0.

    In the first case, no sign needs to be flipped. I.e., for
    |epsilon| small, np.allclose(-epsilon, 0) is true if and only if
    np.allclose(epsilon, 0) is.

    In the second case, the function finds the number in the column
    whose absolute value is largest. Then, it compares its sign with
    the number found in the same index, but in the other array, and
    flips the sign of the column as needed.
    """
    # Let's convert everything to floating point numbers (it's
    # reasonable to assume that eigenvectors will already be floating
    # point numbers). This is necessary because np.array(1) /
    # np.array(0) != np.array(1.) / np.array(0.)
    arr1 = np.asarray(arr1, dtype=np.float64)
    arr2 = np.asarray(arr2, dtype=np.float64)

    if arr1.shape != arr2.shape:
        raise ValueError(
            "Arrays must have the same shape ({0} vs {1}).".format(arr1.shape,
                                                                   arr2.shape)
        )

    # To avoid issues around zero, we'll compare signs of the values
    # with highest absolute value
    max_idx = np.abs(arr1).argmax(axis=0)
    columns = np.arange(arr1.shape[1])
    max_arr1 = arr1[max_idx, columns]
    max_arr2 = arr2[max_idx, columns]

    sign_arr1 = np.sign(max_arr1)
    sign_arr2 = np.sign(max_arr2)

    # Temporarily ignore division by zero (like 1. / 0.) and invalid
    # operations (like 0. / 0.). The context manager guarantees the previous
    # floating-point error state is restored, even if an exception is raised.
    with np.errstate(invalid='ignore', divide='ignore'):
        differences = sign_arr1 / sign_arr2
    # The values in `differences` can be:
    #    1 -> equal signs
    #   -1 -> diff signs
    #   Or nan (0/0), inf (nonzero/0), 0 (0/nonzero)

    # Now let's deal with cases where `differences` is not +/-1
    special_cases = (~np.isfinite(differences)) | (differences == 0)
    # In any of these cases, the sign of the column doesn't matter, so
    # let's just keep it
    differences[special_cases] = 1

    return arr1 * differences, arr2


@experimental(as_of="0.4.0")
def assert_data_frame_almost_equal(left, right):
    """Raise AssertionError if ``pd.DataFrame`` objects are not "almost equal".

    Wrapper of ``pandas.testing.assert_frame_equal``. Floating point values
    are considered "almost equal" if they are within a threshold defined by
    ``assert_frame_equal``. This wrapper uses a number of
    checks that are turned off by default in ``assert_frame_equal`` in order to
    perform stricter comparisons (for example, ensuring the index and column
    types are the same). It also does not consider empty ``pd.DataFrame``
    objects equal if they have a different index.

    Other notes:

    * Index (row) and column ordering must be the same for objects to be equal.
    * NaNs (``np.nan``) in the same locations are considered equal.

    This is a helper function intended to be used in unit tests that need to
    compare ``pd.DataFrame`` objects.

    Parameters
    ----------
    left, right : pd.DataFrame
        ``pd.DataFrame`` objects to compare.

    Raises
    ------
    AssertionError
        If `left` and `right` are not "almost equal".

    See Also
    --------
    pandas.testing.assert_frame_equal

    """
    # pass kwargs explicitly to ensure this function has consistent behavior
    # even if `assert_frame_equal`'s defaults change. `check_less_precise` is
    # deliberately not passed: it was deprecated in pandas 1.1 and removed in
    # pandas 2.0, where passing it raises TypeError.
    pdt.assert_frame_equal(left, right,
                           check_dtype=True,
                           check_index_type=True,
                           check_column_type=True,
                           check_frame_type=True,
                           check_names=True,
                           by_blocks=False,
                           check_exact=False)
    # this check ensures that empty DataFrames with different indices do not
    # compare equal. exact=True specifies that the type of the indices must be
    # exactly the same
    assert_index_equal(left.index, right.index)


def assert_series_almost_equal(left, right):
    """Raise AssertionError if ``pd.Series`` objects are not "almost equal".

    Wrapper of ``pandas.testing.assert_series_equal`` with dtype, index type,
    series type and name checks enabled, followed by a strict comparison of
    the two indices.

    Parameters
    ----------
    left, right : pd.Series
        ``pd.Series`` objects to compare.

    Raises
    ------
    AssertionError
        If `left` and `right` are not "almost equal".

    """
    # pass kwargs explicitly to ensure this function has consistent behavior
    # even if `assert_series_equal`'s defaults change. `check_less_precise` is
    # deliberately not passed: it was deprecated in pandas 1.1 and removed in
    # pandas 2.0, where passing it raises TypeError.
    pdt.assert_series_equal(left, right,
                            check_dtype=True,
                            check_index_type=True,
                            check_series_type=True,
                            check_names=True,
                            check_exact=False,
                            check_datetimelike_compat=False,
                            obj='Series')
    # this check ensures that empty Series with different indices do not
    # compare equal.
    assert_index_equal(left.index, right.index)


def assert_index_equal(a, b):
    """Raise AssertionError unless indices `a` and `b` are exactly equal.

    Wrapper of ``pandas.testing.assert_index_equal`` with exact index class,
    name and exact value checks enabled.
    """
    pdt.assert_index_equal(a, b,
                           exact=True,
                           check_names=True,
                           check_exact=True)


def pytestrunner():
    """Run the ``skbio`` test suite (including doctests) through pytest.

    Display options of numpy and pandas are adjusted first, when those
    packages are importable, so that doctest output does not depend on the
    installed version or the terminal size. Any extra command line arguments
    (``sys.argv[1:]``) are forwarded to pytest, and the process exits with
    pytest's return code.
    """
    try:
        import numpy
        try:
            # NumPy 1.14 changed repr output breaking our doctests,
            # request the legacy 1.13 style
            numpy.set_printoptions(legacy="1.13")
        except TypeError:
            # Old Numpy, output should be fine as it is :)
            # TypeError: set_printoptions() got an unexpected
            # keyword argument 'legacy'
            pass
    except ImportError:
        # numpy is not available: nothing to configure
        pass
    try:
        import pandas
        # Max columns is automatically set by pandas based on terminal
        # width, so set columns to unlimited to prevent the test suite
        # from passing/failing based on terminal size.
        pandas.options.display.max_columns = None
    except ImportError:
        # pandas is not available: nothing to configure
        pass

    # import here, cause outside the eggs aren't loaded
    import pytest

    # NOTE(review): the double quotes inside the `-o` value are part of the
    # string handed to pytest (no shell is involved to strip them) -- confirm
    # that pytest still recognizes the `doctest_optionflags` override.
    args = ['--pyargs', 'skbio', '--doctest-modules', '--doctest-glob',
            '*.pyx', '-o', '"doctest_optionflags=NORMALIZE_WHITESPACE'
            ' IGNORE_EXCEPTION_DETAIL"'] + sys.argv[1:]

    errno = pytest.main(args=args)
    sys.exit(errno)