1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // License Agreement
11 // For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Copyright (C) 2016, Itseez Inc, all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 // * Redistribution's of source code must retain the above copyright notice,
21 // this list of conditions and the following disclaimer.
22 //
23 // * Redistribution's in binary form must reproduce the above copyright notice,
24 // this list of conditions and the following disclaimer in the documentation
25 // and/or other materials provided with the distribution.
26 //
27 // * The name of the copyright holders may not be used to endorse or promote products
28 // derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42
#include "precomp.hpp"

#include <iostream>
#include <limits>

using std::cout;
using std::endl;
50
51 /****************************************************************************************\
52 * Stochastic Gradient Descent SVM Classifier *
53 \****************************************************************************************/
54
55 namespace cv
56 {
57 namespace ml
58 {
59
60 class SVMSGDImpl CV_FINAL : public SVMSGD
61 {
62
63 public:
64 SVMSGDImpl();
65
~SVMSGDImpl()66 virtual ~SVMSGDImpl() {}
67
68 virtual bool train(const Ptr<TrainData>& data, int) CV_OVERRIDE;
69
70 virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const CV_OVERRIDE;
71
72 virtual bool isClassifier() const CV_OVERRIDE;
73
74 virtual bool isTrained() const CV_OVERRIDE;
75
76 virtual void clear() CV_OVERRIDE;
77
78 virtual void write(FileStorage &fs) const CV_OVERRIDE;
79
80 virtual void read(const FileNode &fn) CV_OVERRIDE;
81
getWeights()82 virtual Mat getWeights() CV_OVERRIDE { return weights_; }
83
getShift()84 virtual float getShift() CV_OVERRIDE { return shift_; }
85
getVarCount() const86 virtual int getVarCount() const CV_OVERRIDE { return weights_.cols; }
87
getDefaultName() const88 virtual String getDefaultName() const CV_OVERRIDE {return "opencv_ml_svmsgd";}
89
90 virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN) CV_OVERRIDE;
91
getSvmsgdType() const92 inline int getSvmsgdType() const CV_OVERRIDE { return params.svmsgdType; }
setSvmsgdType(int val)93 inline void setSvmsgdType(int val) CV_OVERRIDE { params.svmsgdType = val; }
getMarginType() const94 inline int getMarginType() const CV_OVERRIDE { return params.marginType; }
setMarginType(int val)95 inline void setMarginType(int val) CV_OVERRIDE { params.marginType = val; }
getMarginRegularization() const96 inline float getMarginRegularization() const CV_OVERRIDE { return params.marginRegularization; }
setMarginRegularization(float val)97 inline void setMarginRegularization(float val) CV_OVERRIDE { params.marginRegularization = val; }
getInitialStepSize() const98 inline float getInitialStepSize() const CV_OVERRIDE { return params.initialStepSize; }
setInitialStepSize(float val)99 inline void setInitialStepSize(float val) CV_OVERRIDE { params.initialStepSize = val; }
getStepDecreasingPower() const100 inline float getStepDecreasingPower() const CV_OVERRIDE { return params.stepDecreasingPower; }
setStepDecreasingPower(float val)101 inline void setStepDecreasingPower(float val) CV_OVERRIDE { params.stepDecreasingPower = val; }
getTermCriteria() const102 inline cv::TermCriteria getTermCriteria() const CV_OVERRIDE { return params.termCrit; }
setTermCriteria(const cv::TermCriteria & val)103 inline void setTermCriteria(const cv::TermCriteria& val) CV_OVERRIDE { params.termCrit = val; }
104
105 private:
106 void updateWeights(InputArray sample, bool positive, float stepSize, Mat &weights);
107
108 void writeParams( FileStorage &fs ) const;
109
110 void readParams( const FileNode &fn );
111
isPositive(float val)112 static inline bool isPositive(float val) { return val > 0; }
113
114 static void normalizeSamples(Mat &matrix, Mat &average, float &multiplier);
115
116 float calcShift(InputArray _samples, InputArray _responses) const;
117
118 static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);
119
120 // Vector with SVM weights
121 Mat weights_;
122 float shift_;
123
124 // Parameters for learning
125 struct SVMSGDParams
126 {
127 float marginRegularization;
128 float initialStepSize;
129 float stepDecreasingPower;
130 TermCriteria termCrit;
131 int svmsgdType;
132 int marginType;
133 };
134
135 SVMSGDParams params;
136 };
137
create()138 Ptr<SVMSGD> SVMSGD::create()
139 {
140 return makePtr<SVMSGDImpl>();
141 }
142
load(const String & filepath,const String & nodeName)143 Ptr<SVMSGD> SVMSGD::load(const String& filepath, const String& nodeName)
144 {
145 return Algorithm::load<SVMSGD>(filepath, nodeName);
146 }
147
148
normalizeSamples(Mat & samples,Mat & average,float & multiplier)149 void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
150 {
151 int featuresCount = samples.cols;
152 int samplesCount = samples.rows;
153
154 average = Mat(1, featuresCount, samples.type());
155 CV_Assert(average.type() == CV_32FC1);
156 for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
157 {
158 average.at<float>(featureIndex) = static_cast<float>(mean(samples.col(featureIndex))[0]);
159 }
160
161 for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
162 {
163 samples.row(sampleIndex) -= average;
164 }
165
166 double normValue = norm(samples);
167
168 multiplier = static_cast<float>(sqrt(static_cast<double>(samples.total())) / normValue);
169
170 samples *= multiplier;
171 }
172
makeExtendedTrainSamples(const Mat & trainSamples,Mat & extendedTrainSamples,Mat & average,float & multiplier)173 void SVMSGDImpl::makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier)
174 {
175 Mat normalizedTrainSamples = trainSamples.clone();
176 int samplesCount = normalizedTrainSamples.rows;
177
178 normalizeSamples(normalizedTrainSamples, average, multiplier);
179
180 Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
181 cv::hconcat(normalizedTrainSamples, onesCol, extendedTrainSamples);
182 }
183
updateWeights(InputArray _sample,bool positive,float stepSize,Mat & weights)184 void SVMSGDImpl::updateWeights(InputArray _sample, bool positive, float stepSize, Mat& weights)
185 {
186 Mat sample = _sample.getMat();
187
188 int response = positive ? 1 : -1; // ensure that trainResponses are -1 or 1
189
190 if ( sample.dot(weights) * response > 1)
191 {
192 // Not a support vector, only apply weight decay
193 weights *= (1.f - stepSize * params.marginRegularization);
194 }
195 else
196 {
197 // It's a support vector, add it to the weights
198 weights -= (stepSize * params.marginRegularization) * weights - (stepSize * response) * sample;
199 }
200 }
201
calcShift(InputArray _samples,InputArray _responses) const202 float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
203 {
204 float margin[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };
205
206 Mat trainSamples = _samples.getMat();
207 int trainSamplesCount = trainSamples.rows;
208
209 Mat trainResponses = _responses.getMat();
210
211 CV_Assert(trainResponses.type() == CV_32FC1);
212 for (int samplesIndex = 0; samplesIndex < trainSamplesCount; samplesIndex++)
213 {
214 Mat currentSample = trainSamples.row(samplesIndex);
215 float dotProduct = static_cast<float>(currentSample.dot(weights_));
216
217 bool positive = isPositive(trainResponses.at<float>(samplesIndex));
218 int index = positive ? 0 : 1;
219 float signToMul = positive ? 1.f : -1.f;
220 float curMargin = dotProduct * signToMul;
221
222 if (curMargin < margin[index])
223 {
224 margin[index] = curMargin;
225 }
226 }
227
228 return -(margin[0] - margin[1]) / 2.f;
229 }
230
train(const Ptr<TrainData> & data,int)231 bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
232 {
233 CV_Assert(!data.empty());
234 clear();
235 CV_Assert( isClassifier() ); //toDo: consider
236
237 Mat trainSamples = data->getTrainSamples();
238
239 int featureCount = trainSamples.cols;
240 Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix
241
242 CV_Assert(trainResponses.rows == trainSamples.rows);
243
244 if (trainResponses.empty())
245 {
246 return false;
247 }
248
249 int positiveCount = countNonZero(trainResponses >= 0);
250 int negativeCount = countNonZero(trainResponses < 0);
251
252 if ( positiveCount <= 0 || negativeCount <= 0 )
253 {
254 weights_ = Mat::zeros(1, featureCount, CV_32F);
255 shift_ = (positiveCount > 0) ? 1.f : -1.f;
256 return true;
257 }
258
259 Mat extendedTrainSamples;
260 Mat average;
261 float multiplier = 0;
262 makeExtendedTrainSamples(trainSamples, extendedTrainSamples, average, multiplier);
263
264 int extendedTrainSamplesCount = extendedTrainSamples.rows;
265 int extendedFeatureCount = extendedTrainSamples.cols;
266
267 Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
268 Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
269 Mat averageExtendedWeights;
270 if (params.svmsgdType == ASGD)
271 {
272 averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
273 }
274
275 RNG rng(0);
276
277 CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
278 int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
279 double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;
280
281 double err = DBL_MAX;
282 CV_Assert (trainResponses.type() == CV_32FC1);
283 // Stochastic gradient descent SVM
284 for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++)
285 {
286 int randomNumber = rng.uniform(0, extendedTrainSamplesCount); //generate sample number
287
288 Mat currentSample = extendedTrainSamples.row(randomNumber);
289
290 float stepSize = params.initialStepSize * std::pow((1 + params.marginRegularization * params.initialStepSize * (float)iter), (-params.stepDecreasingPower)); //update stepSize
291
292 updateWeights( currentSample, isPositive(trainResponses.at<float>(randomNumber)), stepSize, extendedWeights );
293
294 //average weights (only for ASGD model)
295 if (params.svmsgdType == ASGD)
296 {
297 averageExtendedWeights = ((float)iter/ (1 + (float)iter)) * averageExtendedWeights + extendedWeights / (1 + (float) iter);
298 err = norm(averageExtendedWeights - previousWeights);
299 averageExtendedWeights.copyTo(previousWeights);
300 }
301 else
302 {
303 err = norm(extendedWeights - previousWeights);
304 extendedWeights.copyTo(previousWeights);
305 }
306 }
307
308 if (params.svmsgdType == ASGD)
309 {
310 extendedWeights = averageExtendedWeights;
311 }
312
313 Rect roi(0, 0, featureCount, 1);
314 weights_ = extendedWeights(roi);
315 weights_ *= multiplier;
316
317 CV_Assert((params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN) && (extendedWeights.type() == CV_32FC1));
318
319 if (params.marginType == SOFT_MARGIN)
320 {
321 shift_ = extendedWeights.at<float>(featureCount) - static_cast<float>(weights_.dot(average));
322 }
323 else
324 {
325 shift_ = calcShift(trainSamples, trainResponses);
326 }
327
328 return true;
329 }
330
predict(InputArray _samples,OutputArray _results,int) const331 float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
332 {
333 float result = 0;
334 cv::Mat samples = _samples.getMat();
335 int nSamples = samples.rows;
336 cv::Mat results;
337
338 CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32FC1);
339
340 if( _results.needed() )
341 {
342 _results.create( nSamples, 1, samples.type() );
343 results = _results.getMat();
344 }
345 else
346 {
347 CV_Assert( nSamples == 1 );
348 results = Mat(1, 1, CV_32FC1, &result);
349 }
350
351 for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++)
352 {
353 Mat currentSample = samples.row(sampleIndex);
354 float criterion = static_cast<float>(currentSample.dot(weights_)) + shift_;
355 results.at<float>(sampleIndex) = (criterion >= 0) ? 1.f : -1.f;
356 }
357
358 return result;
359 }
360
isClassifier() const361 bool SVMSGDImpl::isClassifier() const
362 {
363 return (params.svmsgdType == SGD || params.svmsgdType == ASGD)
364 &&
365 (params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN)
366 &&
367 (params.marginRegularization > 0) && (params.initialStepSize > 0) && (params.stepDecreasingPower >= 0);
368 }
369
isTrained() const370 bool SVMSGDImpl::isTrained() const
371 {
372 return !weights_.empty();
373 }
374
write(FileStorage & fs) const375 void SVMSGDImpl::write(FileStorage& fs) const
376 {
377 if( !isTrained() )
378 CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" );
379
380 writeFormat(fs);
381 writeParams( fs );
382
383 fs << "weights" << weights_;
384 fs << "shift" << shift_;
385 }
386
writeParams(FileStorage & fs) const387 void SVMSGDImpl::writeParams( FileStorage& fs ) const
388 {
389 String SvmsgdTypeStr;
390
391 switch (params.svmsgdType)
392 {
393 case SGD:
394 SvmsgdTypeStr = "SGD";
395 break;
396 case ASGD:
397 SvmsgdTypeStr = "ASGD";
398 break;
399 default:
400 SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
401 }
402
403 fs << "svmsgdType" << SvmsgdTypeStr;
404
405 String marginTypeStr;
406
407 switch (params.marginType)
408 {
409 case SOFT_MARGIN:
410 marginTypeStr = "SOFT_MARGIN";
411 break;
412 case HARD_MARGIN:
413 marginTypeStr = "HARD_MARGIN";
414 break;
415 default:
416 marginTypeStr = format("Unknown_%d", params.marginType);
417 }
418
419 fs << "marginType" << marginTypeStr;
420
421 fs << "marginRegularization" << params.marginRegularization;
422 fs << "initialStepSize" << params.initialStepSize;
423 fs << "stepDecreasingPower" << params.stepDecreasingPower;
424
425 fs << "term_criteria" << "{:";
426 if( params.termCrit.type & TermCriteria::EPS )
427 fs << "epsilon" << params.termCrit.epsilon;
428 if( params.termCrit.type & TermCriteria::COUNT )
429 fs << "iterations" << params.termCrit.maxCount;
430 fs << "}";
431 }
readParams(const FileNode & fn)432 void SVMSGDImpl::readParams( const FileNode& fn )
433 {
434 String svmsgdTypeStr = (String)fn["svmsgdType"];
435 int svmsgdType =
436 svmsgdTypeStr == "SGD" ? SGD :
437 svmsgdTypeStr == "ASGD" ? ASGD : -1;
438
439 if( svmsgdType < 0 )
440 CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
441
442 params.svmsgdType = svmsgdType;
443
444 String marginTypeStr = (String)fn["marginType"];
445 int marginType =
446 marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
447 marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
448
449 if( marginType < 0 )
450 CV_Error( CV_StsParseError, "Missing or invalid margin type" );
451
452 params.marginType = marginType;
453
454 CV_Assert ( fn["marginRegularization"].isReal() );
455 params.marginRegularization = (float)fn["marginRegularization"];
456
457 CV_Assert ( fn["initialStepSize"].isReal() );
458 params.initialStepSize = (float)fn["initialStepSize"];
459
460 CV_Assert ( fn["stepDecreasingPower"].isReal() );
461 params.stepDecreasingPower = (float)fn["stepDecreasingPower"];
462
463 FileNode tcnode = fn["term_criteria"];
464 CV_Assert(!tcnode.empty());
465 params.termCrit.epsilon = (double)tcnode["epsilon"];
466 params.termCrit.maxCount = (int)tcnode["iterations"];
467 params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
468 (params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
469 CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS));
470 }
471
read(const FileNode & fn)472 void SVMSGDImpl::read(const FileNode& fn)
473 {
474 clear();
475
476 readParams(fn);
477
478 fn["weights"] >> weights_;
479 fn["shift"] >> shift_;
480 }
481
clear()482 void SVMSGDImpl::clear()
483 {
484 weights_.release();
485 shift_ = 0;
486 }
487
488
SVMSGDImpl()489 SVMSGDImpl::SVMSGDImpl()
490 {
491 clear();
492 setOptimalParameters();
493 }
494
setOptimalParameters(int svmsgdType,int marginType)495 void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
496 {
497 switch (svmsgdType)
498 {
499 case SGD:
500 params.svmsgdType = SGD;
501 params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
502 (marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
503 params.marginRegularization = 0.0001f;
504 params.initialStepSize = 0.05f;
505 params.stepDecreasingPower = 1.f;
506 params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
507 break;
508
509 case ASGD:
510 params.svmsgdType = ASGD;
511 params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
512 (marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
513 params.marginRegularization = 0.00001f;
514 params.initialStepSize = 0.05f;
515 params.stepDecreasingPower = 0.75f;
516 params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
517 break;
518
519 default:
520 CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
521 }
522 }
523 } //ml
524 } //cv
525