1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 package org.apache.spark.ml.feature;
19 
20 import java.util.Arrays;
21 
22 import static org.apache.spark.sql.types.DataTypes.*;
23 
24 import org.junit.Assert;
25 import org.junit.Test;
26 
27 import org.apache.spark.SharedSparkSession;
28 import org.apache.spark.ml.linalg.Vector;
29 import org.apache.spark.ml.linalg.VectorUDT;
30 import org.apache.spark.ml.linalg.Vectors;
31 import org.apache.spark.sql.Dataset;
32 import org.apache.spark.sql.Row;
33 import org.apache.spark.sql.RowFactory;
34 import org.apache.spark.sql.types.StructField;
35 import org.apache.spark.sql.types.StructType;
36 
37 public class JavaVectorAssemblerSuite extends SharedSparkSession {
38 
39   @Test
testVectorAssembler()40   public void testVectorAssembler() {
41     StructType schema = createStructType(new StructField[]{
42       createStructField("id", IntegerType, false),
43       createStructField("x", DoubleType, false),
44       createStructField("y", new VectorUDT(), false),
45       createStructField("name", StringType, false),
46       createStructField("z", new VectorUDT(), false),
47       createStructField("n", LongType, false)
48     });
49     Row row = RowFactory.create(
50       0, 0.0, Vectors.dense(1.0, 2.0), "a",
51       Vectors.sparse(2, new int[]{1}, new double[]{3.0}), 10L);
52     Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(row), schema);
53     VectorAssembler assembler = new VectorAssembler()
54       .setInputCols(new String[]{"x", "y", "z", "n"})
55       .setOutputCol("features");
56     Dataset<Row> output = assembler.transform(dataset);
57     Assert.assertEquals(
58       Vectors.sparse(6, new int[]{1, 2, 4, 5}, new double[]{1.0, 2.0, 3.0, 10.0}),
59       output.select("features").first().<Vector>getAs(0));
60   }
61 }
62