1from pytest_regressions.common import perform_regression_check, import_error_message
2from pytest_regressions.dataframe_regression import DataFrameRegressionFixture
3
4
class NumericRegressionFixture(DataFrameRegressionFixture):
    """
    Numeric Data Regression fixture implementation used on num_regression fixture.

    Normalizes a dict of 1d numeric array-likes into a ``pandas.DataFrame`` and delegates
    the actual comparison/recording to ``DataFrameRegressionFixture.check``.
    """

    def check(
        self,
        data_dict,
        basename=None,
        fullpath=None,
        tolerances=None,
        default_tolerance=None,
        data_index=None,
        fill_different_shape_with_nan=True,
    ):
        """
        Checks the given dict against a previously recorded version, or generate a new file.
        The dict must map from user-defined keys to 1d numpy arrays or array-like values.

        The given ``data_dict`` is never modified: coercion and NaN-padding are done on a
        private copy of the mapping.

        Example::

            num_regression.check({
                'U_gas': U[0][positions],
                'U_liquid': U[1][positions],
                'gas_vol_frac [-]': vol_frac[0][positions],
                'liquid_vol_frac [-]': vol_frac[1][positions],
                'P': Pa_to_bar(P)[positions],
            })

        :param dict data_dict: dict mapping keys to numpy arrays, or objects that can be
            coerced to 1d numpy arrays with a numeric dtype (e.g. list, tuple, etc).

        :param str basename: basename of the file to test/record. If not given the name
            of the test is used.

        :param str fullpath: complete path to use as a reference file. This option
            will ignore embed_data completely, being useful if a reference file is located
            in the session data dir for example.

        :param dict tolerances: dict mapping keys from the data_dict to tolerance settings for the
            given data. Example::

                tolerances={'U': Tolerance(atol=1e-2)}

        :param dict default_tolerance: dict mapping the default tolerance for the current check
            call. Example::

                default_tolerance=dict(atol=1e-7, rtol=1e-18).

            If not provided, will use defaults from numpy's ``isclose`` function.

        :param list data_index: If set, will override the indexes shown in the outputs. Default
            is pandas' default, which is ``range(0, len(data))``.

        :param bool fill_different_shape_with_nan: If set, all the data provided in the data_dict
            that has size lower than the bigger size will be filled with ``np.nan``, in order to save
            the data in a CSV file.

        :raises ModuleNotFoundError: if numpy or pandas cannot be imported.
        :raises AssertionError: if ``data_dict`` is empty, if a value cannot be coerced to a
            numeric numpy array, if an array is not 1d, or if the arrays have different
            lengths and ``fill_different_shape_with_nan`` is false.
        :raises TypeError: if the arrays have different lengths and at least one of them does
            not have a floating point dtype (such arrays cannot hold the NaN padding).

        ``basename`` and ``fullpath`` are exclusive.
        """

        try:
            import numpy as np
        except ModuleNotFoundError:
            raise ModuleNotFoundError(import_error_message("Numpy"))
        try:
            import pandas as pd
        except ModuleNotFoundError:
            raise ModuleNotFoundError(import_error_message("Pandas"))

        __tracebackhide__ = True

        # Work on a private mapping so the caller's dict is left untouched.
        arrays = {}
        for key, obj in data_dict.items():
            if not isinstance(obj, np.ndarray):
                candidate = np.atleast_1d(np.asarray(obj))
                # Non-numeric results are kept as given so validation below rejects them.
                if np.issubdtype(candidate.dtype, np.number):
                    obj = candidate
            arrays[key] = obj

        # Explicit raises (instead of ``assert`` statements) keep the validation
        # active when Python runs with ``-O``.
        if not arrays:
            raise AssertionError(
                "data_dict must contain at least one array to be checked."
            )

        sizes = []
        for arr in arrays.values():
            # Exact-type check: ndarray subclasses (which skip the coercion above)
            # are not accepted.
            if type(arr) is not np.ndarray:
                raise AssertionError(
                    "Only objects that can be coerced to numpy arrays are valid for numeric_data_regression fixture.\n"
                )
            if arr.ndim != 1:
                raise AssertionError(
                    "Only 1D arrays are supported on num_data_regression fixture.\n"
                    "Array with shape %s was given.\n" % (arr.shape,)
                )
            sizes.append(arr.shape[0])

        if any(size != sizes[0] for size in sizes):
            if not fill_different_shape_with_nan:
                raise AssertionError(
                    "Data dict with different array lengths will not be accepted. Try setting fill_different_shape_with_nan=True."
                )
            # NaN can only be stored in floating point arrays.
            if not all(
                np.issubdtype(arr.dtype, np.floating) for arr in arrays.values()
            ):
                raise TypeError(
                    "Checking multiple arrays with different shapes are not supported for non-float arrays"
                )

            max_size = max(sizes)
            for key, arr in arrays.items():
                # Start from an all-NaN array of the common length and copy the data
                # into its head, preserving the original dtype.
                padded = np.full((max_size,), np.nan, dtype=arr.dtype)
                padded[: len(arr)] = arr
                arrays[key] = padded

        data_frame = pd.DataFrame(arrays, index=data_index)

        DataFrameRegressionFixture.check(
            self, data_frame, basename, fullpath, tolerances, default_tolerance
        )
121