1 /************************************************************************/
2 /* */
3 /* Copyright 2009,2014, 2015 by Sven Peter, Philip Schill, */
4 /* Rahul Nair and Ullrich Koethe */
5 /* */
6 /* This file is part of the VIGRA computer vision library. */
7 /* The VIGRA Website is */
8 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
9 /* Please direct questions, bug reports, and contributions to */
10 /* ullrich.koethe@iwr.uni-heidelberg.de or */
11 /* vigra@informatik.uni-hamburg.de */
12 /* */
13 /* Permission is hereby granted, free of charge, to any person */
14 /* obtaining a copy of this software and associated documentation */
15 /* files (the "Software"), to deal in the Software without */
16 /* restriction, including without limitation the rights to use, */
17 /* copy, modify, merge, publish, distribute, sublicense, and/or */
18 /* sell copies of the Software, and to permit persons to whom the */
19 /* Software is furnished to do so, subject to the following */
20 /* conditions: */
21 /* */
22 /* The above copyright notice and this permission notice shall be */
23 /* included in all copies or substantial portions of the */
24 /* Software. */
25 /* */
26 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
27 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
28 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
29 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
30 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
31 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
32 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
33 /* OTHER DEALINGS IN THE SOFTWARE. */
34 /* */
35 /************************************************************************/
36
37 #ifndef VIGRA_RF3_IMPEX_HDF5_HXX
38 #define VIGRA_RF3_IMPEX_HDF5_HXX
39
#include <cstddef>
#include <iomanip>
#include <numeric>
#include <queue>
#include <sstream>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include "config.hxx"
#include "random_forest_3/random_forest.hxx"
#include "random_forest_3/random_forest_common.hxx"
#include "random_forest_3/random_forest_visitors.hxx"
#include "hdf5impex.hxx"
50
51 namespace vigra
52 {
53 namespace rf3
54 {
55
// Group, dataset and attribute names of the HDF5 random forest layout.
// These must stay identical to the names used by random_forest_hdf5_impex
// (the old random forest) so that files remain readable by both versions.
static char const * const rf_hdf5_ext_param     = "_ext_param";   // group: problem description
static char const * const rf_hdf5_options       = "_options";     // group: training options
static char const * const rf_hdf5_topology      = "topology";     // dataset inside each tree group
static char const * const rf_hdf5_parameters    = "parameters";   // dataset inside each tree group
static char const * const rf_hdf5_tree          = "Tree_";        // prefix of the per-tree groups
static char const * const rf_hdf5_version_group = ".";            // object carrying the version attribute
static char const * const rf_hdf5_version_tag   = "vigra_random_forest_version";
// Format version written on export; import rejects files with a larger value.
static double const rf_hdf5_version = 0.1;
65
// Node type words stored in the "topology" dataset of the old random forest.
// Keep in sync with include/vigra/random_forest/rf_nodeproxy.hxx.
enum NodeTags
{
    // special markers
    rf_UnFilledNode      = 42,
    rf_AllColumns        = 0x00000000,

    // flag bits in the upper nibble
    rf_ToBePrunedTag     = 0x80000000,
    rf_LeafNodeTag       = 0x40000000,

    // inner node types (leaf flag not set)
    rf_i_ThresholdNode   = 0,
    rf_i_HyperplaneNode  = 1,
    rf_i_HypersphereNode = 2,

    // leaf node types (leaf flag set)
    rf_e_ConstProbNode   = rf_LeafNodeTag | 0,
    rf_e_LogRegProbNode  = rf_LeafNodeTag | 1
};

// Bits holding the flags, bits holding the node type, and every other bit
// (which must be zero in a valid node type word).
static unsigned int const rf_tag_mask  = 0xf0000000;
static unsigned int const rf_type_mask = 0x00000003;
static unsigned int const rf_zero_mask = 0xffffffffu & ~(rf_tag_mask | rf_type_mask);
84
85 namespace detail
86 {
get_cwd(HDF5File & h5context)87 inline std::string get_cwd(HDF5File & h5context)
88 {
89 return h5context.get_absolute_path(h5context.pwd());
90 }
91 }
92
93 template <typename FEATURES, typename LABELS>
94 typename DefaultRF<FEATURES, LABELS>::type
random_forest_import_HDF5(HDF5File & h5ctx,std::string const & pathname="")95 random_forest_import_HDF5(HDF5File & h5ctx, std::string const & pathname = "")
96 {
97 typedef typename DefaultRF<FEATURES, LABELS>::type RF;
98 typedef typename RF::Graph Graph;
99 typedef typename RF::Node Node;
100 typedef typename RF::SplitTests SplitTest;
101 typedef typename LABELS::value_type LabelType;
102 typedef typename RF::AccInputType AccInputType;
103 typedef typename AccInputType::value_type AccValueType;
104
105 std::string cwd;
106
107 if (pathname.size()) {
108 cwd = detail::get_cwd(h5ctx);
109 h5ctx.cd(pathname);
110 }
111
112 if (h5ctx.existsAttribute(rf_hdf5_version_group, rf_hdf5_version_tag)) {
113 double version;
114 h5ctx.readAttribute(rf_hdf5_version_group, rf_hdf5_version_tag, version);
115 vigra_precondition(version <= rf_hdf5_version, "random_forest_import_HDF5(): unexpected file format version.");
116 }
117
118 // Read ext params.
119 size_t actual_mtry;
120 size_t num_instances;
121 size_t num_features;
122 size_t num_classes;
123 size_t msample;
124 int is_weighted_int;
125 MultiArray<1, LabelType> distinct_labels_marray;
126 MultiArray<1, double> class_weights_marray;
127
128 h5ctx.cd(rf_hdf5_ext_param);
129 h5ctx.read("actual_msample_", msample);
130 h5ctx.read("actual_mtry_", actual_mtry);
131 h5ctx.read("class_count_", num_classes);
132 h5ctx.readAndResize("class_weights_", class_weights_marray);
133 h5ctx.read("column_count_", num_features);
134 h5ctx.read("is_weighted_", is_weighted_int);
135 h5ctx.readAndResize("labels", distinct_labels_marray);
136 h5ctx.read("row_count_", num_instances);
137 h5ctx.cd_up();
138
139 bool is_weighted = is_weighted_int == 1 ? true : false;
140
141 // Read options.
142 size_t min_num_instances;
143 int mtry;
144 int mtry_switch_int;
145 int bootstrap_sampling_int;
146 int tree_count;
147 h5ctx.cd(rf_hdf5_options);
148 h5ctx.read("min_split_node_size_", min_num_instances);
149 h5ctx.read("mtry_", mtry);
150 h5ctx.read("mtry_switch_", mtry_switch_int);
151 h5ctx.read("sample_with_replacement_", bootstrap_sampling_int);
152 h5ctx.read("tree_count_", tree_count);
153 h5ctx.cd_up();
154
155 RandomForestOptionTags mtry_switch = (RandomForestOptionTags)mtry_switch_int;
156 bool bootstrap_sampling = bootstrap_sampling_int == 1 ? true : false;
157
158 std::vector<LabelType> const distinct_labels(distinct_labels_marray.begin(), distinct_labels_marray.end());
159 std::vector<double> const class_weights(class_weights_marray.begin(), class_weights_marray.end());
160
161 auto const pspec = ProblemSpec<LabelType>()
162 .num_features(num_features)
163 .num_instances(num_instances)
164 .num_classes(num_classes)
165 .distinct_classes(distinct_labels)
166 .actual_mtry(actual_mtry)
167 .actual_msample(msample);
168
169 auto options = RandomForestOptions()
170 .min_num_instances(min_num_instances)
171 .bootstrap_sampling(bootstrap_sampling)
172 .tree_count(tree_count);
173 options.features_per_node_switch_ = mtry_switch;
174 options.features_per_node_ = mtry;
175 if (is_weighted)
176 options.class_weights(class_weights);
177
178 Graph gr;
179 typename RF::template NodeMap<SplitTest>::type split_tests;
180 typename RF::template NodeMap<AccInputType>::type leaf_responses;
181
182 auto const groups = h5ctx.ls();
183 for (auto const & groupname : groups) {
184 if (groupname.substr(0, std::char_traits<char>::length(rf_hdf5_tree)).compare(rf_hdf5_tree) != 0) {
185 continue;
186 }
187
188 MultiArray<1, unsigned int> topology;
189 MultiArray<1, double> parameters;
190 h5ctx.cd(groupname);
191 h5ctx.readAndResize(rf_hdf5_topology, topology);
192 h5ctx.readAndResize(rf_hdf5_parameters, parameters);
193 h5ctx.cd_up();
194
195 vigra_precondition(topology[0] == num_features, "random_forest_import_HDF5(): number of features mismatch.");
196 vigra_precondition(topology[1] == num_classes, "random_forest_import_HDF5(): number of classes mismatch.");
197
198 Node const n = gr.addNode();
199
200 std::queue<std::pair<unsigned int, Node> > q;
201 q.emplace(2, n);
202 while (!q.empty()) {
203 auto const el = q.front();
204
205 unsigned int const index = el.first;
206 Node const parent = el.second;
207
208 vigra_precondition((topology[index] & rf_zero_mask) == 0, "random_forest_import_HDF5(): unexpected node type: type & zero_mask > 0");
209
210 if (topology[index] & rf_LeafNodeTag) {
211 unsigned int const probs_start = topology[index+1] + 1;
212
213 vigra_precondition((topology[index] & rf_tag_mask) == rf_LeafNodeTag, "random_forest_import_HDF5(): unexpected node type: additional tags in leaf node");
214
215 std::vector<AccValueType> node_response;
216
217 for (unsigned int i = 0; i < num_classes; ++i) {
218 node_response.push_back(parameters[probs_start + i]);
219 }
220
221 leaf_responses.insert(parent, node_response);
222
223 } else {
224 vigra_precondition(topology[index] == rf_i_ThresholdNode, "random_forest_import_HDF5(): unexpected node type.");
225
226 Node const left = gr.addNode();
227 Node const right = gr.addNode();
228
229 gr.addArc(parent, left);
230 gr.addArc(parent, right);
231
232 split_tests.insert(parent, SplitTest(topology[index+4], parameters[topology[index+1]+1]));
233
234 q.push(std::make_pair(topology[index+2], left));
235 q.push(std::make_pair(topology[index+3], right));
236 }
237
238 q.pop();
239 }
240 }
241
242 if (cwd.size()) {
243 h5ctx.cd(cwd);
244 }
245
246 RF rf(gr, split_tests, leaf_responses, pspec);
247 rf.options_ = options;
248 return rf;
249 }
250
namespace detail
{
/** Formats integers as zero-padded decimal strings of a fixed minimum width.

    The width is the number of characters needed to print n-1. Consequently
    the numbers 0 .. n-1 all come out with the same length (e.g. n = 100
    yields "00" .. "99"). This is presumably so that the per-tree group names
    sort in tree order. Numbers that are longer than the width are returned
    unpadded. Padding is always inserted on the left, i.e. before a minus sign.
*/
class PaddedNumberString
{
public:

    PaddedNumberString(int n)
    :   width_(static_cast<unsigned int>(std::to_string(n - 1).size()))
    {}

    std::string operator()(int k) const
    {
        std::string text = std::to_string(k);
        if (text.size() < width_)
            text.insert(text.begin(), width_ - text.size(), '0');
        return text;
    }

private:

    unsigned int width_;
};
}
276
277 template <typename RF>
random_forest_export_HDF5(RF const & rf,HDF5File & h5context,std::string const & pathname="")278 void random_forest_export_HDF5(
279 RF const & rf,
280 HDF5File & h5context,
281 std::string const & pathname = ""
282 ){
283 typedef typename RF::LabelType LabelType;
284 typedef typename RF::Node Node;
285
286 std::string cwd;
287 if (pathname.size()) {
288 cwd = detail::get_cwd(h5context);
289 h5context.cd_mk(pathname);
290 }
291
292 // version attribute
293 h5context.writeAttribute(rf_hdf5_version_group, rf_hdf5_version_tag,
294 rf_hdf5_version);
295
296
297 auto const & p = rf.problem_spec_;
298 auto const & opts = rf.options_;
299 MultiArray<1, LabelType> distinct_classes(Shape1(p.distinct_classes_.size()), p.distinct_classes_.data());
300 MultiArray<1, double> class_weights(Shape1(p.num_classes_), 1.0);
301 int is_weighted = 0;
302 if (opts.class_weights_.size() > 0)
303 {
304 is_weighted = 1;
305 for (size_t i = 0; i < opts.class_weights_.size(); ++i)
306 class_weights(i) = opts.class_weights_[i];
307 }
308
309 // Save external parameters.
310 h5context.cd_mk(rf_hdf5_ext_param);
311 h5context.write("column_count_", p.num_features_);
312 h5context.write("row_count_", p.num_instances_);
313 h5context.write("class_count_", p.num_classes_);
314 h5context.write("actual_mtry_", p.actual_mtry_);
315 h5context.write("actual_msample_", p.actual_msample_);
316 h5context.write("labels", distinct_classes);
317 h5context.write("is_weighted_", is_weighted);
318 h5context.write("class_weights_", class_weights);
319 h5context.write("precision_", 0.0);
320 h5context.write("problem_type_", 1.0);
321 h5context.write("response_size_", 1.0);
322 h5context.write("used_", 1.0);
323 h5context.cd_up();
324
325 // Save the options.
326 h5context.cd_mk(rf_hdf5_options);
327 h5context.write("min_split_node_size_", opts.min_num_instances_);
328 h5context.write("mtry_", opts.features_per_node_);
329 h5context.write("mtry_func_", 0.0);
330 h5context.write("mtry_switch_", opts.features_per_node_switch_);
331 h5context.write("predict_weighted_", 0.0);
332 h5context.write("prepare_online_learning_", 0.0);
333 h5context.write("sample_with_replacement_", opts.bootstrap_sampling_ ? 1 : 0);
334 h5context.write("stratification_method_", 3.0);
335 h5context.write("training_set_calc_switch_", 1.0);
336 h5context.write("training_set_func_", 0.0);
337 h5context.write("training_set_proportion_", 1.0);
338 h5context.write("training_set_size_", 0.0);
339 h5context.write("tree_count_", opts.tree_count_);
340 h5context.cd_up();
341
342 // Save the trees.
343 detail::PaddedNumberString tree_number(rf.num_trees());
344 for (size_t i = 0; i < rf.num_trees(); ++i)
345 {
346 // Create the topology and parameters arrays.
347 std::vector<UInt32> topology;
348 std::vector<double> parameters;
349 topology.push_back(p.num_features_);
350 topology.push_back(p.num_classes_);
351
352 auto const & probs = rf.node_responses_;
353 auto const & splits = rf.split_tests_;
354 auto const & gr = rf.graph_;
355 auto const root = gr.getRoot(i);
356
357 // Write the tree nodes using a depth-first search.
358 // When a node is created, the indices of the child nodes are unknown.
359 // Therefore, they have to be updated once the child nodes are created.
360 // The stack holds the node and the topology-index that must be updated.
361 std::stack<std::pair<Node, std::ptrdiff_t> > stack;
362 stack.emplace(root, -1);
363 while (!stack.empty())
364 {
365 auto const n = stack.top().first; // the node descriptor
366 auto const i = stack.top().second; // index from the parent node that must be updated
367 stack.pop();
368
369 // Update the index in the parent node.
370 if (i != -1)
371 topology[i] = topology.size();
372
373 if (gr.numChildren(n) == 0)
374 {
375 // The node is a leaf.
376 // Topology: leaf node tag, index of weight in parameters array.
377 // Parameters: node weight, class probabilities.
378 topology.push_back(rf_LeafNodeTag);
379 topology.push_back(parameters.size());
380 auto const & prob = probs.at(n);
381 auto const weight = std::accumulate(prob.begin(), prob.end(), 0.0);
382 parameters.push_back(weight);
383 parameters.insert(parameters.end(), prob.begin(), prob.end());
384 }
385 else
386 {
387 // The node is an inner node.
388 // Topology: threshold tag, index of weight in parameters array, index of left child, index of right child, split dimension.
389 // Parameters: node weight, split value.
390 topology.push_back(rf_i_ThresholdNode);
391 topology.push_back(parameters.size());
392 topology.push_back(-1); // index of left children (currently unknown, will be updated when the child node is taken from the stack)
393 topology.push_back(-1); // index of right children (see above)
394 topology.push_back(splits.at(n).dim_);
395 parameters.push_back(1.0); // inner nodes have the weight 1.
396 parameters.push_back(splits.at(n).val_);
397
398 // Place the children on the stack.
399 stack.emplace(gr.getChild(n, 0), topology.size()-3);
400 stack.emplace(gr.getChild(n, 1), topology.size()-2);
401 }
402 }
403
404 // Convert the vectors to multi arrays.
405 MultiArray<1, UInt32> topo(Shape1(topology.size()), topology.data());
406 MultiArray<1, double> para(Shape1(parameters.size()), parameters.data());
407
408 auto const name = rf_hdf5_tree + tree_number(i);
409 h5context.cd_mk(name);
410 h5context.write(rf_hdf5_topology, topo);
411 h5context.write(rf_hdf5_parameters, para);
412 h5context.cd_up();
413 }
414
415 if (pathname.size())
416 h5context.cd(cwd);
417 }
418
419
420
421 } // namespace rf3
422 } // namespace vigra
423
#endif // VIGRA_RF3_IMPEX_HDF5_HXX
425