1 /************************************************************************/
2 /*                                                                      */
3 /*    Copyright 2009,2014, 2015 by Sven Peter, Philip Schill,           */
4 /*                                 Rahul Nair and Ullrich Koethe        */
5 /*                                                                      */
6 /*    This file is part of the VIGRA computer vision library.           */
7 /*    The VIGRA Website is                                              */
8 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
9 /*    Please direct questions, bug reports, and contributions to        */
10 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
11 /*        vigra@informatik.uni-hamburg.de                               */
12 /*                                                                      */
13 /*    Permission is hereby granted, free of charge, to any person       */
14 /*    obtaining a copy of this software and associated documentation    */
15 /*    files (the "Software"), to deal in the Software without           */
16 /*    restriction, including without limitation the rights to use,      */
17 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
18 /*    sell copies of the Software, and to permit persons to whom the    */
19 /*    Software is furnished to do so, subject to the following          */
20 /*    conditions:                                                       */
21 /*                                                                      */
22 /*    The above copyright notice and this permission notice shall be    */
23 /*    included in all copies or substantial portions of the             */
24 /*    Software.                                                         */
25 /*                                                                      */
26 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
27 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
28 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
29 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
30 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
31 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
32 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
33 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
34 /*                                                                      */
35 /************************************************************************/
36 
37 #ifndef VIGRA_RF3_IMPEX_HDF5_HXX
38 #define VIGRA_RF3_IMPEX_HDF5_HXX
39 
#include <cstddef>
#include <iomanip>
#include <numeric>
#include <queue>
#include <sstream>
#include <stack>
#include <string>
#include <utility>
#include <vector>
44 
45 #include "config.hxx"
46 #include "random_forest_3/random_forest.hxx"
47 #include "random_forest_3/random_forest_common.hxx"
48 #include "random_forest_3/random_forest_visitors.hxx"
49 #include "hdf5impex.hxx"
50 
51 namespace vigra
52 {
53 namespace rf3
54 {
55 
// Names and version of the on-disk random forest layout used by the import
// and export functions in this file.
// needs to be in sync with random_forest_hdf5_impex for backwards compatibility

// Group holding the problem description (feature/class/instance counts, labels, class weights).
static const char *const rf_hdf5_ext_param     = "_ext_param";
// Group holding the training options (mtry, bootstrap sampling, tree count, ...).
static const char *const rf_hdf5_options       = "_options";
// Per-tree dataset with the unsigned integer node records.
static const char *const rf_hdf5_topology      = "topology";
// Per-tree dataset with the floating point node parameters.
static const char *const rf_hdf5_parameters    = "parameters";
// Prefix of every per-tree group name; the exporter appends a zero-padded tree number.
static const char *const rf_hdf5_tree          = "Tree_";
// Object the version attribute is attached to (presumably "." addresses the current group -- confirm against HDF5File).
static const char *const rf_hdf5_version_group = ".";
// Name of the attribute that stores the file format version.
static const char *const rf_hdf5_version_tag   = "vigra_random_forest_version";
// Version written on export; import rejects files that carry a larger version.
static const double      rf_hdf5_version       =  0.1;
65 
// Node type words as stored in the first entry of each node record of the
// "topology" dataset.
// Within this file only rf_LeafNodeTag (leaf detection / leaf export) and
// rf_i_ThresholdNode (the only inner node type accepted on import and written
// on export) are used; the remaining enumerators are kept so that the values
// match the older random forest implementation.
// keep in sync with include/vigra/random_forest/rf_nodeproxy.hxx
enum NodeTags
{
    rf_UnFilledNode        = 42,
    rf_AllColumns          = 0x00000000,
    rf_ToBePrunedTag       = 0x80000000,
    rf_LeafNodeTag         = 0x40000000, // bit that marks a node record as a leaf

    // NOTE(review): "rf_i_" / "rf_e_" presumably stand for interior / exterior
    // (leaf) node types -- confirm against rf_nodeproxy.hxx.
    rf_i_ThresholdNode     = 0,
    rf_i_HyperplaneNode    = 1,
    rf_i_HypersphereNode   = 2,
    rf_e_ConstProbNode     = 0 | rf_LeafNodeTag,
    rf_e_LogRegProbNode    = 1 | rf_LeafNodeTag
};
80 
// Bit masks used to validate the node type word of a topology record on import.
// The four most significant bits carry the tags (e.g. rf_LeafNodeTag).
static const unsigned int rf_tag_mask = 0xf0000000;
// The two least significant bits carry the node type number.
static const unsigned int rf_type_mask = 0x00000003;
// All bits that belong neither to the tag nor to the type; import requires them to be zero.
static const unsigned int rf_zero_mask = 0xffffffff & ~rf_tag_mask & ~rf_type_mask;
84 
85 namespace detail
86 {
get_cwd(HDF5File & h5context)87     inline std::string get_cwd(HDF5File & h5context)
88     {
89         return h5context.get_absolute_path(h5context.pwd());
90     }
91 }
92 
93 template <typename FEATURES, typename LABELS>
94 typename DefaultRF<FEATURES, LABELS>::type
random_forest_import_HDF5(HDF5File & h5ctx,std::string const & pathname="")95 random_forest_import_HDF5(HDF5File & h5ctx, std::string const & pathname = "")
96 {
97     typedef typename DefaultRF<FEATURES, LABELS>::type RF;
98     typedef typename RF::Graph Graph;
99     typedef typename RF::Node Node;
100     typedef typename RF::SplitTests SplitTest;
101     typedef typename LABELS::value_type LabelType;
102     typedef typename RF::AccInputType AccInputType;
103     typedef typename AccInputType::value_type AccValueType;
104 
105     std::string cwd;
106 
107     if (pathname.size()) {
108         cwd = detail::get_cwd(h5ctx);
109         h5ctx.cd(pathname);
110     }
111 
112     if (h5ctx.existsAttribute(rf_hdf5_version_group, rf_hdf5_version_tag)) {
113         double version;
114         h5ctx.readAttribute(rf_hdf5_version_group, rf_hdf5_version_tag, version);
115         vigra_precondition(version <= rf_hdf5_version, "random_forest_import_HDF5(): unexpected file format version.");
116     }
117 
118     // Read ext params.
119     size_t actual_mtry;
120     size_t num_instances;
121     size_t num_features;
122     size_t num_classes;
123     size_t msample;
124     int is_weighted_int;
125     MultiArray<1, LabelType> distinct_labels_marray;
126     MultiArray<1, double> class_weights_marray;
127 
128     h5ctx.cd(rf_hdf5_ext_param);
129     h5ctx.read("actual_msample_", msample);
130     h5ctx.read("actual_mtry_", actual_mtry);
131     h5ctx.read("class_count_", num_classes);
132     h5ctx.readAndResize("class_weights_", class_weights_marray);
133     h5ctx.read("column_count_", num_features);
134     h5ctx.read("is_weighted_", is_weighted_int);
135     h5ctx.readAndResize("labels", distinct_labels_marray);
136     h5ctx.read("row_count_", num_instances);
137     h5ctx.cd_up();
138 
139     bool is_weighted = is_weighted_int == 1 ? true : false;
140 
141     // Read options.
142     size_t min_num_instances;
143     int mtry;
144     int mtry_switch_int;
145     int bootstrap_sampling_int;
146     int tree_count;
147     h5ctx.cd(rf_hdf5_options);
148     h5ctx.read("min_split_node_size_", min_num_instances);
149     h5ctx.read("mtry_", mtry);
150     h5ctx.read("mtry_switch_", mtry_switch_int);
151     h5ctx.read("sample_with_replacement_", bootstrap_sampling_int);
152     h5ctx.read("tree_count_", tree_count);
153     h5ctx.cd_up();
154 
155     RandomForestOptionTags mtry_switch = (RandomForestOptionTags)mtry_switch_int;
156     bool bootstrap_sampling = bootstrap_sampling_int == 1 ? true : false;
157 
158     std::vector<LabelType> const distinct_labels(distinct_labels_marray.begin(), distinct_labels_marray.end());
159     std::vector<double> const class_weights(class_weights_marray.begin(), class_weights_marray.end());
160 
161     auto const pspec = ProblemSpec<LabelType>()
162                                .num_features(num_features)
163                                .num_instances(num_instances)
164                                .num_classes(num_classes)
165                                .distinct_classes(distinct_labels)
166                                .actual_mtry(actual_mtry)
167                                .actual_msample(msample);
168 
169     auto options = RandomForestOptions()
170                             .min_num_instances(min_num_instances)
171                             .bootstrap_sampling(bootstrap_sampling)
172                             .tree_count(tree_count);
173     options.features_per_node_switch_ = mtry_switch;
174     options.features_per_node_ = mtry;
175     if (is_weighted)
176         options.class_weights(class_weights);
177 
178     Graph gr;
179     typename RF::template NodeMap<SplitTest>::type split_tests;
180     typename RF::template NodeMap<AccInputType>::type leaf_responses;
181 
182     auto const groups = h5ctx.ls();
183     for (auto const & groupname : groups) {
184         if (groupname.substr(0, std::char_traits<char>::length(rf_hdf5_tree)).compare(rf_hdf5_tree) != 0) {
185             continue;
186         }
187 
188         MultiArray<1, unsigned int> topology;
189         MultiArray<1, double> parameters;
190         h5ctx.cd(groupname);
191         h5ctx.readAndResize(rf_hdf5_topology, topology);
192         h5ctx.readAndResize(rf_hdf5_parameters, parameters);
193         h5ctx.cd_up();
194 
195         vigra_precondition(topology[0] == num_features, "random_forest_import_HDF5(): number of features mismatch.");
196         vigra_precondition(topology[1] == num_classes, "random_forest_import_HDF5(): number of classes mismatch.");
197 
198         Node const n = gr.addNode();
199 
200         std::queue<std::pair<unsigned int, Node> > q;
201         q.emplace(2, n);
202         while (!q.empty()) {
203             auto const el = q.front();
204 
205             unsigned int const index = el.first;
206             Node const parent = el.second;
207 
208             vigra_precondition((topology[index] & rf_zero_mask) == 0, "random_forest_import_HDF5(): unexpected node type: type & zero_mask > 0");
209 
210             if (topology[index] & rf_LeafNodeTag) {
211                 unsigned int const probs_start = topology[index+1] + 1;
212 
213                 vigra_precondition((topology[index] & rf_tag_mask) == rf_LeafNodeTag, "random_forest_import_HDF5(): unexpected node type: additional tags in leaf node");
214 
215                 std::vector<AccValueType> node_response;
216 
217                 for (unsigned int i = 0; i < num_classes; ++i) {
218                     node_response.push_back(parameters[probs_start + i]);
219                 }
220 
221                 leaf_responses.insert(parent, node_response);
222 
223             } else {
224                 vigra_precondition(topology[index] == rf_i_ThresholdNode, "random_forest_import_HDF5(): unexpected node type.");
225 
226                 Node const left = gr.addNode();
227                 Node const right = gr.addNode();
228 
229                 gr.addArc(parent, left);
230                 gr.addArc(parent, right);
231 
232                 split_tests.insert(parent, SplitTest(topology[index+4], parameters[topology[index+1]+1]));
233 
234                 q.push(std::make_pair(topology[index+2], left));
235                 q.push(std::make_pair(topology[index+3], right));
236             }
237 
238             q.pop();
239         }
240     }
241 
242     if (cwd.size()) {
243         h5ctx.cd(cwd);
244     }
245 
246     RF rf(gr, split_tests, leaf_responses, pspec);
247     rf.options_ = options;
248     return rf;
249 }
250 
namespace detail
{
    /** \brief Formats integers as zero-padded decimal strings of a common width.

        The width is the number of characters needed to print <tt>n-1</tt>,
        i.e. the largest index of a sequence with \a n elements. Numbers that
        need more characters than that are returned unpadded.

        The formatting does not depend on the global locale (a locale with
        digit grouping must not change the names of the tree groups written to
        a file) and the object holds no mutable state, so the const call
        operator can be used concurrently and the object can be copied.
    */
    class PaddedNumberString
    {
    public:

        /// \param n  number of elements in the sequence to be numbered (0 ... n-1).
        PaddedNumberString(int n)
        : width_(std::to_string(n-1).size())
        {}

        /// \return the decimal representation of \a k, left-padded with '0' to the common width.
        std::string operator()(int k) const
        {
            std::string const digits = std::to_string(k);
            if (digits.size() >= width_)
                return digits;
            return std::string(width_ - digits.size(), '0') + digits;
        }

    private:

        std::string::size_type width_;
    };
}
276 
277 template <typename RF>
random_forest_export_HDF5(RF const & rf,HDF5File & h5context,std::string const & pathname="")278 void random_forest_export_HDF5(
279         RF const & rf,
280         HDF5File & h5context,
281         std::string const & pathname = ""
282 ){
283     typedef typename RF::LabelType LabelType;
284     typedef typename RF::Node Node;
285 
286     std::string cwd;
287     if (pathname.size()) {
288         cwd = detail::get_cwd(h5context);
289         h5context.cd_mk(pathname);
290     }
291 
292     // version attribute
293     h5context.writeAttribute(rf_hdf5_version_group, rf_hdf5_version_tag,
294                              rf_hdf5_version);
295 
296 
297     auto const & p = rf.problem_spec_;
298     auto const & opts = rf.options_;
299     MultiArray<1, LabelType> distinct_classes(Shape1(p.distinct_classes_.size()), p.distinct_classes_.data());
300     MultiArray<1, double> class_weights(Shape1(p.num_classes_), 1.0);
301     int is_weighted = 0;
302     if (opts.class_weights_.size() > 0)
303     {
304         is_weighted = 1;
305         for (size_t i = 0; i < opts.class_weights_.size(); ++i)
306             class_weights(i) = opts.class_weights_[i];
307     }
308 
309     // Save external parameters.
310     h5context.cd_mk(rf_hdf5_ext_param);
311     h5context.write("column_count_", p.num_features_);
312     h5context.write("row_count_", p.num_instances_);
313     h5context.write("class_count_", p.num_classes_);
314     h5context.write("actual_mtry_", p.actual_mtry_);
315     h5context.write("actual_msample_", p.actual_msample_);
316     h5context.write("labels", distinct_classes);
317     h5context.write("is_weighted_", is_weighted);
318     h5context.write("class_weights_", class_weights);
319     h5context.write("precision_", 0.0);
320     h5context.write("problem_type_", 1.0);
321     h5context.write("response_size_", 1.0);
322     h5context.write("used_", 1.0);
323     h5context.cd_up();
324 
325     // Save the options.
326     h5context.cd_mk(rf_hdf5_options);
327     h5context.write("min_split_node_size_", opts.min_num_instances_);
328     h5context.write("mtry_", opts.features_per_node_);
329     h5context.write("mtry_func_", 0.0);
330     h5context.write("mtry_switch_", opts.features_per_node_switch_);
331     h5context.write("predict_weighted_", 0.0);
332     h5context.write("prepare_online_learning_", 0.0);
333     h5context.write("sample_with_replacement_", opts.bootstrap_sampling_ ? 1 : 0);
334     h5context.write("stratification_method_", 3.0);
335     h5context.write("training_set_calc_switch_", 1.0);
336     h5context.write("training_set_func_", 0.0);
337     h5context.write("training_set_proportion_", 1.0);
338     h5context.write("training_set_size_", 0.0);
339     h5context.write("tree_count_", opts.tree_count_);
340     h5context.cd_up();
341 
342     // Save the trees.
343     detail::PaddedNumberString tree_number(rf.num_trees());
344     for (size_t i = 0; i < rf.num_trees(); ++i)
345     {
346         // Create the topology and parameters arrays.
347         std::vector<UInt32> topology;
348         std::vector<double> parameters;
349         topology.push_back(p.num_features_);
350         topology.push_back(p.num_classes_);
351 
352         auto const & probs = rf.node_responses_;
353         auto const & splits = rf.split_tests_;
354         auto const & gr = rf.graph_;
355         auto const root = gr.getRoot(i);
356 
357         // Write the tree nodes using a depth-first search.
358         // When a node is created, the indices of the child nodes are unknown.
359         // Therefore, they have to be updated once the child nodes are created.
360         // The stack holds the node and the topology-index that must be updated.
361         std::stack<std::pair<Node, std::ptrdiff_t> > stack;
362         stack.emplace(root, -1);
363         while (!stack.empty())
364         {
365             auto const n = stack.top().first; // the node descriptor
366             auto const i = stack.top().second; // index from the parent node that must be updated
367             stack.pop();
368 
369             // Update the index in the parent node.
370             if (i != -1)
371                 topology[i] = topology.size();
372 
373             if (gr.numChildren(n) == 0)
374             {
375                 // The node is a leaf.
376                 // Topology: leaf node tag, index of weight in parameters array.
377                 // Parameters: node weight, class probabilities.
378                 topology.push_back(rf_LeafNodeTag);
379                 topology.push_back(parameters.size());
380                 auto const & prob = probs.at(n);
381                 auto const weight = std::accumulate(prob.begin(), prob.end(), 0.0);
382                 parameters.push_back(weight);
383                 parameters.insert(parameters.end(), prob.begin(), prob.end());
384             }
385             else
386             {
387                 // The node is an inner node.
388                 // Topology: threshold tag, index of weight in parameters array, index of left child, index of right child, split dimension.
389                 // Parameters: node weight, split value.
390                 topology.push_back(rf_i_ThresholdNode);
391                 topology.push_back(parameters.size());
392                 topology.push_back(-1); // index of left children (currently unknown, will be updated when the child node is taken from the stack)
393                 topology.push_back(-1); // index of right children (see above)
394                 topology.push_back(splits.at(n).dim_);
395                 parameters.push_back(1.0); // inner nodes have the weight 1.
396                 parameters.push_back(splits.at(n).val_);
397 
398                 // Place the children on the stack.
399                 stack.emplace(gr.getChild(n, 0), topology.size()-3);
400                 stack.emplace(gr.getChild(n, 1), topology.size()-2);
401             }
402         }
403 
404         // Convert the vectors to multi arrays.
405         MultiArray<1, UInt32> topo(Shape1(topology.size()), topology.data());
406         MultiArray<1, double> para(Shape1(parameters.size()), parameters.data());
407 
408         auto const name = rf_hdf5_tree + tree_number(i);
409         h5context.cd_mk(name);
410         h5context.write(rf_hdf5_topology, topo);
411         h5context.write(rf_hdf5_parameters, para);
412         h5context.cd_up();
413     }
414 
415     if (pathname.size())
416         h5context.cd(cwd);
417 }
418 
419 
420 
421 } // namespace rf3
422 } // namespace vigra
423 
#endif // VIGRA_RF3_IMPEX_HDF5_HXX
425