1# ----------------------------------------------------------------------------
2# Copyright (c) 2013--, scikit-bio development team.
3#
4# Distributed under the terms of the Modified BSD License.
5#
6# The full license is in the file COPYING.txt, distributed with this software.
7# ----------------------------------------------------------------------------
8
9import inspect
10import os
11import sys
12
13import numpy as np
14import numpy.testing as npt
15import pandas.testing as pdt
16
17from ._decorator import experimental
18
19
class ReallyEqualMixin:
    """Mixin that thoroughly exercises ``__eq__``/``__ne__`` in assertions.

    Meant to be mixed into a test-case class that provides ``assertEqual``,
    ``assertNotEqual``, ``assertTrue`` and ``assertFalse`` (e.g. a
    ``unittest.TestCase``). Both operators are checked in both operand
    orders.

    Adapted from the following public domain code:
      https://ludios.org/testing-your-eq-ne-cmp/

    """

    def assertReallyEqual(self, a, b):
        """Assert that ``a`` and ``b`` are equal under ``==`` and ``!=``."""
        both_orders = ((a, b), (b, a))
        # The assertEqual checks run first because their failure message is
        # the most informative.
        for lhs, rhs in both_orders:
            self.assertEqual(lhs, rhs)
        for lhs, rhs in both_orders:
            self.assertTrue(lhs == rhs)
        for lhs, rhs in both_orders:
            self.assertFalse(lhs != rhs)

    def assertReallyNotEqual(self, a, b):
        """Assert that ``a`` and ``b`` differ under ``==`` and ``!=``."""
        both_orders = ((a, b), (b, a))
        # The assertNotEqual checks run first because their failure message
        # is the most informative.
        for lhs, rhs in both_orders:
            self.assertNotEqual(lhs, rhs)
        for lhs, rhs in both_orders:
            self.assertFalse(lhs == rhs)
        for lhs, rhs in both_orders:
            self.assertTrue(lhs != rhs)
47
48
@experimental(as_of="0.4.0")
def get_data_path(fn, subfolder='data'):
    """Return path to filename ``fn`` in the data folder.

    During testing it is often necessary to load data files. This
    function returns the full path to files in the ``data`` subfolder
    by default.

    Parameters
    ----------
    fn : str
        File name.
    subfolder : str, defaults to ``data``
        Name of the subfolder that contains the data.

    Returns
    -------
    str
        Inferred absolute path to the test data for the module where
        ``get_data_path(fn)`` is called.

    Notes
    -----
    The requested path may not point to an existing file, as its
    existence is not checked.

    """
    # Only the immediate caller's frame is needed. ``getouterframes`` would
    # build a FrameInfo, including source-context lines, for every frame on
    # the stack. Here we look at a single frame and request no context lines.
    # ``getouterframes`` itself calls ``getframeinfo`` per frame, so the
    # resolved filename is the same.
    frame = inspect.currentframe()
    try:
        callers_filename = inspect.getframeinfo(frame.f_back,
                                                context=0).filename
    finally:
        # Holding our own frame in a local creates a reference cycle; drop
        # it explicitly, as recommended by the ``inspect`` documentation.
        del frame
    path = os.path.dirname(os.path.abspath(callers_filename))
    return os.path.join(path, subfolder, fn)
