1 /*!
2 * Copyright 2014-2020 by Contributors
3 * \file updater_prune.cc
4 * \brief prune a tree given the statistics
5 * \author Tianqi Chen
6 */
7 #include <rabit/rabit.h>
8 #include <xgboost/tree_updater.h>
9
10 #include <string>
11 #include <memory>
12
13 #include "xgboost/base.h"
14 #include "xgboost/json.h"
15 #include "./param.h"
16 #include "../common/io.h"
17 #include "../common/timer.h"
18 namespace xgboost {
19 namespace tree {
20
// Registry file tag for this translation unit (dmlc registry macro).
// NOTE(review): presumably paired with a DMLC_REGISTRY_LINK_TAG(updater_prune)
// elsewhere so that static linking keeps this file's updater registration —
// confirm against dmlc/registry.h.
DMLC_REGISTRY_FILE_TAG(updater_prune);
22
23 /*! \brief pruner that prunes a tree after growing finishes */
24 class TreePruner: public TreeUpdater {
25 public:
TreePruner()26 TreePruner() {
27 syncher_.reset(TreeUpdater::Create("sync", tparam_));
28 pruner_monitor_.Init("TreePruner");
29 }
Name() const30 char const* Name() const override {
31 return "prune";
32 }
33
34 // set training parameter
Configure(const Args & args)35 void Configure(const Args& args) override {
36 param_.UpdateAllowUnknown(args);
37 syncher_->Configure(args);
38 }
39
LoadConfig(Json const & in)40 void LoadConfig(Json const& in) override {
41 auto const& config = get<Object const>(in);
42 FromJson(config.at("train_param"), &this->param_);
43 }
SaveConfig(Json * p_out) const44 void SaveConfig(Json* p_out) const override {
45 auto& out = *p_out;
46 out["train_param"] = ToJson(param_);
47 }
CanModifyTree() const48 bool CanModifyTree() const override {
49 return true;
50 }
51
52 // update the tree, do pruning
Update(HostDeviceVector<GradientPair> * gpair,DMatrix * p_fmat,const std::vector<RegTree * > & trees)53 void Update(HostDeviceVector<GradientPair> *gpair,
54 DMatrix *p_fmat,
55 const std::vector<RegTree*> &trees) override {
56 pruner_monitor_.Start("PrunerUpdate");
57 // rescale learning rate according to size of trees
58 float lr = param_.learning_rate;
59 param_.learning_rate = lr / trees.size();
60 for (auto tree : trees) {
61 this->DoPrune(tree);
62 }
63 param_.learning_rate = lr;
64 syncher_->Update(gpair, p_fmat, trees);
65 pruner_monitor_.Stop("PrunerUpdate");
66 }
67
68 private:
69 // try to prune off current leaf
TryPruneLeaf(RegTree & tree,int nid,int depth,int npruned)70 bst_node_t TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) { // NOLINT(*)
71 CHECK(tree[nid].IsLeaf());
72 if (tree[nid].IsRoot()) {
73 return npruned;
74 }
75 bst_node_t pid = tree[nid].Parent();
76 CHECK(!tree[pid].IsLeaf());
77 RTreeNodeStat const &s = tree.Stat(pid);
78 // Only prune when both child are leaf.
79 auto left = tree[pid].LeftChild();
80 auto right = tree[pid].RightChild();
81 bool balanced = tree[left].IsLeaf() &&
82 right != RegTree::kInvalidNodeId && tree[right].IsLeaf();
83 if (balanced && param_.NeedPrune(s.loss_chg, depth)) {
84 // need to be pruned
85 tree.ChangeToLeaf(pid, param_.learning_rate * s.base_weight);
86 // tail recursion
87 return this->TryPruneLeaf(tree, pid, depth - 1, npruned + 2);
88 } else {
89 return npruned;
90 }
91 }
92 /*! \brief do pruning of a tree */
DoPrune(RegTree * p_tree)93 void DoPrune(RegTree* p_tree) {
94 auto& tree = *p_tree;
95 bst_node_t npruned = 0;
96 for (int nid = 0; nid < tree.param.num_nodes; ++nid) {
97 if (tree[nid].IsLeaf() && !tree[nid].IsDeleted()) {
98 npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned);
99 }
100 }
101 LOG(INFO) << "tree pruning end, "
102 << tree.NumExtraNodes() << " extra nodes, " << npruned
103 << " pruned nodes, max_depth=" << tree.MaxDepth();
104 }
105
106 private:
107 // synchronizer
108 std::unique_ptr<TreeUpdater> syncher_;
109 // training parameter
110 TrainParam param_;
111 common::Monitor pruner_monitor_;
112 };
113
114 XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
115 .describe("Pruner that prune the tree according to statistics.")
__anonca808c770102() 116 .set_body([]() {
117 return new TreePruner();
118 });
119 } // namespace tree
120 } // namespace xgboost
121