/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.param

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.MyParams

/**
 * Tests for [[Param]] JSON round-tripping, [[ParamPair]], [[ParamMap]],
 * [[ParamValidators]] and the [[Params]] API, exercised through `TestParams`
 * and `MyParams` instances.
 */
class ParamsSuite extends SparkFunSuite {

  test("json encode/decode") {
    // Minimal Params owner so that each param under test has a parent.
    val dummy = new Params {
      override def copy(extra: ParamMap): Params = defaultCopy(extra)

      override val uid: String = "dummy"
    }

    { // BooleanParam
      val param = new BooleanParam(dummy, "name", "doc")
      for (value <- Seq(true, false)) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }

    { // IntParam
      val param = new IntParam(dummy, "name", "doc")
      for (value <- Seq(Int.MinValue, -1, 0, 1, Int.MaxValue)) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }

    { // LongParam
      val param = new LongParam(dummy, "name", "doc")
      for (value <- Seq(Long.MinValue, -1L, 0L, 1L, Long.MaxValue)) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }

    { // FloatParam
      val param = new FloatParam(dummy, "name", "doc")
      for (value <- Seq(Float.NaN, Float.NegativeInfinity, Float.MinValue, -1.0f, -0.5f, 0.0f,
        Float.MinPositiveValue, 0.5f, 1.0f, Float.MaxValue, Float.PositiveInfinity)) {
        val json = param.jsonEncode(value)
        val decoded = param.jsonDecode(json)
        // NaN != NaN, so NaN round-trips are checked with isNaN instead of equality.
        if (value.isNaN) {
          assert(decoded.isNaN)
        } else {
          assert(decoded === value)
        }
      }
    }

    { // DoubleParam
      val param = new DoubleParam(dummy, "name", "doc")
      for (value <- Seq(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, -0.5, 0.0,
        Double.MinPositiveValue, 0.5, 1.0, Double.MaxValue, Double.PositiveInfinity)) {
        val json = param.jsonEncode(value)
        val decoded = param.jsonDecode(json)
        // NaN != NaN, so NaN round-trips are checked with isNaN instead of equality.
        if (value.isNaN) {
          assert(decoded.isNaN)
        } else {
          assert(decoded === value)
        }
      }
    }

    { // Param[String]
      val param = new Param[String](dummy, "name", "doc")
      // Currently we do not support null.
      for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }

    { // Param[Vector]
      val param = new Param[Vector](dummy, "name", "doc")
      val values = Seq(
        Vectors.dense(Array.empty[Double]),
        Vectors.dense(0.0, 2.0),
        Vectors.sparse(0, Array.empty, Array.empty),
        Vectors.sparse(2, Array(1), Array(2.0)))
      for (value <- values) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }

    { // IntArrayParam
      val param = new IntArrayParam(dummy, "name", "doc")
      val values: Seq[Array[Int]] = Seq(
        Array(),
        Array(1),
        Array(Int.MinValue, 0, Int.MaxValue))
      for (value <- values) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }

    { // DoubleArrayParam
      val param = new DoubleArrayParam(dummy, "name", "doc")
      val values: Seq[Array[Double]] = Seq(
        Array(),
        Array(1.0),
        Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
          Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
      for (value <- values) {
        val json = param.jsonEncode(value)
        val decoded = param.jsonDecode(json)
        // Compare element-wise so that NaN entries can be matched with isNaN.
        assert(decoded.length === value.length)
        decoded.zip(value).foreach { case (actual, expected) =>
          if (expected.isNaN) {
            assert(actual.isNaN)
          } else {
            assert(actual === expected)
          }
        }
      }
    }

    { // StringArrayParam
      val param = new StringArrayParam(dummy, "name", "doc")
      val values: Seq[Array[String]] = Seq(
        Array(),
        Array(""),
        Array("", "1", "abc", "quote\"", "newline\n"))
      for (value <- values) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }
  }

  test("param") {
    val solver = new TestParams()
    val uid = solver.uid
    import solver.{inputCol, maxIter}

    assert(maxIter.name === "maxIter")
    assert(maxIter.doc === "maximum number of iterations (>= 0)")
    assert(maxIter.parent === uid)
    assert(maxIter.toString === s"${uid}__maxIter")
    assert(!maxIter.isValid(-1))
    assert(maxIter.isValid(0))
    assert(maxIter.isValid(1))

    solver.setMaxIter(5)
    assert(solver.explainParam(maxIter) ===
      "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)")

    assert(inputCol.toString === s"${uid}__inputCol")

    // getOrDefault is expected to fail for handleInvalid on a fresh TestParams.
    intercept[NoSuchElementException] {
      solver.getOrDefault(solver.handleInvalid)
    }

    // Setters validate their argument.
    intercept[IllegalArgumentException] {
      solver.setMaxIter(-1)
    }
  }

  test("param pair") {
    val solver = new TestParams()
    import solver.maxIter

    // The three ways of building a ParamPair must produce equivalent pairs.
    val pair0 = maxIter -> 5
    val pair1 = maxIter.w(5)
    val pair2 = ParamPair(maxIter, 5)
    for (pair <- Seq(pair0, pair1, pair2)) {
      assert(pair.param.eq(maxIter))
      assert(pair.value === 5)
    }
    // Building a pair with an invalid value must be rejected.
    intercept[IllegalArgumentException] {
      maxIter -> -1
    }
  }

  test("param map") {
    val solver = new TestParams()
    import solver.{inputCol, maxIter}

    val map0 = ParamMap.empty

    assert(!map0.contains(maxIter))
    map0.put(maxIter, 10)
    assert(map0.contains(maxIter))
    assert(map0(maxIter) === 10)
    intercept[IllegalArgumentException] {
      map0.put(maxIter, -1)
    }

    assert(!map0.contains(inputCol))
    intercept[NoSuchElementException] {
      map0(inputCol)
    }
    map0.put(inputCol -> "input")
    assert(map0.contains(inputCol))
    assert(map0(inputCol) === "input")

    // Every construction/copy route should yield a map with the same contents.
    val map1 = map0.copy
    val map2 = ParamMap(maxIter -> 10, inputCol -> "input")
    val map3 = new ParamMap()
      .put(maxIter, 10)
      .put(inputCol, "input")
    val map4 = ParamMap.empty ++ map0
    val map5 = ParamMap.empty
    map5 ++= map0

    for (m <- Seq(map1, map2, map3, map4, map5)) {
      assert(m.contains(maxIter))
      assert(m(maxIter) === 10)
      assert(m.contains(inputCol))
      assert(m(inputCol) === "input")
    }
  }

  test("params") {
    val solver = new TestParams()
    import solver.{handleInvalid, inputCol, maxIter}

    val params = solver.params
    assert(params.length === 3)
    assert(params(0).eq(handleInvalid), "params must be ordered by name")
    assert(params(1).eq(inputCol), "params must be ordered by name")
    assert(params(2).eq(maxIter), "params must be ordered by name")

    assert(!solver.isSet(maxIter))
    assert(solver.isDefined(maxIter))
    assert(solver.getMaxIter === 10)
    solver.setMaxIter(100)
    assert(solver.isSet(maxIter))
    assert(solver.getMaxIter === 100)
    assert(!solver.isSet(inputCol))
    assert(!solver.isDefined(inputCol))
    intercept[NoSuchElementException](solver.getInputCol)

    assert(solver.explainParam(maxIter) ===
      "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
    assert(solver.explainParams() ===
      Seq(handleInvalid, inputCol, maxIter).map(solver.explainParam).mkString("\n"))

    assert(solver.getParam("inputCol").eq(inputCol))
    assert(solver.getParam("maxIter").eq(maxIter))
    assert(solver.hasParam("inputCol"))
    assert(!solver.hasParam("abc"))
    intercept[NoSuchElementException] {
      solver.getParam("abc")
    }

    solver.setInputCol("input")
    assert(solver.isSet(inputCol))
    assert(solver.isDefined(inputCol))
    assert(solver.getInputCol === "input")
    intercept[IllegalArgumentException] {
      ParamMap(maxIter -> -10)
    }
    intercept[IllegalArgumentException] {
      solver.setMaxIter(-10)
    }

    solver.clearMaxIter()
    assert(!solver.isSet(maxIter))

    // Re-set and clear maxIter using the generic clear API
    solver.setMaxIter(10)
    solver.clear(maxIter)
    assert(!solver.isSet(maxIter))

    val copied = solver.copy(ParamMap(solver.maxIter -> 50))
    assert(copied.uid === solver.uid)
    assert(copied.getInputCol === solver.getInputCol)
    assert(copied.getMaxIter === 50)
  }

  test("ParamValidate") {
    val alwaysTrue = ParamValidators.alwaysTrue[Int]
    assert(alwaysTrue(1))

    val gt1Int = ParamValidators.gt[Int](1)
    assert(!gt1Int(1) && gt1Int(2))
    val gt1Double = ParamValidators.gt[Double](1)
    assert(!gt1Double(1.0) && gt1Double(1.1))

    val gtEq1Int = ParamValidators.gtEq[Int](1)
    assert(!gtEq1Int(0) && gtEq1Int(1))
    val gtEq1Double = ParamValidators.gtEq[Double](1)
    assert(!gtEq1Double(0.9) && gtEq1Double(1.0))

    val lt1Int = ParamValidators.lt[Int](1)
    assert(lt1Int(0) && !lt1Int(1))
    val lt1Double = ParamValidators.lt[Double](1)
    assert(lt1Double(0.9) && !lt1Double(1.0))

    val ltEq1Int = ParamValidators.ltEq[Int](1)
    assert(ltEq1Int(1) && !ltEq1Int(2))
    val ltEq1Double = ParamValidators.ltEq[Double](1)
    assert(ltEq1Double(1.0) && !ltEq1Double(1.1))

    val inRange02IntInclusive = ParamValidators.inRange[Int](0, 2)
    assert(inRange02IntInclusive(0) && inRange02IntInclusive(1) && inRange02IntInclusive(2) &&
      !inRange02IntInclusive(-1) && !inRange02IntInclusive(3))
    val inRange02IntExclusive =
      ParamValidators.inRange[Int](0, 2, lowerInclusive = false, upperInclusive = false)
    assert(!inRange02IntExclusive(0) && inRange02IntExclusive(1) && !inRange02IntExclusive(2))

    val inRange02DoubleInclusive = ParamValidators.inRange[Double](0, 2)
    assert(inRange02DoubleInclusive(0) && inRange02DoubleInclusive(1) &&
      inRange02DoubleInclusive(2) &&
      !inRange02DoubleInclusive(-0.1) && !inRange02DoubleInclusive(2.1))
    val inRange02DoubleExclusive =
      ParamValidators.inRange[Double](0, 2, lowerInclusive = false, upperInclusive = false)
    assert(!inRange02DoubleExclusive(0) && inRange02DoubleExclusive(1) &&
      !inRange02DoubleExclusive(2))

    val inArray = ParamValidators.inArray[Int](Array(1, 2))
    assert(inArray(1) && inArray(2) && !inArray(0))

    val arrayLengthGt = ParamValidators.arrayLengthGt[Int](2.0)
    assert(arrayLengthGt(Array(0, 1, 2)) && !arrayLengthGt(Array(0, 1)))
  }

  test("Params.copyValues") {
    val t = new TestParams()
    val t2 = t.copy(ParamMap.empty)
    assert(!t2.isSet(t2.maxIter))
    val t3 = t.copy(ParamMap(t.maxIter -> 20))
    assert(t3.isSet(t3.maxIter))
  }

  test("Filtering ParamMap") {
    val params1 = new MyParams("my_params1")
    val params2 = new MyParams("my_params2")
    val paramMap = ParamMap(
      params1.intParam -> 1,
      params2.intParam -> 1,
      params1.doubleParam -> 0.2,
      params2.doubleParam -> 0.2)
    val filteredParamMap = paramMap.filter(params1)

    assert(filteredParamMap.size === 2)
    filteredParamMap.toSeq.foreach {
      case ParamPair(p, _) =>
        assert(p.parent === params1.uid)
    }

    // In the previous implementation of ParamMap#filter, mutable.Map#filterKeys was used
    // internally, but the map it returns is not serializable (see SI-6654).
    // ParamMap#filter now uses mutable.Map#filter instead, whose result is serializable,
    // so check that serializing the filtered map succeeds.
    val objOut = new ObjectOutputStream(new ByteArrayOutputStream())
    try {
      objOut.writeObject(filteredParamMap)
    } finally {
      objOut.close()
    }
  }
}

object ParamsSuite extends SparkFunSuite {

  /**
   * Checks common requirements for [[Params.params]]:
   *  - params are ordered by names
   *  - param parent has the same UID as the object's UID
   *  - param name is the same as the param method name
   *  - obj.copy should return the same type as the obj
   *
   * @param obj the Params instance to check
   * @throws IllegalArgumentException if the params are not sorted by name, or if
   *                                  `obj.copy(ParamMap)` is not declared to return
   *                                  the runtime class of `obj`
   */
  def checkParams(obj: Params): Unit = {
    val clazz = obj.getClass

    val params = obj.params
    val paramNames = params.map(_.name)
    require(paramNames === paramNames.sorted, "params must be ordered by names")
    params.foreach { p =>
      assert(p.parent === obj.uid)
      assert(obj.getParam(p.name) === p)
      // TODO: Check that setters return self, which needs special handling for generic types.
    }

    // Reflectively inspect the declared return type of copy(ParamMap) on the concrete class.
    val copyMethod = clazz.getMethod("copy", classOf[ParamMap])
    val copyReturnType = copyMethod.getReturnType
    require(copyReturnType === clazz,
      s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
  }
}