1 /*! 2 * Copyright 2020-2021 by XGBoost Contributors 3 * \file categorical.h 4 */ 5 #ifndef XGBOOST_COMMON_CATEGORICAL_H_ 6 #define XGBOOST_COMMON_CATEGORICAL_H_ 7 8 #include "xgboost/base.h" 9 #include "xgboost/data.h" 10 #include "xgboost/span.h" 11 #include "xgboost/parameter.h" 12 #include "bitfield.h" 13 14 namespace xgboost { 15 namespace common { 16 // Cast the categorical type. 17 template <typename T> AsCat(T const & v)18XGBOOST_DEVICE bst_cat_t AsCat(T const& v) { 19 return static_cast<bst_cat_t>(v); 20 } 21 22 /* \brief Whether is fidx a categorical feature. 23 * 24 * \param ft Feature type for all features. 25 * \param fidx Feature index. 26 * \return Whether feature pointed by fidx is categorical feature. 27 */ IsCat(Span<FeatureType const> ft,bst_feature_t fidx)28inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx) { 29 return !ft.empty() && ft[fidx] == FeatureType::kCategorical; 30 } 31 32 /* \brief Whether should it traverse to left branch of a tree. 33 * 34 * For one hot split, go to left if it's NOT the matching category. 35 */ Decision(common::Span<uint32_t const> cats,bst_cat_t cat)36inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t cat) { 37 auto pos = CLBitField32::ToBitPos(cat); 38 if (pos.int_pos >= cats.size()) { 39 return true; 40 } 41 CLBitField32 const s_cats(cats); 42 return !s_cats.Check(cat); 43 } 44 CheckCat(bst_cat_t cat)45inline void CheckCat(bst_cat_t cat) { 46 CHECK_GE(cat, 0) << "Invalid categorical value detected. Categorical value " 47 "should be non-negative."; 48 } 49 50 struct IsCatOp { operatorIsCatOp51 XGBOOST_DEVICE bool operator()(FeatureType ft) { 52 return ft == FeatureType::kCategorical; 53 } 54 }; 55 56 using CatBitField = LBitField32; 57 using KCatBitField = CLBitField32; 58 } // namespace common 59 } // namespace xgboost 60 61 #endif // XGBOOST_COMMON_CATEGORICAL_H_ 62