1 /** 2 * @file core/cv/k_fold_cv_impl.hpp 3 * @author Kirill Mishchenko 4 * 5 * The implementation of k-fold cross-validation. 6 * 7 * mlpack is free software; you may redistribute it and/or modify it under the 8 * terms of the 3-clause BSD license. You should have received a copy of the 9 * 3-clause BSD license along with mlpack. If not, see 10 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 11 */ 12 #ifndef MLPACK_CORE_CV_K_FOLD_CV_IMPL_HPP 13 #define MLPACK_CORE_CV_K_FOLD_CV_IMPL_HPP 14 15 namespace mlpack { 16 namespace cv { 17 18 template<typename MLAlgorithm, 19 typename Metric, 20 typename MatType, 21 typename PredictionsType, 22 typename WeightsType> 23 KFoldCV<MLAlgorithm, 24 Metric, 25 MatType, 26 PredictionsType, KFoldCV(const size_t k,const MatType & xs,const PredictionsType & ys,const bool shuffle)27 WeightsType>::KFoldCV(const size_t k, 28 const MatType& xs, 29 const PredictionsType& ys, 30 const bool shuffle) : 31 KFoldCV(Base(), k, xs, ys, shuffle) 32 { /* Nothing left to do. */ } 33 34 template<typename MLAlgorithm, 35 typename Metric, 36 typename MatType, 37 typename PredictionsType, 38 typename WeightsType> 39 KFoldCV<MLAlgorithm, 40 Metric, 41 MatType, 42 PredictionsType, KFoldCV(const size_t k,const MatType & xs,const PredictionsType & ys,const size_t numClasses,const bool shuffle)43 WeightsType>::KFoldCV(const size_t k, 44 const MatType& xs, 45 const PredictionsType& ys, 46 const size_t numClasses, 47 const bool shuffle) : 48 KFoldCV(Base(numClasses), k, xs, ys, shuffle) 49 { /* Nothing left to do. */ } 50 51 template<typename MLAlgorithm, 52 typename Metric, 53 typename MatType, 54 typename PredictionsType, 55 typename WeightsType> 56 KFoldCV<MLAlgorithm, 57 Metric, 58 MatType, 59 PredictionsType, KFoldCV(const size_t k,const MatType & xs,const data::DatasetInfo & datasetInfo,const PredictionsType & ys,const size_t numClasses,const bool shuffle)60 WeightsType>::KFoldCV(const size_t k, 61 const MatType& xs, 62 const data::DatasetInfo& datasetInfo, 63 const PredictionsType& ys, 64 const size_t numClasses, 65 const bool shuffle) : 66 KFoldCV(Base(datasetInfo, numClasses), k, xs, ys, shuffle) 67 { /* Nothing left to do. */ } 68 69 template<typename MLAlgorithm, 70 typename Metric, 71 typename MatType, 72 typename PredictionsType, 73 typename WeightsType> 74 KFoldCV<MLAlgorithm, 75 Metric, 76 MatType, 77 PredictionsType, KFoldCV(const size_t k,const MatType & xs,const PredictionsType & ys,const WeightsType & weights,const bool shuffle)78 WeightsType>::KFoldCV(const size_t k, 79 const MatType& xs, 80 const PredictionsType& ys, 81 const WeightsType& weights, 82 const bool shuffle) : 83 KFoldCV(Base(), k, xs, ys, weights, shuffle) 84 { /* Nothing left to do. */ } 85 86 template<typename MLAlgorithm, 87 typename Metric, 88 typename MatType, 89 typename PredictionsType, 90 typename WeightsType> 91 KFoldCV<MLAlgorithm, 92 Metric, 93 MatType, 94 PredictionsType, KFoldCV(const size_t k,const MatType & xs,const PredictionsType & ys,const size_t numClasses,const WeightsType & weights,const bool shuffle)95 WeightsType>::KFoldCV(const size_t k, 96 const MatType& xs, 97 const PredictionsType& ys, 98 const size_t numClasses, 99 const WeightsType& weights, 100 const bool shuffle) : 101 KFoldCV(Base(numClasses), k, xs, ys, weights, shuffle) 102 { /* Nothing left to do. */ } 103 104 template<typename MLAlgorithm, 105 typename Metric, 106 typename MatType, 107 typename PredictionsType, 108 typename WeightsType> 109 KFoldCV<MLAlgorithm, 110 Metric, 111 MatType, 112 PredictionsType, KFoldCV(const size_t k,const MatType & xs,const data::DatasetInfo & datasetInfo,const PredictionsType & ys,const size_t numClasses,const WeightsType & weights,const bool shuffle)113 WeightsType>::KFoldCV(const size_t k, 114 const MatType& xs, 115 const data::DatasetInfo& datasetInfo, 116 const PredictionsType& ys, 117 const size_t numClasses, 118 const WeightsType& weights, 119 const bool shuffle) : 120 KFoldCV(Base(datasetInfo, numClasses), k, xs, ys, weights, shuffle) 121 { /* Nothing left to do. */ } 122 123 template<typename MLAlgorithm, 124 typename Metric, 125 typename MatType, 126 typename PredictionsType, 127 typename WeightsType> 128 KFoldCV<MLAlgorithm, 129 Metric, 130 MatType, 131 PredictionsType, KFoldCV(Base && base,const size_t k,const MatType & xs,const PredictionsType & ys,const bool shuffle)132 WeightsType>::KFoldCV(Base&& base, 133 const size_t k, 134 const MatType& xs, 135 const PredictionsType& ys, 136 const bool shuffle) : 137 base(std::move(base)), 138 k(k) 139 { 140 if (k < 2) 141 throw std::invalid_argument("KFoldCV: k should not be less than 2"); 142 143 Base::AssertDataConsistency(xs, ys); 144 145 InitKFoldCVMat(xs, this->xs); 146 InitKFoldCVMat(ys, this->ys); 147 148 // Do we need to shuffle the dataset? 149 if (shuffle) 150 Shuffle(); 151 } 152 153 template<typename MLAlgorithm, 154 typename Metric, 155 typename MatType, 156 typename PredictionsType, 157 typename WeightsType> 158 KFoldCV<MLAlgorithm, 159 Metric, 160 MatType, 161 PredictionsType, KFoldCV(Base && base,const size_t k,const MatType & xs,const PredictionsType & ys,const WeightsType & weights,const bool shuffle)162 WeightsType>::KFoldCV(Base&& base, 163 const size_t k, 164 const MatType& xs, 165 const PredictionsType& ys, 166 const WeightsType& weights, 167 const bool shuffle) : 168 base(std::move(base)), 169 k(k) 170 { 171 Base::AssertWeightsConsistency(xs, weights); 172 173 InitKFoldCVMat(xs, this->xs); 174 InitKFoldCVMat(ys, this->ys); 175 InitKFoldCVMat(weights, this->weights); 176 177 // Do we need to shuffle the dataset? 178 if (shuffle) 179 Shuffle(); 180 } 181 182 template<typename MLAlgorithm, 183 typename Metric, 184 typename MatType, 185 typename PredictionsType, 186 typename WeightsType> 187 template<typename... MLAlgorithmArgs> 188 double KFoldCV<MLAlgorithm, 189 Metric, 190 MatType, 191 PredictionsType, Evaluate(const MLAlgorithmArgs &...args)192 WeightsType>::Evaluate(const MLAlgorithmArgs&... args) 193 { 194 return TrainAndEvaluate(args...); 195 } 196 197 template<typename MLAlgorithm, 198 typename Metric, 199 typename MatType, 200 typename PredictionsType, 201 typename WeightsType> 202 MLAlgorithm& KFoldCV<MLAlgorithm, 203 Metric, 204 MatType, 205 PredictionsType, 206 WeightsType>::Model() 207 { 208 if (modelPtr == nullptr) 209 throw std::logic_error( 210 "KFoldCV::Model(): attempted to access an uninitialized model"); 211 212 return *modelPtr; 213 } 214 215 template<typename MLAlgorithm, 216 typename Metric, 217 typename MatType, 218 typename PredictionsType, 219 typename WeightsType> 220 template<typename DataType> 221 void KFoldCV<MLAlgorithm, 222 Metric, 223 MatType, 224 PredictionsType, InitKFoldCVMat(const DataType & source,DataType & destination)225 WeightsType>::InitKFoldCVMat(const DataType& source, 226 DataType& destination) 227 { 228 binSize = source.n_cols / k; 229 lastBinSize = source.n_cols - ((k - 1) * binSize); 230 231 destination = (k == 2) ? source : arma::join_rows(source, 232 source.cols(0, source.n_cols - lastBinSize - 1)); 233 } 234 235 template<typename MLAlgorithm, 236 typename Metric, 237 typename MatType, 238 typename PredictionsType, 239 typename WeightsType> 240 template<typename... MLAlgorithmArgs, bool Enabled, typename> 241 double KFoldCV<MLAlgorithm, 242 Metric, 243 MatType, 244 PredictionsType, TrainAndEvaluate(const MLAlgorithmArgs &...args)245 WeightsType>::TrainAndEvaluate(const MLAlgorithmArgs&... args) 246 { 247 arma::vec evaluations(k); 248 249 size_t numInvalidScores = 0; 250 for (size_t i = 0; i < k; ++i) 251 { 252 MLAlgorithm&& model = base.Train(GetTrainingSubset(xs, i), 253 GetTrainingSubset(ys, i), args...); 254 evaluations(i) = Metric::Evaluate(model, GetValidationSubset(xs, i), 255 GetValidationSubset(ys, i)); 256 if (std::isnan(evaluations(i)) || std::isinf(evaluations(i))) 257 { 258 ++numInvalidScores; 259 Log::Warn << "KFoldCV::TrainAndEvaluate(): fold " << i << " returned " 260 << "a score of " << evaluations(i) << "; ignoring when computing " 261 << "the average score." << std::endl; 262 } 263 if (i == k - 1) 264 modelPtr.reset(new MLAlgorithm(std::move(model))); 265 } 266 267 if (numInvalidScores == k) 268 { 269 Log::Warn << "KFoldCV::TrainAndEvaluate(): all folds returned invalid " 270 << "scores! Returning 0.0 as overall score." << std::endl; 271 return 0.0; 272 } 273 274 return arma::mean(evaluations.elem(arma::find_finite(evaluations))); 275 } 276 277 template<typename MLAlgorithm, 278 typename Metric, 279 typename MatType, 280 typename PredictionsType, 281 typename WeightsType> 282 template<typename... MLAlgorithmArgs, bool Enabled, typename, typename> 283 double KFoldCV<MLAlgorithm, 284 Metric, 285 MatType, 286 PredictionsType, TrainAndEvaluate(const MLAlgorithmArgs &...args)287 WeightsType>::TrainAndEvaluate(const MLAlgorithmArgs&... args) 288 { 289 arma::vec evaluations(k); 290 291 for (size_t i = 0; i < k; ++i) 292 { 293 MLAlgorithm&& model = (weights.n_elem > 0) ? 294 base.Train(GetTrainingSubset(xs, i), GetTrainingSubset(ys, i), 295 GetTrainingSubset(weights, i), args...) : 296 base.Train(GetTrainingSubset(xs, i), GetTrainingSubset(ys, i), 297 args...); 298 evaluations(i) = Metric::Evaluate(model, GetValidationSubset(xs, i), 299 GetValidationSubset(ys, i)); 300 if (i == k - 1) 301 modelPtr.reset(new MLAlgorithm(std::move(model))); 302 } 303 304 return arma::mean(evaluations); 305 } 306 307 template<typename MLAlgorithm, 308 typename Metric, 309 typename MatType, 310 typename PredictionsType, 311 typename WeightsType> 312 template<bool Enabled, typename> 313 void KFoldCV<MLAlgorithm, 314 Metric, 315 MatType, 316 PredictionsType, Shuffle()317 WeightsType>::Shuffle() 318 { 319 MatType xsOrig = xs.cols(0, (k - 1) * binSize + lastBinSize - 1); 320 PredictionsType ysOrig = ys.cols(0, (k - 1) * binSize + lastBinSize - 1); 321 322 // Now shuffle the data. 323 math::ShuffleData(xsOrig, ysOrig, xsOrig, ysOrig); 324 325 InitKFoldCVMat(xsOrig, xs); 326 InitKFoldCVMat(ysOrig, ys); 327 } 328 329 template<typename MLAlgorithm, 330 typename Metric, 331 typename MatType, 332 typename PredictionsType, 333 typename WeightsType> 334 template<bool Enabled, typename, typename> 335 void KFoldCV<MLAlgorithm, 336 Metric, 337 MatType, 338 PredictionsType, Shuffle()339 WeightsType>::Shuffle() 340 { 341 MatType xsOrig = xs.cols(0, (k - 1) * binSize + lastBinSize - 1); 342 PredictionsType ysOrig = ys.cols(0, (k - 1) * binSize + lastBinSize - 1); 343 WeightsType weightsOrig; 344 if (weights.n_elem > 0) 345 weightsOrig = weights.cols(0, (k - 1) * binSize + lastBinSize - 1); 346 347 // Now shuffle the data. 348 if (weights.n_elem > 0) 349 math::ShuffleData(xsOrig, ysOrig, weightsOrig, xsOrig, ysOrig, weightsOrig); 350 else 351 math::ShuffleData(xsOrig, ysOrig, xsOrig, ysOrig); 352 353 InitKFoldCVMat(xsOrig, xs); 354 InitKFoldCVMat(ysOrig, ys); 355 if (weights.n_elem > 0) 356 InitKFoldCVMat(weightsOrig, weights); 357 } 358 359 template<typename MLAlgorithm, 360 typename Metric, 361 typename MatType, 362 typename PredictionsType, 363 typename WeightsType> 364 size_t KFoldCV<MLAlgorithm, 365 Metric, 366 MatType, 367 PredictionsType, ValidationSubsetFirstCol(const size_t i)368 WeightsType>::ValidationSubsetFirstCol(const size_t i) 369 { 370 // Use as close to the beginning of the dataset as we can. 371 return (i == 0) ? binSize * (k - 1) : binSize * (i - 1); 372 } 373 374 template<typename MLAlgorithm, 375 typename Metric, 376 typename MatType, 377 typename PredictionsType, 378 typename WeightsType> 379 template<typename ElementType> 380 arma::Mat<ElementType> KFoldCV<MLAlgorithm, 381 Metric, 382 MatType, 383 PredictionsType, GetTrainingSubset(arma::Mat<ElementType> & m,const size_t i)384 WeightsType>::GetTrainingSubset( 385 arma::Mat<ElementType>& m, 386 const size_t i) 387 { 388 // If this is not the first fold, we have to handle it a little bit 389 // differently, since the last fold may contain slightly more than 'binSize' 390 // points. 391 const size_t subsetSize = (i != 0) ? lastBinSize + (k - 2) * binSize : 392 (k - 1) * binSize; 393 394 return arma::Mat<ElementType>(m.colptr(binSize * i), m.n_rows, subsetSize, 395 false, true); 396 } 397 398 template<typename MLAlgorithm, 399 typename Metric, 400 typename MatType, 401 typename PredictionsType, 402 typename WeightsType> 403 template<typename ElementType> 404 arma::Row<ElementType> KFoldCV<MLAlgorithm, 405 Metric, 406 MatType, 407 PredictionsType, GetTrainingSubset(arma::Row<ElementType> & r,const size_t i)408 WeightsType>::GetTrainingSubset( 409 arma::Row<ElementType>& r, 410 const size_t i) 411 { 412 // If this is not the first fold, we have to handle it a little bit 413 // differently, since the last fold may contain slightly more than 'binSize' 414 // points. 415 const size_t subsetSize = (i != 0) ? lastBinSize + (k - 2) * binSize : 416 (k - 1) * binSize; 417 418 return arma::Row<ElementType>(r.colptr(binSize * i), subsetSize, false, true); 419 } 420 421 template<typename MLAlgorithm, 422 typename Metric, 423 typename MatType, 424 typename PredictionsType, 425 typename WeightsType> 426 template<typename ElementType> 427 arma::Mat<ElementType> KFoldCV<MLAlgorithm, 428 Metric, 429 MatType, 430 PredictionsType, GetValidationSubset(arma::Mat<ElementType> & m,const size_t i)431 WeightsType>::GetValidationSubset( 432 arma::Mat<ElementType>& m, 433 const size_t i) 434 { 435 const size_t subsetSize = (i == 0) ? lastBinSize : binSize; 436 return arma::Mat<ElementType>(m.colptr(ValidationSubsetFirstCol(i)), m.n_rows, 437 subsetSize, false, true); 438 } 439 440 template<typename MLAlgorithm, 441 typename Metric, 442 typename MatType, 443 typename PredictionsType, 444 typename WeightsType> 445 template<typename ElementType> 446 arma::Row<ElementType> KFoldCV<MLAlgorithm, 447 Metric, 448 MatType, 449 PredictionsType, GetValidationSubset(arma::Row<ElementType> & r,const size_t i)450 WeightsType>::GetValidationSubset( 451 arma::Row<ElementType>& r, 452 const size_t i) 453 { 454 const size_t subsetSize = (i == 0) ? lastBinSize : binSize; 455 return arma::Row<ElementType>(r.colptr(ValidationSubsetFirstCol(i)), 456 subsetSize, false, true); 457 } 458 459 } // namespace cv 460 } // namespace mlpack 461 462 #endif 463