1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
20 //
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
24 //
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "precomp.hpp"
42 #include <ctype.h>
43 #include <algorithm>
44 #include <iterator>
45
46 #include <opencv2/core/utils/logger.hpp>
47
48 namespace cv { namespace ml {
49
// Sentinel stored in sample matrices for a missing element (TrainData::missingValue()).
static const float MISSED_VAL = TrainData::missingValue();
// Type code decodeElem() reports for a missing CSV element.
// NOTE: it is the same value as VAR_ORDERED, so by type code alone a missing element
// cannot be distinguished from an ordered (numeric) one.
static const int VAR_MISSED = VAR_ORDERED;
52
~TrainData()53 TrainData::~TrainData() {}
54
// Picks the elements of a 1D vector listed in 'idx'.
// A row vector is treated as column-sample data, anything else as row-sample data.
// 2D input is still forwarded to getSubMatrix(), but a deprecation warning is logged.
Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
{
    const bool isRowVector = (vec.rows == 1);
    const bool isOneDimensional = (vec.cols == 1) || isRowVector;
    if (!isOneDimensional)
        CV_LOG_WARNING(NULL, "'getSubVector(const Mat& vec, const Mat& idx)' call with non-1D input is deprecated. It is not designed to work with 2D matrixes (especially with 'cv::ml::COL_SAMPLE' layout).");

    const int sampleLayout = isRowVector ? cv::ml::COL_SAMPLE : cv::ml::ROW_SAMPLE;
    return getSubMatrix(vec, idx, sampleLayout);
}
61
62 template<typename T>
getSubMatrixImpl(const Mat & m,const Mat & idx,int layout)63 Mat getSubMatrixImpl(const Mat& m, const Mat& idx, int layout)
64 {
65 int nidx = idx.checkVector(1, CV_32S);
66 int dims = m.cols, nsamples = m.rows;
67
68 Mat subm;
69 if (layout == COL_SAMPLE)
70 {
71 std::swap(dims, nsamples);
72 subm.create(dims, nidx, m.type());
73 }
74 else
75 {
76 subm.create(nidx, dims, m.type());
77 }
78
79 for (int i = 0; i < nidx; i++)
80 {
81 int k = idx.at<int>(i); CV_CheckGE(k, 0, "Bad idx"); CV_CheckLT(k, nsamples, "Bad idx or layout");
82 if (dims == 1)
83 {
84 subm.at<T>(i) = m.at<T>(k); // at() has "transparent" access for 1D col-based / row-based vectors.
85 }
86 else if (layout == COL_SAMPLE)
87 {
88 for (int j = 0; j < dims; j++)
89 subm.at<T>(j, i) = m.at<T>(j, k);
90 }
91 else
92 {
93 for (int j = 0; j < dims; j++)
94 subm.at<T>(i, j) = m.at<T>(k, j);
95 }
96 }
97 return subm;
98 }
99
// Returns the samples of 'm' selected by 'idx' in the given layout; an empty 'idx'
// returns 'm' itself. Only CV_32S, CV_32F and CV_64F matrices are supported:
// 32-bit types are copied through the int instantiation, CV_64F through double.
Mat TrainData::getSubMatrix(const Mat& m, const Mat& idx, int layout)
{
    if (idx.empty())
        return m;

    int type = m.type();
    CV_CheckType(type, type == CV_32S || type == CV_32F || type == CV_64F, "");
    switch (type)
    {
    case CV_32S:
    case CV_32F:
        return getSubMatrixImpl<int>(m, idx, layout);     // 32-bit
    case CV_64F:
        return getSubMatrixImpl<double>(m, idx, layout);  // 64-bit
    default:
        break;
    }
    CV_Error(Error::StsInternal, "");
}
112
113
114 class TrainDataImpl CV_FINAL : public TrainData
115 {
116 public:
117 typedef std::map<String, int> MapType;
118
TrainDataImpl()119 TrainDataImpl()
120 {
121 file = 0;
122 clear();
123 }
124
~TrainDataImpl()125 virtual ~TrainDataImpl() { closeFile(); }
126
getLayout() const127 int getLayout() const CV_OVERRIDE { return layout; }
getNSamples() const128 int getNSamples() const CV_OVERRIDE
129 {
130 return !sampleIdx.empty() ? (int)sampleIdx.total() :
131 layout == ROW_SAMPLE ? samples.rows : samples.cols;
132 }
getNTrainSamples() const133 int getNTrainSamples() const CV_OVERRIDE
134 {
135 return !trainSampleIdx.empty() ? (int)trainSampleIdx.total() : getNSamples();
136 }
getNTestSamples() const137 int getNTestSamples() const CV_OVERRIDE
138 {
139 return !testSampleIdx.empty() ? (int)testSampleIdx.total() : 0;
140 }
getNVars() const141 int getNVars() const CV_OVERRIDE
142 {
143 return !varIdx.empty() ? (int)varIdx.total() : getNAllVars();
144 }
getNAllVars() const145 int getNAllVars() const CV_OVERRIDE
146 {
147 return layout == ROW_SAMPLE ? samples.cols : samples.rows;
148 }
149
getTestSamples() const150 Mat getTestSamples() const CV_OVERRIDE
151 {
152 Mat idx = getTestSampleIdx();
153 return idx.empty() ? Mat() : getSubMatrix(samples, idx, getLayout());
154 }
155
getSamples() const156 Mat getSamples() const CV_OVERRIDE { return samples; }
getResponses() const157 Mat getResponses() const CV_OVERRIDE { return responses; }
getMissing() const158 Mat getMissing() const CV_OVERRIDE { return missing; }
getVarIdx() const159 Mat getVarIdx() const CV_OVERRIDE { return varIdx; }
getVarType() const160 Mat getVarType() const CV_OVERRIDE { return varType; }
getResponseType() const161 int getResponseType() const CV_OVERRIDE
162 {
163 return classLabels.empty() ? VAR_ORDERED : VAR_CATEGORICAL;
164 }
getTrainSampleIdx() const165 Mat getTrainSampleIdx() const CV_OVERRIDE { return !trainSampleIdx.empty() ? trainSampleIdx : sampleIdx; }
getTestSampleIdx() const166 Mat getTestSampleIdx() const CV_OVERRIDE { return testSampleIdx; }
getSampleWeights() const167 Mat getSampleWeights() const CV_OVERRIDE
168 {
169 return sampleWeights;
170 }
getTrainSampleWeights() const171 Mat getTrainSampleWeights() const CV_OVERRIDE
172 {
173 return getSubVector(sampleWeights, getTrainSampleIdx()); // 1D-vector
174 }
getTestSampleWeights() const175 Mat getTestSampleWeights() const CV_OVERRIDE
176 {
177 Mat idx = getTestSampleIdx();
178 return idx.empty() ? Mat() : getSubVector(sampleWeights, idx); // 1D-vector
179 }
getTrainResponses() const180 Mat getTrainResponses() const CV_OVERRIDE
181 {
182 return getSubMatrix(responses, getTrainSampleIdx(), cv::ml::ROW_SAMPLE); // col-based responses are transposed in setData()
183 }
getTrainNormCatResponses() const184 Mat getTrainNormCatResponses() const CV_OVERRIDE
185 {
186 return getSubMatrix(normCatResponses, getTrainSampleIdx(), cv::ml::ROW_SAMPLE); // like 'responses'
187 }
getTestResponses() const188 Mat getTestResponses() const CV_OVERRIDE
189 {
190 Mat idx = getTestSampleIdx();
191 return idx.empty() ? Mat() : getSubMatrix(responses, idx, cv::ml::ROW_SAMPLE); // col-based responses are transposed in setData()
192 }
getTestNormCatResponses() const193 Mat getTestNormCatResponses() const CV_OVERRIDE
194 {
195 Mat idx = getTestSampleIdx();
196 return idx.empty() ? Mat() : getSubMatrix(normCatResponses, idx, cv::ml::ROW_SAMPLE); // like 'responses'
197 }
getNormCatResponses() const198 Mat getNormCatResponses() const CV_OVERRIDE { return normCatResponses; }
getClassLabels() const199 Mat getClassLabels() const CV_OVERRIDE { return classLabels; }
getClassCounters() const200 Mat getClassCounters() const { return classCounters; }
getCatCount(int vi) const201 int getCatCount(int vi) const CV_OVERRIDE
202 {
203 int n = (int)catOfs.total();
204 CV_Assert( 0 <= vi && vi < n );
205 Vec2i ofs = catOfs.at<Vec2i>(vi);
206 return ofs[1] - ofs[0];
207 }
208
getCatOfs() const209 Mat getCatOfs() const CV_OVERRIDE { return catOfs; }
getCatMap() const210 Mat getCatMap() const CV_OVERRIDE { return catMap; }
211
getDefaultSubstValues() const212 Mat getDefaultSubstValues() const CV_OVERRIDE { return missingSubst; }
213
closeFile()214 void closeFile() { if(file) fclose(file); file=0; }
clear()215 void clear()
216 {
217 closeFile();
218 samples.release();
219 missing.release();
220 varType.release();
221 varSymbolFlags.release();
222 responses.release();
223 sampleIdx.release();
224 trainSampleIdx.release();
225 testSampleIdx.release();
226 normCatResponses.release();
227 classLabels.release();
228 classCounters.release();
229 catMap.release();
230 catOfs.release();
231 nameMap = MapType();
232 layout = ROW_SAMPLE;
233 }
234
235 typedef std::map<int, int> CatMapHash;
236
setData(InputArray _samples,int _layout,InputArray _responses,InputArray _varIdx,InputArray _sampleIdx,InputArray _sampleWeights,InputArray _varType,InputArray _missing)237 void setData(InputArray _samples, int _layout, InputArray _responses,
238 InputArray _varIdx, InputArray _sampleIdx, InputArray _sampleWeights,
239 InputArray _varType, InputArray _missing)
240 {
241 clear();
242
243 CV_Assert(_layout == ROW_SAMPLE || _layout == COL_SAMPLE );
244 samples = _samples.getMat();
245 layout = _layout;
246 responses = _responses.getMat();
247 varIdx = _varIdx.getMat();
248 sampleIdx = _sampleIdx.getMat();
249 sampleWeights = _sampleWeights.getMat();
250 varType = _varType.getMat();
251 missing = _missing.getMat();
252
253 int nsamples = layout == ROW_SAMPLE ? samples.rows : samples.cols;
254 int ninputvars = layout == ROW_SAMPLE ? samples.cols : samples.rows;
255 int i, noutputvars = 0;
256
257 CV_Assert( samples.type() == CV_32F || samples.type() == CV_32S );
258
259 if( !sampleIdx.empty() )
260 {
261 CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
262 checkRange(sampleIdx, true, 0, 0, nsamples)) ||
263 sampleIdx.checkVector(1, CV_8U, true) == nsamples );
264 if( sampleIdx.type() == CV_8U )
265 sampleIdx = convertMaskToIdx(sampleIdx);
266 }
267
268 if( !sampleWeights.empty() )
269 {
270 CV_Assert( sampleWeights.checkVector(1, CV_32F, true) == nsamples );
271 }
272 else
273 {
274 sampleWeights = Mat::ones(nsamples, 1, CV_32F);
275 }
276
277 if( !varIdx.empty() )
278 {
279 CV_Assert( (varIdx.checkVector(1, CV_32S, true) > 0 &&
280 checkRange(varIdx, true, 0, 0, ninputvars)) ||
281 varIdx.checkVector(1, CV_8U, true) == ninputvars );
282 if( varIdx.type() == CV_8U )
283 varIdx = convertMaskToIdx(varIdx);
284 varIdx = varIdx.clone();
285 std::sort(varIdx.ptr<int>(), varIdx.ptr<int>() + varIdx.total());
286 }
287
288 if( !responses.empty() )
289 {
290 CV_Assert( responses.type() == CV_32F || responses.type() == CV_32S );
291 if( (responses.cols == 1 || responses.rows == 1) && (int)responses.total() == nsamples )
292 noutputvars = 1;
293 else
294 {
295 CV_Assert( (layout == ROW_SAMPLE && responses.rows == nsamples) ||
296 (layout == COL_SAMPLE && responses.cols == nsamples) );
297 noutputvars = layout == ROW_SAMPLE ? responses.cols : responses.rows;
298 }
299 if( !responses.isContinuous() || (layout == COL_SAMPLE && noutputvars > 1) )
300 {
301 Mat temp;
302 transpose(responses, temp);
303 responses = temp;
304 }
305 }
306
307 int nvars = ninputvars + noutputvars;
308
309 if( !varType.empty() )
310 {
311 CV_Assert( varType.checkVector(1, CV_8U, true) == nvars &&
312 checkRange(varType, true, 0, VAR_ORDERED, VAR_CATEGORICAL+1) );
313 }
314 else
315 {
316 varType.create(1, nvars, CV_8U);
317 varType = Scalar::all(VAR_ORDERED);
318 if( noutputvars == 1 )
319 varType.at<uchar>(ninputvars) = (uchar)(responses.type() < CV_32F ? VAR_CATEGORICAL : VAR_ORDERED);
320 }
321
322 if( noutputvars > 1 )
323 {
324 for( i = 0; i < noutputvars; i++ )
325 CV_Assert( varType.at<uchar>(ninputvars + i) == VAR_ORDERED );
326 }
327
328 catOfs = Mat::zeros(1, nvars, CV_32SC2);
329 missingSubst = Mat::zeros(1, nvars, CV_32F);
330
331 vector<int> labels, counters, sortbuf, tempCatMap;
332 vector<Vec2i> tempCatOfs;
333 CatMapHash ofshash;
334
335 AutoBuffer<uchar> buf(nsamples);
336 Mat non_missing(layout == ROW_SAMPLE ? Size(1, nsamples) : Size(nsamples, 1), CV_8U, buf.data());
337 bool haveMissing = !missing.empty();
338 if( haveMissing )
339 {
340 CV_Assert( missing.size() == samples.size() && missing.type() == CV_8U );
341 }
342
343 // we iterate through all the variables. For each categorical variable we build a map
344 // in order to convert input values of the variable into normalized values (0..catcount_vi-1)
345 // often many categorical variables are similar, so we compress the map - try to re-use
346 // maps for different variables if they are identical
347 for( i = 0; i < ninputvars; i++ )
348 {
349 Mat values_i = layout == ROW_SAMPLE ? samples.col(i) : samples.row(i);
350
351 if( varType.at<uchar>(i) == VAR_CATEGORICAL )
352 {
353 preprocessCategorical(values_i, 0, labels, 0, sortbuf);
354 missingSubst.at<float>(i) = -1.f;
355 int j, m = (int)labels.size();
356 CV_Assert( m > 0 );
357 int a = labels.front(), b = labels.back();
358 const int* currmap = &labels[0];
359 int hashval = ((unsigned)a*127 + (unsigned)b)*127 + m;
360 CatMapHash::iterator it = ofshash.find(hashval);
361 if( it != ofshash.end() )
362 {
363 int vi = it->second;
364 Vec2i ofs0 = tempCatOfs[vi];
365 int m0 = ofs0[1] - ofs0[0];
366 const int* map0 = &tempCatMap[ofs0[0]];
367 if( m0 == m && map0[0] == a && map0[m0-1] == b )
368 {
369 for( j = 0; j < m; j++ )
370 if( map0[j] != currmap[j] )
371 break;
372 if( j == m )
373 {
374 // re-use the map
375 tempCatOfs.push_back(ofs0);
376 continue;
377 }
378 }
379 }
380 else
381 ofshash[hashval] = i;
382 Vec2i ofs;
383 ofs[0] = (int)tempCatMap.size();
384 ofs[1] = ofs[0] + m;
385 tempCatOfs.push_back(ofs);
386 std::copy(labels.begin(), labels.end(), std::back_inserter(tempCatMap));
387 }
388 else
389 {
390 tempCatOfs.push_back(Vec2i(0, 0));
391 /*Mat missing_i = layout == ROW_SAMPLE ? missing.col(i) : missing.row(i);
392 compare(missing_i, Scalar::all(0), non_missing, CMP_EQ);
393 missingSubst.at<float>(i) = (float)(mean(values_i, non_missing)[0]);*/
394 missingSubst.at<float>(i) = 0.f;
395 }
396 }
397
398 if( !tempCatOfs.empty() )
399 {
400 Mat(tempCatOfs).copyTo(catOfs);
401 Mat(tempCatMap).copyTo(catMap);
402 }
403
404 if( noutputvars > 0 && varType.at<uchar>(ninputvars) == VAR_CATEGORICAL )
405 {
406 preprocessCategorical(responses, &normCatResponses, labels, &counters, sortbuf);
407 Mat(labels).copyTo(classLabels);
408 Mat(counters).copyTo(classCounters);
409 }
410 }
411
// Converts a CV_8U mask into a 1 x N CV_32S row holding the positions of its non-zero elements.
// Assumes the mask is a row or column vector (setData() validates it with checkVector()).
Mat convertMaskToIdx(const Mat& mask)
{
    Mat idx(1, countNonZero(mask), CV_32S);
    const int length = mask.cols + mask.rows - 1;
    int written = 0;
    for (int pos = 0; pos < length; ++pos)
    {
        if (mask.at<uchar>(pos) != 0)
        {
            idx.at<int>(written) = pos;
            ++written;
        }
    }
    return idx;
}
421
// Comparator over positions: position i is "less" than j when the integer stored
// for it, data[i*step], is smaller. Used to sort index arrays by their values.
struct CmpByIdx
{
    CmpByIdx(const int* _data, int _step) : data(_data), step(_step) {}

    bool operator ()(int i, int j) const
    {
        const int lhs = data[i*step];
        const int rhs = data[j*step];
        return lhs < rhs;
    }

    const int* data;
    int step;
};
429
preprocessCategorical(const Mat & data,Mat * normdata,vector<int> & labels,vector<int> * counters,vector<int> & sortbuf)430 void preprocessCategorical(const Mat& data, Mat* normdata, vector<int>& labels,
431 vector<int>* counters, vector<int>& sortbuf)
432 {
433 CV_Assert((data.cols == 1 || data.rows == 1) && (data.type() == CV_32S || data.type() == CV_32F));
434 int* odata = 0;
435 int ostep = 0;
436
437 if(normdata)
438 {
439 normdata->create(data.size(), CV_32S);
440 odata = normdata->ptr<int>();
441 ostep = normdata->isContinuous() ? 1 : (int)normdata->step1();
442 }
443
444 int i, n = data.cols + data.rows - 1;
445 sortbuf.resize(n*2);
446 int* idx = &sortbuf[0];
447 int* idata = (int*)data.ptr<int>();
448 int istep = data.isContinuous() ? 1 : (int)data.step1();
449
450 if( data.type() == CV_32F )
451 {
452 idata = idx + n;
453 const float* fdata = data.ptr<float>();
454 for( i = 0; i < n; i++ )
455 {
456 if( fdata[i*istep] == MISSED_VAL )
457 idata[i] = -1;
458 else
459 {
460 idata[i] = cvRound(fdata[i*istep]);
461 CV_Assert( (float)idata[i] == fdata[i*istep] );
462 }
463 }
464 istep = 1;
465 }
466
467 for( i = 0; i < n; i++ )
468 idx[i] = i;
469
470 std::sort(idx, idx + n, CmpByIdx(idata, istep));
471
472 int clscount = 1;
473 for( i = 1; i < n; i++ )
474 clscount += idata[idx[i]*istep] != idata[idx[i-1]*istep];
475
476 int clslabel = -1;
477 int prev = ~idata[idx[0]*istep];
478 int previdx = 0;
479
480 labels.resize(clscount);
481 if(counters)
482 counters->resize(clscount);
483
484 for( i = 0; i < n; i++ )
485 {
486 int l = idata[idx[i]*istep];
487 if( l != prev )
488 {
489 clslabel++;
490 labels[clslabel] = l;
491 int k = i - previdx;
492 if( clslabel > 0 && counters )
493 counters->at(clslabel-1) = k;
494 prev = l;
495 previdx = i;
496 }
497 if(odata)
498 odata[idx[i]*ostep] = clslabel;
499 }
500 if(counters)
501 counters->at(clslabel) = i - previdx;
502 }
503
loadCSV(const String & filename,int headerLines,int responseStartIdx,int responseEndIdx,const String & varTypeSpec,char delimiter,char missch)504 bool loadCSV(const String& filename, int headerLines,
505 int responseStartIdx, int responseEndIdx,
506 const String& varTypeSpec, char delimiter, char missch)
507 {
508 const int M = 1000000;
509 const char delimiters[3] = { ' ', delimiter, '\0' };
510 int nvars = 0;
511 bool varTypesSet = false;
512
513 clear();
514
515 file = fopen( filename.c_str(), "rt" );
516
517 if( !file )
518 return false;
519
520 std::vector<char> _buf(M);
521 std::vector<float> allresponses;
522 std::vector<float> rowvals;
523 std::vector<uchar> vtypes, rowtypes;
524 std::vector<uchar> vsymbolflags;
525 bool haveMissed = false;
526 char* buf = &_buf[0];
527
528 int i, ridx0 = responseStartIdx, ridx1 = responseEndIdx;
529 int ninputvars = 0, noutputvars = 0;
530
531 Mat tempSamples, tempMissing, tempResponses;
532 MapType tempNameMap;
533 int catCounter = 1;
534
535 // skip header lines
536 int lineno = 0;
537 for(;;lineno++)
538 {
539 if( !fgets(buf, M, file) )
540 break;
541 if(lineno < headerLines )
542 continue;
543 // trim trailing spaces
544 int idx = (int)strlen(buf)-1;
545 while( idx >= 0 && isspace(buf[idx]) )
546 buf[idx--] = '\0';
547 // skip spaces in the beginning
548 char* ptr = buf;
549 while( *ptr != '\0' && isspace(*ptr) )
550 ptr++;
551 // skip commented off lines
552 if(*ptr == '#')
553 continue;
554 rowvals.clear();
555 rowtypes.clear();
556
557 char* token = strtok(buf, delimiters);
558 if (!token)
559 break;
560
561 for(;;)
562 {
563 float val=0.f; int tp = 0;
564 decodeElem( token, val, tp, missch, tempNameMap, catCounter );
565 if( tp == VAR_MISSED )
566 haveMissed = true;
567 rowvals.push_back(val);
568 rowtypes.push_back((uchar)tp);
569 token = strtok(NULL, delimiters);
570 if (!token)
571 break;
572 }
573
574 if( nvars == 0 )
575 {
576 if( rowvals.empty() )
577 CV_Error(CV_StsBadArg, "invalid CSV format; no data found");
578 nvars = (int)rowvals.size();
579 if( !varTypeSpec.empty() && varTypeSpec.size() > 0 )
580 {
581 setVarTypes(varTypeSpec, nvars, vtypes);
582 varTypesSet = true;
583 }
584 else
585 vtypes = rowtypes;
586 vsymbolflags.resize(nvars);
587 for( i = 0; i < nvars; i++ )
588 vsymbolflags[i] = (uchar)(rowtypes[i] == VAR_CATEGORICAL);
589
590 ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
591 ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
592 CV_Assert(ridx1 > ridx0);
593 noutputvars = ridx0 >= 0 ? ridx1 - ridx0 : 0;
594 ninputvars = nvars - noutputvars;
595 }
596 else
597 CV_Assert( nvars == (int)rowvals.size() );
598
599 // check var types
600 for( i = 0; i < nvars; i++ )
601 {
602 CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
603 (varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
604 uchar sflag = (uchar)(rowtypes[i] == VAR_CATEGORICAL);
605 if( vsymbolflags[i] == VAR_MISSED )
606 vsymbolflags[i] = sflag;
607 else
608 CV_Assert(vsymbolflags[i] == sflag || rowtypes[i] == VAR_MISSED);
609 }
610
611 if( ridx0 >= 0 )
612 {
613 for( i = ridx1; i < nvars; i++ )
614 std::swap(rowvals[i], rowvals[i-noutputvars]);
615 for( i = ninputvars; i < nvars; i++ )
616 allresponses.push_back(rowvals[i]);
617 rowvals.pop_back();
618 }
619 Mat rmat(1, ninputvars, CV_32F, &rowvals[0]);
620 tempSamples.push_back(rmat);
621 }
622
623 closeFile();
624
625 int nsamples = tempSamples.rows;
626 if( nsamples == 0 )
627 return false;
628
629 if( haveMissed )
630 compare(tempSamples, MISSED_VAL, tempMissing, CMP_EQ);
631
632 if( ridx0 >= 0 )
633 {
634 for( i = ridx1; i < nvars; i++ )
635 std::swap(vtypes[i], vtypes[i-noutputvars]);
636 if( noutputvars > 1 )
637 {
638 for( i = ninputvars; i < nvars; i++ )
639 if( vtypes[i] == VAR_CATEGORICAL )
640 CV_Error(CV_StsBadArg,
641 "If responses are vector values, not scalars, they must be marked as ordered responses");
642 }
643 }
644
645 if( !varTypesSet && noutputvars == 1 && vtypes[ninputvars] == VAR_ORDERED )
646 {
647 for( i = 0; i < nsamples; i++ )
648 if( allresponses[i] != cvRound(allresponses[i]) )
649 break;
650 if( i == nsamples )
651 vtypes[ninputvars] = VAR_CATEGORICAL;
652 }
653
654 //If there are responses in the csv file, save them. If not, responses matrix will contain just zeros
655 if (noutputvars != 0){
656 Mat(nsamples, noutputvars, CV_32F, &allresponses[0]).copyTo(tempResponses);
657 setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
658 noArray(), Mat(vtypes).clone(), tempMissing);
659 }
660 else{
661 Mat zero_mat(nsamples, 1, CV_32F, Scalar(0));
662 zero_mat.copyTo(tempResponses);
663 setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
664 noArray(), noArray(), tempMissing);
665 }
666 bool ok = !samples.empty();
667 if(ok)
668 {
669 std::swap(tempNameMap, nameMap);
670 Mat(vsymbolflags).copyTo(varSymbolFlags);
671 }
672 return ok;
673 }
674
decodeElem(const char * token,float & elem,int & type,char missch,MapType & namemap,int & counter) const675 void decodeElem( const char* token, float& elem, int& type,
676 char missch, MapType& namemap, int& counter ) const
677 {
678 char* stopstring = NULL;
679 elem = (float)strtod( token, &stopstring );
680 if( *stopstring == missch && strlen(stopstring) == 1 ) // missed value
681 {
682 elem = MISSED_VAL;
683 type = VAR_MISSED;
684 }
685 else if( *stopstring != '\0' )
686 {
687 MapType::iterator it = namemap.find(token);
688 if( it == namemap.end() )
689 {
690 elem = (float)counter;
691 namemap[token] = counter++;
692 }
693 else
694 elem = (float)it->second;
695 type = VAR_CATEGORICAL;
696 }
697 else
698 type = VAR_ORDERED;
699 }
700
// Parses a variable-type specification into one VAR_ORDERED / VAR_CATEGORICAL code per variable.
//  s      - "ord" or "cat" (every variable gets that type), or "ord[...]" and/or "cat[...]"
//           listing 0-based indices and inclusive ranges "a-b" separated by ','
//  nvars  - number of variables; vtypes is resized to it
//  vtypes - output types
// Raises CV_StsBadArg for malformed specs, or when the number of assignments does not equal nvars.
// NOTE(review): specCounter counts assignments, so an index listed twice is counted twice.
void setVarTypes( const String& s, int nvars, std::vector<uchar>& vtypes ) const
{
    const char* errmsg = "type spec is not correct; it should have format \"cat\", \"ord\" or "
      "\"ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\", where n's and m's are 0-based variable indices";
    const char* str = s.c_str();
    int specCounter = 0;

    vtypes.resize(nvars);

    // pass 0 handles the "ord" part of the spec, pass 1 the "cat" part
    for( int k = 0; k < 2; k++ )
    {
        const char* ptr = strstr(str, k == 0 ? "ord" : "cat");
        int tp = k == 0 ? VAR_ORDERED : VAR_CATEGORICAL;
        if( ptr ) // parse ord/cat str
        {
            char* stopstring = NULL;

            // keyword at the very end of the spec: every variable gets this type
            // and the remaining pass is skipped
            if( ptr[3] == '\0' )
            {
                for( int i = 0; i < nvars; i++ )
                    vtypes[i] = (uchar)tp;
                specCounter = nvars;
                break;
            }

            if ( ptr[3] != '[')
                CV_Error( CV_StsBadArg, errmsg );

            ptr += 4; // skip "ord[" / "cat["
            do
            {
                // read an index (parsed with strtod, truncated to int)
                int b1 = (int)strtod( ptr, &stopstring );
                if( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
                    CV_Error( CV_StsBadArg, errmsg );
                ptr = stopstring + 1;
                if( (stopstring[0] == ',') || (stopstring[0] == ']'))
                {
                    // single index
                    CV_Assert( 0 <= b1 && b1 < nvars );
                    vtypes[b1] = (uchar)tp;
                    specCounter++;
                }
                else
                {
                    if( stopstring[0] == '-')
                    {
                        // inclusive range b1-b2
                        int b2 = (int)strtod( ptr, &stopstring);
                        if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
                            CV_Error( CV_StsBadArg, errmsg );
                        ptr = stopstring + 1;
                        CV_Assert( 0 <= b1 && b1 <= b2 && b2 < nvars );
                        for (int i = b1; i <= b2; i++)
                            vtypes[i] = (uchar)tp;
                        specCounter += b2 - b1 + 1;
                    }
                    else
                        CV_Error( CV_StsBadArg, errmsg );

                }
            }
            while(*stopstring != ']'); // the list ends at the closing bracket
        }
    }

    if( specCounter != nvars )
        CV_Error( CV_StsBadArg, "type of some variables is not specified" );
}
767
setTrainTestSplitRatio(double ratio,bool shuffle)768 void setTrainTestSplitRatio(double ratio, bool shuffle) CV_OVERRIDE
769 {
770 CV_Assert( 0. <= ratio && ratio <= 1. );
771 setTrainTestSplit(cvRound(getNSamples()*ratio), shuffle);
772 }
773
setTrainTestSplit(int count,bool shuffle)774 void setTrainTestSplit(int count, bool shuffle) CV_OVERRIDE
775 {
776 int i, nsamples = getNSamples();
777 CV_Assert( 0 <= count && count < nsamples );
778
779 trainSampleIdx.release();
780 testSampleIdx.release();
781
782 if( count == 0 )
783 trainSampleIdx = sampleIdx;
784 else if( count == nsamples )
785 testSampleIdx = sampleIdx;
786 else
787 {
788 Mat mask(1, nsamples, CV_8U);
789 uchar* mptr = mask.ptr();
790 for( i = 0; i < nsamples; i++ )
791 mptr[i] = (uchar)(i < count);
792 trainSampleIdx.create(1, count, CV_32S);
793 testSampleIdx.create(1, nsamples - count, CV_32S);
794 int j0 = 0, j1 = 0;
795 const int* sptr = !sampleIdx.empty() ? sampleIdx.ptr<int>() : 0;
796 int* trainptr = trainSampleIdx.ptr<int>();
797 int* testptr = testSampleIdx.ptr<int>();
798 for( i = 0; i < nsamples; i++ )
799 {
800 int idx = sptr ? sptr[i] : i;
801 if( mptr[i] )
802 trainptr[j0++] = idx;
803 else
804 testptr[j1++] = idx;
805 }
806 if( shuffle )
807 shuffleTrainTest();
808 }
809 }
810
shuffleTrainTest()811 void shuffleTrainTest() CV_OVERRIDE
812 {
813 if( !trainSampleIdx.empty() && !testSampleIdx.empty() )
814 {
815 int i, nsamples = getNSamples(), ntrain = getNTrainSamples(), ntest = getNTestSamples();
816 int* trainIdx = trainSampleIdx.ptr<int>();
817 int* testIdx = testSampleIdx.ptr<int>();
818 RNG& rng = theRNG();
819
820 for( i = 0; i < nsamples; i++)
821 {
822 int a = rng.uniform(0, nsamples);
823 int b = rng.uniform(0, nsamples);
824 int* ptra = trainIdx;
825 int* ptrb = trainIdx;
826 if( a >= ntrain )
827 {
828 ptra = testIdx;
829 a -= ntrain;
830 CV_Assert( a < ntest );
831 }
832 if( b >= ntrain )
833 {
834 ptrb = testIdx;
835 b -= ntrain;
836 CV_Assert( b < ntest );
837 }
838 std::swap(ptra[a], ptrb[b]);
839 }
840 }
841 }
842
getTrainSamples(int _layout,bool compressSamples,bool compressVars) const843 Mat getTrainSamples(int _layout,
844 bool compressSamples,
845 bool compressVars) const CV_OVERRIDE
846 {
847 if( samples.empty() )
848 return samples;
849
850 if( (!compressSamples || (trainSampleIdx.empty() && sampleIdx.empty())) &&
851 (!compressVars || varIdx.empty()) &&
852 layout == _layout )
853 return samples;
854
855 int drows = getNTrainSamples(), dcols = getNVars();
856 Mat sidx = getTrainSampleIdx(), vidx = getVarIdx();
857 const float* src0 = samples.ptr<float>();
858 const int* sptr = !sidx.empty() ? sidx.ptr<int>() : 0;
859 const int* vptr = !vidx.empty() ? vidx.ptr<int>() : 0;
860 size_t sstep0 = samples.step/samples.elemSize();
861 size_t sstep = layout == ROW_SAMPLE ? sstep0 : 1;
862 size_t vstep = layout == ROW_SAMPLE ? 1 : sstep0;
863
864 if( _layout == COL_SAMPLE )
865 {
866 std::swap(drows, dcols);
867 std::swap(sptr, vptr);
868 std::swap(sstep, vstep);
869 }
870
871 Mat dsamples(drows, dcols, CV_32F);
872
873 for( int i = 0; i < drows; i++ )
874 {
875 const float* src = src0 + (sptr ? sptr[i] : i)*sstep;
876 float* dst = dsamples.ptr<float>(i);
877
878 for( int j = 0; j < dcols; j++ )
879 dst[j] = src[(vptr ? vptr[j] : j)*vstep];
880 }
881
882 return dsamples;
883 }
884
getValues(int vi,InputArray _sidx,float * values) const885 void getValues( int vi, InputArray _sidx, float* values ) const CV_OVERRIDE
886 {
887 Mat sidx = _sidx.getMat();
888 int i, n = sidx.checkVector(1, CV_32S), nsamples = getNSamples();
889 CV_Assert( 0 <= vi && vi < getNAllVars() );
890 CV_Assert( n >= 0 );
891 const int* s = n > 0 ? sidx.ptr<int>() : 0;
892 if( n == 0 )
893 n = nsamples;
894
895 size_t step = samples.step/samples.elemSize();
896 size_t sstep = layout == ROW_SAMPLE ? step : 1;
897 size_t vstep = layout == ROW_SAMPLE ? 1 : step;
898
899 const float* src = samples.ptr<float>() + vi*vstep;
900 float subst = missingSubst.at<float>(vi);
901 for( i = 0; i < n; i++ )
902 {
903 int j = i;
904 if( s )
905 {
906 j = s[i];
907 CV_Assert( 0 <= j && j < nsamples );
908 }
909 values[i] = src[j*sstep];
910 if( values[i] == MISSED_VAL )
911 values[i] = subst;
912 }
913 }
914
getNormCatValues(int vi,InputArray _sidx,int * values) const915 void getNormCatValues( int vi, InputArray _sidx, int* values ) const CV_OVERRIDE
916 {
917 float* fvalues = (float*)values;
918 getValues(vi, _sidx, fvalues);
919 int i, n = (int)_sidx.total();
920 Vec2i ofs = catOfs.at<Vec2i>(vi);
921 int m = ofs[1] - ofs[0];
922
923 CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
924 const int* cmap = &catMap.at<int>(ofs[0]);
925 bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
926
927 if( fastMap )
928 {
929 for( i = 0; i < n; i++ )
930 {
931 int val = cvRound(fvalues[i]);
932 int idx = val - cmap[0];
933 CV_Assert(cmap[idx] == val);
934 values[i] = idx;
935 }
936 }
937 else
938 {
939 for( i = 0; i < n; i++ )
940 {
941 int val = cvRound(fvalues[i]);
942 int a = 0, b = m, c = -1;
943
944 while( a < b )
945 {
946 c = (a + b) >> 1;
947 if( val < cmap[c] )
948 b = c;
949 else if( val > cmap[c] )
950 a = c+1;
951 else
952 break;
953 }
954
955 CV_DbgAssert( c >= 0 && val == cmap[c] );
956 values[i] = c;
957 }
958 }
959 }
960
getSample(InputArray _vidx,int sidx,float * buf) const961 void getSample(InputArray _vidx, int sidx, float* buf) const CV_OVERRIDE
962 {
963 CV_Assert(buf != 0 && 0 <= sidx && sidx < getNSamples());
964 Mat vidx = _vidx.getMat();
965 int i, n = vidx.checkVector(1, CV_32S), nvars = getNAllVars();
966 CV_Assert( n >= 0 );
967 const int* vptr = n > 0 ? vidx.ptr<int>() : 0;
968 if( n == 0 )
969 n = nvars;
970
971 size_t step = samples.step/samples.elemSize();
972 size_t sstep = layout == ROW_SAMPLE ? step : 1;
973 size_t vstep = layout == ROW_SAMPLE ? 1 : step;
974
975 const float* src = samples.ptr<float>() + sidx*sstep;
976 for( i = 0; i < n; i++ )
977 {
978 int j = i;
979 if( vptr )
980 {
981 j = vptr[i];
982 CV_Assert( 0 <= j && j < nvars );
983 }
984 buf[i] = src[j*vstep];
985 }
986 }
987
getNames(std::vector<String> & names) const988 void getNames(std::vector<String>& names) const CV_OVERRIDE
989 {
990 size_t n = nameMap.size();
991 TrainDataImpl::MapType::const_iterator it = nameMap.begin(),
992 it_end = nameMap.end();
993 names.resize(n+1);
994 names[0] = "?";
995 for( ; it != it_end; ++it )
996 {
997 String s = it->first;
998 int label = it->second;
999 CV_Assert( label > 0 && label <= (int)n );
1000 names[label] = s;
1001 }
1002 }
1003
getVarSymbolFlags() const1004 Mat getVarSymbolFlags() const CV_OVERRIDE
1005 {
1006 return varSymbolFlags;
1007 }
1008
// CSV file handle used by loadCSV(); 0 when no file is open
FILE* file;
// ROW_SAMPLE or COL_SAMPLE
int layout;
// samples: input data; missing: CV_8U mask of missing entries (may be empty);
// varType: type code of every input and output variable; varIdx: sorted active-variable indices;
// varSymbolFlags: per-CSV-column symbolic flags; responses: output values;
// missingSubst: per-variable substitution value for missing entries
Mat samples, missing, varType, varIdx, varSymbolFlags, responses, missingSubst;
// sampleIdx: active samples (empty = all); trainSampleIdx/testSampleIdx: train/test split
Mat sampleIdx, trainSampleIdx, testSampleIdx;
// sampleWeights: per-sample weights; catMap: concatenated sorted category values;
// catOfs: [begin, end) range of each variable's map inside catMap
Mat sampleWeights, catMap, catOfs;
// normCatResponses: categorical responses replaced by class indices;
// classLabels: distinct response values; classCounters: samples per class
Mat normCatResponses, classLabels, classCounters;
// category token -> integer code, filled by loadCSV()
MapType nameMap;
1016 };
1017
1018
loadFromCSV(const String & filename,int headerLines,int responseStartIdx,int responseEndIdx,const String & varTypeSpec,char delimiter,char missch)1019 Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
1020 int headerLines,
1021 int responseStartIdx,
1022 int responseEndIdx,
1023 const String& varTypeSpec,
1024 char delimiter, char missch)
1025 {
1026 CV_TRACE_FUNCTION_SKIP_NESTED();
1027 Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
1028 if(!td->loadCSV(filename, headerLines, responseStartIdx, responseEndIdx, varTypeSpec, delimiter, missch))
1029 td.release();
1030 return td;
1031 }
1032
create(InputArray samples,int layout,InputArray responses,InputArray varIdx,InputArray sampleIdx,InputArray sampleWeights,InputArray varType)1033 Ptr<TrainData> TrainData::create(InputArray samples, int layout, InputArray responses,
1034 InputArray varIdx, InputArray sampleIdx, InputArray sampleWeights,
1035 InputArray varType)
1036 {
1037 CV_TRACE_FUNCTION_SKIP_NESTED();
1038 Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
1039 td->setData(samples, layout, responses, varIdx, sampleIdx, sampleWeights, varType, noArray());
1040 return td;
1041 }
1042
1043 }}
1044
1045 /* End of file. */
1046