/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark.params

import scala.collection.immutable.HashSet

import org.apache.spark.ml.param.{BooleanParam, DoubleParam, IntParam, Param, Params}

private[spark] trait BoosterParams extends Params {

  /**
   * Step size shrinkage used in updates to prevent overfitting. After each boosting step, we
   * can directly get the weights of new features, and eta shrinks the feature weights to make
   * the boosting process more conservative. [default=0.3] range: [0,1]
   */
  final val eta = new DoubleParam(this, "eta", "step size shrinkage used in updates to prevent" +
    " overfitting. After each boosting step, we can directly get the weights of new features," +
    " and eta shrinks the feature weights to make the boosting process more conservative.",
    (value: Double) => value >= 0 && value <= 1)

  final def getEta: Double = $(eta)

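  // A minimal usage sketch (hedged; assumes the XGBoostClassifier estimator in
  // ml.dmlc.xgboost4j.scala.spark, which mixes in this trait and exposes a setter per param):
  //
  //   val classifier = new XGBoostClassifier()
  //     .setEta(0.1) // smaller eta = more conservative updates, usually needing more rounds
  //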
  /**
   * Minimum loss reduction required to make a further partition on a leaf node of the tree.
   * The larger the value, the more conservative the algorithm will be. [default=0] range: [0,
   * Double.MaxValue]
   */
  final val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a " +
    "further partition on a leaf node of the tree. The larger the value, the more conservative " +
    "the algorithm will be.", (value: Double) => value >= 0)

  final def getGamma: Double = $(gamma)

  /**
   * Maximum depth of a tree. Increasing this value makes the model more complex and more
   * likely to overfit. [default=6] range: [0, Int.MaxValue]
   */
  final val maxDepth = new IntParam(this, "maxDepth", "maximum depth of a tree. Increasing " +
    "this value makes the model more complex and more likely to overfit.",
    (value: Int) => value >= 0)

  final def getMaxDepth: Int = $(maxDepth)


  /**
   * Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.
   */
  final val maxLeaves = new IntParam(this, "maxLeaves",
    "Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.",
    (value: Int) => value >= 0)

  final def getMaxLeaves: Int = $(maxLeaves)


  /**
   * Minimum sum of instance weight (hessian) needed in a child. If the tree partition step
   * results in a leaf node with a sum of instance weight less than min_child_weight, the
   * building process will give up further partitioning. In linear regression mode, this simply
   * corresponds to the minimum number of instances needed in each node. The larger the value,
   * the more conservative the algorithm will be. [default=1] range: [0, Double.MaxValue]
   */
  final val minChildWeight = new DoubleParam(this, "minChildWeight", "minimum sum of instance" +
    " weight (hessian) needed in a child. If the tree partition step results in a leaf node" +
    " with a sum of instance weight less than min_child_weight, the building process will give" +
    " up further partitioning. In linear regression mode, this simply corresponds to the" +
    " minimum number of instances needed in each node. The larger the value, the more" +
    " conservative the algorithm will be.", (value: Double) => value >= 0)

  final def getMinChildWeight: Double = $(minChildWeight)

  /**
   * Maximum delta step we allow each tree's weight estimation to be. If the value is set to 0,
   * it means there is no constraint. If it is set to a positive value, it can help make the
   * update step more conservative. Usually this parameter is not needed, but it might help in
   * logistic regression when the classes are extremely imbalanced. Setting it to a value of
   * 1-10 might help control the update. [default=0] range: [0, Double.MaxValue]
   */
  final val maxDeltaStep = new DoubleParam(this, "maxDeltaStep", "Maximum delta step we allow" +
    " each tree's weight estimation to be. If the value is set to 0, it means there is no" +
    " constraint. If it is set to a positive value, it can help make the update step more" +
    " conservative. Usually this parameter is not needed, but it might help in logistic" +
    " regression when the classes are extremely imbalanced. Setting it to a value of 1-10" +
    " might help control the update.",
    (value: Double) => value >= 0)

  final def getMaxDeltaStep: Double = $(maxDeltaStep)

  /**
   * Subsample ratio of the training instances. Setting it to 0.5 means that XGBoost randomly
   * samples half of the data instances to grow trees, which helps prevent overfitting.
   * [default=1] range: (0,1]
   */
  final val subsample = new DoubleParam(this, "subsample", "subsample ratio of the training " +
    "instances. Setting it to 0.5 means that XGBoost randomly samples half of the data " +
    "instances to grow trees, which helps prevent overfitting.",
    (value: Double) => value <= 1 && value > 0)

  final def getSubsample: Double = $(subsample)

  /**
   * Subsample ratio of columns when constructing each tree. [default=1] range: (0,1]
   */
  final val colsampleBytree = new DoubleParam(this, "colsampleBytree", "subsample ratio of " +
    "columns when constructing each tree.", (value: Double) => value <= 1 && value > 0)

  final def getColsampleBytree: Double = $(colsampleBytree)

  /**
   * Subsample ratio of columns for each split, in each level. [default=1] range: (0,1]
   */
  final val colsampleBylevel = new DoubleParam(this, "colsampleBylevel", "subsample ratio of " +
    "columns for each split, in each level.", (value: Double) => value <= 1 && value > 0)

  final def getColsampleBylevel: Double = $(colsampleBylevel)

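  // A stochastic-boosting sketch (hedged; assumes the XGBoostRegressor estimator from
  // xgboost4j-spark, which mixes in this trait): each tree sees a random 80% of the rows and
  // a random 80% of the columns, which typically reduces overfitting.
  //
  //   val regressor = new XGBoostRegressor()
  //     .setSubsample(0.8)
  //     .setColsampleBytree(0.8)
  //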
  /**
   * L2 regularization term on weights; increasing this value will make the model more
   * conservative. [default=1]
   */
  final val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights; " +
    "increasing this value will make the model more conservative.",
    (value: Double) => value >= 0)

  final def getLambda: Double = $(lambda)

  /**
   * L1 regularization term on weights; increasing this value will make the model more
   * conservative. [default=0]
   */
  final val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights; " +
    "increasing this value will make the model more conservative.",
    (value: Double) => value >= 0)

  final def getAlpha: Double = $(alpha)

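  // Regularization sketch (hedged): lambda (L2) and alpha (L1) can also be supplied through
  // the xgboost-style params map accepted by the estimator constructors, using the native
  // XGBoost key spellings:
  //
  //   val booster = new XGBoostClassifier(Map("lambda" -> 1.0, "alpha" -> 0.1))
  //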
  /**
   * The tree construction algorithm used in XGBoost. Options:
   * {'auto', 'exact', 'approx', 'hist', 'gpu_hist'} [default='auto']
   */
  final val treeMethod = new Param[String](this, "treeMethod",
    "The tree construction algorithm used in XGBoost, options: " +
      "{'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
    (value: String) => BoosterParams.supportedTreeMethods.contains(value))

  final def getTreeMethod: String = $(treeMethod)

  /**
   * Growth policy for the fast histogram algorithm.
   */
  final val growPolicy = new Param[String](this, "growPolicy",
    "Controls the way new nodes are added to the tree. Currently supported only if" +
      " tree_method is set to hist. Choices: depthwise, lossguide. depthwise: split at nodes" +
      " closest to the root. lossguide: split at nodes with the highest loss change.",
    (value: String) => BoosterParams.supportedGrowthPolicies.contains(value))

  final def getGrowPolicy: String = $(growPolicy)

  /**
   * Maximum number of bins in the histogram.
   */
  final val maxBins = new IntParam(this, "maxBin", "maximum number of bins in histogram",
    (value: Int) => value > 0)

  final def getMaxBins: Int = $(maxBins)

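  // A histogram-based configuration sketch (hedged; assumes an xgboost4j-spark estimator with
  // setters for these params): under lossguide growth, maxLeaves rather than maxDepth becomes
  // the binding complexity limit.
  //
  //   val hist = new XGBoostClassifier()
  //     .setTreeMethod("hist")
  //     .setGrowPolicy("lossguide")
  //     .setMaxLeaves(63)
  //     .setMaxBins(256)
  //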
  /**
   * Whether to build histograms using single-precision floating-point values.
   */
  final val singlePrecisionHistogram = new BooleanParam(this, "singlePrecisionHistogram",
    "whether to use single precision to build histograms")

  final def getSinglePrecisionHistogram: Boolean = $(singlePrecisionHistogram)

  /**
   * Only used for the approximate greedy algorithm.
   * This roughly translates into O(1 / sketch_eps) number of bins. Compared to directly
   * selecting the number of bins, this comes with a theoretical guarantee on sketch accuracy.
   * [default=0.03] range: (0, 1)
   */
  final val sketchEps = new DoubleParam(this, "sketchEps",
    "Only used for the approximate greedy algorithm. This roughly translates into" +
      " O(1 / sketch_eps) number of bins. Compared to directly selecting the number of bins," +
      " this comes with a theoretical guarantee on sketch accuracy.",
    (value: Double) => value < 1 && value > 0)

  final def getSketchEps: Double = $(sketchEps)

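  // Worked example: with the default sketch_eps = 0.03, the approximate algorithm uses on the
  // order of 1 / 0.03, i.e. roughly 33 bins per feature; halving sketch_eps roughly doubles
  // the number of bins (finer split candidates, more memory).
  //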
  /**
   * Controls the balance of positive and negative weights; useful for unbalanced classes. A
   * typical value to consider: sum(negative cases) / sum(positive cases). [default=1]
   */
  final val scalePosWeight = new DoubleParam(this, "scalePosWeight", "Controls the balance of " +
    "positive and negative weights; useful for unbalanced classes. A typical value to consider:" +
    " sum(negative cases) / sum(positive cases)")

  final def getScalePosWeight: Double = $(scalePosWeight)

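  // Worked example: with 900 negative and 100 positive training instances, a reasonable
  // starting point is scale_pos_weight = 900.0 / 100.0 = 9.0, e.g. (hedged sketch):
  //
  //   val weighted = new XGBoostClassifier()
  //     .setScalePosWeight(9.0)
  //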
  // DART booster

  /**
   * Parameter for the DART booster.
   * Type of sampling algorithm. "uniform": dropped trees are selected uniformly.
   * "weighted": dropped trees are selected in proportion to weight. [default="uniform"]
   */
  final val sampleType = new Param[String](this, "sampleType", "type of sampling algorithm, " +
    "options: {'uniform', 'weighted'}",
    (value: String) => BoosterParams.supportedSampleType.contains(value))

  final def getSampleType: String = $(sampleType)

  /**
   * Parameter for the DART booster.
   * Type of normalization algorithm, options: {'tree', 'forest'}. [default="tree"]
   */
  final val normalizeType = new Param[String](this, "normalizeType", "type of normalization" +
    " algorithm, options: {'tree', 'forest'}",
    (value: String) => BoosterParams.supportedNormalizeType.contains(value))

  final def getNormalizeType: String = $(normalizeType)

  /**
   * Parameter for the DART booster.
   * Dropout rate. [default=0.0] range: [0.0, 1.0]
   */
  final val rateDrop = new DoubleParam(this, "rateDrop", "dropout rate", (value: Double) =>
    value >= 0 && value <= 1)

  final def getRateDrop: Double = $(rateDrop)

  /**
   * Parameter for the DART booster.
   * Probability of skipping dropout. If a dropout is skipped, new trees are added in the same
   * manner as in gbtree. [default=0.0] range: [0.0, 1.0]
   */
  final val skipDrop = new DoubleParam(this, "skipDrop", "probability of skipping dropout. If" +
    " a dropout is skipped, new trees are added in the same manner as in gbtree.",
    (value: Double) => value >= 0 && value <= 1)

  final def getSkipDrop: Double = $(skipDrop)

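  // A DART configuration sketch (hedged; assumes an xgboost4j-spark estimator, with the
  // booster type itself selected via the general "booster" key in the params map):
  //
  //   val dart = new XGBoostClassifier(Map("booster" -> "dart"))
  //     .setSampleType("uniform")
  //     .setNormalizeType("tree")
  //     .setRateDrop(0.1)
  //     .setSkipDrop(0.5)
  //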
  // linear booster
  /**
   * Parameter for the linear booster.
   * L2 regularization term on bias. [default=0] (There is no L1 regularization term on bias
   * because it is not important.)
   */
  final val lambdaBias = new DoubleParam(this, "lambdaBias", "L2 regularization term on bias, " +
    "default 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0)

  final def getLambdaBias: Double = $(lambdaBias)

  final val treeLimit = new IntParam(this, name = "treeLimit",
    doc = "number of trees used in the prediction; defaults to 0 (use all trees).")

  final def getTreeLimit: Int = $(treeLimit)

  final val monotoneConstraints = new Param[String](this, name = "monotoneConstraints",
    doc = "a list whose length equals the number of features: 1 indicates monotonically " +
      "increasing, -1 monotonically decreasing, and 0 no constraint. If it is shorter than " +
      "the number of features, it will be padded with 0s.")

  final def getMonotoneConstraints: String = $(monotoneConstraints)

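  // Illustrative value (hypothetical feature layout): for a model with three features where
  // the first must be increasing, the second decreasing, and the third unconstrained, the
  // native XGBoost string form is (hedged sketch):
  //
  //   val constrained = new XGBoostRegressor()
  //     .setMonotoneConstraints("(1, -1, 0)")
  //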
  final val interactionConstraints = new Param[String](this,
    name = "interactionConstraints",
    doc = "Constraints for interaction representing permitted interactions. The constraints" +
      " must be specified in the form of a nested list, e.g. [[0, 1], [2, 3, 4]]," +
      " where each inner list is a group of indices of features that are allowed to interact" +
      " with each other. See the tutorial for more information.")

  final def getInteractionConstraints: String = $(interactionConstraints)

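  // Illustrative value (hypothetical feature layout): features 0 and 1 may interact with each
  // other, and features 2, 3, and 4 may interact with each other, but no split path may mix
  // the two groups (hedged sketch):
  //
  //   val isolated = new XGBoostClassifier()
  //     .setInteractionConstraints("[[0, 1], [2, 3, 4]]")
  //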
  setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
    minChildWeight -> 1, maxDeltaStep -> 0,
    growPolicy -> "depthwise", maxBins -> 256,
    subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
    lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
    scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
    rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
}

private[spark] object BoosterParams {

  val supportedBoosters = HashSet("gbtree", "gblinear", "dart")

  val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist", "gpu_hist")

  val supportedGrowthPolicies = HashSet("depthwise", "lossguide")

  val supportedSampleType = HashSet("uniform", "weighted")

  val supportedNormalizeType = HashSet("tree", "forest")
}