1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.ml.param
19
20import java.io.{ByteArrayOutputStream, ObjectOutputStream}
21
22import org.apache.spark.SparkFunSuite
23import org.apache.spark.ml.linalg.{Vector, Vectors}
24import org.apache.spark.ml.util.MyParams
25
/**
 * Tests for the ML parameter API: JSON round-tripping of typed params, [[Param]] metadata and
 * validation, [[ParamPair]] construction, [[ParamMap]] operations, [[Params]] get/set/clear/copy,
 * [[ParamValidators]], and [[ParamMap]] filtering/serializability.
 */
class ParamsSuite extends SparkFunSuite {

  test("json encode/decode") {
    // Minimal Params owner so typed params below can be constructed with a parent.
    val dummy = new Params {
      override def copy(extra: ParamMap): Params = defaultCopy(extra)

      override val uid: String = "dummy"
    }

    { // BooleanParam
      val param = new BooleanParam(dummy, "name", "doc")
      for (value <- Seq(true, false)) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }

    { // IntParam
      val param = new IntParam(dummy, "name", "doc")
      for (value <- Seq(Int.MinValue, -1, 0, 1, Int.MaxValue)) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }

    { // LongParam
      val param = new LongParam(dummy, "name", "doc")
      for (value <- Seq(Long.MinValue, -1L, 0L, 1L, Long.MaxValue)) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }

    { // FloatParam
      val param = new FloatParam(dummy, "name", "doc")
      for (value <- Seq(Float.NaN, Float.NegativeInfinity, Float.MinValue, -1.0f, -0.5f, 0.0f,
          Float.MinPositiveValue, 0.5f, 1.0f, Float.MaxValue, Float.PositiveInfinity)) {
        val json = param.jsonEncode(value)
        val decoded = param.jsonDecode(json)
        // NaN never equals itself, so it is checked with isNaN instead of ===.
        if (value.isNaN) {
          assert(decoded.isNaN)
        } else {
          assert(decoded === value)
        }
      }
    }

    { // DoubleParam
      val param = new DoubleParam(dummy, "name", "doc")
      for (value <- Seq(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, -0.5, 0.0,
          Double.MinPositiveValue, 0.5, 1.0, Double.MaxValue, Double.PositiveInfinity)) {
        val json = param.jsonEncode(value)
        val decoded = param.jsonDecode(json)
        // NaN never equals itself, so it is checked with isNaN instead of ===.
        if (value.isNaN) {
          assert(decoded.isNaN)
        } else {
          assert(decoded === value)
        }
      }
    }

    { // Param[String]
      val param = new Param[String](dummy, "name", "doc")
      // Currently we do not support null.
      for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }

    { // Param[Vector]
      val param = new Param[Vector](dummy, "name", "doc")
      val values = Seq(
        Vectors.dense(Array.empty[Double]),
        Vectors.dense(0.0, 2.0),
        Vectors.sparse(0, Array.empty, Array.empty),
        Vectors.sparse(2, Array(1), Array(2.0)))
      for (value <- values) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }

    { // IntArrayParam
      val param = new IntArrayParam(dummy, "name", "doc")
      val values: Seq[Array[Int]] = Seq(
        Array(),
        Array(1),
        Array(Int.MinValue, 0, Int.MaxValue))
      for (value <- values) {
        val json = param.jsonEncode(value)
        // === compares arrays element-wise here (ScalaTest default Equality).
        assert(param.jsonDecode(json) === value)
      }
    }

    { // DoubleArrayParam
      val param = new DoubleArrayParam(dummy, "name", "doc")
      val values: Seq[Array[Double]] = Seq(
        Array(),
        Array(1.0),
        Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
          Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
      for (value <- values) {
        val json = param.jsonEncode(value)
        val decoded = param.jsonDecode(json)
        assert(decoded.length === value.length)
        // Compare element by element so NaN entries can be checked with isNaN.
        decoded.zip(value).foreach { case (actual, expected) =>
          if (expected.isNaN) {
            assert(actual.isNaN)
          } else {
            assert(actual === expected)
          }
        }
      }
    }

    { // StringArrayParam
      val param = new StringArrayParam(dummy, "name", "doc")
      val values: Seq[Array[String]] = Seq(
        Array(),
        Array(""),
        Array("", "1", "abc", "quote\"", "newline\n"))
      for (value <- values) {
        val json = param.jsonEncode(value)
        assert(param.jsonDecode(json) === value)
      }
    }
  }

  test("param") {
    val solver = new TestParams()
    val uid = solver.uid
    import solver.{inputCol, maxIter}

    assert(maxIter.name === "maxIter")
    assert(maxIter.doc === "maximum number of iterations (>= 0)")
    assert(maxIter.parent === uid)
    assert(maxIter.toString === s"${uid}__maxIter")
    assert(!maxIter.isValid(-1))
    assert(maxIter.isValid(0))
    assert(maxIter.isValid(1))

    solver.setMaxIter(5)
    assert(solver.explainParam(maxIter) ===
      "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)")

    assert(inputCol.toString === s"${uid}__inputCol")

    // handleInvalid has neither a value nor a default on TestParams.
    intercept[NoSuchElementException] {
      solver.getOrDefault(solver.handleInvalid)
    }

    intercept[IllegalArgumentException] {
      solver.setMaxIter(-1)
    }
  }

  test("param pair") {
    val solver = new TestParams()
    import solver.maxIter

    // The three construction styles must produce equivalent pairs.
    val pair0 = maxIter -> 5
    val pair1 = maxIter.w(5)
    val pair2 = ParamPair(maxIter, 5)
    for (pair <- Seq(pair0, pair1, pair2)) {
      assert(pair.param.eq(maxIter))
      assert(pair.value === 5)
    }
    // Building a pair with an invalid value must be rejected eagerly.
    intercept[IllegalArgumentException] {
      maxIter -> -1
    }
  }

  test("param map") {
    val solver = new TestParams()
    import solver.{inputCol, maxIter}

    val map0 = ParamMap.empty

    assert(!map0.contains(maxIter))
    map0.put(maxIter, 10)
    assert(map0.contains(maxIter))
    assert(map0(maxIter) === 10)
    intercept[IllegalArgumentException] {
      map0.put(maxIter, -1)
    }

    assert(!map0.contains(inputCol))
    intercept[NoSuchElementException] {
      map0(inputCol)
    }
    map0.put(inputCol -> "input")
    assert(map0.contains(inputCol))
    assert(map0(inputCol) === "input")

    // Every way of building an equivalent map must expose the same entries.
    val map1 = map0.copy
    val map2 = ParamMap(maxIter -> 10, inputCol -> "input")
    val map3 = new ParamMap()
      .put(maxIter, 10)
      .put(inputCol, "input")
    val map4 = ParamMap.empty ++ map0
    val map5 = ParamMap.empty
    map5 ++= map0

    for (m <- Seq(map1, map2, map3, map4, map5)) {
      assert(m.contains(maxIter))
      assert(m(maxIter) === 10)
      assert(m.contains(inputCol))
      assert(m(inputCol) === "input")
    }
  }

  test("params") {
    val solver = new TestParams()
    import solver.{handleInvalid, inputCol, maxIter}

    val params = solver.params
    assert(params.length === 3)
    assert(params(0).eq(handleInvalid), "params must be ordered by name")
    assert(params(1).eq(inputCol), "params must be ordered by name")
    assert(params(2).eq(maxIter), "params must be ordered by name")

    assert(!solver.isSet(maxIter))
    assert(solver.isDefined(maxIter))
    assert(solver.getMaxIter === 10)
    solver.setMaxIter(100)
    assert(solver.isSet(maxIter))
    assert(solver.getMaxIter === 100)
    assert(!solver.isSet(inputCol))
    assert(!solver.isDefined(inputCol))
    intercept[NoSuchElementException](solver.getInputCol)

    assert(solver.explainParam(maxIter) ===
      "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
    assert(solver.explainParams() ===
      Seq(handleInvalid, inputCol, maxIter).map(solver.explainParam).mkString("\n"))

    assert(solver.getParam("inputCol").eq(inputCol))
    assert(solver.getParam("maxIter").eq(maxIter))
    assert(solver.hasParam("inputCol"))
    assert(!solver.hasParam("abc"))
    intercept[NoSuchElementException] {
      solver.getParam("abc")
    }

    solver.setInputCol("input")
    assert(solver.isSet(inputCol))
    assert(solver.isDefined(inputCol))
    assert(solver.getInputCol === "input")
    intercept[IllegalArgumentException] {
      ParamMap(maxIter -> -10)
    }
    intercept[IllegalArgumentException] {
      solver.setMaxIter(-10)
    }

    solver.clearMaxIter()
    assert(!solver.isSet(maxIter))

    // Re-set and clear maxIter using the generic clear API
    solver.setMaxIter(10)
    solver.clear(maxIter)
    assert(!solver.isSet(maxIter))

    val copied = solver.copy(ParamMap(solver.maxIter -> 50))
    assert(copied.uid === solver.uid)
    assert(copied.getInputCol === solver.getInputCol)
    assert(copied.getMaxIter === 50)
  }

  test("ParamValidate") {
    val alwaysTrue = ParamValidators.alwaysTrue[Int]
    assert(alwaysTrue(1))

    val gt1Int = ParamValidators.gt[Int](1)
    assert(!gt1Int(1) && gt1Int(2))
    val gt1Double = ParamValidators.gt[Double](1)
    assert(!gt1Double(1.0) && gt1Double(1.1))

    val gtEq1Int = ParamValidators.gtEq[Int](1)
    assert(!gtEq1Int(0) && gtEq1Int(1))
    val gtEq1Double = ParamValidators.gtEq[Double](1)
    assert(!gtEq1Double(0.9) && gtEq1Double(1.0))

    val lt1Int = ParamValidators.lt[Int](1)
    assert(lt1Int(0) && !lt1Int(1))
    val lt1Double = ParamValidators.lt[Double](1)
    assert(lt1Double(0.9) && !lt1Double(1.0))

    val ltEq1Int = ParamValidators.ltEq[Int](1)
    assert(ltEq1Int(1) && !ltEq1Int(2))
    val ltEq1Double = ParamValidators.ltEq[Double](1)
    assert(ltEq1Double(1.0) && !ltEq1Double(1.1))

    // inRange defaults to inclusive bounds on both ends.
    val inRange02IntInclusive = ParamValidators.inRange[Int](0, 2)
    assert(inRange02IntInclusive(0) && inRange02IntInclusive(1) && inRange02IntInclusive(2) &&
      !inRange02IntInclusive(-1) && !inRange02IntInclusive(3))
    val inRange02IntExclusive =
      ParamValidators.inRange[Int](0, 2, lowerInclusive = false, upperInclusive = false)
    assert(!inRange02IntExclusive(0) && inRange02IntExclusive(1) && !inRange02IntExclusive(2))

    val inRange02DoubleInclusive = ParamValidators.inRange[Double](0, 2)
    assert(inRange02DoubleInclusive(0) && inRange02DoubleInclusive(1) &&
      inRange02DoubleInclusive(2) &&
      !inRange02DoubleInclusive(-0.1) && !inRange02DoubleInclusive(2.1))
    val inRange02DoubleExclusive =
      ParamValidators.inRange[Double](0, 2, lowerInclusive = false, upperInclusive = false)
    assert(!inRange02DoubleExclusive(0) && inRange02DoubleExclusive(1) &&
      !inRange02DoubleExclusive(2))

    val inArray = ParamValidators.inArray[Int](Array(1, 2))
    assert(inArray(1) && inArray(2) && !inArray(0))

    val arrayLengthGt = ParamValidators.arrayLengthGt[Int](2.0)
    assert(arrayLengthGt(Array(0, 1, 2)) && !arrayLengthGt(Array(0, 1)))
  }

  test("Params.copyValues") {
    val t = new TestParams()
    val t2 = t.copy(ParamMap.empty)
    assert(!t2.isSet(t2.maxIter))
    val t3 = t.copy(ParamMap(t.maxIter -> 20))
    assert(t3.isSet(t3.maxIter))
    // The extra map's value, not just its presence, must be carried into the copy.
    assert(t3.getMaxIter === 20)
  }

  test("Filtering ParamMap") {
    val params1 = new MyParams("my_params1")
    val params2 = new MyParams("my_params2")
    val paramMap = ParamMap(
      params1.intParam -> 1,
      params2.intParam -> 1,
      params1.doubleParam -> 0.2,
      params2.doubleParam -> 0.2)
    val filteredParamMap = paramMap.filter(params1)

    assert(filteredParamMap.size === 2)
    filteredParamMap.toSeq.foreach {
      case ParamPair(p, _) =>
        assert(p.parent === params1.uid)
    }

    // In the previous implementation of ParamMap#filter,
    // mutable.Map#filterKeys was used internally, but
    // the return type of that method is not serializable (see SI-6654).
    // Now mutable.Map#filter is used instead of filterKeys, and the return type is serializable.
    // So let's ensure serializability.
    val objOut = new ObjectOutputStream(new ByteArrayOutputStream())
    try {
      objOut.writeObject(filteredParamMap)
    } finally {
      objOut.close()
    }
  }
}
376
object ParamsSuite extends SparkFunSuite {

  /**
   * Checks common requirements for [[Params.params]]:
   *   - params are ordered by names
   *   - param parent has the same UID as the object's UID
   *   - param name is the same as the param method name
   *   - obj.copy should return the same type as the obj
   *
   * Every violation is reported through ScalaTest `assert`, so callers see a test failure with a
   * descriptive clue rather than a bare `IllegalArgumentException`.
   *
   * @param obj the [[Params]] instance whose params and `copy` signature are validated
   */
  def checkParams(obj: Params): Unit = {
    val clazz = obj.getClass

    val params = obj.params
    val paramNames = params.map(_.name)
    // `===` (ScalaTest default Equality) compares the arrays element-wise; `==` would not.
    assert(paramNames === paramNames.sorted, "params must be ordered by names")
    params.foreach { p =>
      assert(p.parent === obj.uid,
        s"Param ${p.name} has parent ${p.parent} but ${clazz.getName} has uid ${obj.uid}.")
      // Looking the param up by its own name is how the "name matches method name"
      // requirement above is exercised.
      assert(obj.getParam(p.name) === p,
        s"${clazz.getName}.getParam(${p.name}) did not return the param named ${p.name}.")
      // TODO: Check that setters return self, which needs special handling for generic types.
    }

    // copy(ParamMap) must be declared with the concrete class as its return type.
    val copyReturnType = clazz.getMethod("copy", classOf[ParamMap]).getReturnType
    assert(copyReturnType === clazz,
      s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
  }
}
404