85
86
@experimental(as_of="0.4.0")
def assert_ordination_results_equal(left, right, ignore_method_names=False,
                                    ignore_axis_labels=False,
                                    ignore_directionality=False,
                                    decimal=7):
    """Assert that ordination results objects are equal.

    This is a helper function intended to be used in unit tests that need to
    compare ``OrdinationResults`` objects.

    Parameters
    ----------
    left, right : OrdinationResults
        Ordination results to be compared for equality.
    ignore_method_names : bool, optional
        Ignore differences in `short_method_name` and `long_method_name`.
    ignore_axis_labels : bool, optional
        Ignore differences in axis labels, i.e. the column labels of
        `samples`, `features`, `biplot_scores` and `sample_constraints`, and
        the index labels of `eigvals` and `proportion_explained`.
    ignore_directionality : bool, optional
        Ignore differences in directionality (i.e., differences in signs) for
        attributes `samples`, `features`, `biplot_scores` and
        `sample_constraints`.
    decimal : int, optional
        Desired precision, forwarded as ``decimal`` to
        ``numpy.testing.assert_almost_equal`` when comparing numeric values.

    Raises
    ------
    AssertionError
        If the two objects are not equal.

    """
    npt.assert_equal(type(left) is type(right), True,
                     err_msg="Types differ: %r vs %r" % (type(left),
                                                         type(right)))

    if not ignore_method_names:
        npt.assert_equal(left.short_method_name, right.short_method_name)
        npt.assert_equal(left.long_method_name, right.long_method_name)

    # Data-frame attributes: the row index is always compared; column labels
    # and column signs can be ignored on request.
    for attr in ('samples', 'features', 'biplot_scores',
                 'sample_constraints'):
        _assert_frame_equal(getattr(left, attr), getattr(right, attr),
                            ignore_columns=ignore_axis_labels,
                            ignore_directionality=ignore_directionality,
                            decimal=decimal)

    # Series attributes: their index holds the axis labels.
    for attr in ('eigvals', 'proportion_explained'):
        _assert_series_equal(getattr(left, attr), getattr(right, attr),
                             ignore_index=ignore_axis_labels,
                             decimal=decimal)
146
147
def _assert_series_equal(left_s, right_s, ignore_index=False, decimal=7):
    """Assert that two series (or two ``None`` values) are almost equal.

    ``pandas.testing.assert_series_equal`` does not accept ``None``, so it is
    handled here: the check passes only when both inputs are ``None``.
    Otherwise the values are compared with
    ``numpy.testing.assert_almost_equal`` to ``decimal`` places. Unless
    ``ignore_index`` is true, the indexes are also compared with
    ``pandas.testing.assert_index_equal``.

    Raises
    ------
    AssertionError
        If exactly one input is ``None``, or the series differ.
    """
    if left_s is None or right_s is None:
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        if left_s is not None or right_s is not None:
            raise AssertionError(
                "Only one of the series is None (left: %s, right: %s)"
                % (type(left_s).__name__, type(right_s).__name__))
        return

    npt.assert_almost_equal(left_s.values, right_s.values, decimal=decimal)
    if not ignore_index:
        pdt.assert_index_equal(left_s.index, right_s.index)
157
158
def _assert_frame_equal(left_df, right_df, ignore_index=False,
                        ignore_columns=False, ignore_directionality=False,
                        decimal=7):
    """Assert that two data frames (or two ``None`` values) are almost equal.

    ``pandas.testing.assert_frame_equal`` does not accept ``None``, so it is
    handled here: the check passes only when both inputs are ``None``.
    Otherwise the values are compared with
    ``numpy.testing.assert_almost_equal`` to ``decimal`` places. When
    ``ignore_directionality`` is true, column signs are normalized first via
    ``_normalize_signs``. Row and column labels are then compared with
    ``pandas.testing.assert_index_equal`` unless ``ignore_index`` or
    ``ignore_columns`` is set.

    Raises
    ------
    AssertionError
        If exactly one input is ``None``, or the data frames differ.
    """
    if left_df is None or right_df is None:
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        if left_df is not None or right_df is not None:
            raise AssertionError(
                "Only one of the data frames is None (left: %s, right: %s)"
                % (type(left_df).__name__, type(right_df).__name__))
        return

    left_values = left_df.values
    right_values = right_df.values
    if ignore_directionality:
        left_values, right_values = _normalize_signs(left_values,
                                                     right_values)
    npt.assert_almost_equal(left_values, right_values, decimal=decimal)

    if not ignore_index:
        pdt.assert_index_equal(left_df.index, right_df.index)
    if not ignore_columns:
        pdt.assert_index_equal(left_df.columns, right_df.columns)
