"""
Optimizer with cross validation score
"""

import numpy as np
from sklearn.model_selection import KFold, ShuffleSplit
from sklearn.metrics import r2_score
from typing import Any, Dict, Tuple
from .base_optimizer import BaseOptimizer
from .optimizer import Optimizer
from .fit_methods import fit
from .tools import ScatterData
from .model_selection import get_model_metrics


validation_methods = {
    'k-fold': KFold,
    'shuffle-split': ShuffleSplit,
}


class CrossValidationEstimator(BaseOptimizer):
    """
    This class provides an optimizer with cross validation for solving the
    linear :math:`\\boldsymbol{A}\\boldsymbol{x} = \\boldsymbol{y}` problem.
    Cross-validation (CV) scores are calculated by splitting the
    available reference data in multiple different ways. It also produces
    the finalized model (using the full input data) for which the CV score
    is an estimation of its performance.

    Warning
    -------
    Repeatedly setting up a CrossValidationEstimator and training
    *without* changing the seed for the random number generator will yield
    identical or correlated results, to avoid this please specify a different
    seed when setting up multiple CrossValidationEstimator instances.

    Parameters
    ----------
    fit_data : tuple(numpy.ndarray, numpy.ndarray)
        the first element of the tuple represents the fit matrix `A`
        (`N, M` array) while the second element represents the vector
        of target values `y` (`N` array); here `N` (=rows of `A`,
        elements of `y`) equals the number of target values and `M`
        (=columns of `A`) equals the number of parameters
    fit_method : str
        method to be used for training; possible choices are
        "ardr", "bayesian-ridge", "elasticnet", "lasso", "least-squares",
        "omp", "rfe", "ridge", "split-bregman"
    standardize : bool
        if True the fit matrix and target values are standardized before fitting,
        meaning columns in the fit matrix and the target values are rescaled to
        have a standard deviation of 1.0.
    validation_method : str
        method to use for cross-validation; possible choices are
        "shuffle-split", "k-fold"
    n_splits : int
        number of times the fit data set will be split for the cross-validation
    check_condition : bool
        if True the condition number will be checked
        (this can be slightly more time consuming for larger
        matrices)
    seed : int
        seed for pseudo random number generator

    Attributes
    ----------
    train_scatter_data : ScatterData
        contains target and predicted values from each individual
        training set in the cross-validation split;
        :class:`ScatterData` is a namedtuple.
    validation_scatter_data : ScatterData
        contains target and predicted values from each individual
        validation set in the cross-validation split;
        :class:`ScatterData` is a namedtuple.
    """

    def __init__(self,
                 fit_data: Tuple[np.ndarray, np.ndarray],
                 fit_method: str = 'least-squares',
                 standardize: bool = True,
                 validation_method: str = 'k-fold',
                 n_splits: int = 10,
                 check_condition: bool = True,
                 seed: int = 42,
                 **kwargs) -> None:

        super().__init__(fit_data, fit_method, standardize, check_condition, seed)

        if validation_method not in validation_methods:
            msg = ['Validation method not available']
            msg += ['Please choose one of the following:']
            msg += [' * ' + key for key in validation_methods]
            raise ValueError('\n'.join(msg))
        self._validation_method = validation_method
        self._n_splits = n_splits
        self._set_kwargs(kwargs)

        # data set splitting object
        # NOTE: scikit-learn raises a ValueError if random_state is set while
        # shuffling is disabled (KFold(shuffle=False)), so only forward the
        # seed when the splitter actually shuffles; ShuffleSplit always does.
        splitter_kwargs = dict(n_splits=self.n_splits, **self._split_kwargs)
        if self._split_kwargs.get('shuffle', True):
            splitter_kwargs['random_state'] = seed
        self._splitter = validation_methods[validation_method](**splitter_kwargs)

        self.train_scatter_data = None
        self.validation_scatter_data = None

        # per-split results, populated by validate()
        self._parameters_splits = None
        self._rmse_train_splits = None
        self._rmse_valid_splits = None
        self.model_metrics = {}

    def train(self) -> None:
        """ Constructs the final model using all input data available. """
        self._fit_results = fit(self._A, self._y, self.fit_method,
                                self.standardize, self._check_condition,
                                **self._fit_kwargs)
        y_train_predicted = np.dot(self._A, self.parameters)
        metrics = get_model_metrics(self._A, self.parameters, self._y, y_train_predicted)

        # finalize metrics
        self.model_metrics['rmse_train_final'] = metrics['rmse_train']
        self.model_metrics['R2_train'] = metrics['R2_train']
        self.model_metrics['AIC'] = metrics['AIC']
        self.model_metrics['BIC'] = metrics['BIC']

    def validate(self) -> None:
        """ Runs validation. """
        train_target, train_predicted = [], []
        valid_target, valid_predicted = [], []
        rmse_train_splits, rmse_valid_splits = [], []
        parameters_splits = []
        for train_set, test_set in self._splitter.split(self._A):
            # train one model per split; the split's test set acts as the
            # validation set for that model
            opt = Optimizer((self._A, self._y), self.fit_method,
                            standardize=self.standardize,
                            train_set=train_set,
                            test_set=test_set,
                            check_condition=self._check_condition,
                            **self._fit_kwargs)
            opt.train()

            parameters_splits.append(opt.parameters)
            rmse_train_splits.append(opt.rmse_train)
            rmse_valid_splits.append(opt.rmse_test)
            train_target.extend(opt.train_scatter_data.target)
            train_predicted.extend(opt.train_scatter_data.predicted)
            valid_target.extend(opt.test_scatter_data.target)
            valid_predicted.extend(opt.test_scatter_data.predicted)

        self._parameters_splits = np.array(parameters_splits)
        self._rmse_train_splits = np.array(rmse_train_splits)
        self._rmse_valid_splits = np.array(rmse_valid_splits)
        self.train_scatter_data = ScatterData(
            target=np.array(train_target), predicted=np.array(train_predicted))
        self.validation_scatter_data = ScatterData(
            target=np.array(valid_target), predicted=np.array(valid_predicted))

        # aggregate CV error as the quadratic mean of the per-split RMSEs
        self.model_metrics['rmse_validation'] = np.sqrt(np.mean(self._rmse_valid_splits**2))
        self.model_metrics['R2_validation'] = r2_score(valid_target, valid_predicted)

    def _set_kwargs(self, kwargs: dict) -> None:
        """
        Sets up fit_kwargs and split_kwargs.
        Different split methods need different keywords.
        """
        self._fit_kwargs = {}
        self._split_kwargs = {}

        if self.validation_method == 'k-fold':
            self._split_kwargs['shuffle'] = True  # default True
            for key, val in kwargs.items():
                if key in ['shuffle']:
                    self._split_kwargs[key] = val
                else:
                    self._fit_kwargs[key] = val
        elif self.validation_method == 'shuffle-split':
            for key, val in kwargs.items():
                if key in ['test_size', 'train_size']:
                    self._split_kwargs[key] = val
                else:
                    self._fit_kwargs[key] = val

    @property
    def summary(self) -> Dict[str, Any]:
        """ comprehensive information about the optimizer """

        info = super().summary

        # Add class specific data
        info['validation_method'] = self.validation_method
        info['n_splits'] = self.n_splits
        info['rmse_train'] = self.rmse_train
        info['rmse_train_splits'] = self.rmse_train_splits
        info['rmse_validation'] = self.rmse_validation
        info['rmse_validation_splits'] = self.rmse_validation_splits
        info['train_scatter_data'] = self.train_scatter_data
        info['validation_scatter_data'] = self.validation_scatter_data

        # add metrics
        info = {**info, **self.model_metrics}

        # add kwargs used for fitting and splitting
        info = {**info, **self._fit_kwargs, **self._split_kwargs}
        return info

    def __repr__(self) -> str:
        kwargs = dict()
        kwargs['fit_method'] = self.fit_method
        kwargs['validation_method'] = self.validation_method
        kwargs['n_splits'] = self.n_splits
        kwargs['seed'] = self.seed
        kwargs = {**kwargs, **self._fit_kwargs, **self._split_kwargs}
        return 'CrossValidationEstimator((A, y), {})'.format(
            ', '.join('{}={}'.format(*kwarg) for kwarg in kwargs.items()))

    @property
    def validation_method(self) -> str:
        """ validation method name """
        return self._validation_method

    @property
    def n_splits(self) -> int:
        """ number of splits (folds) used for cross-validation """
        return self._n_splits

    @property
    def parameters_splits(self) -> np.ndarray:
        """ all parameters obtained during cross-validation """
        return self._parameters_splits

    @property
    def n_nonzero_parameters_splits(self) -> np.ndarray:
        """ number of non-zero parameters for each split """
        if self.parameters_splits is None:
            return None
        else:
            return np.array([np.count_nonzero(p) for p in self.parameters_splits])

    @property
    def rmse_train_final(self) -> float:
        """
        root mean squared error when using the full set of input data
        """
        if 'rmse_train_final' not in self.model_metrics:
            return None
        return self.model_metrics['rmse_train_final']

    @property
    def rmse_train(self) -> float:
        """
        average root mean squared training error obtained during
        cross-validation
        """
        if self._rmse_train_splits is None:
            return None
        return np.sqrt(np.mean(self._rmse_train_splits**2))

    @property
    def rmse_train_splits(self) -> np.ndarray:
        """
        root mean squared training errors obtained during
        cross-validation
        """
        return self._rmse_train_splits

    @property
    def rmse_validation(self) -> float:
        """ average root mean squared cross-validation error """
        if self._rmse_valid_splits is None:
            return None
        return np.sqrt(np.mean(self._rmse_valid_splits**2))

    @property
    def rmse_validation_splits(self) -> np.ndarray:
        """
        root mean squared validation errors obtained during
        cross-validation
        """
        return self._rmse_valid_splits