1 // Copyright (C) 2014 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 #ifndef DLIB_LDA_Hh_ 4 #define DLIB_LDA_Hh_ 5 6 #include "lda_abstract.h" 7 #include "../algs.h" 8 #include <map> 9 #include "../matrix.h" 10 #include <vector> 11 12 namespace dlib 13 { 14 15 // ---------------------------------------------------------------------------------------- 16 17 namespace impl 18 { 19 make_class_labels(const std::vector<unsigned long> & row_labels)20 inline std::map<unsigned long,unsigned long> make_class_labels( 21 const std::vector<unsigned long>& row_labels 22 ) 23 { 24 std::map<unsigned long,unsigned long> class_labels; 25 for (unsigned long i = 0; i < row_labels.size(); ++i) 26 { 27 const unsigned long next = class_labels.size(); 28 if (class_labels.count(row_labels[i]) == 0) 29 class_labels[row_labels[i]] = next; 30 } 31 return class_labels; 32 } 33 34 // ------------------------------------------------------------------------------------ 35 36 template < 37 typename T 38 > center_matrix(matrix<T> & X)39 matrix<T,0,1> center_matrix ( 40 matrix<T>& X 41 ) 42 { 43 matrix<T,1> mean; 44 for (long r = 0; r < X.nr(); ++r) 45 mean += rowm(X,r); 46 mean /= X.nr(); 47 48 for (long r = 0; r < X.nr(); ++r) 49 set_rowm(X,r) -= mean; 50 51 return trans(mean); 52 } 53 } 54 55 // ---------------------------------------------------------------------------------------- 56 57 template < 58 typename T 59 > 60 void compute_lda_transform ( 61 matrix<T>& X, 62 matrix<T,0,1>& mean, 63 const std::vector<unsigned long>& row_labels, 64 unsigned long lda_dims = 500, 65 unsigned long extra_pca_dims = 200 66 ) 67 { 68 std::map<unsigned long,unsigned long> class_labels = impl::make_class_labels(row_labels); 69 // LDA can only give out at most class_labels.size()-1 dimensions so don't try to 70 // compute more than that. 71 lda_dims = std::min<unsigned long>(lda_dims, class_labels.size()-1); 72 73 // make sure requires clause is not broken 74 DLIB_CASSERT(class_labels.size() > 1, 75 "\t void compute_lda_transform()" 76 << "\n\t You can't call this function if the number of distinct class labels is less than 2." 77 ); 78 DLIB_CASSERT(X.size() != 0 && (long)row_labels.size() == X.nr() && lda_dims != 0, 79 "\t void compute_lda_transform()" 80 << "\n\t Invalid inputs were given to this function." 81 << "\n\t X.size(): " << X.size() 82 << "\n\t row_labels.size(): " << row_labels.size() 83 << "\n\t lda_dims: " << lda_dims 84 ); 85 86 87 mean = impl::center_matrix(X); 88 // Do PCA to reduce dims 89 matrix<T> pu,pw,pv; 90 svd_fast(X, pu, pw, pv, lda_dims+extra_pca_dims, 4); 91 pu.set_size(0,0); // free RAM, we don't need pu. 92 X = X*pv; 93 94 95 matrix<T> class_means(class_labels.size(), X.nc()); 96 class_means = 0; 97 matrix<T,0,1> class_counts(class_labels.size()); 98 class_counts = 0; 99 100 // First compute the means of each class 101 for (unsigned long i = 0; i < row_labels.size(); ++i) 102 { 103 const unsigned long class_idx = class_labels[row_labels[i]]; 104 set_rowm(class_means,class_idx) += rowm(X,i); 105 class_counts(class_idx)++; 106 } 107 class_means = inv(diagm(class_counts))*class_means; 108 // subtract means from the data 109 for (unsigned long i = 0; i < row_labels.size(); ++i) 110 { 111 const unsigned long class_idx = class_labels[row_labels[i]]; 112 set_rowm(X,i) -= rowm(class_means,class_idx); 113 } 114 115 // Note that we are using the formulas from the paper Using Discriminant 116 // Eigenfeatures for Image Retrieval by Swets and Weng. 117 matrix<T> Sw = trans(X)*X; 118 matrix<T> Sb = trans(class_means)*class_means; 119 matrix<T> A, H; 120 matrix<T,0,1> W; 121 svd3(Sw, A, W, H); 122 W = sqrt(W); 123 W = reciprocal(lowerbound(W,max(W)*1e-5)); 124 A = trans(H*diagm(W))*Sb*H*diagm(W); 125 matrix<T> v,s,u; 126 svd3(A, v, s, u); diagm(W)127 matrix<T> tform = H*diagm(W)*u; 128 // pick out only the number of dimensions we are supposed to for the output, unless 129 // we should just keep them all, then don't do anything. 130 if ((long)lda_dims <= tform.nc()) 131 { 132 rsort_columns(tform, s); 133 tform = colm(tform, range(0, lda_dims-1)); 134 } 135 136 X = trans(pv*tform); 137 mean = X*mean; 138 } 139 140 // ---------------------------------------------------------------------------------------- 141 142 struct roc_point 143 { 144 double true_positive_rate; 145 double false_positive_rate; 146 double detection_threshold; 147 }; 148 compute_roc_curve(const std::vector<double> & true_detections,const std::vector<double> & false_detections)149 inline std::vector<roc_point> compute_roc_curve ( 150 const std::vector<double>& true_detections, 151 const std::vector<double>& false_detections 152 ) 153 { 154 DLIB_CASSERT(true_detections.size() != 0); 155 DLIB_CASSERT(false_detections.size() != 0); 156 157 std::vector<std::pair<double,int> > temp; 158 temp.reserve(true_detections.size()+false_detections.size()); 159 // We use -1 for true labels and +1 for false so when we call std::sort() below it will sort 160 // runs with equal detection scores so false come first. This will avoid it seeming like we 161 // can separate true from false when scores are equal in the loop below. 162 const int true_label = -1; 163 const int false_label = +1; 164 for (unsigned long i = 0; i < true_detections.size(); ++i) 165 temp.push_back(std::make_pair(true_detections[i], true_label)); 166 for (unsigned long i = 0; i < false_detections.size(); ++i) 167 temp.push_back(std::make_pair(false_detections[i], false_label)); 168 169 std::sort(temp.rbegin(), temp.rend()); 170 171 172 std::vector<roc_point> roc_curve; 173 roc_curve.reserve(temp.size()); 174 175 double num_false_included = 0; 176 double num_true_included = 0; 177 for (unsigned long i = 0; i < temp.size(); ++i) 178 { 179 if (temp[i].second == true_label) 180 num_true_included++; 181 else 182 num_false_included++; 183 184 roc_point p; 185 p.true_positive_rate = num_true_included/true_detections.size(); 186 p.false_positive_rate = num_false_included/false_detections.size(); 187 p.detection_threshold = temp[i].first; 188 roc_curve.push_back(p); 189 } 190 191 return roc_curve; 192 } 193 194 // ---------------------------------------------------------------------------------------- 195 equal_error_rate(const std::vector<double> & low_vals,const std::vector<double> & high_vals)196 inline std::pair<double,double> equal_error_rate ( 197 const std::vector<double>& low_vals, 198 const std::vector<double>& high_vals 199 ) 200 { 201 if (low_vals.size() == 0 && high_vals.size() == 0) 202 return std::make_pair(0,0); 203 else if (low_vals.size() == 0) 204 return std::make_pair(0, min(mat(high_vals))); 205 else if (high_vals.size() == 0) 206 return std::make_pair(0, max(mat(low_vals))+1); 207 208 // Find the point of equal error rates 209 double best_thresh = 0; 210 double best_error = 0; 211 double best_delta = std::numeric_limits<double>::infinity(); 212 for (const auto& pt : compute_roc_curve(high_vals, low_vals)) 213 { 214 const double false_negative_rate = 1-pt.true_positive_rate; 215 const double delta = std::abs(false_negative_rate - pt.false_positive_rate); 216 if (delta < best_delta) 217 { 218 best_delta = delta; 219 best_error = std::max(false_negative_rate, pt.false_positive_rate); 220 best_thresh = pt.detection_threshold; 221 } 222 } 223 224 return std::make_pair(best_error, best_thresh); 225 } 226 227 // ---------------------------------------------------------------------------------------- 228 229 } 230 231 #endif // DLIB_LDA_Hh_ 232 233