from pytest_regressions.common import perform_regression_check, import_error_message


class DataFrameRegressionFixture:
    """
    Pandas DataFrame Regression fixture implementation used on dataframe_regression fixture.

    The DataFrame under test is dumped to a CSV file and compared column by column against a
    previously recorded CSV file (see ``check``).
    """

    DISPLAY_PRECISION = 17  # Decimal places
    DISPLAY_WIDTH = 1000  # Max. Chars on outputs
    DISPLAY_MAX_COLUMNS = 1000  # Max. Number of columns (see #3)

    def __init__(self, datadir, original_datadir, request):
        """
        :type datadir: Path
        :type original_datadir: Path
        :type request: FixtureRequest
        """
        # Per-column keyword arguments for ``numpy.isclose`` and the fallback used for
        # columns without an entry; both are (re)assigned by every ``check`` call.
        self._tolerances_dict = {}
        self._default_tolerance = {}

        self.request = request
        self.datadir = datadir
        self.original_datadir = original_datadir
        self._force_regen = False

        # Flat (option, value, option, value, ...) sequence for ``pandas.option_context``;
        # applied around the regression check so comparison tables embedded in failure
        # messages are rendered with high precision and without column truncation.
        self._pandas_display_options = (
            "display.precision",
            DataFrameRegressionFixture.DISPLAY_PRECISION,
            "display.width",
            DataFrameRegressionFixture.DISPLAY_WIDTH,
            "display.max_columns",
            DataFrameRegressionFixture.DISPLAY_MAX_COLUMNS,
        )

    def _check_data_types(self, key, obtained_column, expected_column):
        """
        Check if data type of obtained and expected columns are the same. Fail if not.
        Helper method used in _check_fn method.

        Differing dtypes are accepted when both are numeric (e.g. int64 vs float64).

        :param key: column name, used only in the error message.
        :param pd.Series obtained_column:
        :param pd.Series expected_column:
        :raises AssertionError: if the dtypes differ and are not both numeric.
        """
        try:
            import numpy as np
        except ModuleNotFoundError:
            raise ModuleNotFoundError(import_error_message("Numpy"))

        __tracebackhide__ = True
        obtained_data_type = obtained_column.values.dtype
        expected_data_type = expected_column.values.dtype
        if obtained_data_type != expected_data_type:
            # Check if both data types are comparable as numbers (float, int, short, bytes, etc...)
            if np.issubdtype(obtained_data_type, np.number) and np.issubdtype(
                expected_data_type, np.number
            ):
                return

            # In case they are not, assume they are not comparable
            error_msg = (
                "Data type for data %s of obtained and expected are not the same.\n"
                "Obtained: %s\n"
                "Expected: %s\n" % (key, obtained_data_type, expected_data_type)
            )
            raise AssertionError(error_msg)

    def _check_data_shapes(self, obtained_column, expected_column):
        """
        Check if obtained and expected columns have the same size.
        Helper method used in _check_fn method.

        :param pd.Series obtained_column:
        :param pd.Series expected_column:
        :raises AssertionError: if the underlying array shapes differ.
        """
        __tracebackhide__ = True

        obtained_data_shape = obtained_column.values.shape
        expected_data_shape = expected_column.values.shape
        if obtained_data_shape != expected_data_shape:
            error_msg = (
                "Obtained and expected data shape are not the same.\n"
                "Obtained: %s\n"
                "Expected: %s\n" % (obtained_data_shape, expected_data_shape)
            )
            raise AssertionError(error_msg)

    def _check_fn(self, obtained_filename, expected_filename):
        """
        Check if the DataFrame dumped to ``obtained_filename`` matches the one recorded in
        ``expected_filename``.

        Every column of the obtained data must exist in the expected data with a compatible
        dtype and the same shape. Floating point columns are compared with ``numpy.isclose``
        (NaN equals NaN) using the tolerances given to ``check``; all other columns are
        compared for exact equality, with missing values on both sides considered equal.

        :param str obtained_filename: CSV file written for the current run.
        :param str expected_filename: previously recorded reference CSV file.
        :raises AssertionError: on a missing column, an incompatible dtype or shape, or
            values that are not sufficiently close (the message lists the differing rows).
        """
        try:
            import numpy as np
        except ModuleNotFoundError:
            raise ModuleNotFoundError(import_error_message("Numpy"))
        try:
            import pandas as pd
        except ModuleNotFoundError:
            raise ModuleNotFoundError(import_error_message("Pandas"))

        __tracebackhide__ = True

        obtained_data = pd.read_csv(str(obtained_filename))
        expected_data = pd.read_csv(str(expected_filename))

        comparison_tables_dict = {}
        for k in obtained_data.keys():
            obtained_column = obtained_data[k]
            expected_column = expected_data.get(k)

            if expected_column is None:
                # Distinct loop name so the missing column ``k`` is not shadowed.
                obtained_keys = "".join(f"'{key}', " for key in obtained_data.keys())
                expected_keys = "".join(f"'{key}', " for key in expected_data.keys())
                error_msg = (
                    f"Could not find key '{k}' in the expected results.\n"
                    f"Keys in the obtained data table: [{obtained_keys}]\n"
                    f"Keys in the expected data table: [{expected_keys}]\n"
                    "To update values, use --force-regen option.\n\n"
                )
                raise AssertionError(error_msg)

            tolerance_args = self._tolerances_dict.get(k, self._default_tolerance)

            self._check_data_types(k, obtained_column, expected_column)
            self._check_data_shapes(obtained_column, expected_column)

            obtained_values = obtained_column.values
            expected_values = expected_column.values
            data_type = obtained_values.dtype
            # Test dtype categories rather than listing aliases: ``np.float`` and
            # ``np.bool`` were removed in NumPy 1.24.
            if np.issubdtype(data_type, np.floating):
                not_close_mask = ~np.isclose(
                    obtained_values,
                    expected_values,
                    equal_nan=True,
                    **tolerance_args,
                )
            else:
                not_close_mask = np.asarray(obtained_values != expected_values)
                # Empty CSV cells are read back as NaN and NaN != NaN; a value missing on
                # both sides is not a difference.
                not_close_mask &= ~(pd.isna(obtained_values) & pd.isna(expected_values))

            if np.any(not_close_mask):
                # ``np.where`` yields positions, so index positionally.
                diff_ids = np.where(not_close_mask)[0]
                diff_obtained_data = obtained_column.iloc[diff_ids]
                diff_expected_data = expected_column.iloc[diff_ids]
                if np.issubdtype(data_type, np.bool_):
                    # Booleans cannot be subtracted; XOR marks the differing entries.
                    diffs = np.logical_xor(obtained_column, expected_column).iloc[diff_ids]
                elif np.issubdtype(data_type, np.number):
                    diffs = np.abs(obtained_column - expected_column).iloc[diff_ids]
                else:
                    # Non-numeric data (e.g. strings) has no arithmetic difference; just
                    # flag the differing rows instead of failing with a TypeError.
                    diffs = pd.Series(True, index=diff_obtained_data.index)

                comparison_table = pd.concat(
                    [diff_obtained_data, diff_expected_data, diffs], axis=1
                )
                comparison_table.columns = [f"obtained_{k}", f"expected_{k}", "diff"]
                comparison_tables_dict[k] = comparison_table

        if comparison_tables_dict:
            error_msg = "Values are not sufficiently close.\n"
            error_msg += "To update values, use --force-regen option.\n\n"
            for k, comparison_table in comparison_tables_dict.items():
                error_msg += f"{k}:\n{comparison_table}\n\n"
            raise AssertionError(error_msg)

    def _dump_fn(self, data_object, filename):
        """
        Dump the DataFrame contents to the given filename as CSV.

        Floats are written with ``DISPLAY_PRECISION`` significant digits so that the
        recorded values round-trip with (near) full precision.

        :param pd.DataFrame data_object:
        :param str filename:
        """
        data_object.to_csv(
            str(filename),
            float_format=f"%.{DataFrameRegressionFixture.DISPLAY_PRECISION}g",
        )

    def check(
        self,
        data_frame,
        basename=None,
        fullpath=None,
        tolerances=None,
        default_tolerance=None,
    ):
        """
        Checks the given pandas dataframe against a previously recorded version, or generate a new file.

        Example::

            data_frame = pandas.DataFrame.from_dict({
                'U_gas': U[0][positions],
                'U_liquid': U[1][positions],
                'gas_vol_frac [-]': vol_frac[0][positions],
                'liquid_vol_frac [-]': vol_frac[1][positions],
                'P': Pa_to_bar(P)[positions],
            })
            dataframe_regression.check(data_frame)

        :param pandas.DataFrame data_frame: pandas DataFrame containing data for regression check.

        :param str basename: basename of the file to test/record. If not given the name
            of the test is used.

        :param str fullpath: complete path to use as a reference file, useful for example
            if a reference file is located in the session data dir.

        :param dict tolerances: dict mapping column names of the data frame to tolerance
            settings (keyword arguments for ``numpy.isclose``) for the given column.
            Example::

                tolerances={'U': dict(atol=1e-2)}

        :param dict default_tolerance: dict mapping the default tolerance for the current check
            call. Example::

                default_tolerance=dict(atol=1e-7, rtol=1e-18).

            If not provided, will use defaults from numpy's ``isclose`` function.

        ``basename`` and ``fullpath`` are exclusive.
        """
        try:
            import pandas as pd
        except ModuleNotFoundError:
            raise ModuleNotFoundError(import_error_message("Pandas"))

        import functools

        __tracebackhide__ = True

        assert type(data_frame) is pd.DataFrame, (
            "Only pandas DataFrames are supported on dataframe_regression fixture.\n"
            "Object with type '%s' was given." % (str(type(data_frame)),)
        )

        for column in data_frame.columns:
            array = data_frame[column]
            # Skip assertion if an array of strings (or an empty object column, which has
            # no values to reject). ``iloc`` is positional, so frames whose index does not
            # contain the label 0 are handled too.
            if array.dtype == "O" and (array.empty or type(array.iloc[0]) is str):
                continue
            # Rejected: timedelta, datetime, objects, zero-terminated bytes, unicode strings and raw data
            assert array.dtype not in ["m", "M", "O", "S", "a", "U", "V"], (
                "Only numeric data is supported on dataframe_regression fixture.\n"
                "Array with type '%s' was given.\n" % (str(array.dtype),)
            )

        if tolerances is None:
            tolerances = {}
        self._tolerances_dict = tolerances

        if default_tolerance is None:
            default_tolerance = {}
        self._default_tolerance = default_tolerance

        dump_fn = functools.partial(self._dump_fn, data_frame)

        with pd.option_context(*self._pandas_display_options):
            perform_regression_check(
                datadir=self.datadir,
                original_datadir=self.original_datadir,
                request=self.request,
                check_fn=self._check_fn,
                dump_fn=dump_fn,
                extension=".csv",
                basename=basename,
                fullpath=fullpath,
                force_regen=self._force_regen,
            )