1 /*!
2 * Copyright 2017-2021 XGBoost contributors
3 */
4 #include <thrust/iterator/discard_iterator.h>
5 #include <thrust/iterator/transform_output_iterator.h>
6 #include <thrust/sequence.h>
7 #include <vector>
8 #include "../../common/device_helpers.cuh"
9 #include "row_partitioner.cuh"
10
11 namespace xgboost {
12 namespace tree {
/*!
 * \brief Value type carried through the partitioning scan.
 *
 * The member order is significant: instances in this file are brace-initialised
 * positionally as {idx, flag}.
 */
struct IndexFlagTuple {
  // Index of the element within the range being scanned.
  size_t idx;
  // 1 when the element belongs to the left node and 0 otherwise on input; after
  // being combined by the scan operator it holds the accumulated sum of flags.
  size_t flag;
};
17
/*!
 * \brief Binary operator that combines two IndexFlagTuple values for the scan.
 *
 * The result carries the index of the right-hand operand and the sum of both
 * operands' flags.
 */
struct IndexFlagOp {
  __device__ IndexFlagTuple operator()(const IndexFlagTuple& a,
                                       const IndexFlagTuple& b) const {
    IndexFlagTuple combined;
    combined.idx = b.idx;              // index always follows the right operand
    combined.flag = a.flag + b.flag;   // flags accumulate
    return combined;
  }
};
24
/*!
 * \brief Functor applied to every value produced by the partitioning scan; it
 *        scatters the corresponding row to its partitioned slot.
 *
 * For a scanned tuple x, x.idx is the input slot and x.flag is the inclusive
 * running sum of "belongs to left node" flags up to and including x.idx.
 * Rows whose position equals left_nidx are written to the front of the output,
 * all remaining rows are written after them.
 *
 * The member order must not change: the struct is brace-initialised
 * positionally by its user.
 */
struct WriteResultsFunctor {
  // Node id that identifies rows going to the left partition.
  bst_node_t left_nidx;
  // Node id of every row before partitioning.
  common::Span<bst_node_t> position_in;
  // Destination for node ids in partitioned order.
  common::Span<bst_node_t> position_out;
  // Row indices before partitioning.
  common::Span<RowPartitioner::RowIndexT> ridx_in;
  // Destination for row indices in partitioned order.
  common::Span<RowPartitioner::RowIndexT> ridx_out;
  // Device pointer that is only read here; it is used as the offset at which
  // the non-left rows start, i.e. it is expected to hold the total number of
  // left-node rows. NOTE(review): who fills it in is not visible in this file.
  int64_t* d_left_count;

  /*!
   * \brief Write the row described by x to its output slot.
   * \param x Scanned tuple (input index, inclusive count of left rows so far).
   * \return A default-constructed tuple; the value is meant to be discarded.
   */
  __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
    // The address is computed from size_t and int64_t operands and is used as a
    // span index, so keep it in size_t. Storing it in a 32-bit int would
    // truncate (or turn negative) once a segment holds more than INT_MAX rows.
    size_t scatter_address;
    if (position_in[x.idx] == left_nidx) {
      // x.flag counts the left rows up to and including this one.
      scatter_address = x.flag - 1;  // -1 because inclusive scan
    } else {
      // number of non-left rows seen before this one + total number of rows
      // belonging to the left node
      scatter_address = (x.idx - x.flag) + static_cast<size_t>(*d_left_count);
    }
    // copy the node id and the row index to the output
    position_out[scatter_address] = position_in[x.idx];
    ridx_out[scatter_address] = ridx_in[x.idx];

    // Discard
    return {};
  }
};
52
/*!
 * \brief Partition rows into "left node" rows followed by all other rows.
 *
 * Implemented as a single inclusive scan over (index, is-left flag) tuples. The
 * scan's output iterator does not store the scan result; instead every scanned
 * value is passed to WriteResultsFunctor, which scatters the matching entry of
 * position / ridx into position_out / ridx_out.
 *
 * \param position      Node id of each row (input).
 * \param position_out  Receives the node ids in partitioned order.
 * \param ridx          Row indices (input).
 * \param ridx_out      Receives the row indices in partitioned order.
 * \param left_nidx     Node id whose rows are placed first.
 * \param (unnamed)     Unused bst_node_t argument, kept for the interface.
 * \param d_left_count  Device pointer dereferenced by the scatter functor while
 *                      the scan runs. NOTE(review): it presumably has to hold
 *                      the left-row count already - confirm against the caller.
 * \param stream        CUDA stream the scan is submitted to.
 *
 * Errors reported by CUB are raised through dh::safe_cuda instead of being
 * silently ignored.
 */
void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
                                  common::Span<bst_node_t> position_out,
                                  common::Span<RowIndexT> ridx,
                                  common::Span<RowIndexT> ridx_out,
                                  bst_node_t left_nidx, bst_node_t,
                                  int64_t* d_left_count, cudaStream_t stream) {
  WriteResultsFunctor write_results{left_nidx, position, position_out,
                                    ridx, ridx_out, d_left_count};
  // Output iterator: run write_results on every scanned value, then drop it.
  auto discard_write_iterator =
      thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
  // Input iterator: element i becomes {i, 1 if row i goes left else 0}.
  auto counting = thrust::make_counting_iterator(0llu);
  auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
      counting, [=] __device__(size_t idx) {
        return IndexFlagTuple{idx, static_cast<size_t>(position[idx] == left_nidx)};
      });
  // First call with a null workspace only reports the required temporary
  // storage size in temp_bytes.
  size_t temp_bytes = 0;
  dh::safe_cuda(cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator,
                                               discard_write_iterator, IndexFlagOp(),
                                               position.size(), stream));
  dh::TemporaryArray<int8_t> temp(temp_bytes);
  // Second call performs the scan (and, through the output iterator, the scatter).
  dh::safe_cuda(cub::DeviceScan::InclusiveScan(temp.data().get(), temp_bytes, input_iterator,
                                               discard_write_iterator, IndexFlagOp(),
                                               position.size(), stream));
}
79
/*!
 * \brief Initialise the partitioner buffers: ridx[i] = i and position[i] = 0.
 *
 * \param device_idx Not used by this function.
 * \param ridx       Row-index buffer to fill with 0, 1, 2, ...
 * \param position   Node-id buffer to fill with 0; must have the same size as
 *                   ridx (checked).
 */