177
178
def _normalize_signs(arr1, arr2):
    """Change column signs so that "column" and "-column" compare equal.

    This is needed because results of eigenproblems can have signs
    flipped, but they're still right.

    Parameters
    ----------
    arr1, arr2 : array_like
        Two-dimensional arrays of the same shape. Both are converted to
        ``float64``.

    Returns
    -------
    tuple of np.ndarray
        ``arr1`` with some of its columns negated, and ``arr2`` unchanged.

    Raises
    ------
    ValueError
        If the arrays do not have the same shape.

    Notes
    -----
    This function tries hard to make sure that, if you find "column"
    and "-column" almost equal, calling a function like np.allclose to
    compare them after calling `_normalize_signs` succeeds.

    To do so, it distinguishes two cases for every column:

    - It can be all almost equal to 0 (this includes a column of
      zeros).
    - Otherwise, it has a value that isn't close to 0.

    In the first case, no sign needs to be flipped. I.e., for
    |epsilon| small, np.allclose(-epsilon, 0) is true if and only if
    np.allclose(epsilon, 0) is.

    In the second case, the function finds the number in the column
    whose absolute value is largest. Then, it compares its sign with
    the number found in the same index, but in the other array, and
    flips the sign of the column as needed.
    """
    # Convert everything to floating point numbers (eigenvectors will
    # normally already be floats), so that the arithmetic below behaves
    # the same regardless of the input dtype.
    arr1 = np.asarray(arr1, dtype=np.float64)
    arr2 = np.asarray(arr2, dtype=np.float64)

    if arr1.shape != arr2.shape:
        raise ValueError(
            "Arrays must have the same shape ({0} vs {1}).".format(arr1.shape,
                                                                   arr2.shape)
        )

    # Nothing to flip. This also avoids ``argmax`` raising on zero-row input.
    if arr1.size == 0:
        return arr1.copy(), arr2

    # To avoid issues around zero, compare signs of the values with the
    # highest absolute value in each column of arr1.
    max_idx = np.abs(arr1).argmax(axis=0)
    cols = np.arange(arr1.shape[1])
    sign_arr1 = np.sign(arr1[max_idx, cols])
    sign_arr2 = np.sign(arr2[max_idx, cols])

    # Flip a column only when both reference values are nonzero and have
    # opposite signs. If either is zero (or NaN), the sign carries no
    # information, so the column is kept as is.
    flips = np.where(sign_arr1 * sign_arr2 < 0, -1.0, 1.0)

    return arr1 * flips, arr2
246
247
@experimental(as_of="0.4.0")
def assert_data_frame_almost_equal(left, right):
    """Raise AssertionError if ``pd.DataFrame`` objects are not "almost equal".

    Wrapper of ``pandas.testing.assert_frame_equal``. Floating point values
    are considered "almost equal" if they are within a threshold defined by
    ``assert_frame_equal``. This wrapper uses a number of
    checks that are turned off by default in ``assert_frame_equal`` in order to
    perform stricter comparisons (for example, ensuring the index and column
    types are the same). It also does not consider empty ``pd.DataFrame``
    objects equal if they have a different index.

    Other notes:

    * Index (row) and column ordering must be the same for objects to be equal.
    * NaNs (``np.nan``) in the same locations are considered equal.

    This is a helper function intended to be used in unit tests that need to
    compare ``pd.DataFrame`` objects.

    Parameters
    ----------
    left, right : pd.DataFrame
        ``pd.DataFrame`` objects to compare.

    Raises
    ------
    AssertionError
        If `left` and `right` are not "almost equal".

    See Also
    --------
    pandas.testing.assert_frame_equal

    """
    # Pass the relevant kwargs explicitly to ensure this function has
    # consistent behavior even if `assert_frame_equal`'s defaults change.
    # `check_less_precise` is intentionally not passed. The value used here
    # (False) was its default, it is deprecated since pandas 1.1, and passing
    # it raises TypeError since pandas 2.0.
    pdt.assert_frame_equal(left, right,
                           check_dtype=True,
                           check_index_type=True,
                           check_column_type=True,
                           check_frame_type=True,
                           check_names=True,
                           by_blocks=False,
                           check_exact=False)
    # This check ensures that empty DataFrames with different indices do not
    # compare equal. `assert_index_equal` requires the index types to be
    # exactly the same.
    assert_index_equal(left.index, right.index)
