1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.ml.feature
19
20import org.apache.spark.{SparkException, SparkFunSuite}
21import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
22import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
23import org.apache.spark.ml.param.ParamsSuite
24import org.apache.spark.ml.util.DefaultReadWriteTest
25import org.apache.spark.mllib.util.MLlibTestSparkContext
26import org.apache.spark.sql.Row
27import org.apache.spark.sql.functions.col
28
/**
 * Tests for [[VectorAssembler]]: parameter checks, the low-level `assemble` helper,
 * the DataFrame transform, ML attribute propagation, and persistence.
 */
class VectorAssemblerSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  import testImplicits._

  test("params") {
    ParamsSuite.checkParams(new VectorAssembler)
  }

  test("assemble") {
    import org.apache.spark.ml.feature.VectorAssembler.assemble
    assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
    assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
    val dv = Vectors.dense(2.0, 0.0)
    // Layout: [0.0 | 2.0, 0.0 | 1.0] -> non-zeros at indices 1 and 3.
    assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
    val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
    // Sparse inputs are offset by the width of everything assembled before them.
    assert(assemble(0.0, dv, 1.0, sv) ===
      Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
    // Non-double / non-vector values (including null) are rejected, whether they
    // appear first or after a valid value.
    for (v <- Seq(1, "a", null)) {
      intercept[SparkException](assemble(v))
      intercept[SparkException](assemble(1.0, v))
    }
  }

  test("assemble should compress vectors") {
    import org.apache.spark.ml.feature.VectorAssembler.assemble
    // Mostly zeros -> sparse representation.
    val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
    assert(v1.isInstanceOf[SparseVector])
    // All non-zero -> dense representation, even though one input was sparse.
    val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
    assert(v2.isInstanceOf[DenseVector])
  }

  test("VectorAssembler") {
    val df = Seq(
      (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
    ).toDF("id", "x", "y", "name", "z", "n")
    val assembler = new VectorAssembler()
      .setInputCols(Array("x", "y", "z", "n"))
      .setOutputCol("features")
    val rows = assembler.transform(df).select("features").collect()
    // Without this, an empty result would make the per-row check below a no-op
    // and the test would pass vacuously.
    assert(rows.length === 1)
    rows.foreach {
      case Row(v: Vector) =>
        // x (index 0) is zero; y fills 1-2; z fills 3-4 with its non-zero at 4; n is index 5.
        assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0)))
      case other =>
        fail(s"Expected a single Vector column in the output row, but got: $other")
    }
  }

  test("transform should throw an exception in case of unsupported type") {
    val df = Seq(("a", "b", "c")).toDF("a", "b", "c")
    val assembler = new VectorAssembler()
      .setInputCols(Array("a", "b", "c"))
      .setOutputCol("features")
    val thrown = intercept[IllegalArgumentException] {
      assembler.transform(df)
    }
    assert(thrown.getMessage contains "Data type StringType is not supported")
  }

  test("ML attributes") {
    val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari")
    val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0)
    val user = new AttributeGroup("user", Array(
      NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
      NumericAttribute.defaultAttr.withName("salary")))
    val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
    val df = Seq(row).toDF("browser", "hour", "count", "user", "ad")
      .select(
        col("browser").as("browser", browser.toMetadata()),
        col("hour").as("hour", hour.toMetadata()),
        col("count"), // "count" is an integer column without ML attribute
        col("user").as("user", user.toMetadata()),
        col("ad")) // "ad" is a vector column without ML attribute
    val assembler = new VectorAssembler()
      .setInputCols(Array("browser", "hour", "count", "user", "ad"))
      .setOutputCol("features")
    val output = assembler.transform(df)
    val schema = output.schema
    val features = AttributeGroup.fromStructField(schema("features"))
    // 1 (browser) + 1 (hour) + 1 (count) + 2 (user) + 2 (ad) = 7 output slots.
    assert(features.size === 7)
    // Scalar columns keep their attribute, renamed to the column name and re-indexed.
    val browserOut = features.getAttr(0)
    assert(browserOut === browser.withIndex(0).withName("browser"))
    val hourOut = features.getAttr(1)
    assert(hourOut === hour.withIndex(1).withName("hour"))
    // A numeric column without metadata gets a default numeric attribute.
    val countOut = features.getAttr(2)
    assert(countOut === NumericAttribute.defaultAttr.withName("count").withIndex(2))
    // Attribute-group members are named "<column>_<attribute>".
    val userGenderOut = features.getAttr(3)
    assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
    val userSalaryOut = features.getAttr(4)
    assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
    // A vector column without metadata gets default numeric attributes named "<column>_<i>".
    assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5).withName("ad_0"))
    assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6).withName("ad_1"))
  }

  test("read/write") {
    val t = new VectorAssembler()
      .setInputCols(Array("myInputCol", "myInputCol2"))
      .setOutputCol("myOutputCol")
    testDefaultReadWrite(t)
  }
}
127