/*!
 * Copyright 2021 by XGBoost Contributors
 */
#include <thrust/scan.h>
#include <cub/cub.cuh>
#include <cassert>
#include <limits>
#include <memory>
#include <utility>
#include <tuple>

#include "rabit/rabit.h"
#include "xgboost/span.h"
#include "xgboost/data.h"
#include "auc.h"
#include "../common/device_helpers.cuh"
#include "../common/ranking_utils.cuh"

namespace xgboost {
namespace metric {
namespace {
template <typename T>
using Discard = thrust::discard_iterator<T>;

struct GetWeightOp {
  common::Span<float const> weights;
  common::Span<size_t const> sorted_idx;

  __device__ float operator()(size_t i) const {
    return weights.empty() ? 1.0f : weights[sorted_idx[i]];
  }
};
}  // namespace

/**
 * A cache for GPU data to avoid reallocating memory across calls.
 */
struct DeviceAUCCache {
  // Pair of FP/TP
  using Pair = thrust::pair<float, float>;
  // index sorted by prediction value
  dh::device_vector<size_t> sorted_idx;
  // track FP/TP for computing the trapezoid areas
  dh::device_vector<Pair> fptp;
  // track FP_PREV/TP_PREV for computing the trapezoid areas
  dh::device_vector<Pair> neg_pos;
  // index of unique prediction values.
  dh::device_vector<size_t> unique_idx;
  // p^T: transposed prediction matrix, used by MultiClassAUC
  dh::device_vector<float> predts_t;
  std::unique_ptr<dh::AllReducer> reducer;

  void Init(common::Span<float const> predts, bool is_multi, int32_t device) {
    if (sorted_idx.size() != predts.size()) {
      sorted_idx.resize(predts.size());
      fptp.resize(sorted_idx.size());
      unique_idx.resize(sorted_idx.size());
      neg_pos.resize(sorted_idx.size());
      if (is_multi) {
        predts_t.resize(sorted_idx.size());
      }
    }
    if (is_multi && !reducer) {
      reducer.reset(new dh::AllReducer);
      reducer->Init(device);
    }
  }
};

/**
 * The GPU implementation uses the same calculation as the CPU version, with a
 * few extra steps to distribute work across threads:
 *
 * - Run a scan to obtain the TP/FP values, which are the right coordinates of
 *   the trapezoids.
 * - Find distinct prediction values and get the corresponding FP_PREV/TP_PREV
 *   values, which are the left coordinates of the trapezoids.
 * - Reduce the scan array into a single AUC value.
 */
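/**
 * A small worked example (illustrative only): take predictions
 * {0.9, 0.7, 0.7, 0.1} already sorted in descending order, labels
 * {1, 0, 1, 0}, and unit weights.
 *
 * - The inclusive scan of (FP, TP) pairs yields {(0,1), (1,1), (1,2), (2,2)}.
 * - The distinct predictions start at sorted positions {0, 1, 3}, so the ROC
 *   curve passes through (0,0), (0,1), (1,2), (2,2) in (FP, TP) coordinates.
 * - The trapezoid areas are 0, 1.5, and 2; dividing their sum by
 *   total_fp * total_tp = 4 (done by the caller) gives AUC = 3.5 / 4 = 0.875.
 */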
std::tuple<float, float, float>
GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
             int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
  auto& cache = *p_cache;
  if (!cache) {
    cache.reset(new DeviceAUCCache);
  }
  cache->Init(predts, false, device);

  auto labels = info.labels_.ConstDeviceSpan();
  auto weights = info.weights_.ConstDeviceSpan();
  dh::safe_cuda(cudaSetDevice(device));

  CHECK(!labels.empty());
  CHECK_EQ(labels.size(), predts.size());

  /**
   * Create sorted index for each class
   */
  auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
  dh::ArgSort<false>(predts, d_sorted_idx);

  /**
   * Linear scan
   */
  auto get_weight = GetWeightOp{weights, d_sorted_idx};
  using Pair = thrust::pair<float, float>;
  auto get_fp_tp = [=]__device__(size_t i) {
    size_t idx = d_sorted_idx[i];

    float label = labels[idx];
    float w = get_weight(i);

    float fp = (1.0 - label) * w;
    float tp = label * w;

    return thrust::make_pair(fp, tp);
  };  // NOLINT
  auto d_fptp = dh::ToSpan(cache->fptp);
  dh::LaunchN(d_sorted_idx.size(),
              [=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); });

  dh::XGBDeviceAllocator<char> alloc;
  auto d_unique_idx = dh::ToSpan(cache->unique_idx);
  dh::Iota(d_unique_idx);

  auto uni_key = dh::MakeTransformIterator<float>(
      thrust::make_counting_iterator(0),
      [=] __device__(size_t i) { return predts[d_sorted_idx[i]]; });
  auto end_unique = thrust::unique_by_key_copy(
      thrust::cuda::par(alloc), uni_key, uni_key + d_sorted_idx.size(),
      dh::tbegin(d_unique_idx), thrust::make_discard_iterator(),
      dh::tbegin(d_unique_idx));
  d_unique_idx = d_unique_idx.subspan(0, end_unique.second - dh::tbegin(d_unique_idx));

  dh::InclusiveScan(
      dh::tbegin(d_fptp), dh::tbegin(d_fptp),
      [=] __device__(Pair const &l, Pair const &r) {
        return thrust::make_pair(l.first + r.first, l.second + r.second);
      },
      d_fptp.size());

  auto d_neg_pos = dh::ToSpan(cache->neg_pos);
  // scatter the unique negative/positive values,
  // shifted to the right by 1 with an initial value of 0
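  // Continuing the example above: with d_unique_idx = {0, 1, 3}, this kernel
  // stores d_neg_pos[0] = (0, 0), d_neg_pos[1] = d_fptp[0] = (0, 1), and
  // d_neg_pos[3] = d_fptp[2] = (1, 2), giving every trapezoid its left
  // coordinate.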
  dh::LaunchN(d_unique_idx.size(), [=] __device__(size_t i) {
    if (d_unique_idx[i] == 0) {  // the first unique index is 0
      assert(i == 0);
      d_neg_pos[0] = {0, 0};
      return;
    }
    d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1];
    if (i == d_unique_idx.size() - 1) {
      // the last element needs to be included; this may overwrite the
      // assignment above when the last prediction value is distinct from the
      // previous one.
      d_neg_pos.back() = d_fptp[d_unique_idx[i] - 1];
      return;
    }
  });

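  // Each element below is the area of one trapezoid in (FP, TP) coordinates,
  // i.e. the standard trapezoid rule |fp - fp_prev| * (tp_prev + tp) / 2.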
  auto in = dh::MakeTransformIterator<float>(
      thrust::make_counting_iterator(0), [=] __device__(size_t i) {
        float fp, tp;
        float fp_prev, tp_prev;
        if (i == 0) {
          // use thread 0 to handle the last element
          thrust::tie(fp, tp) = d_fptp.back();
          thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx.back()];
        } else {
          thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1];
          thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]];
        }
        return TrapesoidArea(fp_prev, fp, tp_prev, tp);
      });

  Pair last = cache->fptp.back();
  float auc = thrust::reduce(thrust::cuda::par(alloc), in, in + d_unique_idx.size());
  return std::make_tuple(last.first, last.second, auc);
}

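/**
 * Transpose a row-major m-by-n matrix.  As a sketch, with m = 2 samples and
 * n = 3 classes, in = {a0, b0, c0, a1, b1, c1} becomes
 * out = {a0, a1, b0, b1, c0, c1}, so each class occupies a contiguous block.
 */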
void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
               size_t n, int32_t device) {
  CHECK_EQ(in.size(), out.size());
  CHECK_EQ(in.size(), m * n);
  dh::LaunchN(in.size(), [=] __device__(size_t i) {
    size_t col = i / m;
    size_t row = i % m;
    size_t idx = row * n + col;
    out[i] = in[idx];
  });
}

/**
 * Last index of a group in a CSR style of index pointer.
 */
template <typename Idx>
XGBOOST_DEVICE size_t LastOf(size_t group, common::Span<Idx> indptr) {
  return indptr[group + 1] - 1;
}
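// For example, with indptr = {0, 3, 5}: group 0 spans [0, 3), so
// LastOf(0, indptr) == 2, and group 1 spans [3, 5), so LastOf(1, indptr) == 4.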
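/**
 * Combine the per-class results into one value: each class AUC is normalized
 * by its local area (total_fp * total_tp), and the classes are then averaged
 * with their total TP as weights.  Returns NaN when any class lacks positive
 * or negative samples.
 */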
float ScaleClasses(common::Span<float> results, common::Span<float> local_area,
                   common::Span<float> fp, common::Span<float> tp,
                   common::Span<float> auc, std::shared_ptr<DeviceAUCCache> cache,
                   size_t n_classes) {
  dh::XGBDeviceAllocator<char> alloc;
  if (rabit::IsDistributed()) {
    CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice());
    cache->reducer->AllReduceSum(results.data(), results.data(), results.size());
  }
  auto reduce_in = dh::MakeTransformIterator<thrust::pair<float, float>>(
      thrust::make_counting_iterator(0), [=] __device__(size_t i) {
        if (local_area[i] > 0) {
          return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]);
        }
        return thrust::make_pair(std::numeric_limits<float>::quiet_NaN(), 0.0f);
      });

  float tp_sum;
  float auc_sum;
  thrust::tie(auc_sum, tp_sum) = thrust::reduce(
      thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
      thrust::make_pair(0.0f, 0.0f),
      [=] __device__(auto const &l, auto const &r) {
        return thrust::make_pair(l.first + r.first, l.second + r.second);
      });
  if (tp_sum != 0 && !std::isnan(auc_sum)) {
    auc_sum /= tp_sum;
  } else {
    return std::numeric_limits<float>::quiet_NaN();
  }
  return auc_sum;
}

/**
 * The multi-class implementation is similar to binary classification, except
 * that every kernel has to handle each class (one-vs-rest) separately.
 */
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
                          int32_t device, std::shared_ptr<DeviceAUCCache>* p_cache,
                          size_t n_classes) {
  dh::safe_cuda(cudaSetDevice(device));
  auto& cache = *p_cache;
  if (!cache) {
    cache.reset(new DeviceAUCCache);
  }
  cache->Init(predts, true, device);

  auto labels = info.labels_.ConstDeviceSpan();
  auto weights = info.weights_.ConstDeviceSpan();

  size_t n_samples = labels.size();

  if (n_samples == 0) {
    dh::TemporaryArray<float> results(n_classes * 4, 0.0f);
    auto d_results = dh::ToSpan(results);
    dh::LaunchN(n_classes * 4,
                [=] __device__(size_t i) { d_results[i] = 0.0f; });
    auto local_area = d_results.subspan(0, n_classes);
    auto fp = d_results.subspan(n_classes, n_classes);
    auto tp = d_results.subspan(2 * n_classes, n_classes);
    auto auc = d_results.subspan(3 * n_classes, n_classes);
    return ScaleClasses(d_results, local_area, fp, tp, auc, cache, n_classes);
  }

  /**
   * Create sorted index for each class
   */
  auto d_predts_t = dh::ToSpan(cache->predts_t);
  Transpose(predts, d_predts_t, n_samples, n_classes, device);

  dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
  auto d_class_ptr = dh::ToSpan(class_ptr);
  dh::LaunchN(n_classes + 1,
              [=] __device__(size_t i) { d_class_ptr[i] = i * n_samples; });
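  // e.g. with n_samples = 3 and n_classes = 2, d_class_ptr = {0, 3, 6}: the
  // transposed predictions for class c occupy [c * n_samples, (c + 1) * n_samples).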
  // Thrust has no out-of-place sort, and the CUB sort doesn't accept general
  // iterators, so we can't feed a transform iterator into the sort.
  auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
  dh::SegmentedArgSort<false>(d_predts_t, d_class_ptr, d_sorted_idx);

  /**
   * Linear scan
   */
  dh::caching_device_vector<float> d_auc(n_classes, 0);
  auto s_d_auc = dh::ToSpan(d_auc);
  auto d_fptp = dh::ToSpan(cache->fptp);
  auto get_fp_tp = [=]__device__(size_t i) {
    size_t idx = d_sorted_idx[i];

    size_t class_id = i / n_samples;
    // labels is a vector of size n_samples.
    float label = labels[idx % n_samples] == class_id;

    float w = weights.empty() ? 1.0f : weights[idx % n_samples];
    float fp = (1.0 - label) * w;
    float tp = label * w;
    return thrust::make_pair(fp, tp);
  };  // NOLINT
  dh::LaunchN(d_sorted_idx.size(),
              [=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); });

  /**
   *  Handle duplicated predictions
   */
  dh::XGBDeviceAllocator<char> alloc;
  auto d_unique_idx = dh::ToSpan(cache->unique_idx);
  dh::Iota(d_unique_idx);
  auto uni_key = dh::MakeTransformIterator<thrust::pair<uint32_t, float>>(
      thrust::make_counting_iterator(0), [=] __device__(size_t i) {
        uint32_t class_id = i / n_samples;
        float predt = d_predts_t[d_sorted_idx[i]];
        return thrust::make_pair(class_id, predt);
      });

  // unique values are sparse, so we need a CSR style indptr
  dh::TemporaryArray<uint32_t> unique_class_ptr(class_ptr.size());
  auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr);
  auto n_uniques = dh::SegmentedUniqueByKey(
      thrust::cuda::par(alloc),
      dh::tbegin(d_class_ptr),
      dh::tend(d_class_ptr),
      uni_key,
      uni_key + d_sorted_idx.size(),
      dh::tbegin(d_unique_idx),
      d_unique_class_ptr.data(),
      dh::tbegin(d_unique_idx),
      thrust::equal_to<thrust::pair<uint32_t, float>>{});
  d_unique_idx = d_unique_idx.subspan(0, n_uniques);
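  // Duplicates are removed independently within each class segment:
  // d_unique_idx keeps the global position of the first occurrence of each
  // distinct value, while d_unique_class_ptr is the CSR indptr over those
  // survivors.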

  using Triple = thrust::tuple<uint32_t, float, float>;
  // expand to tuple to include class id
  auto fptp_it_in = dh::MakeTransformIterator<Triple>(
      thrust::make_counting_iterator(0), [=] __device__(size_t i) {
        return thrust::make_tuple(i, d_fptp[i].first, d_fptp[i].second);
      });
  // shrink down to pair
  auto fptp_it_out = thrust::make_transform_output_iterator(
      dh::TypedDiscard<Triple>{}, [d_fptp] __device__(Triple const &t) {
        d_fptp[thrust::get<0>(t)] =
            thrust::make_pair(thrust::get<1>(t), thrust::get<2>(t));
        return t;
      });
  dh::InclusiveScan(
      fptp_it_in, fptp_it_out,
      [=] __device__(Triple const &l, Triple const &r) {
        uint32_t l_cid = thrust::get<0>(l) / n_samples;
        uint32_t r_cid = thrust::get<0>(r) / n_samples;
        if (l_cid != r_cid) {
          return r;
        }

        return Triple(thrust::get<0>(r),
                      thrust::get<1>(l) + thrust::get<1>(r),   // fp
                      thrust::get<2>(l) + thrust::get<2>(r));  // tp
      },
      d_fptp.size());
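  // The scan operator above restarts the running (fp, tp) sum whenever the
  // class id derived from the global index changes, so a single InclusiveScan
  // behaves like a segmented scan: e.g. with n_samples = 3, sums over indices
  // {0, 1, 2} (class 0) never mix with sums over {3, 4, 5} (class 1).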

  // scatter unique FP_PREV/TP_PREV values
  auto d_neg_pos = dh::ToSpan(cache->neg_pos);
  // When the dataset is not empty, each class has at least 1 (unique) sample
  // prediction, so there is no special case to handle.
  dh::LaunchN(d_unique_idx.size(), [=] __device__(size_t i) {
    if (d_unique_idx[i] % n_samples == 0) {
      // the first unique index of each class is class_id * n_samples
      d_neg_pos[d_unique_idx[i]] = {0, 0};
      return;
    }
    uint32_t class_id = d_unique_idx[i] / n_samples;
    d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1];
    if (i == LastOf(class_id, d_unique_class_ptr)) {
      // the last element needs to be included as well.
      size_t last = d_unique_idx[LastOf(class_id, d_unique_class_ptr)];
      d_neg_pos[LastOf(class_id, d_class_ptr)] = d_fptp[last - 1];
      return;
    }
  });

  /**
   * Reduce the result for each class
   */
  auto key_in = dh::MakeTransformIterator<uint32_t>(
      thrust::make_counting_iterator(0), [=] __device__(size_t i) {
        size_t class_id = d_unique_idx[i] / n_samples;
        return class_id;
      });
  auto val_in = dh::MakeTransformIterator<float>(
      thrust::make_counting_iterator(0), [=] __device__(size_t i) {
        size_t class_id = d_unique_idx[i] / n_samples;
        float fp, tp;
        float fp_prev, tp_prev;
        if (i == d_unique_class_ptr[class_id]) {
          // the first item is ignored; use its thread to calculate the last
          // item of this class instead
          thrust::tie(fp, tp) = d_fptp[class_id * n_samples + (n_samples - 1)];
          thrust::tie(fp_prev, tp_prev) =
              d_neg_pos[d_unique_idx[LastOf(class_id, d_unique_class_ptr)]];
        } else {
          thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1];
          thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]];
        }
        float auc = TrapesoidArea(fp_prev, fp, tp_prev, tp);
        return auc;
      });

  thrust::reduce_by_key(thrust::cuda::par(alloc), key_in,
                        key_in + d_unique_idx.size(), val_in,
                        thrust::make_discard_iterator(), d_auc.begin());

  /**
   * Scale the classes with the number of samples in each class.
   */
  dh::TemporaryArray<float> results(n_classes * 4);
  auto d_results = dh::ToSpan(results);
  auto local_area = d_results.subspan(0, n_classes);
  auto fp = d_results.subspan(n_classes, n_classes);
  auto tp = d_results.subspan(2 * n_classes, n_classes);
  auto auc = d_results.subspan(3 * n_classes, n_classes);

  dh::LaunchN(n_classes, [=] __device__(size_t c) {
    auc[c] = s_d_auc[c];
    auto last = d_fptp[n_samples * c + (n_samples - 1)];
    fp[c] = last.first;
    tp[c] = last.second;
    local_area[c] = last.first * last.second;
  });
  return ScaleClasses(d_results, local_area, fp, tp, auc, cache, n_classes);
}

namespace {
struct RankScanItem {
  size_t idx;
  float predt;
  float w;
  bst_group_t group_id;
};
}  // anonymous namespace

std::pair<float, uint32_t>
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
              int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
  auto& cache = *p_cache;
  if (!cache) {
    cache.reset(new DeviceAUCCache);
  }
  cache->Init(predts, false, device);

  dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
  dh::XGBCachingDeviceAllocator<char> alloc;

  auto d_group_ptr = dh::ToSpan(group_ptr);
  /**
   * Validate the dataset
   */
  auto check_it = dh::MakeTransformIterator<size_t>(
      thrust::make_counting_iterator(0),
      [=] __device__(size_t i) { return d_group_ptr[i + 1] - d_group_ptr[i]; });
  size_t n_valid = thrust::count_if(
      thrust::cuda::par(alloc), check_it, check_it + group_ptr.size() - 1,
      [=] __device__(size_t len) { return len >= 3; });
  if (n_valid < info.group_ptr_.size() - 1) {
    InvalidGroupAUC();
  }
  if (n_valid == 0) {
    return std::make_pair(0.0f, 0);
  }

  /**
   * Sort the labels
   */
  auto d_labels = info.labels_.ConstDeviceSpan();

  auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
  dh::SegmentedArgSort<false>(d_labels, d_group_ptr, d_sorted_idx);

  auto d_weights = info.weights_.ConstDeviceSpan();

  dh::caching_device_vector<size_t> threads_group_ptr(group_ptr.size(), 0);
  auto d_threads_group_ptr = dh::ToSpan(threads_group_ptr);
  // Pass max as the height to treat each segment as a full triangle (all pairs).
  auto n_threads = common::SegmentedTrapezoidThreads(
      d_group_ptr, d_threads_group_ptr, std::numeric_limits<size_t>::max());
  // get the coordinates in the nested summation
  auto get_i_j = [=]__device__(size_t idx, size_t query_group_idx) {
    auto data_group_begin = d_group_ptr[query_group_idx];
    size_t n_samples = d_group_ptr[query_group_idx + 1] - data_group_begin;
    auto thread_group_begin = d_threads_group_ptr[query_group_idx];
    auto idx_in_thread_group = idx - thread_group_begin;

    size_t i, j;
    common::UnravelTrapeziodIdx(idx_in_thread_group, n_samples, &i, &j);
    // the sorted index is global across all groups, so i and j must be global
    // indices as well.
    i += data_group_begin;
    j += data_group_begin;
    return thrust::make_pair(i, j);
  };  // NOLINT
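  // For a group of n documents with the max height above, there are
  // n * (n - 1) / 2 threads, one per document pair; UnravelTrapeziodIdx maps a
  // thread's local index to its pair (i, j) inside the group, and the group
  // offset turns both into global indices into d_sorted_idx.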
  auto in = dh::MakeTransformIterator<RankScanItem>(
      thrust::make_counting_iterator(0), [=] __device__(size_t idx) {
        bst_group_t query_group_idx = dh::SegmentId(d_threads_group_ptr, idx);
        auto data_group_begin = d_group_ptr[query_group_idx];
        size_t n_samples = d_group_ptr[query_group_idx + 1] - data_group_begin;
        if (n_samples < 3) {
          // at least 3 documents are required.
          return RankScanItem{idx, 0, 0, query_group_idx};
        }

        size_t i, j;
        thrust::tie(i, j) = get_i_j(idx, query_group_idx);

        float predt = predts[d_sorted_idx[i]] - predts[d_sorted_idx[j]];
        float w = common::Sqr(d_weights.empty() ? 1.0f : d_weights[query_group_idx]);
        if (predt > 0) {
          predt = 1.0;
        } else if (predt == 0) {
          predt = 0.5;
        } else {
          predt = 0;
        }
        predt *= w;
        return RankScanItem{idx, predt, w, query_group_idx};
      });
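
  // Documents in each group are sorted by label in descending order, so for a
  // pair (i, j) with i < j the model should rank document i above document j.
  // A comparison contributes w for a correct ordering, w / 2 for a tied
  // prediction, and 0 otherwise; after the scan below, predt / w for a group's
  // last item is its fraction of correctly ranked pairs.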
  dh::TemporaryArray<float> d_auc(group_ptr.size() - 1);
  auto s_d_auc = dh::ToSpan(d_auc);
  auto out = thrust::make_transform_output_iterator(
      dh::TypedDiscard<RankScanItem>{},
      [=] __device__(RankScanItem const &item) -> RankScanItem {
        auto group_id = item.group_id;
        assert(group_id < d_group_ptr.size());
        // the last item of the current group writes out that group's result
        if (item.idx == LastOf(group_id, d_threads_group_ptr)) {
          if (item.w > 0) {
            s_d_auc[group_id] = item.predt / item.w;
          } else {
            s_d_auc[group_id] = 0;
          }
        }
        return {};  // discard
      });
  dh::InclusiveScan(
      in, out,
      [] __device__(RankScanItem const &l, RankScanItem const &r) {
        if (l.group_id != r.group_id) {
          return r;
        }
        return RankScanItem{r.idx, l.predt + r.predt, l.w + r.w, l.group_id};
      },
      n_threads);

  /**
   * Sum the per-group AUCs; each one was already normalized by its group's
   * total pair weight in the scan output above.
   */
  float auc = thrust::reduce(thrust::cuda::par(alloc), dh::tbegin(s_d_auc),
                             dh::tend(s_d_auc), 0.0f);
  return std::make_pair(auc, n_valid);
}
}  // namespace metric
}  // namespace xgboost