/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark.params

import scala.collection.immutable.HashSet

import org.apache.spark.ml.param.{DoubleParam, IntParam, BooleanParam, Param, Params}

private[spark] trait BoosterParams extends Params {

  /**
   * step size shrinkage used in update to prevent overfitting. After each boosting step, we
   * can directly get the weights of new features, and eta actually shrinks the feature weights
   * to make the boosting process more conservative. [default=0.3] range: [0,1]
   */
  final val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevent" +
    " overfitting. After each boosting step, we can directly get the weights of new features," +
    " and eta actually shrinks the feature weights to make the boosting process more conservative.",
    (value: Double) => value >= 0 && value <= 1)

  final def getEta: Double = $(eta)
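
  // Illustration (not from the upstream docs, tree boosters only): with eta = 0.1 the leaf
  // weights of each newly added tree are scaled by 0.1 before being added to the ensemble, so
  // smaller values generally call for more boosting rounds.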

  /**
   * minimum loss reduction required to make a further partition on a leaf node of the tree.
   * The larger, the more conservative the algorithm will be. [default=0] range: [0,
   * Double.MaxValue]
   */
  final val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a " +
    "further partition on a leaf node of the tree. The larger, the more conservative the " +
    "algorithm will be.", (value: Double) => value >= 0)

  final def getGamma: Double = $(gamma)

  /**
   * maximum depth of a tree; increasing this value will make the model more complex and more
   * likely to overfit. [default=6] range: [1, Int.MaxValue]
   */
  final val maxDepth = new IntParam(this, "maxDepth", "maximum depth of a tree; increasing this " +
    "value will make the model more complex and more likely to overfit.", (value: Int) => value >= 0)

  final def getMaxDepth: Int = $(maxDepth)

  /**
   * Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.
   */
  final val maxLeaves = new IntParam(this, "maxLeaves",
    "Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.",
    (value: Int) => value >= 0)

  final def getMaxLeaves: Int = $(maxLeaves)

  /**
   * minimum sum of instance weight (hessian) needed in a child. If the tree partition step
   * results in a leaf node with the sum of instance weight less than min_child_weight, then the
   * building process will give up further partitioning. In linear regression mode, this simply
   * corresponds to the minimum number of instances needed in each node. The larger, the more
   * conservative the algorithm will be. [default=1] range: [0, Double.MaxValue]
   */
  final val minChildWeight = new DoubleParam(this, "minChildWeight", "minimum sum of instance" +
    " weight (hessian) needed in a child. If the tree partition step results in a leaf node with" +
    " the sum of instance weight less than min_child_weight, then the building process will" +
    " give up further partitioning. In linear regression mode, this simply corresponds to the" +
    " minimum number of instances needed in each node. The larger, the more conservative" +
    " the algorithm will be.", (value: Double) => value >= 0)

  final def getMinChildWeight: Double = $(minChildWeight)

  /**
   * Maximum delta step we allow each tree's weight estimation to be. If the value is set to 0, it
   * means there is no constraint. If it is set to a positive value, it can help make the update
   * step more conservative. Usually this parameter is not needed, but it might help in logistic
   * regression when the classes are extremely imbalanced. Setting it to a value of 1-10 might
   * help control the update. [default=0] range: [0, Double.MaxValue]
   */
  final val maxDeltaStep = new DoubleParam(this, "maxDeltaStep", "Maximum delta step we allow " +
    "each tree's weight estimation to be. If the value is set to 0, it means there is no" +
    " constraint. If it is set to a positive value, it can help make the update step more" +
    " conservative. Usually this parameter is not needed, but it might help in logistic" +
    " regression when the classes are extremely imbalanced. Setting it to a value of 1-10 might" +
    " help control the update.",
    (value: Double) => value >= 0)

  final def getMaxDeltaStep: Double = $(maxDeltaStep)

  /**
   * subsample ratio of the training instances. Setting it to 0.5 means that XGBoost randomly
   * collects half of the data instances to grow trees, which will prevent overfitting.
   * [default=1] range: (0,1]
   */
  final val subsample = new DoubleParam(this, "subsample", "subsample ratio of the training " +
    "instances. Setting it to 0.5 means that XGBoost randomly collects half of the data " +
    "instances to grow trees, which will prevent overfitting.",
    (value: Double) => value <= 1 && value > 0)

  final def getSubsample: Double = $(subsample)

  /**
   * subsample ratio of columns when constructing each tree. [default=1] range: (0,1]
   */
  final val colsampleBytree = new DoubleParam(this, "colsampleBytree", "subsample ratio of " +
    "columns when constructing each tree.", (value: Double) => value <= 1 && value > 0)

  final def getColsampleBytree: Double = $(colsampleBytree)

  /**
   * subsample ratio of columns for each split, in each level. [default=1] range: (0,1]
   */
  final val colsampleBylevel = new DoubleParam(this, "colsampleBylevel", "subsample ratio of " +
    "columns for each split, in each level.", (value: Double) => value <= 1 && value > 0)

  final def getColsampleBylevel: Double = $(colsampleBylevel)

  /**
   * L2 regularization term on weights; increasing this value will make the model more
   * conservative. [default=1]
   */
  final val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights; " +
    "increasing this value will make the model more conservative.", (value: Double) => value >= 0)

  final def getLambda: Double = $(lambda)

  /**
   * L1 regularization term on weights; increasing this value will make the model more
   * conservative. [default=0]
   */
  final val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights; " +
    "increasing this value will make the model more conservative.", (value: Double) => value >= 0)

  final def getAlpha: Double = $(alpha)
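
  // For reference (paraphrasing the XGBoost documentation, not verbatim): the per-tree
  // regularization term is roughly gamma * T + 0.5 * lambda * sum(w_j^2) + alpha * sum(|w_j|),
  // where T is the number of leaves and w_j the leaf weights, so larger lambda/alpha values
  // shrink the leaf weights toward zero.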

  /**
   * The tree construction algorithm used in XGBoost. options:
   * {'auto', 'exact', 'approx', 'hist', 'gpu_hist'} [default='auto']
   */
  final val treeMethod = new Param[String](this, "treeMethod",
    "The tree construction algorithm used in XGBoost, options: " +
      "{'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
    (value: String) => BoosterParams.supportedTreeMethods.contains(value))

  final def getTreeMethod: String = $(treeMethod)

  /**
   * growth policy for the fast histogram algorithm
   */
  final val growPolicy = new Param[String](this, "growPolicy",
    "Controls the way new nodes are added to the tree. Currently supported only if" +
      " tree_method is set to hist. Choices: depthwise, lossguide. depthwise: split at nodes" +
      " closest to the root. lossguide: split at nodes with the highest loss change.",
    (value: String) => BoosterParams.supportedGrowthPolicies.contains(value))

  final def getGrowPolicy: String = $(growPolicy)

  /**
   * maximum number of bins in histogram
   */
  final val maxBins = new IntParam(this, "maxBin", "maximum number of bins in histogram",
    (value: Int) => value > 0)

  final def getMaxBins: Int = $(maxBins)

  /**
   * whether to build histograms using single precision floating point values
   */
  final val singlePrecisionHistogram = new BooleanParam(this, "singlePrecisionHistogram",
    "whether to use single precision to build histograms")

  final def getSinglePrecisionHistogram: Boolean = $(singlePrecisionHistogram)

  /**
   * This is only used for the approximate greedy algorithm.
   * It roughly translates into O(1 / sketch_eps) number of bins. Compared to directly selecting
   * the number of bins, this comes with a theoretical guarantee on sketch accuracy.
   * [default=0.03] range: (0, 1)
   */
  final val sketchEps = new DoubleParam(this, "sketchEps",
    "This is only used for the approximate greedy algorithm. It roughly translates into" +
      " O(1 / sketch_eps) number of bins. Compared to directly selecting the number of bins," +
      " this comes with a theoretical guarantee on sketch accuracy.",
    (value: Double) => value < 1 && value > 0)

  final def getSketchEps: Double = $(sketchEps)

  /**
   * Control the balance of positive and negative weights, useful for unbalanced classes. A typical
   * value to consider: sum(negative cases) / sum(positive cases). [default=1]
   */
  final val scalePosWeight = new DoubleParam(this, "scalePosWeight", "Control the balance of " +
    "positive and negative weights, useful for unbalanced classes. A typical value to consider:" +
    " sum(negative cases) / sum(positive cases)")

  final def getScalePosWeight: Double = $(scalePosWeight)
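
  // Worked example of the rule of thumb above: with 900 negative and 100 positive training
  // instances, a reasonable starting point is scalePosWeight = 900.0 / 100.0 = 9.0.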

  // Dart boosters

  /**
   * Parameter for Dart booster.
   * Type of sampling algorithm. "uniform": dropped trees are selected uniformly.
   * "weighted": dropped trees are selected in proportion to weight. [default="uniform"]
   */
  final val sampleType = new Param[String](this, "sampleType", "type of sampling algorithm, " +
    "options: {'uniform', 'weighted'}",
    (value: String) => BoosterParams.supportedSampleType.contains(value))

  final def getSampleType: String = $(sampleType)

  /**
   * Parameter of Dart booster.
   * type of normalization algorithm, options: {'tree', 'forest'}. [default="tree"]
   */
  final val normalizeType = new Param[String](this, "normalizeType", "type of normalization" +
    " algorithm, options: {'tree', 'forest'}",
    (value: String) => BoosterParams.supportedNormalizeType.contains(value))

  final def getNormalizeType: String = $(normalizeType)

  /**
   * Parameter of Dart booster.
   * dropout rate. [default=0.0] range: [0.0, 1.0]
   */
  final val rateDrop = new DoubleParam(this, "rateDrop", "dropout rate", (value: Double) =>
    value >= 0 && value <= 1)

  final def getRateDrop: Double = $(rateDrop)

  /**
   * Parameter of Dart booster.
   * probability of skipping the dropout. If a dropout is skipped, new trees are added in the
   * same manner as gbtree. [default=0.0] range: [0.0, 1.0]
   */
  final val skipDrop = new DoubleParam(this, "skipDrop", "probability of skipping the dropout." +
    " If a dropout is skipped, new trees are added in the same manner as gbtree.",
    (value: Double) => value >= 0 && value <= 1)

  final def getSkipDrop: Double = $(skipDrop)

  // linear booster
  /**
   * Parameter of linear booster.
   * L2 regularization term on bias, default 0 (no L1 reg on bias because it is not important)
   */
  final val lambdaBias = new DoubleParam(this, "lambdaBias", "L2 regularization term on bias, " +
    "default 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0)

  final def getLambdaBias: Double = $(lambdaBias)

  final val treeLimit = new IntParam(this, name = "treeLimit",
    doc = "number of trees used in the prediction; defaults to 0 (use all trees).")

  final def getTreeLimit: Int = $(treeLimit)

  final val monotoneConstraints = new Param[String](this, name = "monotoneConstraints",
    doc = "a list whose length equals the number of features; 1 indicates monotonically " +
      "increasing, -1 means decreasing, 0 means no constraint. If it is shorter than the " +
      "number of features, it will be padded with 0s.")

  final def getMonotoneConstraints: String = $(monotoneConstraints)
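
  // Illustration (format follows native XGBoost's monotone_constraints): for three features
  // where the first must be increasing and the second decreasing, the string would be "(1,-1,0)".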

  final val interactionConstraints = new Param[String](this,
    name = "interactionConstraints",
    doc = "Constraints for interaction representing permitted interactions. The constraints" +
      " must be specified in the form of a nested list, e.g. [[0, 1], [2, 3, 4]]," +
      " where each inner list is a group of indices of features that are allowed to interact" +
      " with each other. See the tutorial for more information.")

  final def getInteractionConstraints: String = $(interactionConstraints)

  setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
    minChildWeight -> 1, maxDeltaStep -> 0,
    growPolicy -> "depthwise", maxBins -> 256,
    subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
    lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
    scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
    rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
}

private[spark] object BoosterParams {

  val supportedBoosters = HashSet("gbtree", "gblinear", "dart")

  val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist", "gpu_hist")

  val supportedGrowthPolicies = HashSet("depthwise", "lossguide")

  val supportedSampleType = HashSet("uniform", "weighted")

  val supportedNormalizeType = HashSet("tree", "forest")
}
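
// ---------------------------------------------------------------------------------------------
// Illustrative sketch only, not part of the library API: how the Spark-side param names above
// line up with the native XGBoost parameter names, shown with the defaults from setDefault().
// The snake_case names follow native XGBoost conventions; treat this mapping as an informal
// reference rather than an exhaustive or authoritative list.
// ---------------------------------------------------------------------------------------------
private[spark] object BoosterParamsDefaultsExample {
  val nativeDefaults: Map[String, Any] = Map(
    "eta" -> 0.3,                  // eta
    "gamma" -> 0.0,                // gamma
    "max_depth" -> 6,              // maxDepth
    "min_child_weight" -> 1.0,     // minChildWeight
    "max_delta_step" -> 0.0,       // maxDeltaStep
    "grow_policy" -> "depthwise",  // growPolicy
    "max_bin" -> 256,              // maxBins
    "subsample" -> 1.0,            // subsample
    "colsample_bytree" -> 1.0,     // colsampleBytree
    "colsample_bylevel" -> 1.0,    // colsampleBylevel
    "lambda" -> 1.0,               // lambda
    "alpha" -> 0.0,                // alpha
    "tree_method" -> "auto",       // treeMethod
    "sketch_eps" -> 0.03,          // sketchEps
    "scale_pos_weight" -> 1.0,     // scalePosWeight
    "sample_type" -> "uniform",    // sampleType
    "normalize_type" -> "tree",    // normalizeType
    "rate_drop" -> 0.0,            // rateDrop
    "skip_drop" -> 0.0,            // skipDrop
    "lambda_bias" -> 0.0           // lambdaBias
  )
}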