/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col

/**
 * Tests for [[VectorAssembler]].
 *
 * Covers:
 *  - the low-level `assemble` helper;
 *  - DataFrame transformation;
 *  - rejection of unsupported input column types;
 *  - propagation of ML attribute metadata into the output column;
 *  - default read/write persistence.
 */
class VectorAssemblerSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  import testImplicits._

  test("params") {
    ParamsSuite.checkParams(new VectorAssembler)
  }

  test("assemble") {
    import org.apache.spark.ml.feature.VectorAssembler.assemble
    assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
    assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
    val dv = Vectors.dense(2.0, 0.0)
    assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
    val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
    assert(assemble(0.0, dv, 1.0, sv) ===
      Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
    // An Int, a String and null must each be rejected with a SparkException,
    // both on their own and when they follow a valid Double.
    for (v <- Seq(1, "a", null)) {
      intercept[SparkException](assemble(v))
      intercept[SparkException](assemble(1.0, v))
    }
  }

  test("assemble should compress vectors") {
    import org.apache.spark.ml.feature.VectorAssembler.assemble
    // Mostly zeros: the result is expected to be stored sparsely.
    // Its values must still be correct.
    val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
    assert(v1.isInstanceOf[SparseVector])
    assert(v1 === Vectors.sparse(4, Array(3), Array(4.0)))
    // No zeros: the result is expected to be stored densely, even though one input was sparse.
    val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
    assert(v2.isInstanceOf[DenseVector])
    assert(v2 === Vectors.dense(1.0, 2.0, 3.0, 4.0))
  }

  test("VectorAssembler") {
    val df = Seq(
      (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
    ).toDF("id", "x", "y", "name", "z", "n")
    // "id" and "name" are deliberately left out of the input columns.
    val assembler = new VectorAssembler()
      .setInputCols(Array("x", "y", "z", "n"))
      .setOutputCol("features")
    val rows = assembler.transform(df).select("features").collect()
    // Guard against a vacuous pass: the per-row check below succeeds trivially on no rows.
    assert(rows.length === 1)
    rows.foreach { case Row(v: Vector) =>
      // The Long column "n" contributes its value (10) as the last element.
      assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0)))
    }
  }

  test("transform should throw an exception in case of unsupported type") {
    val df = Seq(("a", "b", "c")).toDF("a", "b", "c")
    val assembler = new VectorAssembler()
      .setInputCols(Array("a", "b", "c"))
      .setOutputCol("features")
    val thrown = intercept[IllegalArgumentException] {
      assembler.transform(df)
    }
    assert(thrown.getMessage contains "Data type StringType is not supported")
  }

  test("ML attributes") {
    val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari")
    val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0)
    val user = new AttributeGroup("user", Array(
      NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
      NumericAttribute.defaultAttr.withName("salary")))
    val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
    val df = Seq(row).toDF("browser", "hour", "count", "user", "ad")
      .select(
        col("browser").as("browser", browser.toMetadata()),
        col("hour").as("hour", hour.toMetadata()),
        col("count"), // "count" is an integer column without ML attribute
        col("user").as("user", user.toMetadata()),
        col("ad")) // "ad" is a vector column without ML attribute
    val assembler = new VectorAssembler()
      .setInputCols(Array("browser", "hour", "count", "user", "ad"))
      .setOutputCol("features")
    val output = assembler.transform(df)
    val schema = output.schema
    val features = AttributeGroup.fromStructField(schema("features"))
    // 1 (browser) + 1 (hour) + 1 (count) + 2 (user) + 2 (ad) output slots.
    assert(features.size === 7)
    val browserOut = features.getAttr(0)
    assert(browserOut === browser.withIndex(0).withName("browser"))
    val hourOut = features.getAttr(1)
    assert(hourOut === hour.withIndex(1).withName("hour"))
    val countOut = features.getAttr(2)
    assert(countOut === NumericAttribute.defaultAttr.withName("count").withIndex(2))
    // Attributes of a vector column carrying metadata are named "<column>_<attribute>".
    val userGenderOut = features.getAttr(3)
    assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
    val userSalaryOut = features.getAttr(4)
    assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
    // A vector column without metadata gets default numeric attributes named "<column>_<i>".
    assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5).withName("ad_0"))
    assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6).withName("ad_1"))
  }

  test("read/write") {
    val t = new VectorAssembler()
      .setInputCols(Array("myInputCol", "myInputCol2"))
      .setOutputCol("myOutputCol")
    testDefaultReadWrite(t)
  }
}