1 // Copyright by Contributors
#include <algorithm>
#include <cmath>
#include <memory>
#include <stack>
#include <string>
#include <vector>

#include <gtest/gtest.h>

#include "../helpers.h"
#include "dmlc/filesystem.h"
#include "xgboost/json_io.h"
#include "xgboost/tree_model.h"
#include "../../../src/common/bitfield.h"
#include "../../../src/common/categorical.h"
9 
10 namespace xgboost {
11 #if DMLC_IO_NO_ENDIAN_SWAP  // skip on big-endian machines
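// Manually construct a tree in the legacy binary format.
// Fields are written one by one instead of dumping structs, so the test keeps
// pinning the on-disk layout even if the structs change; this preserves
// backwards compatibility with previously saved models.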
TEST(Tree,Load)15 TEST(Tree, Load) {
16   dmlc::TemporaryDirectory tempdir;
17   const std::string tmp_file = tempdir.path + "/tree.model";
18   std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(tmp_file.c_str(), "w"));
19 
20   // Write params
21   EXPECT_EQ(sizeof(TreeParam), (31 + 6) * sizeof(int));
22   int num_roots = 1;
23   int num_nodes = 2;
24   int num_deleted = 0;
25   int max_depth = 1;
26   int num_feature = 0;
27   int size_leaf_vector = 0;
28   int reserved[31];
29   fo->Write(&num_roots, sizeof(int));
30   fo->Write(&num_nodes, sizeof(int));
31   fo->Write(&num_deleted, sizeof(int));
32   fo->Write(&max_depth, sizeof(int));
33   fo->Write(&num_feature, sizeof(int));
34   fo->Write(&size_leaf_vector, sizeof(int));
35   fo->Write(reserved, sizeof(int) * 31);
36 
37   // Write 2 nodes
38   EXPECT_EQ(sizeof(RegTree::Node),
39             3 * sizeof(int) + 1 * sizeof(unsigned) + sizeof(float));
40   int parent = -1;
41   int cleft = 1;
42   int cright = -1;
43   unsigned sindex = 5;
44   float split_or_weight = 0.5;
45   fo->Write(&parent, sizeof(int));
46   fo->Write(&cleft, sizeof(int));
47   fo->Write(&cright, sizeof(int));
48   fo->Write(&sindex, sizeof(unsigned));
49   fo->Write(&split_or_weight, sizeof(float));
50   parent = 0;
51   cleft = -1;
52   cright = -1;
53   sindex = 2;
54   split_or_weight = 0.1;
55   fo->Write(&parent, sizeof(int));
56   fo->Write(&cleft, sizeof(int));
57   fo->Write(&cright, sizeof(int));
58   fo->Write(&sindex, sizeof(unsigned));
59   fo->Write(&split_or_weight, sizeof(float));
60 
61   // Write 2x node stats
62   EXPECT_EQ(sizeof(RTreeNodeStat), 3 * sizeof(float) + sizeof(int));
63   bst_float loss_chg = 5.0;
64   bst_float sum_hess = 1.0;
65   bst_float base_weight = 3.0;
66   int leaf_child_cnt = 0;
67   fo->Write(&loss_chg, sizeof(float));
68   fo->Write(&sum_hess, sizeof(float));
69   fo->Write(&base_weight, sizeof(float));
70   fo->Write(&leaf_child_cnt, sizeof(int));
71 
72   loss_chg = 50.0;
73   sum_hess = 10.0;
74   base_weight = 30.0;
75   leaf_child_cnt = 0;
76   fo->Write(&loss_chg, sizeof(float));
77   fo->Write(&sum_hess, sizeof(float));
78   fo->Write(&base_weight, sizeof(float));
79   fo->Write(&leaf_child_cnt, sizeof(int));
80   fo.reset();
81   std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(tmp_file.c_str(), "r"));
82 
83   xgboost::RegTree tree;
84   tree.Load(fi.get());
85   EXPECT_EQ(tree.GetDepth(1), 1);
86   EXPECT_EQ(tree[0].SplitCond(), 0.5f);
87   EXPECT_EQ(tree[0].SplitIndex(), 5ul);
88   EXPECT_EQ(tree[1].LeafValue(), 0.1f);
89   EXPECT_TRUE(tree[1].IsLeaf());
90 }
91 #endif  // DMLC_IO_NO_ENDIAN_SWAP
92 
TEST(Tree,AllocateNode)93 TEST(Tree, AllocateNode) {
94   RegTree tree;
95   tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
96                   /*left_sum=*/0.0f, /*right_sum=*/0.0f);
97   tree.CollapseToLeaf(0, 0);
98   ASSERT_EQ(tree.NumExtraNodes(), 0);
99 
100   tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
101                   /*left_sum=*/0.0f, /*right_sum=*/0.0f);
102   ASSERT_EQ(tree.NumExtraNodes(), 2);
103 
104   auto& nodes = tree.GetNodes();
105   ASSERT_FALSE(nodes.at(1).IsDeleted());
106   ASSERT_TRUE(nodes.at(1).IsLeaf());
107   ASSERT_TRUE(nodes.at(2).IsLeaf());
108 }
109 
TEST(Tree,ExpandCategoricalFeature)110 TEST(Tree, ExpandCategoricalFeature) {
111   {
112     RegTree tree;
113     tree.ExpandCategorical(0, 0, {}, true, 1.0, 2.0, 3.0, 11.0, 2.0,
114                            /*left_sum=*/3.0, /*right_sum=*/4.0);
115     ASSERT_EQ(tree.GetNodes().size(), 3ul);
116     ASSERT_EQ(tree.GetNumLeaves(), 2);
117     ASSERT_EQ(tree.GetSplitTypes().size(), 3ul);
118     ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
119     ASSERT_EQ(tree.GetSplitTypes()[1], FeatureType::kNumerical);
120     ASSERT_EQ(tree.GetSplitTypes()[2], FeatureType::kNumerical);
121     ASSERT_EQ(tree.GetSplitCategories().size(), 0ul);
122     ASSERT_TRUE(std::isnan(tree[0].SplitCond()));
123   }
124   {
125     RegTree tree;
126     bst_cat_t cat = 33;
127     std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(cat+1));
128     LBitField32 bitset {split_cats};
129     bitset.Set(cat);
130     tree.ExpandCategorical(0, 0, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0,
131                            /*left_sum=*/3.0, /*right_sum=*/4.0);
132     auto categories = tree.GetSplitCategories();
133     auto segments = tree.GetSplitCategoriesPtr();
134     auto got = categories.subspan(segments[0].beg, segments[0].size);
135     ASSERT_TRUE(std::equal(got.cbegin(), got.cend(), split_cats.cbegin()));
136 
137     Json out{Object()};
138     tree.SaveModel(&out);
139 
140     RegTree loaded_tree;
141     loaded_tree.LoadModel(out);
142 
143     auto const& cat_ptr = loaded_tree.GetSplitCategoriesPtr();
144     ASSERT_EQ(cat_ptr.size(), 3ul);
145     ASSERT_EQ(cat_ptr[0].beg, 0ul);
146     ASSERT_EQ(cat_ptr[0].size, 2ul);
147 
148     auto loaded_categories = loaded_tree.GetSplitCategories();
149     auto loaded_root = loaded_categories.subspan(cat_ptr[0].beg, cat_ptr[0].size);
150     ASSERT_TRUE(std::equal(loaded_root.begin(), loaded_root.end(), split_cats.begin()));
151   }
152 }
153 
GrowTree(RegTree * p_tree)154 void GrowTree(RegTree* p_tree) {
155   SimpleLCG lcg;
156   size_t n_expands = 10;
157   constexpr size_t kCols = 256;
158   SimpleRealUniformDistribution<double> coin(0.0, 1.0);
159   SimpleRealUniformDistribution<double> feat(0.0, kCols);
160   SimpleRealUniformDistribution<double> split_cat(0.0, 128.0);
161   SimpleRealUniformDistribution<double> split_value(0.0, kCols);
162 
163   std::stack<bst_node_t> stack;
164   stack.push(RegTree::kRoot);
165   auto& tree = *p_tree;
166 
167   for (size_t i = 0; i < n_expands; ++i) {
168     auto is_cat = coin(&lcg) <= 0.5;
169     bst_node_t node = stack.top();
170     stack.pop();
171 
172     bst_feature_t f = feat(&lcg);
173     if (is_cat) {
174       bst_cat_t cat = common::AsCat(split_cat(&lcg));
175       std::vector<uint32_t> split_cats(
176           LBitField32::ComputeStorageSize(cat + 1));
177       LBitField32 bitset{split_cats};
178       bitset.Set(cat);
179       tree.ExpandCategorical(node, f, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0,
180                              /*left_sum=*/3.0, /*right_sum=*/4.0);
181     } else {
182       auto split = split_value(&lcg);
183       tree.ExpandNode(node, f, split, true, 1.0, 2.0, 3.0, 11.0, 2.0,
184                       /*left_sum=*/3.0, /*right_sum=*/4.0);
185     }
186 
187     stack.push(tree[node].LeftChild());
188     stack.push(tree[node].RightChild());
189   }
190 }
191 
CheckReload(RegTree const & tree)192 void CheckReload(RegTree const &tree) {
193   Json out{Object()};
194   tree.SaveModel(&out);
195 
196   RegTree loaded_tree;
197   loaded_tree.LoadModel(out);
198   Json saved{Object()};
199   loaded_tree.SaveModel(&saved);
200 
201   auto same = out == saved;
202   ASSERT_TRUE(same);
203 }
204 
TEST(Tree,CategoricalIO)205 TEST(Tree, CategoricalIO) {
206   {
207     RegTree tree;
208     bst_cat_t cat = 32;
209     std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(cat + 1));
210     LBitField32 bitset{split_cats};
211     bitset.Set(cat);
212     tree.ExpandCategorical(0, 0, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0,
213                            /*left_sum=*/3.0, /*right_sum=*/4.0);
214 
215     CheckReload(tree);
216   }
217 
218   {
219     RegTree tree;
220     GrowTree(&tree);
221     CheckReload(tree);
222   }
223 }
224 
225 namespace {
ConstructTree()226 RegTree ConstructTree() {
227   RegTree tree;
228   tree.ExpandNode(
229       /*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
230       /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
231       /*right_sum=*/0.0f);
232   auto left = tree[0].LeftChild();
233   auto right = tree[0].RightChild();
234   tree.ExpandNode(
235       /*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
236       /*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
237       /*right_sum=*/0.0f);
238   tree.ExpandNode(
239       /*nid=*/right, /*split_index=*/2, /*split_value=*/2.0f,
240       /*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
241       /*right_sum=*/0.0f);
242   return tree;
243 }
244 
ConstructTreeCat(std::vector<bst_cat_t> * cond)245 RegTree ConstructTreeCat(std::vector<bst_cat_t>* cond) {
246   RegTree tree;
247   std::vector<uint32_t> cats_storage(common::CatBitField::ComputeStorageSize(33), 0);
248   common::CatBitField split_cats(cats_storage);
249   split_cats.Set(0);
250   split_cats.Set(14);
251   split_cats.Set(32);
252 
253   cond->push_back(0);
254   cond->push_back(14);
255   cond->push_back(32);
256 
257   tree.ExpandCategorical(0, /*split_index=*/0, cats_storage, true, 0.0f, 2.0,
258                          3.00, 11.0, 2.0, 3.0, 4.0);
259   auto left = tree[0].LeftChild();
260   auto right = tree[0].RightChild();
261   tree.ExpandNode(
262       /*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
263       /*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
264       /*right_sum=*/0.0f);
265   tree.ExpandCategorical(right, /*split_index=*/0, cats_storage, true, 0.0f,
266                          2.0, 3.00, 11.0, 2.0, 3.0, 4.0);
267   return tree;
268 }
269 
TestCategoricalTreeDump(std::string format,std::string sep)270 void TestCategoricalTreeDump(std::string format, std::string sep) {
271   std::vector<bst_cat_t> cond;
272   auto tree = ConstructTreeCat(&cond);
273 
274   FeatureMap fmap;
275   auto str = tree.DumpModel(fmap, true, format);
276   std::string cond_str;
277   for (size_t c = 0; c < cond.size(); ++c) {
278     cond_str += std::to_string(cond[c]);
279     if (c != cond.size() - 1) {
280       cond_str += sep;
281     }
282   }
283   auto pos = str.find(cond_str);
284   ASSERT_NE(pos, std::string::npos);
285   pos = str.find(cond_str, pos + 1);
286   ASSERT_NE(pos, std::string::npos);
287 
288   fmap.PushBack(0, "feat_0", "c");
289   fmap.PushBack(1, "feat_1", "q");
290   fmap.PushBack(2, "feat_2", "int");
291 
292   str = tree.DumpModel(fmap, true, format);
293   pos = str.find(cond_str);
294   ASSERT_NE(pos, std::string::npos);
295   pos = str.find(cond_str, pos + 1);
296   ASSERT_NE(pos, std::string::npos);
297 
298   if (format == "json") {
299     // Make sure it's valid JSON
300     Json::Load(StringView{str});
301   }
302 }
303 }  // anonymous namespace
304 
TEST(Tree,DumpJson)305 TEST(Tree, DumpJson) {
306   auto tree = ConstructTree();
307   FeatureMap fmap;
308   auto str = tree.DumpModel(fmap, true, "json");
309   size_t n_leaves = 0;
310   size_t iter = 0;
311   while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
312     n_leaves++;
313   }
314   ASSERT_EQ(n_leaves, 4ul);
315 
316   size_t n_conditions = 0;
317   iter = 0;
318   while ((iter = str.find("split_condition", iter + 1)) != std::string::npos) {
319     n_conditions++;
320   }
321   ASSERT_EQ(n_conditions, 3ul);
322 
323   fmap.PushBack(0, "feat_0", "i");
324   fmap.PushBack(1, "feat_1", "q");
325   fmap.PushBack(2, "feat_2", "int");
326 
327   str = tree.DumpModel(fmap, true, "json");
328   ASSERT_NE(str.find(R"("split": "feat_0")"), std::string::npos);
329   ASSERT_NE(str.find(R"("split": "feat_1")"), std::string::npos);
330   ASSERT_NE(str.find(R"("split": "feat_2")"), std::string::npos);
331 
332   str = tree.DumpModel(fmap, false, "json");
333   ASSERT_EQ(str.find("cover"), std::string::npos);
334 
335 
336   auto j_tree = Json::Load({str.c_str(), str.size()});
337   ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2ul);
338 }
339 
TEST(Tree,DumpJsonCategorical)340 TEST(Tree, DumpJsonCategorical) {
341   TestCategoricalTreeDump("json", ", ");
342 }
343 
TEST(Tree,DumpText)344 TEST(Tree, DumpText) {
345   auto tree = ConstructTree();
346   FeatureMap fmap;
347   auto str = tree.DumpModel(fmap, true, "text");
348   size_t n_leaves = 0;
349   size_t iter = 0;
350   while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
351     n_leaves++;
352   }
353   ASSERT_EQ(n_leaves, 4ul);
354 
355   iter = 0;
356   size_t n_conditions = 0;
357   while ((iter = str.find("gain", iter + 1)) != std::string::npos) {
358     n_conditions++;
359   }
360   ASSERT_EQ(n_conditions, 3ul);
361 
362   ASSERT_NE(str.find("[f0<0]"), std::string::npos);
363   ASSERT_NE(str.find("[f1<1]"), std::string::npos);
364   ASSERT_NE(str.find("[f2<2]"), std::string::npos);
365 
366   fmap.PushBack(0, "feat_0", "i");
367   fmap.PushBack(1, "feat_1", "q");
368   fmap.PushBack(2, "feat_2", "int");
369 
370   str = tree.DumpModel(fmap, true, "text");
371   ASSERT_NE(str.find("[feat_0]"), std::string::npos);
372   ASSERT_NE(str.find("[feat_1<1]"), std::string::npos);
373   ASSERT_NE(str.find("[feat_2<2]"), std::string::npos);
374 
375   str = tree.DumpModel(fmap, false, "text");
376   ASSERT_EQ(str.find("cover"), std::string::npos);
377 }
378 
TEST(Tree,DumpTextCategorical)379 TEST(Tree, DumpTextCategorical) {
380   TestCategoricalTreeDump("text", ",");
381 }
382 
TEST(Tree,DumpDot)383 TEST(Tree, DumpDot) {
384   auto tree = ConstructTree();
385   FeatureMap fmap;
386   auto str = tree.DumpModel(fmap, true, "dot");
387 
388   size_t n_leaves = 0;
389   size_t iter = 0;
390   while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
391     n_leaves++;
392   }
393   ASSERT_EQ(n_leaves, 4ul);
394 
395   size_t n_edges = 0;
396   iter = 0;
397   while ((iter = str.find("->", iter + 1)) != std::string::npos) {
398     n_edges++;
399   }
400   ASSERT_EQ(n_edges, 6ul);
401 
402   fmap.PushBack(0, "feat_0", "i");
403   fmap.PushBack(1, "feat_1", "q");
404   fmap.PushBack(2, "feat_2", "int");
405 
406   str = tree.DumpModel(fmap, true, "dot");
407   ASSERT_NE(str.find(R"("feat_0")"), std::string::npos);
408   ASSERT_NE(str.find(R"(feat_1<1)"), std::string::npos);
409   ASSERT_NE(str.find(R"(feat_2<2)"), std::string::npos);
410 
411   str = tree.DumpModel(fmap, true, R"(dot:{"graph_attrs": {"bgcolor": "#FFFF00"}})");
412   ASSERT_NE(str.find(R"(graph [ bgcolor="#FFFF00" ])"), std::string::npos);
413 
414   // Default left for root.
415   ASSERT_NE(str.find(R"(0 -> 1 [label="yes, missing")"), std::string::npos);
416   // Default right for node 1
417   ASSERT_NE(str.find(R"(1 -> 4 [label="no, missing")"), std::string::npos);
418 }
419 
TEST(Tree,DumpDotCategorical)420 TEST(Tree, DumpDotCategorical) {
421   TestCategoricalTreeDump("dot", ",");
422 }
423 
TEST(Tree,JsonIO)424 TEST(Tree, JsonIO) {
425   RegTree tree;
426   tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
427                   /*left_sum=*/0.0f, /*right_sum=*/0.0f);
428   Json j_tree{Object()};
429   tree.SaveModel(&j_tree);
430 
431   auto tparam = j_tree["tree_param"];
432   ASSERT_EQ(get<String>(tparam["num_feature"]), "0");
433   ASSERT_EQ(get<String>(tparam["num_nodes"]), "3");
434   ASSERT_EQ(get<String>(tparam["size_leaf_vector"]), "0");
435 
436   ASSERT_EQ(get<Array const>(j_tree["left_children"]).size(), 3ul);
437   ASSERT_EQ(get<Array const>(j_tree["right_children"]).size(), 3ul);
438   ASSERT_EQ(get<Array const>(j_tree["parents"]).size(), 3ul);
439   ASSERT_EQ(get<Array const>(j_tree["split_indices"]).size(), 3ul);
440   ASSERT_EQ(get<Array const>(j_tree["split_conditions"]).size(), 3ul);
441   ASSERT_EQ(get<Array const>(j_tree["default_left"]).size(), 3ul);
442 
443   RegTree loaded_tree;
444   loaded_tree.LoadModel(j_tree);
445   ASSERT_EQ(loaded_tree.param.num_nodes, 3);
446 
447   ASSERT_TRUE(loaded_tree == tree);
448 
449   auto left = tree[0].LeftChild();
450   auto right = tree[0].RightChild();
451   tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
452                   /*left_sum=*/0.0f, /*right_sum=*/0.0f);
453   tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
454                   /*left_sum=*/0.0f, /*right_sum=*/0.0f);
455   tree.SaveModel(&j_tree);
456 
457   tree.ChangeToLeaf(1, 1.0f);
458   ASSERT_EQ(tree[1].LeftChild(), -1);
459   ASSERT_EQ(tree[1].RightChild(), -1);
460   tree.SaveModel(&j_tree);
461   loaded_tree.LoadModel(j_tree);
462   ASSERT_EQ(loaded_tree[1].LeftChild(), -1);
463   ASSERT_EQ(loaded_tree[1].RightChild(), -1);
464   ASSERT_TRUE(tree.Equal(loaded_tree));
465 }
466 }  // namespace xgboost
467