void Reset(int device_idx, common::Span<RowPartitioner::RowIndexT> ridx,
           common::Span<bst_node_t> position) {
  CHECK_EQ(ridx.size(), position.size());
  // One invocation per element; each writes its own slot in both buffers.
  auto init_row = [=] __device__(size_t i) {
    ridx[i] = i;
    position[i] = 0;
  };
  dh::LaunchN(ridx.size(), init_row);
}
88
/*!
 * \brief Construct a partitioner for num_rows rows on device device_idx.
 *
 * Sizes both row-index and both position buffers to num_rows, wraps them in
 * double buffers, registers one segment covering [0, num_rows), resets the
 * current buffers (row i -> index i, node 0), zero-fills 256 left-count slots
 * and creates two CUDA streams.
 *
 * \param device_idx CUDA device ordinal made current before any work is issued
 *                   from the constructor body.
 * \param num_rows   Number of rows to track.
 */
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
    : device_idx_(device_idx),
      ridx_a_(num_rows),
      position_a_(num_rows),
      ridx_b_(num_rows),
      position_b_(num_rows) {
  constexpr size_t kNumLeftCounts = 256;
  constexpr size_t kNumStreams = 2;

  dh::safe_cuda(cudaSetDevice(device_idx_));

  // Pair up the a/b storage so that partitioning can write into the "other"
  // buffer while reading from the "current" one.
  ridx_ = dh::DoubleBuffer<RowIndexT>{&ridx_a_, &ridx_b_};
  position_ = dh::DoubleBuffer<bst_node_t>{&position_a_, &position_b_};

  // Initially a single segment spans every row.
  ridx_segments_.emplace_back(Segment(0, num_rows));
  Reset(device_idx, ridx_.CurrentSpan(), position_.CurrentSpan());

  left_counts_.resize(kNumLeftCounts);
  thrust::fill(left_counts_.begin(), left_counts_.end(), 0);

  streams_.resize(kNumStreams);
  for (size_t i = 0; i < streams_.size(); ++i) {
    dh::safe_cuda(cudaStreamCreate(&streams_[i]));
  }
}
/*!
 * \brief Destroy every stream held in streams_, after making the partitioner's
 *        device current.
 */
