1 /*!
2  * Copyright 2020-2021 by XGBoost Contributors
3  * \file categorical.h
4  */
5 #ifndef XGBOOST_COMMON_CATEGORICAL_H_
6 #define XGBOOST_COMMON_CATEGORICAL_H_
7 
8 #include "xgboost/base.h"
9 #include "xgboost/data.h"
10 #include "xgboost/span.h"
11 #include "xgboost/parameter.h"
12 #include "bitfield.h"
13 
14 namespace xgboost {
15 namespace common {
16 // Cast the categorical type.
17 template <typename T>
AsCat(T const & v)18 XGBOOST_DEVICE bst_cat_t AsCat(T const& v) {
19   return static_cast<bst_cat_t>(v);
20 }
21 
22 /* \brief Whether is fidx a categorical feature.
23  *
24  * \param ft   Feature type for all features.
25  * \param fidx Feature index.
26  * \return Whether feature pointed by fidx is categorical feature.
27  */
IsCat(Span<FeatureType const> ft,bst_feature_t fidx)28 inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx) {
29   return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
30 }
31 
32 /* \brief Whether should it traverse to left branch of a tree.
33  *
34  *  For one hot split, go to left if it's NOT the matching category.
35  */
Decision(common::Span<uint32_t const> cats,bst_cat_t cat)36 inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t cat) {
37   auto pos = CLBitField32::ToBitPos(cat);
38   if (pos.int_pos >= cats.size()) {
39     return true;
40   }
41   CLBitField32 const s_cats(cats);
42   return !s_cats.Check(cat);
43 }
44 
CheckCat(bst_cat_t cat)45 inline void CheckCat(bst_cat_t cat) {
46   CHECK_GE(cat, 0) << "Invalid categorical value detected.  Categorical value "
47                       "should be non-negative.";
48 }
49 
50 struct IsCatOp {
operatorIsCatOp51   XGBOOST_DEVICE bool operator()(FeatureType ft) {
52     return ft == FeatureType::kCategorical;
53   }
54 };
55 
56 using CatBitField = LBitField32;
57 using KCatBitField = CLBitField32;
58 }  // namespace common
59 }  // namespace xgboost
60 
61 #endif  // XGBOOST_COMMON_CATEGORICAL_H_
62