1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40 
41 #include "precomp.hpp"
42 #include <ctype.h>
43 #include <algorithm>
44 #include <iterator>
45 
46 #include <opencv2/core/utils/logger.hpp>
47 
48 namespace cv { namespace ml {
49 
50 static const float MISSED_VAL = TrainData::missingValue();
51 static const int VAR_MISSED = VAR_ORDERED;
52 
~TrainData()53 TrainData::~TrainData() {}
54 
getSubVector(const Mat & vec,const Mat & idx)55 Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
56 {
57     if (!(vec.cols == 1 || vec.rows == 1))
58         CV_LOG_WARNING(NULL, "'getSubVector(const Mat& vec, const Mat& idx)' call with non-1D input is deprecated. It is not designed to work with 2D matrixes (especially with 'cv::ml::COL_SAMPLE' layout).");
59     return getSubMatrix(vec, idx, vec.rows == 1 ? cv::ml::COL_SAMPLE : cv::ml::ROW_SAMPLE);
60 }
61 
62 template<typename T>
getSubMatrixImpl(const Mat & m,const Mat & idx,int layout)63 Mat getSubMatrixImpl(const Mat& m, const Mat& idx, int layout)
64 {
65     int nidx = idx.checkVector(1, CV_32S);
66     int dims = m.cols, nsamples = m.rows;
67 
68     Mat subm;
69     if (layout == COL_SAMPLE)
70     {
71         std::swap(dims, nsamples);
72         subm.create(dims, nidx, m.type());
73     }
74     else
75     {
76         subm.create(nidx, dims, m.type());
77     }
78 
79     for (int i = 0; i < nidx; i++)
80     {
81         int k = idx.at<int>(i); CV_CheckGE(k, 0, "Bad idx"); CV_CheckLT(k, nsamples, "Bad idx or layout");
82         if (dims == 1)
83         {
84             subm.at<T>(i) = m.at<T>(k);  // at() has "transparent" access for 1D col-based / row-based vectors.
85         }
86         else if (layout == COL_SAMPLE)
87         {
88             for (int j = 0; j < dims; j++)
89                 subm.at<T>(j, i) = m.at<T>(j, k);
90         }
91         else
92         {
93             for (int j = 0; j < dims; j++)
94                 subm.at<T>(i, j) = m.at<T>(k, j);
95         }
96     }
97     return subm;
98 }
99 
getSubMatrix(const Mat & m,const Mat & idx,int layout)100 Mat TrainData::getSubMatrix(const Mat& m, const Mat& idx, int layout)
101 {
102     if (idx.empty())
103         return m;
104     int type = m.type();
105     CV_CheckType(type, type == CV_32S || type == CV_32F || type == CV_64F, "");
106     if (type == CV_32S || type == CV_32F)  // 32-bit
107         return getSubMatrixImpl<int>(m, idx, layout);
108     if (type == CV_64F)  // 64-bit
109         return getSubMatrixImpl<double>(m, idx, layout);
110     CV_Error(Error::StsInternal, "");
111 }
112 
113 
114 class TrainDataImpl CV_FINAL : public TrainData
115 {
116 public:
    // Symbolic (non-numeric) CSV token -> numeric category id; filled by decodeElem() while loadCSV() parses.
    typedef std::map<String, int> MapType;
118 
TrainDataImpl()119     TrainDataImpl()
120     {
121         file = 0;
122         clear();
123     }
124 
~TrainDataImpl()125     virtual ~TrainDataImpl() { closeFile(); }
126 
getLayout() const127     int getLayout() const CV_OVERRIDE { return layout; }
getNSamples() const128     int getNSamples() const CV_OVERRIDE
129     {
130         return !sampleIdx.empty() ? (int)sampleIdx.total() :
131                layout == ROW_SAMPLE ? samples.rows : samples.cols;
132     }
getNTrainSamples() const133     int getNTrainSamples() const CV_OVERRIDE
134     {
135         return !trainSampleIdx.empty() ? (int)trainSampleIdx.total() : getNSamples();
136     }
getNTestSamples() const137     int getNTestSamples() const CV_OVERRIDE
138     {
139         return !testSampleIdx.empty() ? (int)testSampleIdx.total() : 0;
140     }
getNVars() const141     int getNVars() const CV_OVERRIDE
142     {
143         return !varIdx.empty() ? (int)varIdx.total() : getNAllVars();
144     }
getNAllVars() const145     int getNAllVars() const CV_OVERRIDE
146     {
147         return layout == ROW_SAMPLE ? samples.cols : samples.rows;
148     }
149 
getTestSamples() const150     Mat getTestSamples() const CV_OVERRIDE
151     {
152         Mat idx = getTestSampleIdx();
153         return idx.empty() ? Mat() : getSubMatrix(samples, idx, getLayout());
154     }
155 
getSamples() const156     Mat getSamples() const CV_OVERRIDE { return samples; }
getResponses() const157     Mat getResponses() const CV_OVERRIDE { return responses; }
getMissing() const158     Mat getMissing() const CV_OVERRIDE { return missing; }
getVarIdx() const159     Mat getVarIdx() const CV_OVERRIDE { return varIdx; }
getVarType() const160     Mat getVarType() const CV_OVERRIDE { return varType; }
getResponseType() const161     int getResponseType() const CV_OVERRIDE
162     {
163         return classLabels.empty() ? VAR_ORDERED : VAR_CATEGORICAL;
164     }
getTrainSampleIdx() const165     Mat getTrainSampleIdx() const CV_OVERRIDE { return !trainSampleIdx.empty() ? trainSampleIdx : sampleIdx; }
getTestSampleIdx() const166     Mat getTestSampleIdx() const CV_OVERRIDE { return testSampleIdx; }
getSampleWeights() const167     Mat getSampleWeights() const CV_OVERRIDE
168     {
169         return sampleWeights;
170     }
getTrainSampleWeights() const171     Mat getTrainSampleWeights() const CV_OVERRIDE
172     {
173         return getSubVector(sampleWeights, getTrainSampleIdx());  // 1D-vector
174     }
getTestSampleWeights() const175     Mat getTestSampleWeights() const CV_OVERRIDE
176     {
177         Mat idx = getTestSampleIdx();
178         return idx.empty() ? Mat() : getSubVector(sampleWeights, idx);  // 1D-vector
179     }
getTrainResponses() const180     Mat getTrainResponses() const CV_OVERRIDE
181     {
182         return getSubMatrix(responses, getTrainSampleIdx(), cv::ml::ROW_SAMPLE);  // col-based responses are transposed in setData()
183     }
getTrainNormCatResponses() const184     Mat getTrainNormCatResponses() const CV_OVERRIDE
185     {
186         return getSubMatrix(normCatResponses, getTrainSampleIdx(), cv::ml::ROW_SAMPLE);  // like 'responses'
187     }
getTestResponses() const188     Mat getTestResponses() const CV_OVERRIDE
189     {
190         Mat idx = getTestSampleIdx();
191         return idx.empty() ? Mat() : getSubMatrix(responses, idx, cv::ml::ROW_SAMPLE);  // col-based responses are transposed in setData()
192     }
getTestNormCatResponses() const193     Mat getTestNormCatResponses() const CV_OVERRIDE
194     {
195         Mat idx = getTestSampleIdx();
196         return idx.empty() ? Mat() : getSubMatrix(normCatResponses, idx, cv::ml::ROW_SAMPLE);  // like 'responses'
197     }
getNormCatResponses() const198     Mat getNormCatResponses() const CV_OVERRIDE { return normCatResponses; }
getClassLabels() const199     Mat getClassLabels() const CV_OVERRIDE { return classLabels; }
getClassCounters() const200     Mat getClassCounters() const { return classCounters; }
getCatCount(int vi) const201     int getCatCount(int vi) const CV_OVERRIDE
202     {
203         int n = (int)catOfs.total();
204         CV_Assert( 0 <= vi && vi < n );
205         Vec2i ofs = catOfs.at<Vec2i>(vi);
206         return ofs[1] - ofs[0];
207     }
208 
getCatOfs() const209     Mat getCatOfs() const CV_OVERRIDE { return catOfs; }
getCatMap() const210     Mat getCatMap() const CV_OVERRIDE { return catMap; }
211 
getDefaultSubstValues() const212     Mat getDefaultSubstValues() const CV_OVERRIDE { return missingSubst; }
213 
closeFile()214     void closeFile() { if(file) fclose(file); file=0; }
clear()215     void clear()
216     {
217         closeFile();
218         samples.release();
219         missing.release();
220         varType.release();
221         varSymbolFlags.release();
222         responses.release();
223         sampleIdx.release();
224         trainSampleIdx.release();
225         testSampleIdx.release();
226         normCatResponses.release();
227         classLabels.release();
228         classCounters.release();
229         catMap.release();
230         catOfs.release();
231         nameMap = MapType();
232         layout = ROW_SAMPLE;
233     }
234 
    // Hash of a categorical variable's value map -> index of the first variable that produced a map
    // with that hash; setData() uses it to re-use identical category maps across variables.
    typedef std::map<int, int> CatMapHash;
236 
setData(InputArray _samples,int _layout,InputArray _responses,InputArray _varIdx,InputArray _sampleIdx,InputArray _sampleWeights,InputArray _varType,InputArray _missing)237     void setData(InputArray _samples, int _layout, InputArray _responses,
238                  InputArray _varIdx, InputArray _sampleIdx, InputArray _sampleWeights,
239                  InputArray _varType, InputArray _missing)
240     {
241         clear();
242 
243         CV_Assert(_layout == ROW_SAMPLE || _layout == COL_SAMPLE );
244         samples = _samples.getMat();
245         layout = _layout;
246         responses = _responses.getMat();
247         varIdx = _varIdx.getMat();
248         sampleIdx = _sampleIdx.getMat();
249         sampleWeights = _sampleWeights.getMat();
250         varType = _varType.getMat();
251         missing = _missing.getMat();
252 
253         int nsamples = layout == ROW_SAMPLE ? samples.rows : samples.cols;
254         int ninputvars = layout == ROW_SAMPLE ? samples.cols : samples.rows;
255         int i, noutputvars = 0;
256 
257         CV_Assert( samples.type() == CV_32F || samples.type() == CV_32S );
258 
259         if( !sampleIdx.empty() )
260         {
261             CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
262                        checkRange(sampleIdx, true, 0, 0, nsamples)) ||
263                        sampleIdx.checkVector(1, CV_8U, true) == nsamples );
264             if( sampleIdx.type() == CV_8U )
265                 sampleIdx = convertMaskToIdx(sampleIdx);
266         }
267 
268         if( !sampleWeights.empty() )
269         {
270             CV_Assert( sampleWeights.checkVector(1, CV_32F, true) == nsamples );
271         }
272         else
273         {
274             sampleWeights = Mat::ones(nsamples, 1, CV_32F);
275         }
276 
277         if( !varIdx.empty() )
278         {
279             CV_Assert( (varIdx.checkVector(1, CV_32S, true) > 0 &&
280                        checkRange(varIdx, true, 0, 0, ninputvars)) ||
281                        varIdx.checkVector(1, CV_8U, true) == ninputvars );
282             if( varIdx.type() == CV_8U )
283                 varIdx = convertMaskToIdx(varIdx);
284             varIdx = varIdx.clone();
285             std::sort(varIdx.ptr<int>(), varIdx.ptr<int>() + varIdx.total());
286         }
287 
288         if( !responses.empty() )
289         {
290             CV_Assert( responses.type() == CV_32F || responses.type() == CV_32S );
291             if( (responses.cols == 1 || responses.rows == 1) && (int)responses.total() == nsamples )
292                 noutputvars = 1;
293             else
294             {
295                 CV_Assert( (layout == ROW_SAMPLE && responses.rows == nsamples) ||
296                            (layout == COL_SAMPLE && responses.cols == nsamples) );
297                 noutputvars = layout == ROW_SAMPLE ? responses.cols : responses.rows;
298             }
299             if( !responses.isContinuous() || (layout == COL_SAMPLE && noutputvars > 1) )
300             {
301                 Mat temp;
302                 transpose(responses, temp);
303                 responses = temp;
304             }
305         }
306 
307         int nvars = ninputvars + noutputvars;
308 
309         if( !varType.empty() )
310         {
311             CV_Assert( varType.checkVector(1, CV_8U, true) == nvars &&
312                        checkRange(varType, true, 0, VAR_ORDERED, VAR_CATEGORICAL+1) );
313         }
314         else
315         {
316             varType.create(1, nvars, CV_8U);
317             varType = Scalar::all(VAR_ORDERED);
318             if( noutputvars == 1 )
319                 varType.at<uchar>(ninputvars) = (uchar)(responses.type() < CV_32F ? VAR_CATEGORICAL : VAR_ORDERED);
320         }
321 
322         if( noutputvars > 1 )
323         {
324             for( i = 0; i < noutputvars; i++ )
325                 CV_Assert( varType.at<uchar>(ninputvars + i) == VAR_ORDERED );
326         }
327 
328         catOfs = Mat::zeros(1, nvars, CV_32SC2);
329         missingSubst = Mat::zeros(1, nvars, CV_32F);
330 
331         vector<int> labels, counters, sortbuf, tempCatMap;
332         vector<Vec2i> tempCatOfs;
333         CatMapHash ofshash;
334 
335         AutoBuffer<uchar> buf(nsamples);
336         Mat non_missing(layout == ROW_SAMPLE ? Size(1, nsamples) : Size(nsamples, 1), CV_8U, buf.data());
337         bool haveMissing = !missing.empty();
338         if( haveMissing )
339         {
340             CV_Assert( missing.size() == samples.size() && missing.type() == CV_8U );
341         }
342 
343         // we iterate through all the variables. For each categorical variable we build a map
344         // in order to convert input values of the variable into normalized values (0..catcount_vi-1)
345         // often many categorical variables are similar, so we compress the map - try to re-use
346         // maps for different variables if they are identical
347         for( i = 0; i < ninputvars; i++ )
348         {
349             Mat values_i = layout == ROW_SAMPLE ? samples.col(i) : samples.row(i);
350 
351             if( varType.at<uchar>(i) == VAR_CATEGORICAL )
352             {
353                 preprocessCategorical(values_i, 0, labels, 0, sortbuf);
354                 missingSubst.at<float>(i) = -1.f;
355                 int j, m = (int)labels.size();
356                 CV_Assert( m > 0 );
357                 int a = labels.front(), b = labels.back();
358                 const int* currmap = &labels[0];
359                 int hashval = ((unsigned)a*127 + (unsigned)b)*127 + m;
360                 CatMapHash::iterator it = ofshash.find(hashval);
361                 if( it != ofshash.end() )
362                 {
363                     int vi = it->second;
364                     Vec2i ofs0 = tempCatOfs[vi];
365                     int m0 = ofs0[1] - ofs0[0];
366                     const int* map0 = &tempCatMap[ofs0[0]];
367                     if( m0 == m && map0[0] == a && map0[m0-1] == b )
368                     {
369                         for( j = 0; j < m; j++ )
370                             if( map0[j] != currmap[j] )
371                                 break;
372                         if( j == m )
373                         {
374                             // re-use the map
375                             tempCatOfs.push_back(ofs0);
376                             continue;
377                         }
378                     }
379                 }
380                 else
381                     ofshash[hashval] = i;
382                 Vec2i ofs;
383                 ofs[0] = (int)tempCatMap.size();
384                 ofs[1] = ofs[0] + m;
385                 tempCatOfs.push_back(ofs);
386                 std::copy(labels.begin(), labels.end(), std::back_inserter(tempCatMap));
387             }
388             else
389             {
390                 tempCatOfs.push_back(Vec2i(0, 0));
391                 /*Mat missing_i = layout == ROW_SAMPLE ? missing.col(i) : missing.row(i);
392                 compare(missing_i, Scalar::all(0), non_missing, CMP_EQ);
393                 missingSubst.at<float>(i) = (float)(mean(values_i, non_missing)[0]);*/
394                 missingSubst.at<float>(i) = 0.f;
395             }
396         }
397 
398         if( !tempCatOfs.empty() )
399         {
400             Mat(tempCatOfs).copyTo(catOfs);
401             Mat(tempCatMap).copyTo(catMap);
402         }
403 
404         if( noutputvars > 0 && varType.at<uchar>(ninputvars) == VAR_CATEGORICAL )
405         {
406             preprocessCategorical(responses, &normCatResponses, labels, &counters, sortbuf);
407             Mat(labels).copyTo(classLabels);
408             Mat(counters).copyTo(classCounters);
409         }
410     }
411 
convertMaskToIdx(const Mat & mask)412     Mat convertMaskToIdx(const Mat& mask)
413     {
414         int i, j, nz = countNonZero(mask), n = mask.cols + mask.rows - 1;
415         Mat idx(1, nz, CV_32S);
416         for( i = j = 0; i < n; i++ )
417             if( mask.at<uchar>(i) )
418                 idx.at<int>(j++) = i;
419         return idx;
420     }
421 
422     struct CmpByIdx
423     {
CmpByIdxcv::ml::CV_FINAL::CmpByIdx424         CmpByIdx(const int* _data, int _step) : data(_data), step(_step) {}
operator ()cv::ml::CV_FINAL::CmpByIdx425         bool operator ()(int i, int j) const { return data[i*step] < data[j*step]; }
426         const int* data;
427         int step;
428     };
429 
preprocessCategorical(const Mat & data,Mat * normdata,vector<int> & labels,vector<int> * counters,vector<int> & sortbuf)430     void preprocessCategorical(const Mat& data, Mat* normdata, vector<int>& labels,
431                                vector<int>* counters, vector<int>& sortbuf)
432     {
433         CV_Assert((data.cols == 1 || data.rows == 1) && (data.type() == CV_32S || data.type() == CV_32F));
434         int* odata = 0;
435         int ostep = 0;
436 
437         if(normdata)
438         {
439             normdata->create(data.size(), CV_32S);
440             odata = normdata->ptr<int>();
441             ostep = normdata->isContinuous() ? 1 : (int)normdata->step1();
442         }
443 
444         int i, n = data.cols + data.rows - 1;
445         sortbuf.resize(n*2);
446         int* idx = &sortbuf[0];
447         int* idata = (int*)data.ptr<int>();
448         int istep = data.isContinuous() ? 1 : (int)data.step1();
449 
450         if( data.type() == CV_32F )
451         {
452             idata = idx + n;
453             const float* fdata = data.ptr<float>();
454             for( i = 0; i < n; i++ )
455             {
456                 if( fdata[i*istep] == MISSED_VAL )
457                     idata[i] = -1;
458                 else
459                 {
460                     idata[i] = cvRound(fdata[i*istep]);
461                     CV_Assert( (float)idata[i] == fdata[i*istep] );
462                 }
463             }
464             istep = 1;
465         }
466 
467         for( i = 0; i < n; i++ )
468             idx[i] = i;
469 
470         std::sort(idx, idx + n, CmpByIdx(idata, istep));
471 
472         int clscount = 1;
473         for( i = 1; i < n; i++ )
474             clscount += idata[idx[i]*istep] != idata[idx[i-1]*istep];
475 
476         int clslabel = -1;
477         int prev = ~idata[idx[0]*istep];
478         int previdx = 0;
479 
480         labels.resize(clscount);
481         if(counters)
482             counters->resize(clscount);
483 
484         for( i = 0; i < n; i++ )
485         {
486             int l = idata[idx[i]*istep];
487             if( l != prev )
488             {
489                 clslabel++;
490                 labels[clslabel] = l;
491                 int k = i - previdx;
492                 if( clslabel > 0 && counters )
493                     counters->at(clslabel-1) = k;
494                 prev = l;
495                 previdx = i;
496             }
497             if(odata)
498                 odata[idx[i]*ostep] = clslabel;
499         }
500         if(counters)
501             counters->at(clslabel) = i - previdx;
502     }
503 
loadCSV(const String & filename,int headerLines,int responseStartIdx,int responseEndIdx,const String & varTypeSpec,char delimiter,char missch)504     bool loadCSV(const String& filename, int headerLines,
505                  int responseStartIdx, int responseEndIdx,
506                  const String& varTypeSpec, char delimiter, char missch)
507     {
508         const int M = 1000000;
509         const char delimiters[3] = { ' ', delimiter, '\0' };
510         int nvars = 0;
511         bool varTypesSet = false;
512 
513         clear();
514 
515         file = fopen( filename.c_str(), "rt" );
516 
517         if( !file )
518             return false;
519 
520         std::vector<char> _buf(M);
521         std::vector<float> allresponses;
522         std::vector<float> rowvals;
523         std::vector<uchar> vtypes, rowtypes;
524         std::vector<uchar> vsymbolflags;
525         bool haveMissed = false;
526         char* buf = &_buf[0];
527 
528         int i, ridx0 = responseStartIdx, ridx1 = responseEndIdx;
529         int ninputvars = 0, noutputvars = 0;
530 
531         Mat tempSamples, tempMissing, tempResponses;
532         MapType tempNameMap;
533         int catCounter = 1;
534 
535         // skip header lines
536         int lineno = 0;
537         for(;;lineno++)
538         {
539             if( !fgets(buf, M, file) )
540                 break;
541             if(lineno < headerLines )
542                 continue;
543             // trim trailing spaces
544             int idx = (int)strlen(buf)-1;
545             while( idx >= 0 && isspace(buf[idx]) )
546                 buf[idx--] = '\0';
547             // skip spaces in the beginning
548             char* ptr = buf;
549             while( *ptr != '\0' && isspace(*ptr) )
550                 ptr++;
551             // skip commented off lines
552             if(*ptr == '#')
553                 continue;
554             rowvals.clear();
555             rowtypes.clear();
556 
557             char* token = strtok(buf, delimiters);
558             if (!token)
559                 break;
560 
561             for(;;)
562             {
563                 float val=0.f; int tp = 0;
564                 decodeElem( token, val, tp, missch, tempNameMap, catCounter );
565                 if( tp == VAR_MISSED )
566                     haveMissed = true;
567                 rowvals.push_back(val);
568                 rowtypes.push_back((uchar)tp);
569                 token = strtok(NULL, delimiters);
570                 if (!token)
571                     break;
572             }
573 
574             if( nvars == 0 )
575             {
576                 if( rowvals.empty() )
577                     CV_Error(CV_StsBadArg, "invalid CSV format; no data found");
578                 nvars = (int)rowvals.size();
579                 if( !varTypeSpec.empty() && varTypeSpec.size() > 0 )
580                 {
581                     setVarTypes(varTypeSpec, nvars, vtypes);
582                     varTypesSet = true;
583                 }
584                 else
585                     vtypes = rowtypes;
586                 vsymbolflags.resize(nvars);
587                 for( i = 0; i < nvars; i++ )
588                     vsymbolflags[i] = (uchar)(rowtypes[i] == VAR_CATEGORICAL);
589 
590                 ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
591                 ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
592                 CV_Assert(ridx1 > ridx0);
593                 noutputvars = ridx0 >= 0 ? ridx1 - ridx0 : 0;
594                 ninputvars = nvars - noutputvars;
595             }
596             else
597                 CV_Assert( nvars == (int)rowvals.size() );
598 
599             // check var types
600             for( i = 0; i < nvars; i++ )
601             {
602                 CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
603                            (varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
604                 uchar sflag = (uchar)(rowtypes[i] == VAR_CATEGORICAL);
605                 if( vsymbolflags[i] == VAR_MISSED )
606                     vsymbolflags[i] = sflag;
607                 else
608                     CV_Assert(vsymbolflags[i] == sflag || rowtypes[i] == VAR_MISSED);
609             }
610 
611             if( ridx0 >= 0 )
612             {
613                 for( i = ridx1; i < nvars; i++ )
614                     std::swap(rowvals[i], rowvals[i-noutputvars]);
615                 for( i = ninputvars; i < nvars; i++ )
616                     allresponses.push_back(rowvals[i]);
617                 rowvals.pop_back();
618             }
619             Mat rmat(1, ninputvars, CV_32F, &rowvals[0]);
620             tempSamples.push_back(rmat);
621         }
622 
623         closeFile();
624 
625         int nsamples = tempSamples.rows;
626         if( nsamples == 0 )
627             return false;
628 
629         if( haveMissed )
630             compare(tempSamples, MISSED_VAL, tempMissing, CMP_EQ);
631 
632         if( ridx0 >= 0 )
633         {
634             for( i = ridx1; i < nvars; i++ )
635                 std::swap(vtypes[i], vtypes[i-noutputvars]);
636             if( noutputvars > 1 )
637             {
638                 for( i = ninputvars; i < nvars; i++ )
639                     if( vtypes[i] == VAR_CATEGORICAL )
640                         CV_Error(CV_StsBadArg,
641                                  "If responses are vector values, not scalars, they must be marked as ordered responses");
642             }
643         }
644 
645         if( !varTypesSet && noutputvars == 1 && vtypes[ninputvars] == VAR_ORDERED )
646         {
647             for( i = 0; i < nsamples; i++ )
648                 if( allresponses[i] != cvRound(allresponses[i]) )
649                     break;
650             if( i == nsamples )
651                 vtypes[ninputvars] = VAR_CATEGORICAL;
652         }
653 
654         //If there are responses in the csv file, save them. If not, responses matrix will contain just zeros
655         if (noutputvars != 0){
656             Mat(nsamples, noutputvars, CV_32F, &allresponses[0]).copyTo(tempResponses);
657             setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
658                     noArray(), Mat(vtypes).clone(), tempMissing);
659         }
660         else{
661             Mat zero_mat(nsamples, 1, CV_32F, Scalar(0));
662             zero_mat.copyTo(tempResponses);
663             setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
664                     noArray(), noArray(), tempMissing);
665         }
666         bool ok = !samples.empty();
667         if(ok)
668         {
669             std::swap(tempNameMap, nameMap);
670             Mat(vsymbolflags).copyTo(varSymbolFlags);
671         }
672         return ok;
673     }
674 
decodeElem(const char * token,float & elem,int & type,char missch,MapType & namemap,int & counter) const675     void decodeElem( const char* token, float& elem, int& type,
676                      char missch, MapType& namemap, int& counter ) const
677     {
678         char* stopstring = NULL;
679         elem = (float)strtod( token, &stopstring );
680         if( *stopstring == missch && strlen(stopstring) == 1 ) // missed value
681         {
682             elem = MISSED_VAL;
683             type = VAR_MISSED;
684         }
685         else if( *stopstring != '\0' )
686         {
687             MapType::iterator it = namemap.find(token);
688             if( it == namemap.end() )
689             {
690                 elem = (float)counter;
691                 namemap[token] = counter++;
692             }
693             else
694                 elem = (float)it->second;
695             type = VAR_CATEGORICAL;
696         }
697         else
698             type = VAR_ORDERED;
699     }
700 
setVarTypes(const String & s,int nvars,std::vector<uchar> & vtypes) const701     void setVarTypes( const String& s, int nvars, std::vector<uchar>& vtypes ) const
702     {
703         const char* errmsg = "type spec is not correct; it should have format \"cat\", \"ord\" or "
704           "\"ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\", where n's and m's are 0-based variable indices";
705         const char* str = s.c_str();
706         int specCounter = 0;
707 
708         vtypes.resize(nvars);
709 
710         for( int k = 0; k < 2; k++ )
711         {
712             const char* ptr = strstr(str, k == 0 ? "ord" : "cat");
713             int tp = k == 0 ? VAR_ORDERED : VAR_CATEGORICAL;
714             if( ptr ) // parse ord/cat str
715             {
716                 char* stopstring = NULL;
717 
718                 if( ptr[3] == '\0' )
719                 {
720                     for( int i = 0; i < nvars; i++ )
721                         vtypes[i] = (uchar)tp;
722                     specCounter = nvars;
723                     break;
724                 }
725 
726                 if ( ptr[3] != '[')
727                     CV_Error( CV_StsBadArg, errmsg );
728 
729                 ptr += 4; // pass "ord["
730                 do
731                 {
732                     int b1 = (int)strtod( ptr, &stopstring );
733                     if( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
734                         CV_Error( CV_StsBadArg, errmsg );
735                     ptr = stopstring + 1;
736                     if( (stopstring[0] == ',') || (stopstring[0] == ']'))
737                     {
738                         CV_Assert( 0 <= b1 && b1 < nvars );
739                         vtypes[b1] = (uchar)tp;
740                         specCounter++;
741                     }
742                     else
743                     {
744                         if( stopstring[0] == '-')
745                         {
746                             int b2 = (int)strtod( ptr, &stopstring);
747                             if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
748                                 CV_Error( CV_StsBadArg, errmsg );
749                             ptr = stopstring + 1;
750                             CV_Assert( 0 <= b1 && b1 <= b2 && b2 < nvars );
751                             for (int i = b1; i <= b2; i++)
752                                 vtypes[i] = (uchar)tp;
753                             specCounter += b2 - b1 + 1;
754                         }
755                         else
756                             CV_Error( CV_StsBadArg, errmsg );
757 
758                     }
759                 }
760                 while(*stopstring != ']');
761             }
762         }
763 
764         if( specCounter != nvars )
765             CV_Error( CV_StsBadArg, "type of some variables is not specified" );
766     }
767 
setTrainTestSplitRatio(double ratio,bool shuffle)768     void setTrainTestSplitRatio(double ratio, bool shuffle) CV_OVERRIDE
769     {
770         CV_Assert( 0. <= ratio && ratio <= 1. );
771         setTrainTestSplit(cvRound(getNSamples()*ratio), shuffle);
772     }
773 
setTrainTestSplit(int count,bool shuffle)774     void setTrainTestSplit(int count, bool shuffle) CV_OVERRIDE
775     {
776         int i, nsamples = getNSamples();
777         CV_Assert( 0 <= count && count < nsamples );
778 
779         trainSampleIdx.release();
780         testSampleIdx.release();
781 
782         if( count == 0 )
783             trainSampleIdx = sampleIdx;
784         else if( count == nsamples )
785             testSampleIdx = sampleIdx;
786         else
787         {
788             Mat mask(1, nsamples, CV_8U);
789             uchar* mptr = mask.ptr();
790             for( i = 0; i < nsamples; i++ )
791                 mptr[i] = (uchar)(i < count);
792             trainSampleIdx.create(1, count, CV_32S);
793             testSampleIdx.create(1, nsamples - count, CV_32S);
794             int j0 = 0, j1 = 0;
795             const int* sptr = !sampleIdx.empty() ? sampleIdx.ptr<int>() : 0;
796             int* trainptr = trainSampleIdx.ptr<int>();
797             int* testptr = testSampleIdx.ptr<int>();
798             for( i = 0; i < nsamples; i++ )
799             {
800                 int idx = sptr ? sptr[i] : i;
801                 if( mptr[i] )
802                     trainptr[j0++] = idx;
803                 else
804                     testptr[j1++] = idx;
805             }
806             if( shuffle )
807                 shuffleTrainTest();
808         }
809     }
810 
shuffleTrainTest()811     void shuffleTrainTest() CV_OVERRIDE
812     {
813         if( !trainSampleIdx.empty() && !testSampleIdx.empty() )
814         {
815             int i, nsamples = getNSamples(), ntrain = getNTrainSamples(), ntest = getNTestSamples();
816             int* trainIdx = trainSampleIdx.ptr<int>();
817             int* testIdx = testSampleIdx.ptr<int>();
818             RNG& rng = theRNG();
819 
820             for( i = 0; i < nsamples; i++)
821             {
822                 int a = rng.uniform(0, nsamples);
823                 int b = rng.uniform(0, nsamples);
824                 int* ptra = trainIdx;
825                 int* ptrb = trainIdx;
826                 if( a >= ntrain )
827                 {
828                     ptra = testIdx;
829                     a -= ntrain;
830                     CV_Assert( a < ntest );
831                 }
832                 if( b >= ntrain )
833                 {
834                     ptrb = testIdx;
835                     b -= ntrain;
836                     CV_Assert( b < ntest );
837                 }
838                 std::swap(ptra[a], ptrb[b]);
839             }
840         }
841     }
842 
getTrainSamples(int _layout,bool compressSamples,bool compressVars) const843     Mat getTrainSamples(int _layout,
844                         bool compressSamples,
845                         bool compressVars) const CV_OVERRIDE
846     {
847         if( samples.empty() )
848             return samples;
849 
850         if( (!compressSamples || (trainSampleIdx.empty() && sampleIdx.empty())) &&
851             (!compressVars || varIdx.empty()) &&
852             layout == _layout )
853             return samples;
854 
855         int drows = getNTrainSamples(), dcols = getNVars();
856         Mat sidx = getTrainSampleIdx(), vidx = getVarIdx();
857         const float* src0 = samples.ptr<float>();
858         const int* sptr = !sidx.empty() ? sidx.ptr<int>() : 0;
859         const int* vptr = !vidx.empty() ? vidx.ptr<int>() : 0;
860         size_t sstep0 = samples.step/samples.elemSize();
861         size_t sstep = layout == ROW_SAMPLE ? sstep0 : 1;
862         size_t vstep = layout == ROW_SAMPLE ? 1 : sstep0;
863 
864         if( _layout == COL_SAMPLE )
865         {
866             std::swap(drows, dcols);
867             std::swap(sptr, vptr);
868             std::swap(sstep, vstep);
869         }
870 
871         Mat dsamples(drows, dcols, CV_32F);
872 
873         for( int i = 0; i < drows; i++ )
874         {
875             const float* src = src0 + (sptr ? sptr[i] : i)*sstep;
876             float* dst = dsamples.ptr<float>(i);
877 
878             for( int j = 0; j < dcols; j++ )
879                 dst[j] = src[(vptr ? vptr[j] : j)*vstep];
880         }
881 
882         return dsamples;
883     }
884 
getValues(int vi,InputArray _sidx,float * values) const885     void getValues( int vi, InputArray _sidx, float* values ) const CV_OVERRIDE
886     {
887         Mat sidx = _sidx.getMat();
888         int i, n = sidx.checkVector(1, CV_32S), nsamples = getNSamples();
889         CV_Assert( 0 <= vi && vi < getNAllVars() );
890         CV_Assert( n >= 0 );
891         const int* s = n > 0 ? sidx.ptr<int>() : 0;
892         if( n == 0 )
893             n = nsamples;
894 
895         size_t step = samples.step/samples.elemSize();
896         size_t sstep = layout == ROW_SAMPLE ? step : 1;
897         size_t vstep = layout == ROW_SAMPLE ? 1 : step;
898 
899         const float* src = samples.ptr<float>() + vi*vstep;
900         float subst = missingSubst.at<float>(vi);
901         for( i = 0; i < n; i++ )
902         {
903             int j = i;
904             if( s )
905             {
906                 j = s[i];
907                 CV_Assert( 0 <= j && j < nsamples );
908             }
909             values[i] = src[j*sstep];
910             if( values[i] == MISSED_VAL )
911                 values[i] = subst;
912         }
913     }
914 
getNormCatValues(int vi,InputArray _sidx,int * values) const915     void getNormCatValues( int vi, InputArray _sidx, int* values ) const CV_OVERRIDE
916     {
917         float* fvalues = (float*)values;
918         getValues(vi, _sidx, fvalues);
919         int i, n = (int)_sidx.total();
920         Vec2i ofs = catOfs.at<Vec2i>(vi);
921         int m = ofs[1] - ofs[0];
922 
923         CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
924         const int* cmap = &catMap.at<int>(ofs[0]);
925         bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
926 
927         if( fastMap )
928         {
929             for( i = 0; i < n; i++ )
930             {
931                 int val = cvRound(fvalues[i]);
932                 int idx = val - cmap[0];
933                 CV_Assert(cmap[idx] == val);
934                 values[i] = idx;
935             }
936         }
937         else
938         {
939             for( i = 0; i < n; i++ )
940             {
941                 int val = cvRound(fvalues[i]);
942                 int a = 0, b = m, c = -1;
943 
944                 while( a < b )
945                 {
946                     c = (a + b) >> 1;
947                     if( val < cmap[c] )
948                         b = c;
949                     else if( val > cmap[c] )
950                         a = c+1;
951                     else
952                         break;
953                 }
954 
955                 CV_DbgAssert( c >= 0 && val == cmap[c] );
956                 values[i] = c;
957             }
958         }
959     }
960 
getSample(InputArray _vidx,int sidx,float * buf) const961     void getSample(InputArray _vidx, int sidx, float* buf) const CV_OVERRIDE
962     {
963         CV_Assert(buf != 0 && 0 <= sidx && sidx < getNSamples());
964         Mat vidx = _vidx.getMat();
965         int i, n = vidx.checkVector(1, CV_32S), nvars = getNAllVars();
966         CV_Assert( n >= 0 );
967         const int* vptr = n > 0 ? vidx.ptr<int>() : 0;
968         if( n == 0 )
969             n = nvars;
970 
971         size_t step = samples.step/samples.elemSize();
972         size_t sstep = layout == ROW_SAMPLE ? step : 1;
973         size_t vstep = layout == ROW_SAMPLE ? 1 : step;
974 
975         const float* src = samples.ptr<float>() + sidx*sstep;
976         for( i = 0; i < n; i++ )
977         {
978             int j = i;
979             if( vptr )
980             {
981                 j = vptr[i];
982                 CV_Assert( 0 <= j && j < nvars );
983             }
984             buf[i] = src[j*vstep];
985         }
986     }
987 
getNames(std::vector<String> & names) const988     void getNames(std::vector<String>& names) const CV_OVERRIDE
989     {
990         size_t n = nameMap.size();
991         TrainDataImpl::MapType::const_iterator it = nameMap.begin(),
992                                                it_end = nameMap.end();
993         names.resize(n+1);
994         names[0] = "?";
995         for( ; it != it_end; ++it )
996         {
997             String s = it->first;
998             int label = it->second;
999             CV_Assert( label > 0 && label <= (int)n );
1000             names[label] = s;
1001         }
1002     }
1003 
getVarSymbolFlags() const1004     Mat getVarSymbolFlags() const CV_OVERRIDE
1005     {
1006         return varSymbolFlags;
1007     }
1008 
    // NOTE(review): not referenced in this part of the file; presumably the
    // handle used while loading CSV data - confirm.
    FILE* file;
    // ROW_SAMPLE or COL_SAMPLE: how samples are laid out in `samples`.
    int layout;
    // `samples` holds the float data; missingSubst[vi] replaces MISSED_VAL
    // entries of variable vi in getValues().
    Mat samples, missing, varType, varIdx, varSymbolFlags, responses, missingSubst;
    // sampleIdx: CV_32S storage indices of the active samples (empty = all).
    // trainSampleIdx / testSampleIdx: CV_32S index lists built by setTrainTestSplit().
    Mat sampleIdx, trainSampleIdx, testSampleIdx;
    // catOfs.at<Vec2i>(vi) is the [begin, end) range of variable vi's category
    // values inside catMap, which getNormCatValues() searches as a sorted array.
    Mat sampleWeights, catMap, catOfs;
    Mat normCatResponses, classLabels, classCounters;
    // Maps a name to a positive label; getNames() reserves label 0 for "?".
    MapType nameMap;
};
1017 
1018 
loadFromCSV(const String & filename,int headerLines,int responseStartIdx,int responseEndIdx,const String & varTypeSpec,char delimiter,char missch)1019 Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
1020                                       int headerLines,
1021                                       int responseStartIdx,
1022                                       int responseEndIdx,
1023                                       const String& varTypeSpec,
1024                                       char delimiter, char missch)
1025 {
1026     CV_TRACE_FUNCTION_SKIP_NESTED();
1027     Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
1028     if(!td->loadCSV(filename, headerLines, responseStartIdx, responseEndIdx, varTypeSpec, delimiter, missch))
1029         td.release();
1030     return td;
1031 }
1032 
create(InputArray samples,int layout,InputArray responses,InputArray varIdx,InputArray sampleIdx,InputArray sampleWeights,InputArray varType)1033 Ptr<TrainData> TrainData::create(InputArray samples, int layout, InputArray responses,
1034                                  InputArray varIdx, InputArray sampleIdx, InputArray sampleWeights,
1035                                  InputArray varType)
1036 {
1037     CV_TRACE_FUNCTION_SKIP_NESTED();
1038     Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
1039     td->setData(samples, layout, responses, varIdx, sampleIdx, sampleWeights, varType, noArray());
1040     return td;
1041 }
1042 
1043 }}
1044 
1045 /* End of file. */
1046