1 // Copyright (C) 2010 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 #ifndef DLIB_LIBSVM_iO_Hh_ 4 #define DLIB_LIBSVM_iO_Hh_ 5 6 #include "libsvm_io_abstract.h" 7 8 #include <fstream> 9 #include <string> 10 #include <utility> 11 #include "../algs.h" 12 #include "../matrix.h" 13 #include "../string.h" 14 #include "../svm/sparse_vector.h" 15 #include <vector> 16 17 namespace dlib 18 { 19 struct sample_data_io_error : public error 20 { sample_data_io_errorsample_data_io_error21 sample_data_io_error(const std::string& message): error(message) {} 22 }; 23 24 // ---------------------------------------------------------------------------------------- 25 26 template <typename sample_type, typename label_type, typename alloc1, typename alloc2> load_libsvm_formatted_data(const std::string & file_name,std::vector<sample_type,alloc1> & samples,std::vector<label_type,alloc2> & labels)27 void load_libsvm_formatted_data ( 28 const std::string& file_name, 29 std::vector<sample_type, alloc1>& samples, 30 std::vector<label_type, alloc2>& labels 31 ) 32 { 33 using namespace std; 34 typedef typename sample_type::value_type pair_type; 35 typedef typename basic_type<typename pair_type::first_type>::type key_type; 36 typedef typename pair_type::second_type value_type; 37 38 // You must use unsigned integral key types in your sparse vectors 39 COMPILE_TIME_ASSERT(is_unsigned_type<key_type>::value); 40 41 samples.clear(); 42 labels.clear(); 43 44 ifstream fin(file_name.c_str()); 45 46 if (!fin) 47 throw sample_data_io_error("Unable to open file " + file_name); 48 49 string line; 50 istringstream sin; 51 key_type key; 52 value_type value; 53 label_type label; 54 sample_type sample; 55 long line_num = 0; 56 while (fin.peek() != EOF) 57 { 58 ++line_num; 59 getline(fin, line); 60 61 string::size_type pos = line.find_first_not_of(" \t\r\n"); 62 63 // ignore empty lines or comment lines 64 if (pos == string::npos || line[pos] == '#') 65 continue; 66 67 sin.clear(); 68 sin.str(line); 69 sample.clear(); 70 71 sin >> label; 72 73 if (!sin) 74 throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name ); 75 76 // eat whitespace 77 sin >> ws; 78 79 while (sin.peek() != EOF && sin.peek() != '#') 80 { 81 82 sin >> key >> ws; 83 84 // ignore what should be a : character 85 if (sin.get() != ':') 86 throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name); 87 88 sin >> value; 89 90 if (sin && value != 0) 91 { 92 sample.insert(sample.end(), make_pair(key, value)); 93 } 94 95 sin >> ws; 96 } 97 98 samples.push_back(sample); 99 labels.push_back(label); 100 } 101 102 } 103 104 // ---------------------------------------------------------------------------------------- 105 // ---------------------------------------------------------------------------------------- 106 107 template <typename sample_type, typename alloc> 108 typename enable_if<is_const_type<typename sample_type::value_type::first_type> >::type fix_nonzero_indexing(std::vector<sample_type,alloc> & samples)109 fix_nonzero_indexing ( 110 std::vector<sample_type,alloc>& samples 111 ) 112 { 113 typedef typename sample_type::value_type pair_type; 114 typedef typename basic_type<typename pair_type::first_type>::type key_type; 115 116 if (samples.size() == 0) 117 return; 118 119 // figure out the min index value 120 key_type min_idx = samples[0].begin()->first; 121 for (unsigned long i = 0; i < samples.size(); ++i) 122 min_idx = std::min(min_idx, samples[i].begin()->first); 123 124 // Now adjust all the samples so that their min index value is zero. 125 if (min_idx != 0) 126 { 127 sample_type temp; 128 for (unsigned long i = 0; i < samples.size(); ++i) 129 { 130 // copy samples[i] into temp but make sure it has a min index of zero. 131 temp.clear(); 132 typename sample_type::iterator j; 133 for (j = samples[i].begin(); j != samples[i].end(); ++j) 134 { 135 temp.insert(temp.end(), std::make_pair(j->first-min_idx, j->second)); 136 } 137 138 // replace the current sample with temp. 139 samples[i].swap(temp); 140 } 141 } 142 } 143 144 // ---------------------------------------------------------------------------------------- 145 146 // If the "first" values in the std::pair objects are not const then we can modify them 147 // directly and that is what this version of fix_nonzero_indexing() does. 148 template <typename sample_type, typename alloc> 149 typename disable_if<is_const_type<typename sample_type::value_type::first_type> >::type fix_nonzero_indexing(std::vector<sample_type,alloc> & samples)150 fix_nonzero_indexing ( 151 std::vector<sample_type,alloc>& samples 152 ) 153 { 154 typedef typename sample_type::value_type pair_type; 155 typedef typename basic_type<typename pair_type::first_type>::type key_type; 156 157 if (samples.size() == 0) 158 return; 159 160 // figure out the min index value 161 key_type min_idx = samples[0].begin()->first; 162 for (unsigned long i = 0; i < samples.size(); ++i) 163 min_idx = std::min(min_idx, samples[i].begin()->first); 164 165 // Now adjust all the samples so that their min index value is zero. 166 if (min_idx != 0) 167 { 168 for (unsigned long i = 0; i < samples.size(); ++i) 169 { 170 typename sample_type::iterator j; 171 for (j = samples[i].begin(); j != samples[i].end(); ++j) 172 { 173 j->first -= min_idx; 174 } 175 } 176 } 177 } 178 179 // ---------------------------------------------------------------------------------------- 180 // ---------------------------------------------------------------------------------------- 181 182 // This is an overload for sparse vectors 183 template <typename sample_type, typename label_type, typename alloc1, typename alloc2> save_libsvm_formatted_data(const std::string & file_name,const std::vector<sample_type,alloc1> & samples,const std::vector<label_type,alloc2> & labels)184 typename disable_if<is_matrix<sample_type>,void>::type save_libsvm_formatted_data ( 185 const std::string& file_name, 186 const std::vector<sample_type, alloc1>& samples, 187 const std::vector<label_type, alloc2>& labels 188 ) 189 { 190 typedef typename sample_type::value_type pair_type; 191 typedef typename basic_type<typename pair_type::first_type>::type key_type; 192 193 // You must use unsigned integral key types in your sparse vectors 194 COMPILE_TIME_ASSERT(is_unsigned_type<key_type>::value); 195 196 // make sure requires clause is not broken 197 DLIB_ASSERT(samples.size() == labels.size(), 198 "\t void save_libsvm_formatted_data()" 199 << "\n\t You have to have labels for each sample and vice versa" 200 << "\n\t samples.size(): " << samples.size() 201 << "\n\t labels.size(): " << labels.size() 202 ); 203 204 205 using namespace std; 206 ofstream fout(file_name.c_str()); 207 fout.precision(14); 208 209 if (!fout) 210 throw sample_data_io_error("Unable to open file " + file_name); 211 212 for (unsigned long i = 0; i < samples.size(); ++i) 213 { 214 fout << labels[i]; 215 216 for (typename sample_type::const_iterator j = samples[i].begin(); j != samples[i].end(); ++j) 217 { 218 if (j->second != 0) 219 fout << " " << j->first << ":" << j->second; 220 } 221 fout << "\n"; 222 223 if (!fout) 224 throw sample_data_io_error("Error while writing to file " + file_name); 225 } 226 227 } 228 229 // ---------------------------------------------------------------------------------------- 230 231 // This is an overload for dense vectors 232 template <typename sample_type, typename label_type, typename alloc1, typename alloc2> save_libsvm_formatted_data(const std::string & file_name,const std::vector<sample_type,alloc1> & samples,const std::vector<label_type,alloc2> & labels)233 typename enable_if<is_matrix<sample_type>,void>::type save_libsvm_formatted_data ( 234 const std::string& file_name, 235 const std::vector<sample_type, alloc1>& samples, 236 const std::vector<label_type, alloc2>& labels 237 ) 238 { 239 // make sure requires clause is not broken 240 DLIB_ASSERT(samples.size() == labels.size(), 241 "\t void save_libsvm_formatted_data()" 242 << "\n\t You have to have labels for each sample and vice versa" 243 << "\n\t samples.size(): " << samples.size() 244 << "\n\t labels.size(): " << labels.size() 245 ); 246 247 using namespace std; 248 ofstream fout(file_name.c_str()); 249 fout.precision(14); 250 251 if (!fout) 252 throw sample_data_io_error("Unable to open file " + file_name); 253 254 for (unsigned long i = 0; i < samples.size(); ++i) 255 { 256 fout << labels[i]; 257 258 for (long j = 0; j < samples[i].size(); ++j) 259 { 260 if (samples[i](j) != 0) 261 fout << " " << j << ":" << samples[i](j); 262 } 263 fout << "\n"; 264 265 if (!fout) 266 throw sample_data_io_error("Error while writing to file " + file_name); 267 } 268 269 } 270 271 // ---------------------------------------------------------------------------------------- 272 273 } 274 275 #endif // DLIB_LIBSVM_iO_Hh_ 276 277