1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.spark.ml.feature; 19 20 import java.util.Arrays; 21 22 import static org.apache.spark.sql.types.DataTypes.*; 23 24 import org.junit.Assert; 25 import org.junit.Test; 26 27 import org.apache.spark.SharedSparkSession; 28 import org.apache.spark.ml.linalg.Vector; 29 import org.apache.spark.ml.linalg.VectorUDT; 30 import org.apache.spark.ml.linalg.Vectors; 31 import org.apache.spark.sql.Dataset; 32 import org.apache.spark.sql.Row; 33 import org.apache.spark.sql.RowFactory; 34 import org.apache.spark.sql.types.StructField; 35 import org.apache.spark.sql.types.StructType; 36 37 public class JavaVectorAssemblerSuite extends SharedSparkSession { 38 39 @Test testVectorAssembler()40 public void testVectorAssembler() { 41 StructType schema = createStructType(new StructField[]{ 42 createStructField("id", IntegerType, false), 43 createStructField("x", DoubleType, false), 44 createStructField("y", new VectorUDT(), false), 45 createStructField("name", StringType, false), 46 createStructField("z", new VectorUDT(), false), 47 createStructField("n", LongType, false) 48 }); 49 Row row = RowFactory.create( 50 0, 0.0, Vectors.dense(1.0, 2.0), "a", 51 Vectors.sparse(2, new int[]{1}, new double[]{3.0}), 10L); 52 Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(row), schema); 53 VectorAssembler assembler = new VectorAssembler() 54 .setInputCols(new String[]{"x", "y", "z", "n"}) 55 .setOutputCol("features"); 56 Dataset<Row> output = assembler.transform(dataset); 57 Assert.assertEquals( 58 Vectors.sparse(6, new int[]{1, 2, 4, 5}, new double[]{1.0, 2.0, 3.0, 10.0}), 59 output.select("features").first().<Vector>getAs(0)); 60 } 61 } 62