1 /*!
2  * Copyright 2014-2020 by Contributors
3  * \file updater_prune.cc
4  * \brief prune a tree given the statistics
5  * \author Tianqi Chen
6  */
7 #include <rabit/rabit.h>
8 #include <xgboost/tree_updater.h>
9 
10 #include <string>
11 #include <memory>
12 
13 #include "xgboost/base.h"
14 #include "xgboost/json.h"
15 #include "./param.h"
16 #include "../common/io.h"
17 #include "../common/timer.h"
18 namespace xgboost {
19 namespace tree {
20 
// Tag this translation unit for the DMLC registry. This is presumably paired
// with a DMLC_REGISTRY_LINK_TAG(updater_prune) elsewhere, so that the static
// "prune" registration at the bottom of this file is not dropped by the
// linker. Confirm against dmlc/registry.h.
DMLC_REGISTRY_FILE_TAG(updater_prune);
22 
23 /*! \brief pruner that prunes a tree after growing finishes */
24 class TreePruner: public TreeUpdater {
25  public:
TreePruner()26   TreePruner() {
27     syncher_.reset(TreeUpdater::Create("sync", tparam_));
28     pruner_monitor_.Init("TreePruner");
29   }
Name() const30   char const* Name() const override {
31     return "prune";
32   }
33 
34   // set training parameter
Configure(const Args & args)35   void Configure(const Args& args) override {
36     param_.UpdateAllowUnknown(args);
37     syncher_->Configure(args);
38   }
39 
LoadConfig(Json const & in)40   void LoadConfig(Json const& in) override {
41     auto const& config = get<Object const>(in);
42     FromJson(config.at("train_param"), &this->param_);
43   }
SaveConfig(Json * p_out) const44   void SaveConfig(Json* p_out) const override {
45     auto& out = *p_out;
46     out["train_param"] = ToJson(param_);
47   }
CanModifyTree() const48   bool CanModifyTree() const override {
49     return true;
50   }
51 
52   // update the tree, do pruning
Update(HostDeviceVector<GradientPair> * gpair,DMatrix * p_fmat,const std::vector<RegTree * > & trees)53   void Update(HostDeviceVector<GradientPair> *gpair,
54               DMatrix *p_fmat,
55               const std::vector<RegTree*> &trees) override {
56     pruner_monitor_.Start("PrunerUpdate");
57     // rescale learning rate according to size of trees
58     float lr = param_.learning_rate;
59     param_.learning_rate = lr / trees.size();
60     for (auto tree : trees) {
61       this->DoPrune(tree);
62     }
63     param_.learning_rate = lr;
64     syncher_->Update(gpair, p_fmat, trees);
65     pruner_monitor_.Stop("PrunerUpdate");
66   }
67 
68  private:
69   // try to prune off current leaf
TryPruneLeaf(RegTree & tree,int nid,int depth,int npruned)70   bst_node_t TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) { // NOLINT(*)
71     CHECK(tree[nid].IsLeaf());
72     if (tree[nid].IsRoot()) {
73       return npruned;
74     }
75     bst_node_t pid = tree[nid].Parent();
76     CHECK(!tree[pid].IsLeaf());
77     RTreeNodeStat const &s = tree.Stat(pid);
78     // Only prune when both child are leaf.
79     auto left = tree[pid].LeftChild();
80     auto right = tree[pid].RightChild();
81     bool balanced = tree[left].IsLeaf() &&
82                     right != RegTree::kInvalidNodeId && tree[right].IsLeaf();
83     if (balanced && param_.NeedPrune(s.loss_chg, depth)) {
84       // need to be pruned
85       tree.ChangeToLeaf(pid, param_.learning_rate * s.base_weight);
86       // tail recursion
87       return this->TryPruneLeaf(tree, pid, depth - 1, npruned + 2);
88     } else {
89       return npruned;
90     }
91   }
92   /*! \brief do pruning of a tree */
DoPrune(RegTree * p_tree)93   void DoPrune(RegTree* p_tree) {
94     auto& tree = *p_tree;
95     bst_node_t npruned = 0;
96     for (int nid = 0; nid < tree.param.num_nodes; ++nid) {
97       if (tree[nid].IsLeaf() && !tree[nid].IsDeleted()) {
98         npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned);
99       }
100     }
101     LOG(INFO) << "tree pruning end, "
102               << tree.NumExtraNodes() << " extra nodes, " << npruned
103               << " pruned nodes, max_depth=" << tree.MaxDepth();
104   }
105 
106  private:
107   // synchronizer
108   std::unique_ptr<TreeUpdater> syncher_;
109   // training parameter
110   TrainParam param_;
111   common::Monitor pruner_monitor_;
112 };
113 
114 XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
115 .describe("Pruner that prune the tree according to statistics.")
__anonca808c770102() 116 .set_body([]() {
117     return new TreePruner();
118   });
119 }  // namespace tree
120 }  // namespace xgboost
121