1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql.catalyst.optimizer
19
20import org.apache.spark.sql.catalyst.dsl._
21import org.apache.spark.sql.catalyst.dsl.expressions._
22import org.apache.spark.sql.catalyst.dsl.plans._
23import org.apache.spark.sql.catalyst.expressions._
24import org.apache.spark.sql.catalyst.plans.PlanTest
25import org.apache.spark.sql.catalyst.plans.logical._
26import org.apache.spark.sql.catalyst.rules.RuleExecutor
27import org.apache.spark.sql.types._
28
/**
 * Tests for the [[SimplifyCasts]] optimizer rule on complex types.
 *
 * The cases below pin two expectations:
 *  - a cast that only relaxes nullability (non-nullable array elements / map values to nullable)
 *    is dropped, leaving the bare column under the same alias;
 *  - a cast that tightens nullability (nullable to non-nullable) is left in the plan untouched.
 */
class SimplifyCastsSuite extends PlanTest {

  /** Executor that applies only `SimplifyCasts`, iterating to a fixed point (at most 50 runs). */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyCasts", FixedPoint(50), SimplifyCasts) :: Nil
  }

  /**
   * Analyzes `SELECT CAST(column AS target) AS casted FROM input`, runs [[Optimize]] on it and
   * compares the outcome with the expected plan.
   *
   * @param input       relation providing `column`
   * @param column      the column being cast
   * @param target      the type the column is cast to
   * @param castRemoved when true, the expected plan is the same projection without the cast;
   *                    when false, the optimizer is expected to leave the analyzed plan unchanged
   */
  private def checkCast(
      input: LocalRelation,
      column: Symbol,
      target: DataType,
      castRemoved: Boolean): Unit = {
    val analyzed = input.select(column.cast(target).as("casted")).analyze
    val actual = Optimize.execute(analyzed)
    // Build the expected plan only after optimizing, mirroring the original per-test ordering.
    val reference =
      if (castRemoved) input.select(column.as("casted")).analyze
      else analyzed
    comparePlans(actual, reference)
  }

  test("non-nullable element array to nullable element array cast") {
    checkCast(
      LocalRelation('a.array(ArrayType(IntegerType, containsNull = false))),
      'a,
      ArrayType(IntegerType, containsNull = true),
      castRemoved = true)
  }

  test("nullable element to non-nullable element array cast") {
    checkCast(
      LocalRelation('a.array(ArrayType(IntegerType, containsNull = true))),
      'a,
      ArrayType(IntegerType, containsNull = false),
      castRemoved = false)
  }

  test("non-nullable value map to nullable value map cast") {
    checkCast(
      LocalRelation('m.map(MapType(StringType, StringType, valueContainsNull = false))),
      'm,
      MapType(StringType, StringType, valueContainsNull = true),
      castRemoved = true)
  }

  test("nullable value map to non-nullable value map cast") {
    checkCast(
      LocalRelation('m.map(MapType(StringType, StringType, valueContainsNull = true))),
      'm,
      MapType(StringType, StringType, valueContainsNull = false),
      castRemoved = false)
  }
}
67
68