1"""
2Optimizer with cross validation score
3"""
4
5import numpy as np
6from sklearn.model_selection import KFold, ShuffleSplit
7from sklearn.metrics import r2_score
8from typing import Any, Dict, Tuple
9from .base_optimizer import BaseOptimizer
10from .optimizer import Optimizer
11from .fit_methods import fit
12from .tools import ScatterData
13from .model_selection import get_model_metrics
14
15
# Maps the user-facing name of a cross-validation scheme to the scikit-learn
# splitter class that implements it.  NOTE: the insertion order of this dict
# determines the order in which available methods are listed in the error
# message raised by CrossValidationEstimator for an unknown method.
validation_methods = {
    'k-fold': KFold,
    'shuffle-split': ShuffleSplit,
}
20
21
class CrossValidationEstimator(BaseOptimizer):
    """
    This class provides an optimizer with cross validation for solving the
    linear :math:`\\boldsymbol{A}\\boldsymbol{x} = \\boldsymbol{y}` problem.
    Cross-validation (CV) scores are calculated by splitting the
    available reference data in multiple different ways.  It also produces
    the finalized model (using the full input data) for which the CV score
    is an estimation of its performance.

    Warning
    -------
    Repeatedly setting up a CrossValidationEstimator and training
    *without* changing the seed for the random number generator will yield
    identical or correlated results, to avoid this please specify a different
    seed when setting up multiple CrossValidationEstimator instances.

    Parameters
    ----------
    fit_data : tuple(numpy.ndarray, numpy.ndarray)
        the first element of the tuple represents the fit matrix `A`
        (`N, M` array) while the second element represents the vector
        of target values `y` (`N` array); here `N` (=rows of `A`,
        elements of `y`) equals the number of target values and `M`
        (=columns of `A`) equals the number of parameters
    fit_method : str
        method to be used for training; possible choices are
        "ardr", "bayesian-ridge", "elasticnet", "lasso", "least-squares",
        "omp", "rfe", "ridge", "split-bregman"
    standardize : bool
        if True the fit matrix and target values are standardized before fitting,
        meaning columns in the fit matrix and the target values are rescaled to
        have a standard deviation of 1.0.
    validation_method : str
        method to use for cross-validation; possible choices are
        "shuffle-split", "k-fold"
    n_splits : int
        number of times the fit data set will be split for the cross-validation
    check_condition : bool
        if True the condition number will be checked
        (this can be slightly more time consuming for larger
        matrices)
    seed : int
        seed for pseudo random number generator

    Attributes
    ----------
    train_scatter_data : ScatterData
        contains target and predicted values from each individual
        training set in the cross-validation split;
        :class:`ScatterData` is a namedtuple.
    validation_scatter_data : ScatterData
        contains target and predicted values from each individual
        validation set in the cross-validation split;
        :class:`ScatterData` is a namedtuple.

    """

    def __init__(self,
                 fit_data: Tuple[np.ndarray, np.ndarray],
                 fit_method: str = 'least-squares',
                 standardize: bool = True,
                 validation_method: str = 'k-fold',
                 n_splits: int = 10,
                 check_condition: bool = True,
                 seed: int = 42,
                 **kwargs) -> None:

        super().__init__(fit_data, fit_method, standardize, check_condition, seed)

        if validation_method not in validation_methods:
            msg = ['Validation method not available']
            msg += ['Please choose one of the following:']
            for key in validation_methods:
                msg += [' * ' + key]
            raise ValueError('\n'.join(msg))
        self._validation_method = validation_method
        self._n_splits = n_splits
        # route remaining **kwargs to either the splitter or the fit method
        self._set_kwargs(kwargs)

        # data set splitting object
        # NOTE(review): scikit-learn's KFold raises a ValueError when
        # random_state is set together with shuffle=False; this is only
        # reachable if a caller explicitly passes shuffle=False — confirm
        # intended behavior for that case.
        self._splitter = validation_methods[validation_method](
            n_splits=self.n_splits, random_state=seed,
            **self._split_kwargs)

        # populated by validate(); None until then
        self.train_scatter_data = None
        self.validation_scatter_data = None

        self._parameters_splits = None
        self._rmse_train_splits = None
        self._rmse_valid_splits = None
        # metric name -> value; filled incrementally by train() and validate()
        self.model_metrics = {}

    def train(self) -> None:
        """ Constructs the final model using all input data available. """
        self._fit_results = fit(self._A, self._y, self.fit_method,
                                self.standardize, self._check_condition,
                                **self._fit_kwargs)
        # evaluate the finalized model on the full training set
        y_train_predicted = np.dot(self._A, self.parameters)
        metrics = get_model_metrics(self._A, self.parameters, self._y, y_train_predicted)

        # finalize metrics
        self.model_metrics['rmse_train_final'] = metrics['rmse_train']
        self.model_metrics['R2_train'] = metrics['R2_train']
        self.model_metrics['AIC'] = metrics['AIC']
        self.model_metrics['BIC'] = metrics['BIC']

    def validate(self) -> None:
        """ Runs validation. """
        train_target, train_predicted = [], []
        valid_target, valid_predicted = [], []
        rmse_train_splits, rmse_valid_splits = [], []
        parameters_splits = []
        # fit one model per split; collect per-split parameters, errors and
        # the pooled target/predicted values for scatter plots
        for train_set, test_set in self._splitter.split(self._A):
            opt = Optimizer((self._A, self._y), self.fit_method,
                            standardize=self.standardize,
                            train_set=train_set,
                            test_set=test_set,
                            check_condition=self._check_condition,
                            **self._fit_kwargs)
            opt.train()

            parameters_splits.append(opt.parameters)
            rmse_train_splits.append(opt.rmse_train)
            rmse_valid_splits.append(opt.rmse_test)
            train_target.extend(opt.train_scatter_data.target)
            train_predicted.extend(opt.train_scatter_data.predicted)
            valid_target.extend(opt.test_scatter_data.target)
            valid_predicted.extend(opt.test_scatter_data.predicted)

        self._parameters_splits = np.array(parameters_splits)
        self._rmse_train_splits = np.array(rmse_train_splits)
        self._rmse_valid_splits = np.array(rmse_valid_splits)
        self.train_scatter_data = ScatterData(
            target=np.array(train_target), predicted=np.array(train_predicted))
        self.validation_scatter_data = ScatterData(
            target=np.array(valid_target), predicted=np.array(valid_predicted))

        # aggregate CV error as the quadratic mean of the per-split RMSEs
        self.model_metrics['rmse_validation'] = np.sqrt(np.mean(self._rmse_valid_splits**2))
        self.model_metrics['R2_validation'] = r2_score(valid_target, valid_predicted)

    def _set_kwargs(self, kwargs: dict) -> None:
        """
        Sets up fit_kwargs and split_kwargs.
        Different split methods need different keywords.
        """
        self._fit_kwargs = {}
        self._split_kwargs = {}

        if self.validation_method == 'k-fold':
            self._split_kwargs['shuffle'] = True  # default True
            for key, val in kwargs.items():
                if key in ['shuffle']:
                    self._split_kwargs[key] = val
                else:
                    self._fit_kwargs[key] = val
        elif self.validation_method == 'shuffle-split':
            for key, val in kwargs.items():
                if key in ['test_size', 'train_size']:
                    self._split_kwargs[key] = val
                else:
                    self._fit_kwargs[key] = val

    @property
    def summary(self) -> Dict[str, Any]:
        """ comprehensive information about the optimizer """

        info = super().summary

        # Add class specific data
        info['validation_method'] = self.validation_method
        info['n_splits'] = self.n_splits
        info['rmse_train'] = self.rmse_train
        info['rmse_train_splits'] = self.rmse_train_splits
        info['rmse_validation'] = self.rmse_validation
        info['rmse_validation_splits'] = self.rmse_validation_splits
        info['train_scatter_data'] = self.train_scatter_data
        info['validation_scatter_data'] = self.validation_scatter_data

        # add metrics
        info = {**info, **self.model_metrics}

        # add kwargs used for fitting and splitting
        info = {**info, **self._fit_kwargs, **self._split_kwargs}
        return info

    def __repr__(self) -> str:
        kwargs = dict()
        kwargs['fit_method'] = self.fit_method
        kwargs['validation_method'] = self.validation_method
        kwargs['n_splits'] = self.n_splits
        kwargs['seed'] = self.seed
        kwargs = {**kwargs, **self._fit_kwargs, **self._split_kwargs}
        return 'CrossValidationEstimator((A, y), {})'.format(
            ', '.join('{}={}'.format(*kwarg) for kwarg in kwargs.items()))

    @property
    def validation_method(self) -> str:
        """ validation method name """
        return self._validation_method

    @property
    def n_splits(self) -> int:
        """ number of splits (folds) used for cross-validation """
        return self._n_splits

    @property
    def parameters_splits(self) -> np.ndarray:
        """ all parameters obtained during cross-validation """
        return self._parameters_splits

    @property
    def n_nonzero_parameters_splits(self) -> np.ndarray:
        """ number of non-zero parameters for each split """
        if self.parameters_splits is None:
            return None
        else:
            return np.array([np.count_nonzero(p) for p in self.parameters_splits])

    @property
    def rmse_train_final(self) -> float:
        """
        root mean squared error when using the full set of input data
        """
        # set by train(); dict.get returns None when train() has not run yet
        return self.model_metrics.get('rmse_train_final')

    @property
    def rmse_train(self) -> float:
        """
        average root mean squared training error obtained during
        cross-validation
        """
        if self._rmse_train_splits is None:
            return None
        return np.sqrt(np.mean(self._rmse_train_splits**2))

    @property
    def rmse_train_splits(self) -> np.ndarray:
        """
        root mean squared training errors obtained during
        cross-validation
        """
        return self._rmse_train_splits

    @property
    def rmse_validation(self) -> float:
        """ average root mean squared cross-validation error """
        if self._rmse_valid_splits is None:
            return None
        return np.sqrt(np.mean(self._rmse_valid_splits**2))

    @property
    def rmse_validation_splits(self) -> np.ndarray:
        """
        root mean squared validation errors obtained during
        cross-validation
        """
        return self._rmse_valid_splits