1 /*!
2 * Copyright 2014-2019 by Contributors
3 * \file updater_histmaker.cc
4 * \brief use histogram counting to construct a tree
5 * \author Tianqi Chen
6 */
#include <rabit/rabit.h>
#include <algorithm>
#include <cmath>
#include <vector>

#include "xgboost/tree_updater.h"
#include "xgboost/base.h"
#include "xgboost/logging.h"

#include "../common/quantile.h"
#include "../common/group_data.h"
#include "./updater_basemaker-inl.h"
#include "constraints.h"
19
20 namespace xgboost {
21 namespace tree {
22
23 DMLC_REGISTRY_FILE_TAG(updater_histmaker);
24
25 class HistMaker: public BaseMaker {
26 public:
Update(HostDeviceVector<GradientPair> * gpair,DMatrix * p_fmat,const std::vector<RegTree * > & trees)27 void Update(HostDeviceVector<GradientPair> *gpair,
28 DMatrix *p_fmat,
29 const std::vector<RegTree*> &trees) override {
30 interaction_constraints_.Configure(param_, p_fmat->Info().num_col_);
31 // rescale learning rate according to size of trees
32 float lr = param_.learning_rate;
33 param_.learning_rate = lr / trees.size();
34 // build tree
35 for (auto tree : trees) {
36 this->UpdateTree(gpair->ConstHostVector(), p_fmat, tree);
37 }
38 param_.learning_rate = lr;
39 }
Name() const40 char const* Name() const override {
41 return "grow_histmaker";
42 }
43
44 protected:
45 /*! \brief a single column of histogram cuts */
46 struct HistUnit {
47 /*! \brief cutting point of histogram, contains maximum point */
48 const float *cut;
49 /*! \brief content of statistics data */
50 GradStats *data;
51 /*! \brief size of histogram */
52 uint32_t size;
53 // default constructor
54 HistUnit() = default;
55 // constructor
HistUnitxgboost::tree::HistMaker::HistUnit56 HistUnit(const float *cut, GradStats *data, uint32_t size)
57 : cut{cut}, data{data}, size{size} {}
58 /*! \brief add a histogram to data */
59 };
60 /*! \brief a set of histograms from different index */
61 struct HistSet {
62 /*! \brief the index pointer of each histunit */
63 const uint32_t *rptr;
64 /*! \brief cutting points in each histunit */
65 const bst_float *cut;
66 /*! \brief data in different hist unit */
67 std::vector<GradStats> data;
68 /*! \brief return a column of histogram cuts */
operator []xgboost::tree::HistMaker::HistSet69 inline HistUnit operator[](size_t fid) {
70 return {cut + rptr[fid], &data[0] + rptr[fid], rptr[fid+1] - rptr[fid]};
71 }
72 };
73 // thread workspace
74 struct ThreadWSpace {
75 /*! \brief actual unit pointer */
76 std::vector<unsigned> rptr;
77 /*! \brief cut field */
78 std::vector<bst_float> cut;
79 // per thread histset
80 std::vector<HistSet> hset;
81 // initialize the hist set
Configurexgboost::tree::HistMaker::ThreadWSpace82 inline void Configure(int nthread) {
83 hset.resize(nthread);
84 // cleanup statistics
85 for (int tid = 0; tid < nthread; ++tid) {
86 for (auto& d : hset[tid].data) { d = GradStats(); }
87 hset[tid].rptr = dmlc::BeginPtr(rptr);
88 hset[tid].cut = dmlc::BeginPtr(cut);
89 hset[tid].data.resize(cut.size(), GradStats());
90 }
91 }
92 /*! \brief clear the workspace */
Clearxgboost::tree::HistMaker::ThreadWSpace93 inline void Clear() {
94 cut.clear(); rptr.resize(1); rptr[0] = 0;
95 }
96 /*! \brief total size */
Sizexgboost::tree::HistMaker::ThreadWSpace97 inline size_t Size() const {
98 return rptr.size() - 1;
99 }
100 };
101 // workspace of thread
102 ThreadWSpace wspace_;
103 // reducer for histogram
104 rabit::Reducer<GradStats, GradStats::Reduce> histred_;
105 // set of working features
106 std::vector<bst_feature_t> selected_features_;
107 // update function implementation
UpdateTree(const std::vector<GradientPair> & gpair,DMatrix * p_fmat,RegTree * p_tree)108 virtual void UpdateTree(const std::vector<GradientPair> &gpair,
109 DMatrix *p_fmat,
110 RegTree *p_tree) {
111 CHECK(param_.max_depth > 0) << "max_depth must be larger than 0";
112 this->InitData(gpair, *p_fmat, *p_tree);
113 this->InitWorkSet(p_fmat, *p_tree, &selected_features_);
114 // mark root node as fresh.
115 (*p_tree)[0].SetLeaf(0.0f, 0);
116
117 for (int depth = 0; depth < param_.max_depth; ++depth) {
118 // reset and propose candidate split
119 this->ResetPosAndPropose(gpair, p_fmat, selected_features_, *p_tree);
120 // create histogram
121 this->CreateHist(gpair, p_fmat, selected_features_, *p_tree);
122 // find split based on histogram statistics
123 this->FindSplit(selected_features_, p_tree);
124 // reset position after split
125 this->ResetPositionAfterSplit(p_fmat, *p_tree);
126 this->UpdateQueueExpand(*p_tree);
127 // if nothing left to be expand, break
128 if (qexpand_.size() == 0) break;
129 }
130 for (int const nid : qexpand_) {
131 (*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
132 }
133 }
134 // this function does two jobs
135 // (1) reset the position in array position, to be the latest leaf id
136 // (2) propose a set of candidate cuts and set wspace.rptr wspace.cut correctly
137 virtual void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
138 DMatrix *p_fmat,
139 const std::vector <bst_feature_t> &fset,
140 const RegTree &tree) = 0;
141 // initialize the current working set of features in this round
InitWorkSet(DMatrix *,const RegTree & tree,std::vector<bst_feature_t> * p_fset)142 virtual void InitWorkSet(DMatrix *,
143 const RegTree &tree,
144 std::vector<bst_feature_t> *p_fset) {
145 p_fset->resize(tree.param.num_feature);
146 for (size_t i = 0; i < p_fset->size(); ++i) {
147 (*p_fset)[i] = static_cast<unsigned>(i);
148 }
149 }
150 // reset position after split, this is not a must, depending on implementation
ResetPositionAfterSplit(DMatrix * p_fmat,const RegTree & tree)151 virtual void ResetPositionAfterSplit(DMatrix *p_fmat,
152 const RegTree &tree) {
153 }
154 virtual void CreateHist(const std::vector<GradientPair> &gpair,
155 DMatrix *,
156 const std::vector <bst_feature_t> &fset,
157 const RegTree &) = 0;
158
159 private:
EnumerateSplit(const HistUnit & hist,const GradStats & node_sum,bst_uint fid,SplitEntry * best,GradStats * left_sum) const160 void EnumerateSplit(const HistUnit &hist,
161 const GradStats &node_sum,
162 bst_uint fid,
163 SplitEntry *best,
164 GradStats *left_sum) const {
165 if (hist.size == 0) return;
166
167 double root_gain = CalcGain(param_, node_sum.GetGrad(), node_sum.GetHess());
168 GradStats s, c;
169 for (bst_uint i = 0; i < hist.size; ++i) {
170 s.Add(hist.data[i]);
171 if (s.sum_hess >= param_.min_child_weight) {
172 c.SetSubstract(node_sum, s);
173 if (c.sum_hess >= param_.min_child_weight) {
174 double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) +
175 CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain;
176 if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i], false, s, c)) {
177 *left_sum = s;
178 }
179 }
180 }
181 }
182 s = GradStats();
183 for (bst_uint i = hist.size - 1; i != 0; --i) {
184 s.Add(hist.data[i]);
185 if (s.sum_hess >= param_.min_child_weight) {
186 c.SetSubstract(node_sum, s);
187 if (c.sum_hess >= param_.min_child_weight) {
188 double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) +
189 CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain;
190 if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true, c, s)) {
191 *left_sum = c;
192 }
193 }
194 }
195 }
196 }
197
FindSplit(const std::vector<bst_feature_t> & feature_set,RegTree * p_tree)198 void FindSplit(const std::vector <bst_feature_t> &feature_set,
199 RegTree *p_tree) {
200 const size_t num_feature = feature_set.size();
201 // get the best split condition for each node
202 std::vector<SplitEntry> sol(qexpand_.size());
203 std::vector<GradStats> left_sum(qexpand_.size());
204 auto nexpand = static_cast<bst_omp_uint>(qexpand_.size());
205 dmlc::OMPException exc;
206 #pragma omp parallel for schedule(dynamic, 1)
207 for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
208 exc.Run([&]() {
209 const int nid = qexpand_[wid];
210 CHECK_EQ(node2workindex_[nid], static_cast<int>(wid));
211 SplitEntry &best = sol[wid];
212 GradStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0];
213 for (size_t i = 0; i < feature_set.size(); ++i) {
214 // Query is thread safe as it's a const function.
215 if (!this->interaction_constraints_.Query(nid, feature_set[i])) {
216 continue;
217 }
218
219 EnumerateSplit(this->wspace_.hset[0][i + wid * (num_feature+1)],
220 node_sum, feature_set[i], &best, &left_sum[wid]);
221 }
222 });
223 }
224 exc.Rethrow();
225 // get the best result, we can synchronize the solution
226 for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
227 const bst_node_t nid = qexpand_[wid];
228 SplitEntry const& best = sol[wid];
229 const GradStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0];
230 this->SetStats(p_tree, nid, node_sum);
231 // set up the values
232 p_tree->Stat(nid).loss_chg = best.loss_chg;
233 // now we know the solution in snode[nid], set split
234 if (best.loss_chg > kRtEps) {
235 bst_float base_weight = CalcWeight(param_, node_sum);
236 bst_float left_leaf_weight =
237 CalcWeight(param_, best.left_sum.sum_grad, best.left_sum.sum_hess) *
238 param_.learning_rate;
239 bst_float right_leaf_weight =
240 CalcWeight(param_, best.right_sum.sum_grad,
241 best.right_sum.sum_hess) *
242 param_.learning_rate;
243 p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
244 best.DefaultLeft(), base_weight, left_leaf_weight,
245 right_leaf_weight, best.loss_chg,
246 node_sum.sum_hess,
247 best.left_sum.GetHess(), best.right_sum.GetHess());
248 GradStats right_sum;
249 right_sum.SetSubstract(node_sum, left_sum[wid]);
250 auto left_child = (*p_tree)[nid].LeftChild();
251 auto right_child = (*p_tree)[nid].RightChild();
252 this->SetStats(p_tree, left_child, left_sum[wid]);
253 this->SetStats(p_tree, right_child, right_sum);
254 this->interaction_constraints_.Split(nid, best.SplitIndex(), left_child, right_child);
255 } else {
256 (*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
257 }
258 }
259 }
260
SetStats(RegTree * p_tree,int nid,const GradStats & node_sum)261 inline void SetStats(RegTree *p_tree, int nid, const GradStats &node_sum) {
262 p_tree->Stat(nid).base_weight =
263 static_cast<bst_float>(CalcWeight(param_, node_sum));
264 p_tree->Stat(nid).sum_hess = static_cast<bst_float>(node_sum.sum_hess);
265 }
266 };
267
268 class CQHistMaker: public HistMaker {
269 public:
270 CQHistMaker() = default;
Name() const271 char const* Name() const override {
272 return "grow_local_histmaker";
273 }
274
275 protected:
276 struct HistEntry {
277 HistMaker::HistUnit hist;
278 unsigned istart;
279 /*!
280 * \brief add a histogram to data,
281 * do linear scan, start from istart
282 */
Addxgboost::tree::CQHistMaker::HistEntry283 inline void Add(bst_float fv,
284 const std::vector<GradientPair> &gpair,
285 const bst_uint ridx) {
286 while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
287 CHECK_NE(istart, hist.size);
288 hist.data[istart].Add(gpair[ridx]);
289 }
290 /*!
291 * \brief add a histogram to data,
292 * do linear scan, start from istart
293 */
Addxgboost::tree::CQHistMaker::HistEntry294 inline void Add(bst_float fv,
295 GradientPair gstats) {
296 if (fv < hist.cut[istart]) {
297 hist.data[istart].Add(gstats);
298 } else {
299 while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
300 if (istart != hist.size) {
301 hist.data[istart].Add(gstats);
302 } else {
303 LOG(INFO) << "fv=" << fv << ", hist.size=" << hist.size;
304 for (size_t i = 0; i < hist.size; ++i) {
305 LOG(INFO) << "hist[" << i << "]=" << hist.cut[i];
306 }
307 LOG(FATAL) << "fv=" << fv << ", hist.last=" << hist.cut[hist.size - 1];
308 }
309 }
310 }
311 };
312 // sketch type used for this
313 using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
314 // initialize the work set of tree
InitWorkSet(DMatrix * p_fmat,const RegTree & tree,std::vector<bst_feature_t> * p_fset)315 void InitWorkSet(DMatrix *p_fmat,
316 const RegTree &tree,
317 std::vector<bst_feature_t> *p_fset) override {
318 if (p_fmat != cache_dmatrix_) {
319 feat_helper_.InitByCol(p_fmat, tree);
320 cache_dmatrix_ = p_fmat;
321 }
322 feat_helper_.SyncInfo();
323 feat_helper_.SampleCol(this->param_.colsample_bytree, p_fset);
324 }
325 // code to create histogram
CreateHist(const std::vector<GradientPair> & gpair,DMatrix * p_fmat,const std::vector<bst_feature_t> & fset,const RegTree & tree)326 void CreateHist(const std::vector<GradientPair> &gpair,
327 DMatrix *p_fmat,
328 const std::vector<bst_feature_t> &fset,
329 const RegTree &tree) override {
330 const MetaInfo &info = p_fmat->Info();
331 // fill in reverse map
332 feat2workindex_.resize(tree.param.num_feature);
333 std::fill(feat2workindex_.begin(), feat2workindex_.end(), -1);
334 for (size_t i = 0; i < fset.size(); ++i) {
335 feat2workindex_[fset[i]] = static_cast<int>(i);
336 }
337 // start to work
338 this->wspace_.Configure(1);
339 // if it is C++11, use lazy evaluation for Allreduce,
340 // to gain speedup in recovery
341 auto lazy_get_hist = [&]() {
342 thread_hist_.resize(omp_get_max_threads());
343 // start accumulating statistics
344 for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
345 auto page = batch.GetView();
346 // start enumeration
347 const auto nsize = static_cast<bst_omp_uint>(fset.size());
348 dmlc::OMPException exc;
349 #pragma omp parallel for schedule(dynamic, 1)
350 for (bst_omp_uint i = 0; i < nsize; ++i) {
351 exc.Run([&]() {
352 int fid = fset[i];
353 int offset = feat2workindex_[fid];
354 if (offset >= 0) {
355 this->UpdateHistCol(gpair, page[fid], info, tree,
356 fset, offset,
357 &thread_hist_[omp_get_thread_num()]);
358 }
359 });
360 }
361 exc.Rethrow();
362 }
363 // update node statistics.
364 this->GetNodeStats(gpair, *p_fmat, tree,
365 &thread_stats_, &node_stats_);
366 for (int const nid : this->qexpand_) {
367 const int wid = this->node2workindex_[nid];
368 this->wspace_.hset[0][fset.size() + wid * (fset.size() + 1)]
369 .data[0] = node_stats_[nid];
370 }
371 };
372 // sync the histogram
373 this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
374 this->wspace_.hset[0].data.size(), lazy_get_hist);
375 }
376
ResetPositionAfterSplit(DMatrix *,const RegTree & tree)377 void ResetPositionAfterSplit(DMatrix *,
378 const RegTree &tree) override {
379 this->GetSplitSet(this->qexpand_, tree, &fsplit_set_);
380 }
ResetPosAndPropose(const std::vector<GradientPair> & gpair,DMatrix * p_fmat,const std::vector<bst_feature_t> & fset,const RegTree & tree)381 void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
382 DMatrix *p_fmat,
383 const std::vector<bst_feature_t> &fset,
384 const RegTree &tree) override {
385 const MetaInfo &info = p_fmat->Info();
386 // fill in reverse map
387 feat2workindex_.resize(tree.param.num_feature);
388 std::fill(feat2workindex_.begin(), feat2workindex_.end(), -1);
389 work_set_.clear();
390 for (auto fidx : fset) {
391 if (feat_helper_.Type(fidx) == 2) {
392 feat2workindex_[fidx] = static_cast<int>(work_set_.size());
393 work_set_.push_back(fidx);
394 } else {
395 feat2workindex_[fidx] = -2;
396 }
397 }
398 const size_t work_set_size = work_set_.size();
399
400 sketchs_.resize(this->qexpand_.size() * work_set_size);
401 for (auto& sketch : sketchs_) {
402 sketch.Init(info.num_row_, this->param_.sketch_eps);
403 }
404 // initialize the summary array
405 summary_array_.resize(sketchs_.size());
406 // setup maximum size
407 unsigned max_size = this->param_.MaxSketchSize();
408 for (size_t i = 0; i < sketchs_.size(); ++i) {
409 summary_array_[i].Reserve(max_size);
410 }
411 {
412 // get summary
413 thread_sketch_.resize(omp_get_max_threads());
414
415 // TWOPASS: use the real set + split set in the column iteration.
416 this->SetDefaultPostion(p_fmat, tree);
417 work_set_.insert(work_set_.end(), fsplit_set_.begin(), fsplit_set_.end());
418 std::sort(work_set_.begin(), work_set_.end());
419 work_set_.resize(std::unique(work_set_.begin(), work_set_.end()) - work_set_.begin());
420
421 // start accumulating statistics
422 for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
423 // TWOPASS: use the real set + split set in the column iteration.
424 this->CorrectNonDefaultPositionByBatch(batch, fsplit_set_, tree);
425 auto page = batch.GetView();
426 // start enumeration
427 const auto nsize = static_cast<bst_omp_uint>(work_set_.size());
428 dmlc::OMPException exc;
429 #pragma omp parallel for schedule(dynamic, 1)
430 for (bst_omp_uint i = 0; i < nsize; ++i) {
431 exc.Run([&]() {
432 int fid = work_set_[i];
433 int offset = feat2workindex_[fid];
434 if (offset >= 0) {
435 this->UpdateSketchCol(gpair, page[fid], tree,
436 work_set_size, offset,
437 &thread_sketch_[omp_get_thread_num()]);
438 }
439 });
440 }
441 exc.Rethrow();
442 }
443 for (size_t i = 0; i < sketchs_.size(); ++i) {
444 common::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
445 sketchs_[i].GetSummary(&out);
446 summary_array_[i].SetPrune(out, max_size);
447 }
448 CHECK_EQ(summary_array_.size(), sketchs_.size());
449 }
450 if (summary_array_.size() != 0) {
451 size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
452 sreducer_.Allreduce(dmlc::BeginPtr(summary_array_), nbytes, summary_array_.size());
453 }
454 // now we get the final result of sketch, setup the cut
455 this->wspace_.cut.clear();
456 this->wspace_.rptr.clear();
457 this->wspace_.rptr.push_back(0);
458 for (size_t wid = 0; wid < this->qexpand_.size(); ++wid) {
459 for (unsigned int i : fset) {
460 int offset = feat2workindex_[i];
461 if (offset >= 0) {
462 const WXQSketch::Summary &a = summary_array_[wid * work_set_size + offset];
463 for (size_t i = 1; i < a.size; ++i) {
464 bst_float cpt = a.data[i].value - kRtEps;
465 if (i == 1 || cpt > this->wspace_.cut.back()) {
466 this->wspace_.cut.push_back(cpt);
467 }
468 }
469 // push a value that is greater than anything
470 if (a.size != 0) {
471 bst_float cpt = a.data[a.size - 1].value;
472 // this must be bigger than last value in a scale
473 bst_float last = cpt + fabs(cpt) + kRtEps;
474 this->wspace_.cut.push_back(last);
475 }
476 this->wspace_.rptr.push_back(static_cast<unsigned>(this->wspace_.cut.size()));
477 } else {
478 CHECK_EQ(offset, -2);
479 bst_float cpt = feat_helper_.MaxValue(i);
480 this->wspace_.cut.push_back(cpt + fabs(cpt) + kRtEps);
481 this->wspace_.rptr.push_back(static_cast<unsigned>(this->wspace_.cut.size()));
482 }
483 }
484 // reserve last value for global statistics
485 this->wspace_.cut.push_back(0.0f);
486 this->wspace_.rptr.push_back(static_cast<unsigned>(this->wspace_.cut.size()));
487 }
488 CHECK_EQ(this->wspace_.rptr.size(),
489 (fset.size() + 1) * this->qexpand_.size() + 1);
490 }
491
UpdateHistCol(const std::vector<GradientPair> & gpair,const SparsePage::Inst & col,const MetaInfo & info,const RegTree & tree,const std::vector<bst_feature_t> & fset,bst_uint fid_offset,std::vector<HistEntry> * p_temp)492 inline void UpdateHistCol(const std::vector<GradientPair> &gpair,
493 const SparsePage::Inst &col,
494 const MetaInfo &info,
495 const RegTree &tree,
496 const std::vector<bst_feature_t> &fset,
497 bst_uint fid_offset,
498 std::vector<HistEntry> *p_temp) {
499 if (col.size() == 0) return;
500 // initialize sbuilder for use
501 std::vector<HistEntry> &hbuilder = *p_temp;
502 hbuilder.resize(tree.param.num_nodes);
503 for (int const nid : this->qexpand_) {
504 const unsigned wid = this->node2workindex_[nid];
505 hbuilder[nid].istart = 0;
506 hbuilder[nid].hist = this->wspace_.hset[0][fid_offset + wid * (fset.size()+1)];
507 }
508 if (this->param_.cache_opt != 0) {
509 constexpr bst_uint kBuffer = 32;
510 bst_uint align_length = col.size() / kBuffer * kBuffer;
511 int buf_position[kBuffer];
512 GradientPair buf_gpair[kBuffer];
513 for (bst_uint j = 0; j < align_length; j += kBuffer) {
514 for (bst_uint i = 0; i < kBuffer; ++i) {
515 bst_uint ridx = col[j + i].index;
516 buf_position[i] = this->position_[ridx];
517 buf_gpair[i] = gpair[ridx];
518 }
519 for (bst_uint i = 0; i < kBuffer; ++i) {
520 const int nid = buf_position[i];
521 if (nid >= 0) {
522 hbuilder[nid].Add(col[j + i].fvalue, buf_gpair[i]);
523 }
524 }
525 }
526 for (bst_uint j = align_length; j < col.size(); ++j) {
527 const bst_uint ridx = col[j].index;
528 const int nid = this->position_[ridx];
529 if (nid >= 0) {
530 hbuilder[nid].Add(col[j].fvalue, gpair[ridx]);
531 }
532 }
533 } else {
534 for (const auto& c : col) {
535 const bst_uint ridx = c.index;
536 const int nid = this->position_[ridx];
537 if (nid >= 0) {
538 hbuilder[nid].Add(c.fvalue, gpair, ridx);
539 }
540 }
541 }
542 }
UpdateSketchCol(const std::vector<GradientPair> & gpair,const SparsePage::Inst & col,const RegTree & tree,size_t work_set_size,bst_uint offset,std::vector<BaseMaker::SketchEntry> * p_temp)543 inline void UpdateSketchCol(const std::vector<GradientPair> &gpair,
544 const SparsePage::Inst &col,
545 const RegTree &tree,
546 size_t work_set_size,
547 bst_uint offset,
548 std::vector<BaseMaker::SketchEntry> *p_temp) {
549 if (col.size() == 0) return;
550 // initialize sbuilder for use
551 std::vector<BaseMaker::SketchEntry> &sbuilder = *p_temp;
552 sbuilder.resize(tree.param.num_nodes);
553 for (int const nid : this->qexpand_) {
554 const unsigned wid = this->node2workindex_[nid];
555 sbuilder[nid].sum_total = 0.0f;
556 sbuilder[nid].sketch = &sketchs_[wid * work_set_size + offset];
557 }
558 // first pass, get sum of weight, TODO, optimization to skip first pass
559 for (const auto& c : col) {
560 const bst_uint ridx = c.index;
561 const int nid = this->position_[ridx];
562 if (nid >= 0) {
563 sbuilder[nid].sum_total += gpair[ridx].GetHess();
564 }
565 }
566 // if only one value, no need to do second pass
567 if (col[0].fvalue == col[col.size()-1].fvalue) {
568 for (int const nid : this->qexpand_) {
569 sbuilder[nid].sketch->Push(
570 col[0].fvalue, static_cast<bst_float>(sbuilder[nid].sum_total));
571 }
572 return;
573 }
574 // two pass scan
575 unsigned max_size = this->param_.MaxSketchSize();
576 for (int const nid : this->qexpand_) {
577 sbuilder[nid].Init(max_size);
578 }
579 // second pass, build the sketch
580 if (this->param_.cache_opt != 0) {
581 constexpr bst_uint kBuffer = 32;
582 bst_uint align_length = col.size() / kBuffer * kBuffer;
583 int buf_position[kBuffer];
584 bst_float buf_hess[kBuffer];
585 for (bst_uint j = 0; j < align_length; j += kBuffer) {
586 for (bst_uint i = 0; i < kBuffer; ++i) {
587 bst_uint ridx = col[j + i].index;
588 buf_position[i] = this->position_[ridx];
589 buf_hess[i] = gpair[ridx].GetHess();
590 }
591 for (bst_uint i = 0; i < kBuffer; ++i) {
592 const int nid = buf_position[i];
593 if (nid >= 0) {
594 sbuilder[nid].Push(col[j + i].fvalue, buf_hess[i], max_size);
595 }
596 }
597 }
598 for (bst_uint j = align_length; j < col.size(); ++j) {
599 const bst_uint ridx = col[j].index;
600 const int nid = this->position_[ridx];
601 if (nid >= 0) {
602 sbuilder[nid].Push(col[j].fvalue, gpair[ridx].GetHess(), max_size);
603 }
604 }
605 } else {
606 for (const auto& c : col) {
607 const bst_uint ridx = c.index;
608 const int nid = this->position_[ridx];
609 if (nid >= 0) {
610 sbuilder[nid].Push(c.fvalue, gpair[ridx].GetHess(), max_size);
611 }
612 }
613 }
614 for (int const nid : this->qexpand_) { sbuilder[nid].Finalize(max_size); }
615 }
616 // cached dmatrix where we initialized the feature on.
617 const DMatrix* cache_dmatrix_{nullptr};
618 // feature helper
619 BaseMaker::FMetaHelper feat_helper_;
620 // temp space to map feature id to working index
621 std::vector<int> feat2workindex_;
622 // set of index from fset that are current work set
623 std::vector<bst_feature_t> work_set_;
624 // set of index from that are split candidates.
625 std::vector<bst_uint> fsplit_set_;
626 // thread temp data
627 std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch_;
628 // used to hold statistics
629 std::vector<std::vector<GradStats> > thread_stats_;
630 // used to hold start pointer
631 std::vector<std::vector<HistEntry> > thread_hist_;
632 // node statistics
633 std::vector<GradStats> node_stats_;
634 // summary array
635 std::vector<WXQSketch::SummaryContainer> summary_array_;
636 // reducer for summary
637 rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer_;
638 // per node, per feature sketch
639 std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs_;
640 };
641
642 // global proposal
643 class GlobalProposalHistMaker: public CQHistMaker {
644 public:
Name() const645 char const* Name() const override {
646 return "grow_histmaker";
647 }
648
649 protected:
ResetPosAndPropose(const std::vector<GradientPair> & gpair,DMatrix * p_fmat,const std::vector<bst_feature_t> & fset,const RegTree & tree)650 void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
651 DMatrix *p_fmat,
652 const std::vector<bst_feature_t> &fset,
653 const RegTree &tree) override {
654 if (this->qexpand_.size() == 1) {
655 cached_rptr_.clear();
656 cached_cut_.clear();
657 }
658 if (cached_rptr_.size() == 0) {
659 CHECK_EQ(this->qexpand_.size(), 1U);
660 CQHistMaker::ResetPosAndPropose(gpair, p_fmat, fset, tree);
661 cached_rptr_ = this->wspace_.rptr;
662 cached_cut_ = this->wspace_.cut;
663 } else {
664 this->wspace_.cut.clear();
665 this->wspace_.rptr.clear();
666 this->wspace_.rptr.push_back(0);
667 for (size_t i = 0; i < this->qexpand_.size(); ++i) {
668 for (size_t j = 0; j < cached_rptr_.size() - 1; ++j) {
669 this->wspace_.rptr.push_back(
670 this->wspace_.rptr.back() + cached_rptr_[j + 1] - cached_rptr_[j]);
671 }
672 this->wspace_.cut.insert(this->wspace_.cut.end(), cached_cut_.begin(), cached_cut_.end());
673 }
674 CHECK_EQ(this->wspace_.rptr.size(),
675 (fset.size() + 1) * this->qexpand_.size() + 1);
676 CHECK_EQ(this->wspace_.rptr.back(), this->wspace_.cut.size());
677 }
678 }
679
680 // code to create histogram
CreateHist(const std::vector<GradientPair> & gpair,DMatrix * p_fmat,const std::vector<bst_feature_t> & fset,const RegTree & tree)681 void CreateHist(const std::vector<GradientPair> &gpair,
682 DMatrix *p_fmat,
683 const std::vector<bst_feature_t> &fset,
684 const RegTree &tree) override {
685 const MetaInfo &info = p_fmat->Info();
686 // fill in reverse map
687 this->feat2workindex_.resize(tree.param.num_feature);
688 this->work_set_ = fset;
689 std::fill(this->feat2workindex_.begin(), this->feat2workindex_.end(), -1);
690 for (size_t i = 0; i < fset.size(); ++i) {
691 this->feat2workindex_[fset[i]] = static_cast<int>(i);
692 }
693 // start to work
694 this->wspace_.Configure(1);
695 // to gain speedup in recovery
696 {
697 this->thread_hist_.resize(omp_get_max_threads());
698
699 // TWOPASS: use the real set + split set in the column iteration.
700 this->SetDefaultPostion(p_fmat, tree);
701 this->work_set_.insert(this->work_set_.end(), this->fsplit_set_.begin(),
702 this->fsplit_set_.end());
703 XGBOOST_PARALLEL_SORT(this->work_set_.begin(), this->work_set_.end(),
704 std::less<>{});
705 this->work_set_.resize(
706 std::unique(this->work_set_.begin(), this->work_set_.end()) - this->work_set_.begin());
707
708 // start accumulating statistics
709 for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
710 // TWOPASS: use the real set + split set in the column iteration.
711 this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree);
712 auto page = batch.GetView();
713
714 // start enumeration
715 const auto nsize = static_cast<bst_omp_uint>(this->work_set_.size());
716 dmlc::OMPException exc;
717 #pragma omp parallel for schedule(dynamic, 1)
718 for (bst_omp_uint i = 0; i < nsize; ++i) {
719 exc.Run([&]() {
720 int fid = this->work_set_[i];
721 int offset = this->feat2workindex_[fid];
722 if (offset >= 0) {
723 this->UpdateHistCol(gpair, page[fid], info, tree,
724 fset, offset,
725 &this->thread_hist_[omp_get_thread_num()]);
726 }
727 });
728 }
729 exc.Rethrow();
730 }
731
732 // update node statistics.
733 this->GetNodeStats(gpair, *p_fmat, tree,
734 &(this->thread_stats_), &(this->node_stats_));
735 for (const int nid : this->qexpand_) {
736 const int wid = this->node2workindex_[nid];
737 this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)]
738 .data[0] = this->node_stats_[nid];
739 }
740 }
741 this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
742 this->wspace_.hset[0].data.size());
743 }
744
745 // cached unit pointer
746 std::vector<unsigned> cached_rptr_;
747 // cached cut value.
748 std::vector<bst_float> cached_cut_;
749 };
750
751 XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
752 .describe("Tree constructor that uses approximate histogram construction.")
__anon3e3db4f50602() 753 .set_body([]() {
754 return new CQHistMaker();
755 });
756
757 // The updater for approx tree method.
758 XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
759 .describe("Tree constructor that uses approximate global of histogram construction.")
__anon3e3db4f50702() 760 .set_body([]() {
761 return new GlobalProposalHistMaker();
762 });
763 } // namespace tree
764 } // namespace xgboost
765