1 /*!
2  * Copyright 2017-2021 XGBoost contributors
3  */
4 #include <thrust/iterator/discard_iterator.h>
5 #include <thrust/iterator/transform_output_iterator.h>
6 #include <thrust/sequence.h>
7 #include <vector>
8 #include "../../common/device_helpers.cuh"
9 #include "row_partitioner.cuh"
10 
11 namespace xgboost {
12 namespace tree {
// Value carried through the partition scan: the row's original offset within
// the segment plus a flag that the inclusive scan turns into a running count
// of rows assigned to the left child.
struct IndexFlagTuple {
  size_t idx;   // original position of the row within the segment
  size_t flag;  // 1 if this row goes to the left node; running sum after scan
};
17 
// Binary operator for cub::DeviceScan over IndexFlagTuple: keeps the
// right-hand (most recent) row index and accumulates the left-child row
// count, so an inclusive scan yields per-row running totals.
struct IndexFlagOp {
  __device__ IndexFlagTuple operator()(const IndexFlagTuple& a,
                                       const IndexFlagTuple& b) const {
    return {b.idx, a.flag + b.flag};
  }
};
24 
/**
 * \brief Functor applied through a transform-output iterator during the
 * partition scan.  As a side effect it scatters each row's node id and row
 * index into the output buffers: rows belonging to the left child are packed
 * at the front, all remaining rows are packed after the left block.
 */
struct WriteResultsFunctor {
  bst_node_t left_nidx;  // node id of the left child
  common::Span<bst_node_t> position_in;
  common::Span<bst_node_t> position_out;
  common::Span<RowPartitioner::RowIndexT> ridx_in;
  common::Span<RowPartitioner::RowIndexT> ridx_out;
  int64_t* d_left_count;  // device scalar: total rows assigned to left node

  __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
    // x.flag is the inclusive running count of rows assigned to the left
    // node so far during the scan.  Use size_t (not int) so that segments
    // with more than 2^31 rows do not overflow the scatter index.
    size_t scatter_address;
    if (position_in[x.idx] == left_nidx) {
      scatter_address = x.flag - 1;  // -1 because inclusive scan
    } else {
      // current number of rows belonging to the right node, offset past the
      // total number of rows belonging to the left node
      scatter_address = (x.idx - x.flag) + *d_left_count;
    }
    // copy the node id and row index to their partitioned locations
    position_out[scatter_address] = position_in[x.idx];
    ridx_out[scatter_address] = ridx_in[x.idx];

    // The scan's own output is discarded; only the writes above matter.
    return {};
  }
};
52 
// Implement partitioning via a single scan operation, using a transform
// output iterator to scatter the result.  Rows whose position equals
// left_nidx are packed to the front of the output spans, the rest follow.
void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
                                  common::Span<bst_node_t> position_out,
                                  common::Span<RowIndexT> ridx,
                                  common::Span<RowIndexT> ridx_out,
                                  bst_node_t left_nidx, bst_node_t,
                                  int64_t* d_left_count, cudaStream_t stream) {
  WriteResultsFunctor write_results{left_nidx, position, position_out,
                                    ridx,      ridx_out, d_left_count};
  // The scan value itself is thrown away; write_results performs the actual
  // scatter into position_out/ridx_out as a side effect of each write.
  auto discard_write_iterator =
      thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
  auto counting = thrust::make_counting_iterator(0llu);
  // Map row offset idx to (idx, 1 if the row goes left else 0); the inclusive
  // scan then produces, per row, the running count of left-bound rows.
  auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
      counting, [=] __device__(size_t idx) {
        return IndexFlagTuple{idx, static_cast<size_t>(position[idx] == left_nidx)};
      });
  // Standard CUB two-phase pattern: a null temp pointer queries the required
  // temporary storage size, then the scan runs for real on `stream`.
  size_t temp_bytes = 0;
  cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator,
                                 discard_write_iterator, IndexFlagOp(),
                                 position.size(), stream);
  dh::TemporaryArray<int8_t> temp(temp_bytes);
  cub::DeviceScan::InclusiveScan(temp.data().get(), temp_bytes, input_iterator,
                                 discard_write_iterator, IndexFlagOp(),
                                 position.size(), stream);
}
79 
// Initialise the row index buffer to the identity permutation and assign
// every row to the root node (position 0).  `device_idx` is accepted for
// symmetry with the caller but not used here; the caller is expected to have
// made the target device current already.
void Reset(int device_idx, common::Span<RowPartitioner::RowIndexT> ridx,
           common::Span<bst_node_t> position) {
  CHECK_EQ(ridx.size(), position.size());
  dh::LaunchN(ridx.size(), [=] __device__(size_t i) {
    ridx[i] = i;
    position[i] = 0;
  });
}
88 
// Construct a partitioner for `num_rows` rows on device `device_idx`.
// Initially every row belongs to a single root segment [0, num_rows).
//
// NOTE(review): the member initializers below allocate the double-buffered
// device vectors *before* cudaSetDevice runs in the body; this assumes the
// caller has already made device_idx current -- confirm.
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
    : device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows),
      ridx_b_(num_rows), position_b_(num_rows) {
  dh::safe_cuda(cudaSetDevice(device_idx_));
  // Wire the a/b storage into double buffers; Current() starts on the a side.
  ridx_ = dh::DoubleBuffer<RowIndexT>{&ridx_a_, &ridx_b_};
  position_ = dh::DoubleBuffer<bst_node_t>{&position_a_, &position_b_};
  ridx_segments_.emplace_back(Segment(0, num_rows));

  Reset(device_idx, ridx_.CurrentSpan(), position_.CurrentSpan());
  // Per-node left-row counters, zero-initialised.  256 appears to be a fixed
  // upper bound on nodes processed per batch -- TODO confirm against callers.
  left_counts_.resize(256);
  thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
  // Two streams allow overlapping the left/right child partition work.
  streams_.resize(2);
  for (auto& stream : streams_) {
    dh::safe_cuda(cudaStreamCreate(&stream));
  }
}
// Release the CUDA streams created in the constructor.  The streams belong to
// device_idx_, so make that device current before destroying them.
RowPartitioner::~RowPartitioner() {
  dh::safe_cuda(cudaSetDevice(device_idx_));
  for (auto stream : streams_) {  // cudaStream_t is a handle; copying is cheap
    dh::safe_cuda(cudaStreamDestroy(stream));
  }
}
111 
// Return the device span of row indices currently assigned to node `nidx`.
// Throws (via std::vector::at) if nidx has no recorded segment.
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
    bst_node_t nidx) {
  const auto& segment = ridx_segments_.at(nidx);
  if (segment.Size() == 0) {
    // An empty span is a valid answer here; building a subspan from a
    // zero-sized segment would fail the span's construction check.
    return {};
  }
  return ridx_.CurrentSpan().subspan(segment.begin, segment.Size());
}
122 
// Return all row indices in their current partitioned order.
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
  return ridx_.CurrentSpan();
}
126 
// Return the per-row node assignment, aligned with GetRows().
common::Span<const bst_node_t> RowPartitioner::GetPosition() {
  return position_.CurrentSpan();
}
// Copy the device-resident row indices for node `nidx` into a host vector.
// Intended for testing/inspection; performs a device-to-host transfer.
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
    bst_node_t nidx) {
  auto device_rows = GetRows(nidx);
  std::vector<RowIndexT> result(device_rows.size());
  dh::CopyDeviceSpanToVector(&result, device_rows);
  return result;
}
137 
// Copy the per-row node assignments into a host vector.
// Intended for testing/inspection; performs a device-to-host transfer.
std::vector<bst_node_t> RowPartitioner::GetPositionHost() {
  auto device_position = GetPosition();
  std::vector<bst_node_t> result(device_position.size());
  dh::CopyDeviceSpanToVector(&result, device_position);
  return result;
}
144 
// Partition the rows of `segment` between left_nidx and right_nidx on
// `stream`: scatter into the double buffer's Other() side, then copy the
// partitioned data back into Current() so Current() remains authoritative.
void RowPartitioner::SortPositionAndCopy(const Segment& segment,
                                         bst_node_t left_nidx,
                                         bst_node_t right_nidx,
                                         int64_t* d_left_count,
                                         cudaStream_t stream) {
  SortPosition(
      // position_in
      common::Span<bst_node_t>(position_.Current() + segment.begin,
                               segment.Size()),
      // position_out
      common::Span<bst_node_t>(position_.Other() + segment.begin,
                               segment.Size()),
      // row index in
      common::Span<RowIndexT>(ridx_.Current() + segment.begin, segment.Size()),
      // row index out
      common::Span<RowIndexT>(ridx_.Other() + segment.begin, segment.Size()),
      left_nidx, right_nidx, d_left_count, stream);
  // Copy back key/value.  Enqueued on the same stream, so it is ordered after
  // the scatter above without an explicit synchronisation.
  const auto d_position_current = position_.Current() + segment.begin;
  const auto d_position_other = position_.Other() + segment.begin;
  const auto d_ridx_current = ridx_.Current() + segment.begin;
  const auto d_ridx_other = ridx_.Other() + segment.begin;
  dh::LaunchN(segment.Size(), stream, [=] __device__(size_t idx) {
    d_position_current[idx] = d_position_other[idx];
    d_ridx_current[idx] = d_ridx_other[idx];
  });
}
172 };  // namespace tree
173 };  // namespace xgboost
174