1from pytest_regressions.common import perform_regression_check, import_error_message
2
3
4class DataFrameRegressionFixture:
5    """
6    Pandas DataFrame Regression fixture implementation used on dataframe_regression fixture.
7    """
8
9    DISPLAY_PRECISION = 17  # Decimal places
10    DISPLAY_WIDTH = 1000  # Max. Chars on outputs
11    DISPLAY_MAX_COLUMNS = 1000  # Max. Number of columns (see #3)
12
13    def __init__(self, datadir, original_datadir, request):
14        """
15        :type datadir: Path
16        :type original_datadir: Path
17        :type request: FixtureRequest
18        """
19        self._tolerances_dict = {}
20        self._default_tolerance = {}
21
22        self.request = request
23        self.datadir = datadir
24        self.original_datadir = original_datadir
25        self._force_regen = False
26
27        self._pandas_display_options = (
28            "display.precision",
29            DataFrameRegressionFixture.DISPLAY_PRECISION,
30            "display.width",
31            DataFrameRegressionFixture.DISPLAY_WIDTH,
32            "display.max_columns",
33            DataFrameRegressionFixture.DISPLAY_MAX_COLUMNS,
34        )
35
36    def _check_data_types(self, key, obtained_column, expected_column):
37        """
38        Check if data type of obtained and expected columns are the same. Fail if not.
39        Helper method used in _check_fn method.
40        """
41        try:
42            import numpy as np
43        except ModuleNotFoundError:
44            raise ModuleNotFoundError(import_error_message("Numpy"))
45
46        __tracebackhide__ = True
47        obtained_data_type = obtained_column.values.dtype
48        expected_data_type = expected_column.values.dtype
49        if obtained_data_type != expected_data_type:
50            # Check if both data types are comparable as numbers (float, int, short, bytes, etc...)
51            if np.issubdtype(obtained_data_type, np.number) and np.issubdtype(
52                expected_data_type, np.number
53            ):
54                return
55
56            # In case they are not, assume they are not comparable
57            error_msg = (
58                "Data type for data %s of obtained and expected are not the same.\n"
59                "Obtained: %s\n"
60                "Expected: %s\n" % (key, obtained_data_type, expected_data_type)
61            )
62            raise AssertionError(error_msg)
63
64    def _check_data_shapes(self, obtained_column, expected_column):
65        """
66        Check if obtained and expected columns have the same size.
67        Helper method used in _check_fn method.
68        """
69        __tracebackhide__ = True
70
71        obtained_data_shape = obtained_column.values.shape
72        expected_data_shape = expected_column.values.shape
73        if obtained_data_shape != expected_data_shape:
74            error_msg = (
75                "Obtained and expected data shape are not the same.\n"
76                "Obtained: %s\n"
77                "Expected: %s\n" % (obtained_data_shape, expected_data_shape)
78            )
79            raise AssertionError(error_msg)
80
81    def _check_fn(self, obtained_filename, expected_filename):
82        """
83        Check if dict contents dumped to a file match the contents in expected file.
84
85        :param str obtained_filename:
86        :param str expected_filename:
87        """
88        try:
89            import numpy as np
90        except ModuleNotFoundError:
91            raise ModuleNotFoundError(import_error_message("Numpy"))
92        try:
93            import pandas as pd
94        except ModuleNotFoundError:
95            raise ModuleNotFoundError(import_error_message("Pandas"))
96
97        __tracebackhide__ = True
98
99        obtained_data = pd.read_csv(str(obtained_filename))
100        expected_data = pd.read_csv(str(expected_filename))
101
102        comparison_tables_dict = {}
103        for k in obtained_data.keys():
104            obtained_column = obtained_data[k]
105            expected_column = expected_data.get(k)
106
107            if expected_column is None:
108                error_msg = f"Could not find key '{k}' in the expected results.\n"
109                error_msg += "Keys in the obtained data table: ["
110                for k in obtained_data.keys():
111                    error_msg += f"'{k}', "
112                error_msg += "]\n"
113                error_msg += "Keys in the expected data table: ["
114                for k in expected_data.keys():
115                    error_msg += f"'{k}', "
116                error_msg += "]\n"
117                error_msg += "To update values, use --force-regen option.\n\n"
118                raise AssertionError(error_msg)
119
120            tolerance_args = self._tolerances_dict.get(k, self._default_tolerance)
121
122            self._check_data_types(k, obtained_column, expected_column)
123            self._check_data_shapes(obtained_column, expected_column)
124
125            data_type = obtained_column.values.dtype
126            if data_type in [float, np.float, np.float16, np.float32, np.float64]:
127                not_close_mask = ~np.isclose(
128                    obtained_column.values,
129                    expected_column.values,
130                    equal_nan=True,
131                    **tolerance_args,
132                )
133            else:
134                not_close_mask = obtained_column.values != expected_column.values
135
136            if np.any(not_close_mask):
137                diff_ids = np.where(not_close_mask)[0]
138                diff_obtained_data = obtained_column[diff_ids]
139                diff_expected_data = expected_column[diff_ids]
140                if data_type == np.bool:
141                    diffs = np.logical_xor(obtained_column, expected_column)[diff_ids]
142                else:
143                    diffs = np.abs(obtained_column - expected_column)[diff_ids]
144
145                comparison_table = pd.concat(
146                    [diff_obtained_data, diff_expected_data, diffs], axis=1
147                )
148                comparison_table.columns = [f"obtained_{k}", f"expected_{k}", "diff"]
149                comparison_tables_dict[k] = comparison_table
150
151        if len(comparison_tables_dict) > 0:
152            error_msg = "Values are not sufficiently close.\n"
153            error_msg += "To update values, use --force-regen option.\n\n"
154            for k, comparison_table in comparison_tables_dict.items():
155                error_msg += f"{k}:\n{comparison_table}\n\n"
156            raise AssertionError(error_msg)
157
158    def _dump_fn(self, data_object, filename):
159        """
160        Dump dict contents to the given filename
161
162        :param pd.DataFrame data_object:
163        :param str filename:
164        """
165        data_object.to_csv(
166            str(filename),
167            float_format=f"%.{DataFrameRegressionFixture.DISPLAY_PRECISION}g",
168        )
169
170    def check(
171        self,
172        data_frame,
173        basename=None,
174        fullpath=None,
175        tolerances=None,
176        default_tolerance=None,
177    ):
178        """
179        Checks the given pandas dataframe against a previously recorded version, or generate a new file.
180
181        Example::
182
183            data_frame = pandas.DataFrame.from_dict({
184                'U_gas': U[0][positions],
185                'U_liquid': U[1][positions],
186                'gas_vol_frac [-]': vol_frac[0][positions],
187                'liquid_vol_frac [-]': vol_frac[1][positions],
188                'P': Pa_to_bar(P)[positions],
189            })
190            dataframe_regression.check(data_frame)
191
192        :param pandas.DataFrame data_frame: pandas DataFrame containing data for regression check.
193
194        :param str basename: basename of the file to test/record. If not given the name
195            of the test is used.
196
197        :param str fullpath: complete path to use as a reference file. This option
198            will ignore embed_data completely, being useful if a reference file is located
199            in the session data dir for example.
200
201        :param dict tolerances: dict mapping keys from the data_dict to tolerance settings for the
202            given data. Example::
203
204                tolerances={'U': Tolerance(atol=1e-2)}
205
206        :param dict default_tolerance: dict mapping the default tolerance for the current check
207            call. Example::
208
209                default_tolerance=dict(atol=1e-7, rtol=1e-18).
210
211            If not provided, will use defaults from numpy's ``isclose`` function.
212
213        ``basename`` and ``fullpath`` are exclusive.
214        """
215        try:
216            import pandas as pd
217        except ModuleNotFoundError:
218            raise ModuleNotFoundError(import_error_message("Pandas"))
219
220        import functools
221
222        __tracebackhide__ = True
223
224        assert type(data_frame) is pd.DataFrame, (
225            "Only pandas DataFrames are supported on on dataframe_regression fixture.\n"
226            "Object with type '%s' was given." % (str(type(data_frame)),)
227        )
228
229        for column in data_frame.columns:
230            array = data_frame[column]
231            # Skip assertion if an array of strings
232            if (array.dtype == "O") and (type(array[0]) is str):
233                continue
234            # Rejected: timedelta, datetime, objects, zero-terminated bytes, unicode strings and raw data
235            assert array.dtype not in ["m", "M", "O", "S", "a", "U", "V"], (
236                "Only numeric data is supported on dataframe_regression fixture.\n"
237                "Array with type '%s' was given.\n" % (str(array.dtype),)
238            )
239
240        if tolerances is None:
241            tolerances = {}
242        self._tolerances_dict = tolerances
243
244        if default_tolerance is None:
245            default_tolerance = {}
246        self._default_tolerance = default_tolerance
247
248        dump_fn = functools.partial(self._dump_fn, data_frame)
249
250        with pd.option_context(*self._pandas_display_options):
251            perform_regression_check(
252                datadir=self.datadir,
253                original_datadir=self.original_datadir,
254                request=self.request,
255                check_fn=self._check_fn,
256                dump_fn=dump_fn,
257                extension=".csv",
258                basename=basename,
259                fullpath=fullpath,
260                force_regen=self._force_regen,
261            )
262