RowPartitioner::~RowPartitioner() {
  dh::safe_cuda(cudaSetDevice(device_idx_));
  for (size_t i = 0; i < streams_.size(); ++i) {
    dh::safe_cuda(cudaStreamDestroy(streams_[i]));
  }
}
111
/*!
 * \brief Device view of the row indices in the segment stored for node nidx.
 *
 * \param nidx Node id; used to look up the segment with ridx_segments_.at().
 * \return The part of the current row-index buffer covered by that segment, or
 *         an empty span when the segment has size 0.
 */
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
    bst_node_t nidx) {
  const auto segment = ridx_segments_.at(nidx);
  const auto n_rows = segment.Size();
  if (n_rows != 0) {
    return ridx_.CurrentSpan().subspan(segment.begin, n_rows);
  }
  // An empty span is a valid result here. It is returned explicitly because
  // constructing a span from a pointer with size 0 would error.
  return {};
}
122
/*!
 * \brief Read-only device view of the whole current row-index buffer.
 */
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
  common::Span<const RowIndexT> all_rows = ridx_.CurrentSpan();
  return all_rows;
}
126
/*!
 * \brief Read-only device view of the whole current position (node id) buffer.
 */
common::Span<const bst_node_t> RowPartitioner::GetPosition() {
  common::Span<const bst_node_t> all_positions = position_.CurrentSpan();
  return all_positions;
}
/*!
 * \brief Copy the row indices of node nidx into a host vector.
 *
 * \param nidx Node id, forwarded to GetRows(nidx).
 * \return Host vector with one entry per element of the span GetRows(nidx)
 *         returns.
 */
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
    bst_node_t nidx) {
  auto d_rows = GetRows(nidx);
  // Size the destination up front; the copy helper receives both ends.
  std::vector<RowIndexT> h_rows(d_rows.size());
  dh::CopyDeviceSpanToVector(&h_rows, d_rows);
  return h_rows;
}
137
/*!
 * \brief Copy the whole current position (node id) buffer into a host vector.
 *
 * \return Host vector with one node id per element of GetPosition().
 */
std::vector<bst_node_t> RowPartitioner::GetPositionHost() {
  auto d_position = GetPosition();
  // Size the destination up front; the copy helper receives both ends.
  std::vector<bst_node_t> h_position(d_position.size());
  dh::CopyDeviceSpanToVector(&h_position, d_position);
  return h_position;
}
144
/*!
 * \brief Partition one segment and copy the result back into the current
 *        buffers.
 *
 * SortPosition reads the segment from the current position / row-index buffers
 * and writes the partitioned result into the same range of the other buffers.
 * A second launch on the same stream then copies that range back element by
 * element, so the current buffers end up holding the partitioned segment.
 *
 * \param segment       Range [begin, begin + Size()) of rows to partition.
 * \param left_nidx     Node id whose rows are placed first.
 * \param right_nidx    Forwarded to SortPosition.
 * \param d_left_count  Forwarded to SortPosition.
 * \param stream        CUDA stream both steps are submitted to.
 */
void RowPartitioner::SortPositionAndCopy(const Segment& segment,
                                         bst_node_t left_nidx,
                                         bst_node_t right_nidx,
                                         int64_t* d_left_count,
                                         cudaStream_t stream) {
  const auto n_rows = segment.Size();

  // Start of the segment inside each of the four underlying buffers.
  const auto position_src = position_.Current() + segment.begin;
  const auto position_dst = position_.Other() + segment.begin;
  const auto ridx_src = ridx_.Current() + segment.begin;
  const auto ridx_dst = ridx_.Other() + segment.begin;

  // Current -> Other, in partitioned order.
  SortPosition(common::Span<bst_node_t>(position_src, n_rows),
               common::Span<bst_node_t>(position_dst, n_rows),
               common::Span<RowIndexT>(ridx_src, n_rows),
               common::Span<RowIndexT>(ridx_dst, n_rows),
               left_nidx, right_nidx, d_left_count, stream);

  // Other -> Current: copy back key/value.
  dh::LaunchN(n_rows, stream, [=] __device__(size_t i) {
    position_src[i] = position_dst[i];
    ridx_src[i] = ridx_dst[i];
  });
}
172 }; // namespace tree
173 }; // namespace xgboost
174