1 /**
2  * @file core/cv/k_fold_cv_impl.hpp
3  * @author Kirill Mishchenko
4  *
5  * The implementation of k-fold cross-validation.
6  *
7  * mlpack is free software; you may redistribute it and/or modify it under the
8  * terms of the 3-clause BSD license.  You should have received a copy of the
9  * 3-clause BSD license along with mlpack.  If not, see
10  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11  */
12 #ifndef MLPACK_CORE_CV_K_FOLD_CV_IMPL_HPP
13 #define MLPACK_CORE_CV_K_FOLD_CV_IMPL_HPP
14 
15 namespace mlpack {
16 namespace cv {
17 
18 template<typename MLAlgorithm,
19          typename Metric,
20          typename MatType,
21          typename PredictionsType,
22          typename WeightsType>
23 KFoldCV<MLAlgorithm,
24         Metric,
25         MatType,
26         PredictionsType,
KFoldCV(const size_t k,const MatType & xs,const PredictionsType & ys,const bool shuffle)27         WeightsType>::KFoldCV(const size_t k,
28                               const MatType& xs,
29                               const PredictionsType& ys,
30                               const bool shuffle) :
31     KFoldCV(Base(), k, xs, ys, shuffle)
32 { /* Nothing left to do. */ }
33 
34 template<typename MLAlgorithm,
35          typename Metric,
36          typename MatType,
37          typename PredictionsType,
38          typename WeightsType>
39 KFoldCV<MLAlgorithm,
40         Metric,
41         MatType,
42         PredictionsType,
KFoldCV(const size_t k,const MatType & xs,const PredictionsType & ys,const size_t numClasses,const bool shuffle)43         WeightsType>::KFoldCV(const size_t k,
44                               const MatType& xs,
45                               const PredictionsType& ys,
46                               const size_t numClasses,
47                               const bool shuffle) :
48     KFoldCV(Base(numClasses), k, xs, ys, shuffle)
49 { /* Nothing left to do. */ }
50 
51 template<typename MLAlgorithm,
52          typename Metric,
53          typename MatType,
54          typename PredictionsType,
55          typename WeightsType>
56 KFoldCV<MLAlgorithm,
57         Metric,
58         MatType,
59         PredictionsType,
KFoldCV(const size_t k,const MatType & xs,const data::DatasetInfo & datasetInfo,const PredictionsType & ys,const size_t numClasses,const bool shuffle)60         WeightsType>::KFoldCV(const size_t k,
61                               const MatType& xs,
62                               const data::DatasetInfo& datasetInfo,
63                               const PredictionsType& ys,
64                               const size_t numClasses,
65                               const bool shuffle) :
66     KFoldCV(Base(datasetInfo, numClasses), k, xs, ys, shuffle)
67 { /* Nothing left to do. */ }
68 
69 template<typename MLAlgorithm,
70          typename Metric,
71          typename MatType,
72          typename PredictionsType,
73          typename WeightsType>
74 KFoldCV<MLAlgorithm,
75         Metric,
76         MatType,
77         PredictionsType,
KFoldCV(const size_t k,const MatType & xs,const PredictionsType & ys,const WeightsType & weights,const bool shuffle)78         WeightsType>::KFoldCV(const size_t k,
79                               const MatType& xs,
80                               const PredictionsType& ys,
81                               const WeightsType& weights,
82                               const bool shuffle) :
83     KFoldCV(Base(), k, xs, ys, weights, shuffle)
84 { /* Nothing left to do. */ }
85 
86 template<typename MLAlgorithm,
87          typename Metric,
88          typename MatType,
89          typename PredictionsType,
90          typename WeightsType>
91 KFoldCV<MLAlgorithm,
92         Metric,
93         MatType,
94         PredictionsType,
KFoldCV(const size_t k,const MatType & xs,const PredictionsType & ys,const size_t numClasses,const WeightsType & weights,const bool shuffle)95         WeightsType>::KFoldCV(const size_t k,
96                               const MatType& xs,
97                               const PredictionsType& ys,
98                               const size_t numClasses,
99                               const WeightsType& weights,
100                               const bool shuffle) :
101     KFoldCV(Base(numClasses), k, xs, ys, weights, shuffle)
102 { /* Nothing left to do. */ }
103 
104 template<typename MLAlgorithm,
105          typename Metric,
106          typename MatType,
107          typename PredictionsType,
108          typename WeightsType>
109 KFoldCV<MLAlgorithm,
110         Metric,
111         MatType,
112         PredictionsType,
KFoldCV(const size_t k,const MatType & xs,const data::DatasetInfo & datasetInfo,const PredictionsType & ys,const size_t numClasses,const WeightsType & weights,const bool shuffle)113         WeightsType>::KFoldCV(const size_t k,
114                               const MatType& xs,
115                               const data::DatasetInfo& datasetInfo,
116                               const PredictionsType& ys,
117                               const size_t numClasses,
118                               const WeightsType& weights,
119                               const bool shuffle) :
120     KFoldCV(Base(datasetInfo, numClasses), k, xs, ys, weights, shuffle)
121 { /* Nothing left to do. */ }
122 
123 template<typename MLAlgorithm,
124          typename Metric,
125          typename MatType,
126          typename PredictionsType,
127          typename WeightsType>
128 KFoldCV<MLAlgorithm,
129         Metric,
130         MatType,
131         PredictionsType,
KFoldCV(Base && base,const size_t k,const MatType & xs,const PredictionsType & ys,const bool shuffle)132         WeightsType>::KFoldCV(Base&& base,
133                               const size_t k,
134                               const MatType& xs,
135                               const PredictionsType& ys,
136                               const bool shuffle) :
137     base(std::move(base)),
138     k(k)
139 {
140   if (k < 2)
141     throw std::invalid_argument("KFoldCV: k should not be less than 2");
142 
143   Base::AssertDataConsistency(xs, ys);
144 
145   InitKFoldCVMat(xs, this->xs);
146   InitKFoldCVMat(ys, this->ys);
147 
148   // Do we need to shuffle the dataset?
149   if (shuffle)
150     Shuffle();
151 }
152 
153 template<typename MLAlgorithm,
154          typename Metric,
155          typename MatType,
156          typename PredictionsType,
157          typename WeightsType>
158 KFoldCV<MLAlgorithm,
159         Metric,
160         MatType,
161         PredictionsType,
KFoldCV(Base && base,const size_t k,const MatType & xs,const PredictionsType & ys,const WeightsType & weights,const bool shuffle)162         WeightsType>::KFoldCV(Base&& base,
163                               const size_t k,
164                               const MatType& xs,
165                               const PredictionsType& ys,
166                               const WeightsType& weights,
167                               const bool shuffle) :
168     base(std::move(base)),
169     k(k)
170 {
171   Base::AssertWeightsConsistency(xs, weights);
172 
173   InitKFoldCVMat(xs, this->xs);
174   InitKFoldCVMat(ys, this->ys);
175   InitKFoldCVMat(weights, this->weights);
176 
177   // Do we need to shuffle the dataset?
178   if (shuffle)
179     Shuffle();
180 }
181 
182 template<typename MLAlgorithm,
183          typename Metric,
184          typename MatType,
185          typename PredictionsType,
186          typename WeightsType>
187 template<typename... MLAlgorithmArgs>
188 double KFoldCV<MLAlgorithm,
189                Metric,
190                MatType,
191                PredictionsType,
Evaluate(const MLAlgorithmArgs &...args)192                WeightsType>::Evaluate(const MLAlgorithmArgs&... args)
193 {
194   return TrainAndEvaluate(args...);
195 }
196 
197 template<typename MLAlgorithm,
198          typename Metric,
199          typename MatType,
200          typename PredictionsType,
201          typename WeightsType>
202 MLAlgorithm& KFoldCV<MLAlgorithm,
203                      Metric,
204                      MatType,
205                      PredictionsType,
206                      WeightsType>::Model()
207 {
208   if (modelPtr == nullptr)
209     throw std::logic_error(
210         "KFoldCV::Model(): attempted to access an uninitialized model");
211 
212   return *modelPtr;
213 }
214 
215 template<typename MLAlgorithm,
216          typename Metric,
217          typename MatType,
218          typename PredictionsType,
219          typename WeightsType>
220 template<typename DataType>
221 void KFoldCV<MLAlgorithm,
222              Metric,
223              MatType,
224              PredictionsType,
InitKFoldCVMat(const DataType & source,DataType & destination)225              WeightsType>::InitKFoldCVMat(const DataType& source,
226                                           DataType& destination)
227 {
228   binSize = source.n_cols / k;
229   lastBinSize = source.n_cols - ((k - 1) * binSize);
230 
231   destination = (k == 2) ? source : arma::join_rows(source,
232       source.cols(0, source.n_cols - lastBinSize - 1));
233 }
234 
235 template<typename MLAlgorithm,
236          typename Metric,
237          typename MatType,
238          typename PredictionsType,
239          typename WeightsType>
240 template<typename... MLAlgorithmArgs, bool Enabled, typename>
241 double KFoldCV<MLAlgorithm,
242                 Metric,
243                 MatType,
244                 PredictionsType,
TrainAndEvaluate(const MLAlgorithmArgs &...args)245                 WeightsType>::TrainAndEvaluate(const MLAlgorithmArgs&... args)
246 {
247   arma::vec evaluations(k);
248 
249   size_t numInvalidScores = 0;
250   for (size_t i = 0; i < k; ++i)
251   {
252     MLAlgorithm&& model  = base.Train(GetTrainingSubset(xs, i),
253         GetTrainingSubset(ys, i), args...);
254     evaluations(i) = Metric::Evaluate(model, GetValidationSubset(xs, i),
255         GetValidationSubset(ys, i));
256     if (std::isnan(evaluations(i)) || std::isinf(evaluations(i)))
257     {
258       ++numInvalidScores;
259       Log::Warn << "KFoldCV::TrainAndEvaluate(): fold " << i << " returned "
260           << "a score of " << evaluations(i) << "; ignoring when computing "
261           << "the average score." << std::endl;
262     }
263     if (i == k - 1)
264       modelPtr.reset(new MLAlgorithm(std::move(model)));
265   }
266 
267   if (numInvalidScores == k)
268   {
269     Log::Warn << "KFoldCV::TrainAndEvaluate(): all folds returned invalid "
270         << "scores!  Returning 0.0 as overall score." << std::endl;
271     return 0.0;
272   }
273 
274   return arma::mean(evaluations.elem(arma::find_finite(evaluations)));
275 }
276 
277 template<typename MLAlgorithm,
278          typename Metric,
279          typename MatType,
280          typename PredictionsType,
281          typename WeightsType>
282 template<typename... MLAlgorithmArgs, bool Enabled, typename, typename>
283 double KFoldCV<MLAlgorithm,
284                 Metric,
285                 MatType,
286                 PredictionsType,
TrainAndEvaluate(const MLAlgorithmArgs &...args)287                 WeightsType>::TrainAndEvaluate(const MLAlgorithmArgs&... args)
288 {
289   arma::vec evaluations(k);
290 
291   for (size_t i = 0; i < k; ++i)
292   {
293     MLAlgorithm&& model = (weights.n_elem > 0) ?
294         base.Train(GetTrainingSubset(xs, i), GetTrainingSubset(ys, i),
295             GetTrainingSubset(weights, i), args...) :
296         base.Train(GetTrainingSubset(xs, i), GetTrainingSubset(ys, i),
297             args...);
298     evaluations(i) = Metric::Evaluate(model, GetValidationSubset(xs, i),
299         GetValidationSubset(ys, i));
300     if (i == k - 1)
301       modelPtr.reset(new MLAlgorithm(std::move(model)));
302   }
303 
304   return arma::mean(evaluations);
305 }
306 
307 template<typename MLAlgorithm,
308          typename Metric,
309          typename MatType,
310          typename PredictionsType,
311          typename WeightsType>
312 template<bool Enabled, typename>
313 void KFoldCV<MLAlgorithm,
314              Metric,
315              MatType,
316              PredictionsType,
Shuffle()317              WeightsType>::Shuffle()
318 {
319   MatType xsOrig = xs.cols(0, (k - 1) * binSize + lastBinSize - 1);
320   PredictionsType ysOrig = ys.cols(0, (k - 1) * binSize + lastBinSize - 1);
321 
322   // Now shuffle the data.
323   math::ShuffleData(xsOrig, ysOrig, xsOrig, ysOrig);
324 
325   InitKFoldCVMat(xsOrig, xs);
326   InitKFoldCVMat(ysOrig, ys);
327 }
328 
329 template<typename MLAlgorithm,
330          typename Metric,
331          typename MatType,
332          typename PredictionsType,
333          typename WeightsType>
334 template<bool Enabled, typename, typename>
335 void KFoldCV<MLAlgorithm,
336              Metric,
337              MatType,
338              PredictionsType,
Shuffle()339              WeightsType>::Shuffle()
340 {
341   MatType xsOrig = xs.cols(0, (k - 1) * binSize + lastBinSize - 1);
342   PredictionsType ysOrig = ys.cols(0, (k - 1) * binSize + lastBinSize - 1);
343   WeightsType weightsOrig;
344   if (weights.n_elem > 0)
345     weightsOrig = weights.cols(0, (k - 1) * binSize + lastBinSize - 1);
346 
347   // Now shuffle the data.
348   if (weights.n_elem > 0)
349     math::ShuffleData(xsOrig, ysOrig, weightsOrig, xsOrig, ysOrig, weightsOrig);
350   else
351     math::ShuffleData(xsOrig, ysOrig, xsOrig, ysOrig);
352 
353   InitKFoldCVMat(xsOrig, xs);
354   InitKFoldCVMat(ysOrig, ys);
355   if (weights.n_elem > 0)
356     InitKFoldCVMat(weightsOrig, weights);
357 }
358 
359 template<typename MLAlgorithm,
360          typename Metric,
361          typename MatType,
362          typename PredictionsType,
363          typename WeightsType>
364 size_t KFoldCV<MLAlgorithm,
365                Metric,
366                MatType,
367                PredictionsType,
ValidationSubsetFirstCol(const size_t i)368                WeightsType>::ValidationSubsetFirstCol(const size_t i)
369 {
370   // Use as close to the beginning of the dataset as we can.
371   return (i == 0) ? binSize * (k - 1) : binSize * (i - 1);
372 }
373 
374 template<typename MLAlgorithm,
375          typename Metric,
376          typename MatType,
377          typename PredictionsType,
378          typename WeightsType>
379 template<typename ElementType>
380 arma::Mat<ElementType> KFoldCV<MLAlgorithm,
381                                Metric,
382                                MatType,
383                                PredictionsType,
GetTrainingSubset(arma::Mat<ElementType> & m,const size_t i)384                                WeightsType>::GetTrainingSubset(
385     arma::Mat<ElementType>& m,
386     const size_t i)
387 {
388   // If this is not the first fold, we have to handle it a little bit
389   // differently, since the last fold may contain slightly more than 'binSize'
390   // points.
391   const size_t subsetSize = (i != 0) ? lastBinSize + (k - 2) * binSize :
392       (k - 1) * binSize;
393 
394   return arma::Mat<ElementType>(m.colptr(binSize * i), m.n_rows, subsetSize,
395       false, true);
396 }
397 
398 template<typename MLAlgorithm,
399          typename Metric,
400          typename MatType,
401          typename PredictionsType,
402          typename WeightsType>
403 template<typename ElementType>
404 arma::Row<ElementType> KFoldCV<MLAlgorithm,
405                                Metric,
406                                MatType,
407                                PredictionsType,
GetTrainingSubset(arma::Row<ElementType> & r,const size_t i)408                                WeightsType>::GetTrainingSubset(
409     arma::Row<ElementType>& r,
410     const size_t i)
411 {
412   // If this is not the first fold, we have to handle it a little bit
413   // differently, since the last fold may contain slightly more than 'binSize'
414   // points.
415   const size_t subsetSize = (i != 0) ? lastBinSize + (k - 2) * binSize :
416       (k - 1) * binSize;
417 
418   return arma::Row<ElementType>(r.colptr(binSize * i), subsetSize, false, true);
419 }
420 
421 template<typename MLAlgorithm,
422          typename Metric,
423          typename MatType,
424          typename PredictionsType,
425          typename WeightsType>
426 template<typename ElementType>
427 arma::Mat<ElementType> KFoldCV<MLAlgorithm,
428                                Metric,
429                                MatType,
430                                PredictionsType,
GetValidationSubset(arma::Mat<ElementType> & m,const size_t i)431                                WeightsType>::GetValidationSubset(
432     arma::Mat<ElementType>& m,
433     const size_t i)
434 {
435   const size_t subsetSize = (i == 0) ? lastBinSize : binSize;
436   return arma::Mat<ElementType>(m.colptr(ValidationSubsetFirstCol(i)), m.n_rows,
437       subsetSize, false, true);
438 }
439 
440 template<typename MLAlgorithm,
441          typename Metric,
442          typename MatType,
443          typename PredictionsType,
444          typename WeightsType>
445 template<typename ElementType>
446 arma::Row<ElementType> KFoldCV<MLAlgorithm,
447                                Metric,
448                                MatType,
449                                PredictionsType,
GetValidationSubset(arma::Row<ElementType> & r,const size_t i)450                                WeightsType>::GetValidationSubset(
451     arma::Row<ElementType>& r,
452     const size_t i)
453 {
454   const size_t subsetSize = (i == 0) ? lastBinSize : binSize;
455   return arma::Row<ElementType>(r.colptr(ValidationSubsetFirstCol(i)),
456       subsetSize, false, true);
457 }
458 
459 } // namespace cv
460 } // namespace mlpack
461 
462 #endif
463