1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Copyright (C) 2016, Itseez Inc, all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 //   * Redistribution's of source code must retain the above copyright notice,
21 //     this list of conditions and the following disclaimer.
22 //
23 //   * Redistribution's in binary form must reproduce the above copyright notice,
24 //     this list of conditions and the following disclaimer in the documentation
25 //     and/or other materials provided with the distribution.
26 //
27 //   * The name of the copyright holders may not be used to endorse or promote products
28 //     derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42 
43 #include "precomp.hpp"
44 #include "limits"
45 
46 #include <iostream>
47 
48 using std::cout;
49 using std::endl;
50 
51 /****************************************************************************************\
52 *                        Stochastic Gradient Descent SVM Classifier                      *
53 \****************************************************************************************/
54 
55 namespace cv
56 {
57 namespace ml
58 {
59 
// SGD/ASGD linear SVM binary classifier.
// The trained model is a weight vector `weights_` plus a scalar bias `shift_`;
// prediction is sign(sample . weights_ + shift_), reported as +1 / -1.
class SVMSGDImpl CV_FINAL : public SVMSGD
{

public:
    SVMSGDImpl();

    virtual ~SVMSGDImpl() {}

    // Trains with stochastic gradient descent; the trailing int flag is unused.
    virtual bool train(const Ptr<TrainData>& data, int) CV_OVERRIDE;

    // Writes +1/-1 labels into `results`; the float return value is the label
    // only when `results` is not needed and a single sample is given.
    virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const CV_OVERRIDE;

    virtual bool isClassifier() const CV_OVERRIDE;

    virtual bool isTrained() const CV_OVERRIDE;

    virtual void clear() CV_OVERRIDE;

    virtual void write(FileStorage &fs) const CV_OVERRIDE;

    virtual void read(const FileNode &fn) CV_OVERRIDE;

    virtual Mat getWeights() CV_OVERRIDE { return weights_; }

    virtual float getShift() CV_OVERRIDE { return shift_; }

    virtual int getVarCount() const CV_OVERRIDE { return weights_.cols; }

    virtual String getDefaultName() const CV_OVERRIDE {return "opencv_ml_svmsgd";}

    virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN) CV_OVERRIDE;

    // Plain accessors backed by `params`.
    inline int getSvmsgdType() const CV_OVERRIDE { return params.svmsgdType; }
    inline void setSvmsgdType(int val) CV_OVERRIDE { params.svmsgdType = val; }
    inline int getMarginType() const CV_OVERRIDE { return params.marginType; }
    inline void setMarginType(int val) CV_OVERRIDE { params.marginType = val; }
    inline float getMarginRegularization() const CV_OVERRIDE { return params.marginRegularization; }
    inline void setMarginRegularization(float val) CV_OVERRIDE { params.marginRegularization = val; }
    inline float getInitialStepSize() const CV_OVERRIDE { return params.initialStepSize; }
    inline void setInitialStepSize(float val) CV_OVERRIDE { params.initialStepSize = val; }
    inline float getStepDecreasingPower() const CV_OVERRIDE { return params.stepDecreasingPower; }
    inline void setStepDecreasingPower(float val) CV_OVERRIDE { params.stepDecreasingPower = val; }
    inline cv::TermCriteria getTermCriteria() const CV_OVERRIDE { return params.termCrit; }
    inline void setTermCriteria(const cv::TermCriteria& val) CV_OVERRIDE { params.termCrit = val; }

private:
    // One SGD update of `weights` from a single (extended) sample.
    void updateWeights(InputArray sample, bool positive, float stepSize, Mat &weights);

    void writeParams( FileStorage &fs ) const;

    void readParams( const FileNode &fn );

    // Responses > 0 are the positive class; everything else is negative.
    static inline bool isPositive(float val) { return val > 0; }

    // Centers each feature column and rescales the whole matrix in place;
    // outputs the per-feature average and the scalar multiplier applied.
    static void normalizeSamples(Mat &matrix, Mat &average, float &multiplier);

    // HARD_MARGIN bias: places the boundary halfway between the closest
    // positive and closest negative sample (uses the trained weights_).
    float calcShift(InputArray _samples, InputArray _responses) const;

    // Normalizes the samples and appends a constant-1 column so the bias can
    // be learned as an extra weight.
    static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);

    // Vector with SVM weights
    Mat weights_;
    float shift_;

    // Parameters for learning
    struct SVMSGDParams
    {
        float marginRegularization;
        float initialStepSize;
        float stepDecreasingPower;
        TermCriteria termCrit;   // stop on iteration count and/or weight-change epsilon
        int svmsgdType;          // SGD or ASGD
        int marginType;          // SOFT_MARGIN or HARD_MARGIN
    };

    SVMSGDParams params;
};
137 
create()138 Ptr<SVMSGD> SVMSGD::create()
139 {
140     return makePtr<SVMSGDImpl>();
141 }
142 
load(const String & filepath,const String & nodeName)143 Ptr<SVMSGD> SVMSGD::load(const String& filepath, const String& nodeName)
144 {
145     return Algorithm::load<SVMSGD>(filepath, nodeName);
146 }
147 
148 
normalizeSamples(Mat & samples,Mat & average,float & multiplier)149 void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
150 {
151     int featuresCount = samples.cols;
152     int samplesCount = samples.rows;
153 
154     average = Mat(1, featuresCount, samples.type());
155     CV_Assert(average.type() ==  CV_32FC1);
156     for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
157     {
158         average.at<float>(featureIndex) = static_cast<float>(mean(samples.col(featureIndex))[0]);
159     }
160 
161     for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
162     {
163         samples.row(sampleIndex) -= average;
164     }
165 
166     double normValue = norm(samples);
167 
168     multiplier = static_cast<float>(sqrt(static_cast<double>(samples.total())) / normValue);
169 
170     samples *= multiplier;
171 }
172 
makeExtendedTrainSamples(const Mat & trainSamples,Mat & extendedTrainSamples,Mat & average,float & multiplier)173 void SVMSGDImpl::makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier)
174 {
175     Mat normalizedTrainSamples = trainSamples.clone();
176     int samplesCount = normalizedTrainSamples.rows;
177 
178     normalizeSamples(normalizedTrainSamples, average, multiplier);
179 
180     Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
181     cv::hconcat(normalizedTrainSamples, onesCol, extendedTrainSamples);
182 }
183 
updateWeights(InputArray _sample,bool positive,float stepSize,Mat & weights)184 void SVMSGDImpl::updateWeights(InputArray _sample, bool positive, float stepSize, Mat& weights)
185 {
186     Mat sample = _sample.getMat();
187 
188     int response = positive ? 1 : -1; // ensure that trainResponses are -1 or 1
189 
190     if ( sample.dot(weights) * response > 1)
191     {
192         // Not a support vector, only apply weight decay
193         weights *= (1.f - stepSize * params.marginRegularization);
194     }
195     else
196     {
197         // It's a support vector, add it to the weights
198         weights -= (stepSize * params.marginRegularization) * weights - (stepSize * response) * sample;
199     }
200 }
201 
calcShift(InputArray _samples,InputArray _responses) const202 float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
203 {
204     float margin[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };
205 
206     Mat trainSamples = _samples.getMat();
207     int trainSamplesCount = trainSamples.rows;
208 
209     Mat trainResponses = _responses.getMat();
210 
211     CV_Assert(trainResponses.type() ==  CV_32FC1);
212     for (int samplesIndex = 0; samplesIndex < trainSamplesCount; samplesIndex++)
213     {
214         Mat currentSample = trainSamples.row(samplesIndex);
215         float dotProduct = static_cast<float>(currentSample.dot(weights_));
216 
217         bool positive = isPositive(trainResponses.at<float>(samplesIndex));
218         int index = positive ? 0 : 1;
219         float signToMul = positive ? 1.f : -1.f;
220         float curMargin = dotProduct * signToMul;
221 
222         if (curMargin < margin[index])
223         {
224             margin[index] = curMargin;
225         }
226     }
227 
228     return -(margin[0] - margin[1]) / 2.f;
229 }
230 
// Trains the linear SVM with (averaged) stochastic gradient descent.
// Picks one random sample per iteration, updates the extended weight vector,
// and stops when the termination criteria (iteration count / weight-change
// epsilon) are met. Returns false only when there are no responses.
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
{
    CV_Assert(!data.empty());
    clear();
    CV_Assert( isClassifier() );   //toDo: consider

    Mat trainSamples = data->getTrainSamples();

    int featureCount = trainSamples.cols;
    Mat trainResponses = data->getTrainResponses();        // (trainSamplesCount x 1) matrix

    CV_Assert(trainResponses.rows == trainSamples.rows);

    if (trainResponses.empty())
    {
        return false;
    }

    int positiveCount = countNonZero(trainResponses >= 0);
    int negativeCount = countNonZero(trainResponses < 0);

    // Degenerate case: only one class present — emit a constant classifier.
    if ( positiveCount <= 0 || negativeCount <= 0 )
    {
        weights_ = Mat::zeros(1, featureCount, CV_32F);
        shift_ = (positiveCount > 0) ? 1.f : -1.f;
        return true;
    }

    // Normalized samples with an appended constant-1 bias column.
    Mat extendedTrainSamples;
    Mat average;
    float multiplier = 0;
    makeExtendedTrainSamples(trainSamples, extendedTrainSamples, average, multiplier);

    int extendedTrainSamplesCount = extendedTrainSamples.rows;
    int extendedFeatureCount = extendedTrainSamples.cols;

    Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
    Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
    Mat averageExtendedWeights;
    if (params.svmsgdType == ASGD)
    {
        averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
    }

    // Fixed seed: training is deterministic for a given input.
    RNG rng(0);

    CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
    int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
    double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;

    double err = DBL_MAX;
    CV_Assert (trainResponses.type() == CV_32FC1);
    // Stochastic gradient descent SVM
    for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++)
    {
        int randomNumber = rng.uniform(0, extendedTrainSamplesCount);             //generate sample number

        Mat currentSample = extendedTrainSamples.row(randomNumber);

        // Decaying learning rate: initialStepSize * (1 + reg*initialStepSize*iter)^(-power).
        float stepSize = params.initialStepSize * std::pow((1 + params.marginRegularization * params.initialStepSize * (float)iter), (-params.stepDecreasingPower));    //update stepSize

        updateWeights( currentSample, isPositive(trainResponses.at<float>(randomNumber)), stepSize, extendedWeights );

        //average weights (only for ASGD model)
        if (params.svmsgdType == ASGD)
        {
            // Running mean of the iterates; convergence is measured on the average.
            averageExtendedWeights = ((float)iter/ (1 + (float)iter)) * averageExtendedWeights  + extendedWeights / (1 + (float) iter);
            err = norm(averageExtendedWeights - previousWeights);
            averageExtendedWeights.copyTo(previousWeights);
        }
        else
        {
            err = norm(extendedWeights - previousWeights);
            extendedWeights.copyTo(previousWeights);
        }
    }

    if (params.svmsgdType == ASGD)
    {
        extendedWeights = averageExtendedWeights;
    }

    // Strip the bias column; undo the sample scaling so weights_ applies to raw samples.
    Rect roi(0, 0, featureCount, 1);
    weights_ = extendedWeights(roi);
    weights_ *= multiplier;

    CV_Assert((params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN) && (extendedWeights.type() ==  CV_32FC1));

    if (params.marginType == SOFT_MARGIN)
    {
        // Learned bias, corrected for the mean subtracted during normalization.
        shift_ = extendedWeights.at<float>(featureCount) - static_cast<float>(weights_.dot(average));
    }
    else
    {
        shift_ = calcShift(trainSamples, trainResponses);
    }

    return true;
}
330 
predict(InputArray _samples,OutputArray _results,int) const331 float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
332 {
333     float result = 0;
334     cv::Mat samples = _samples.getMat();
335     int nSamples = samples.rows;
336     cv::Mat results;
337 
338     CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32FC1);
339 
340     if( _results.needed() )
341     {
342         _results.create( nSamples, 1, samples.type() );
343         results = _results.getMat();
344     }
345     else
346     {
347         CV_Assert( nSamples == 1 );
348         results = Mat(1, 1, CV_32FC1, &result);
349     }
350 
351     for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++)
352     {
353         Mat currentSample = samples.row(sampleIndex);
354         float criterion = static_cast<float>(currentSample.dot(weights_)) + shift_;
355         results.at<float>(sampleIndex) = (criterion >= 0) ? 1.f : -1.f;
356     }
357 
358     return result;
359 }
360 
isClassifier() const361 bool SVMSGDImpl::isClassifier() const
362 {
363     return (params.svmsgdType == SGD || params.svmsgdType == ASGD)
364             &&
365             (params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN)
366             &&
367             (params.marginRegularization > 0) && (params.initialStepSize > 0) && (params.stepDecreasingPower >= 0);
368 }
369 
isTrained() const370 bool SVMSGDImpl::isTrained() const
371 {
372     return !weights_.empty();
373 }
374 
write(FileStorage & fs) const375 void SVMSGDImpl::write(FileStorage& fs) const
376 {
377     if( !isTrained() )
378         CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" );
379 
380     writeFormat(fs);
381     writeParams( fs );
382 
383     fs << "weights" << weights_;
384     fs << "shift" << shift_;
385 }
386 
writeParams(FileStorage & fs) const387 void SVMSGDImpl::writeParams( FileStorage& fs ) const
388 {
389     String SvmsgdTypeStr;
390 
391     switch (params.svmsgdType)
392     {
393     case SGD:
394         SvmsgdTypeStr = "SGD";
395         break;
396     case ASGD:
397         SvmsgdTypeStr = "ASGD";
398         break;
399     default:
400         SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
401     }
402 
403     fs << "svmsgdType" << SvmsgdTypeStr;
404 
405     String marginTypeStr;
406 
407     switch (params.marginType)
408     {
409     case SOFT_MARGIN:
410         marginTypeStr = "SOFT_MARGIN";
411         break;
412     case HARD_MARGIN:
413         marginTypeStr = "HARD_MARGIN";
414         break;
415     default:
416         marginTypeStr = format("Unknown_%d", params.marginType);
417     }
418 
419     fs << "marginType" << marginTypeStr;
420 
421     fs << "marginRegularization" << params.marginRegularization;
422     fs << "initialStepSize" << params.initialStepSize;
423     fs << "stepDecreasingPower" << params.stepDecreasingPower;
424 
425     fs << "term_criteria" << "{:";
426     if( params.termCrit.type & TermCriteria::EPS )
427         fs << "epsilon" << params.termCrit.epsilon;
428     if( params.termCrit.type & TermCriteria::COUNT )
429         fs << "iterations" << params.termCrit.maxCount;
430     fs << "}";
431 }
readParams(const FileNode & fn)432 void SVMSGDImpl::readParams( const FileNode& fn )
433 {
434     String svmsgdTypeStr = (String)fn["svmsgdType"];
435     int svmsgdType =
436             svmsgdTypeStr == "SGD" ? SGD :
437                                      svmsgdTypeStr == "ASGD" ? ASGD : -1;
438 
439     if( svmsgdType < 0 )
440         CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
441 
442     params.svmsgdType = svmsgdType;
443 
444     String marginTypeStr = (String)fn["marginType"];
445     int marginType =
446             marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
447                                              marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
448 
449     if( marginType < 0 )
450         CV_Error( CV_StsParseError, "Missing or invalid margin type" );
451 
452     params.marginType = marginType;
453 
454     CV_Assert ( fn["marginRegularization"].isReal() );
455     params.marginRegularization = (float)fn["marginRegularization"];
456 
457     CV_Assert ( fn["initialStepSize"].isReal() );
458     params.initialStepSize = (float)fn["initialStepSize"];
459 
460     CV_Assert ( fn["stepDecreasingPower"].isReal() );
461     params.stepDecreasingPower = (float)fn["stepDecreasingPower"];
462 
463     FileNode tcnode = fn["term_criteria"];
464     CV_Assert(!tcnode.empty());
465     params.termCrit.epsilon = (double)tcnode["epsilon"];
466     params.termCrit.maxCount = (int)tcnode["iterations"];
467     params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
468             (params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
469     CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS));
470 }
471 
read(const FileNode & fn)472 void SVMSGDImpl::read(const FileNode& fn)
473 {
474     clear();
475 
476     readParams(fn);
477 
478     fn["weights"] >> weights_;
479     fn["shift"] >> shift_;
480 }
481 
clear()482 void SVMSGDImpl::clear()
483 {
484     weights_.release();
485     shift_ = 0;
486 }
487 
488 
SVMSGDImpl()489 SVMSGDImpl::SVMSGDImpl()
490 {
491     clear();
492     setOptimalParameters();
493 }
494 
setOptimalParameters(int svmsgdType,int marginType)495 void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
496 {
497     switch (svmsgdType)
498     {
499     case SGD:
500         params.svmsgdType = SGD;
501         params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
502                                                           (marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
503         params.marginRegularization = 0.0001f;
504         params.initialStepSize = 0.05f;
505         params.stepDecreasingPower = 1.f;
506         params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
507         break;
508 
509     case ASGD:
510         params.svmsgdType = ASGD;
511         params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
512                                                           (marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
513         params.marginRegularization = 0.00001f;
514         params.initialStepSize = 0.05f;
515         params.stepDecreasingPower = 0.75f;
516         params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
517         break;
518 
519     default:
520         CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
521     }
522 }
523 }   //ml
524 }   //cv
525