1 // Copyright (C) 2010  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 #ifndef DLIB_LIBSVM_iO_Hh_
4 #define DLIB_LIBSVM_iO_Hh_
5 
#include "libsvm_io_abstract.h"

#include <algorithm>
#include <cstdio>
#include <fstream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "../algs.h"
#include "../matrix.h"
#include "../string.h"
#include "../svm/sparse_vector.h"
16 
17 namespace dlib
18 {
19     struct sample_data_io_error : public error
20     {
sample_data_io_errorsample_data_io_error21         sample_data_io_error(const std::string& message): error(message) {}
22     };
23 
24 // ----------------------------------------------------------------------------------------
25 
26     template <typename sample_type, typename label_type, typename alloc1, typename alloc2>
load_libsvm_formatted_data(const std::string & file_name,std::vector<sample_type,alloc1> & samples,std::vector<label_type,alloc2> & labels)27     void load_libsvm_formatted_data (
28         const std::string& file_name,
29         std::vector<sample_type, alloc1>& samples,
30         std::vector<label_type, alloc2>& labels
31     )
32     {
33         using namespace std;
34         typedef typename sample_type::value_type pair_type;
35         typedef typename basic_type<typename pair_type::first_type>::type key_type;
36         typedef typename pair_type::second_type value_type;
37 
38         // You must use unsigned integral key types in your sparse vectors
39         COMPILE_TIME_ASSERT(is_unsigned_type<key_type>::value);
40 
41         samples.clear();
42         labels.clear();
43 
44         ifstream fin(file_name.c_str());
45 
46         if (!fin)
47             throw sample_data_io_error("Unable to open file " + file_name);
48 
49         string line;
50         istringstream sin;
51         key_type key;
52         value_type value;
53         label_type label;
54         sample_type sample;
55         long line_num = 0;
56         while (fin.peek() != EOF)
57         {
58             ++line_num;
59             getline(fin, line);
60 
61             string::size_type pos = line.find_first_not_of(" \t\r\n");
62 
63             // ignore empty lines or comment lines
64             if (pos == string::npos || line[pos] == '#')
65                 continue;
66 
67             sin.clear();
68             sin.str(line);
69             sample.clear();
70 
71             sin >> label;
72 
73             if (!sin)
74                 throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name );
75 
76             // eat whitespace
77             sin >> ws;
78 
79             while (sin.peek() != EOF && sin.peek() != '#')
80             {
81 
82                 sin >> key >> ws;
83 
84                 // ignore what should be a : character
85                 if (sin.get() != ':')
86                     throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name);
87 
88                 sin >> value;
89 
90                 if (sin && value != 0)
91                 {
92                     sample.insert(sample.end(), make_pair(key, value));
93                 }
94 
95                 sin >> ws;
96             }
97 
98             samples.push_back(sample);
99             labels.push_back(label);
100         }
101 
102     }
103 
104 // ----------------------------------------------------------------------------------------
105 // ----------------------------------------------------------------------------------------
106 
107     template <typename sample_type, typename alloc>
108     typename enable_if<is_const_type<typename sample_type::value_type::first_type> >::type
fix_nonzero_indexing(std::vector<sample_type,alloc> & samples)109     fix_nonzero_indexing (
110         std::vector<sample_type,alloc>& samples
111     )
112     {
113         typedef typename sample_type::value_type pair_type;
114         typedef typename basic_type<typename pair_type::first_type>::type key_type;
115 
116         if (samples.size() == 0)
117             return;
118 
119         // figure out the min index value
120         key_type min_idx = samples[0].begin()->first;
121         for (unsigned long i = 0; i < samples.size(); ++i)
122             min_idx = std::min(min_idx, samples[i].begin()->first);
123 
124         // Now adjust all the samples so that their min index value is zero.
125         if (min_idx != 0)
126         {
127             sample_type temp;
128             for (unsigned long i = 0; i < samples.size(); ++i)
129             {
130                 // copy samples[i] into temp but make sure it has a min index of zero.
131                 temp.clear();
132                 typename sample_type::iterator j;
133                 for (j = samples[i].begin(); j != samples[i].end(); ++j)
134                 {
135                     temp.insert(temp.end(), std::make_pair(j->first-min_idx, j->second));
136                 }
137 
138                 // replace the current sample with temp.
139                 samples[i].swap(temp);
140             }
141         }
142     }
143 
144 // ----------------------------------------------------------------------------------------
145 
146 // If the "first" values in the std::pair objects are not const then we can modify them
147 // directly and that is what this version of fix_nonzero_indexing() does.
148     template <typename sample_type, typename alloc>
149     typename disable_if<is_const_type<typename sample_type::value_type::first_type> >::type
fix_nonzero_indexing(std::vector<sample_type,alloc> & samples)150     fix_nonzero_indexing (
151         std::vector<sample_type,alloc>& samples
152     )
153     {
154         typedef typename sample_type::value_type pair_type;
155         typedef typename basic_type<typename pair_type::first_type>::type key_type;
156 
157         if (samples.size() == 0)
158             return;
159 
160         // figure out the min index value
161         key_type min_idx = samples[0].begin()->first;
162         for (unsigned long i = 0; i < samples.size(); ++i)
163             min_idx = std::min(min_idx, samples[i].begin()->first);
164 
165         // Now adjust all the samples so that their min index value is zero.
166         if (min_idx != 0)
167         {
168             for (unsigned long i = 0; i < samples.size(); ++i)
169             {
170                 typename sample_type::iterator j;
171                 for (j = samples[i].begin(); j != samples[i].end(); ++j)
172                 {
173                     j->first -= min_idx;
174                 }
175             }
176         }
177     }
178 
179 // ----------------------------------------------------------------------------------------
180 // ----------------------------------------------------------------------------------------
181 
182 // This is an overload for sparse vectors
183     template <typename sample_type, typename label_type, typename alloc1, typename alloc2>
save_libsvm_formatted_data(const std::string & file_name,const std::vector<sample_type,alloc1> & samples,const std::vector<label_type,alloc2> & labels)184     typename disable_if<is_matrix<sample_type>,void>::type save_libsvm_formatted_data (
185         const std::string& file_name,
186         const std::vector<sample_type, alloc1>& samples,
187         const std::vector<label_type, alloc2>& labels
188     )
189     {
190         typedef typename sample_type::value_type pair_type;
191         typedef typename basic_type<typename pair_type::first_type>::type key_type;
192 
193         // You must use unsigned integral key types in your sparse vectors
194         COMPILE_TIME_ASSERT(is_unsigned_type<key_type>::value);
195 
196         // make sure requires clause is not broken
197         DLIB_ASSERT(samples.size() == labels.size(),
198             "\t void save_libsvm_formatted_data()"
199             << "\n\t You have to have labels for each sample and vice versa"
200             << "\n\t samples.size(): " << samples.size()
201             << "\n\t labels.size():  " << labels.size()
202             );
203 
204 
205         using namespace std;
206         ofstream fout(file_name.c_str());
207         fout.precision(14);
208 
209         if (!fout)
210             throw sample_data_io_error("Unable to open file " + file_name);
211 
212         for (unsigned long i = 0; i < samples.size(); ++i)
213         {
214             fout << labels[i];
215 
216             for (typename sample_type::const_iterator j = samples[i].begin(); j != samples[i].end(); ++j)
217             {
218                 if (j->second != 0)
219                     fout << " " << j->first << ":" << j->second;
220             }
221             fout << "\n";
222 
223             if (!fout)
224                 throw sample_data_io_error("Error while writing to file " + file_name);
225         }
226 
227     }
228 
229 // ----------------------------------------------------------------------------------------
230 
231 // This is an overload for dense vectors
232     template <typename sample_type, typename label_type, typename alloc1, typename alloc2>
save_libsvm_formatted_data(const std::string & file_name,const std::vector<sample_type,alloc1> & samples,const std::vector<label_type,alloc2> & labels)233     typename enable_if<is_matrix<sample_type>,void>::type save_libsvm_formatted_data (
234         const std::string& file_name,
235         const std::vector<sample_type, alloc1>& samples,
236         const std::vector<label_type, alloc2>& labels
237     )
238     {
239         // make sure requires clause is not broken
240         DLIB_ASSERT(samples.size() == labels.size(),
241             "\t void save_libsvm_formatted_data()"
242             << "\n\t You have to have labels for each sample and vice versa"
243             << "\n\t samples.size(): " << samples.size()
244             << "\n\t labels.size():  " << labels.size()
245             );
246 
247         using namespace std;
248         ofstream fout(file_name.c_str());
249         fout.precision(14);
250 
251         if (!fout)
252             throw sample_data_io_error("Unable to open file " + file_name);
253 
254         for (unsigned long i = 0; i < samples.size(); ++i)
255         {
256             fout << labels[i];
257 
258             for (long j = 0; j < samples[i].size(); ++j)
259             {
260                 if (samples[i](j) != 0)
261                     fout << " " << j << ":" << samples[i](j);
262             }
263             fout << "\n";
264 
265             if (!fout)
266                 throw sample_data_io_error("Error while writing to file " + file_name);
267         }
268 
269     }
270 
271 // ----------------------------------------------------------------------------------------
272 
273 }
274 
275 #endif // DLIB_LIBSVM_iO_Hh_
276 
277