298
299
def assert_series_almost_equal(left, right):
    """Raise AssertionError if ``pd.Series`` objects are not "almost equal".

    Wrapper of ``pandas.testing.assert_series_equal`` that passes its
    comparison options explicitly: dtype, index type, series type and name
    checks are enabled, and values are compared inexactly. This is followed
    by an exact index comparison so that empty ``pd.Series`` objects with
    different indices do not compare equal.

    Parameters
    ----------
    left, right : pd.Series
        ``pd.Series`` objects to compare.

    Raises
    ------
    AssertionError
        If `left` and `right` are not "almost equal".

    See Also
    --------
    pandas.testing.assert_series_equal

    """
    # Pass the relevant kwargs explicitly to ensure this function has
    # consistent behavior even if `assert_series_equal`'s defaults change.
    # `check_less_precise` is intentionally not passed. The value used here
    # (False) was its default, it is deprecated since pandas 1.1, and passing
    # it raises TypeError since pandas 2.0.
    pdt.assert_series_equal(left, right,
                            check_dtype=True,
                            check_index_type=True,
                            check_series_type=True,
                            check_names=True,
                            check_exact=False,
                            check_datetimelike_compat=False,
                            obj='Series')
    # This check ensures that empty Series with different indices do not
    # compare equal.
    assert_index_equal(left.index, right.index)
315
316
def assert_index_equal(a, b):
    """Assert that two pandas indexes are strictly equal.

    Thin wrapper around ``pandas.testing.assert_index_equal`` with
    ``exact=True`` (the index class must match exactly), ``check_names=True``
    (index names must match) and ``check_exact=True`` (values are compared
    exactly, not approximately).
    """
    strict_options = dict(exact=True, check_names=True, check_exact=True)
    pdt.assert_index_equal(a, b, **strict_options)
322
323
def _pytestrunner_args(extra_args):
    """Return the argument list that ``pytestrunner`` passes to pytest.

    Parameters
    ----------
    extra_args : iterable of str
        Additional arguments (normally ``sys.argv[1:]``), appended after the
        built-in ones so they are forwarded to pytest.

    Returns
    -------
    list of str
    """
    # These arguments go to ``pytest.main`` directly, not through a shell, so
    # the ``-o`` value must not be wrapped in quote characters. With quotes
    # the option name would be read as '"doctest_optionflags' and the
    # override would never apply to ``doctest_optionflags``.
    return ['--pyargs', 'skbio', '--doctest-modules', '--doctest-glob',
            '*.pyx', '-o', 'doctest_optionflags=NORMALIZE_WHITESPACE'
            ' IGNORE_EXCEPTION_DETAIL'] + list(extra_args)


def pytestrunner():
    """Run the skbio test suite (including doctests) and exit.

    First adjusts NumPy and pandas display options when those packages are
    importable, so doctest output does not depend on the NumPy version or
    the terminal width. Command-line arguments are forwarded to pytest, and
    the process exits with pytest's return code.
    """
    try:
        import numpy
    except ImportError:
        pass
    else:
        try:
            # NumPy 1.14 changed repr output breaking our doctests,
            # request the legacy 1.13 style.
            numpy.set_printoptions(legacy="1.13")
        except TypeError:
            # Old NumPy without the ``legacy`` keyword; its output is
            # already in the expected style.
            pass

    try:
        import pandas
    except ImportError:
        pass
    else:
        # Max columns is automatically set by pandas based on terminal
        # width, so set columns to unlimited to prevent the test suite
        # from passing/failing based on terminal size.
        pandas.options.display.max_columns = None

    # Import here, because outside this function the eggs aren't loaded.
    import pytest

    exit_code = pytest.main(args=_pytestrunner_args(sys.argv[1:]))
    sys.exit(exit_code)
356