/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.analysis

import org.scalatest.ShouldMatchers

import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._


/**
 * Analyzer test cases built with the Catalyst DSL.
 *
 * Each test constructs an unresolved (or partially resolved) logical plan and then either
 * compares the analyzed result against a hand-written expected plan (`checkAnalysis`) or only
 * asserts that analysis succeeds / fails (`assertAnalysisSuccess` / `assertAnalysisError`).
 * Those helpers, `getAnalyzer` and `caseInsensitiveAnalyzer` are not defined in this file;
 * presumably they are inherited from [[AnalysisTest]].
 */
class AnalysisSuite extends AnalysisTest with ShouldMatchers {
  import org.apache.spark.sql.catalyst.analysis.TestRelations._

  // Builds a deep chain of unions (121 relations in total) where every level projects `*`.
  test("union project *") {
    val plan = (1 to 120)
      .map(_ => testRelation)
      .fold[LogicalPlan](testRelation) { (a, b) =>
        a.select(UnresolvedStar(None)).select('a).union(b.select(UnresolvedStar(None)))
      }

    assertAnalysisSuccess(plan)
  }

  // Checks the `resolved` flag of Project for resolved attributes, an unresolved attribute,
  // a generator (Explode) and an aggregate function (count) in the project list.
  test("check project's resolved") {
    assert(Project(testRelation.output, testRelation).resolved)

    assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved)

    val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
    assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)

    assert(!Project(Seq(Alias(count(Literal(1)), "count")()), testRelation).resolved)
  }

  // Attribute resolution through a table alias, with and without `caseSensitive = false`.
  test("analyze project") {
    checkAnalysis(
      Project(Seq(UnresolvedAttribute("a")), testRelation),
      Project(testRelation.output, testRelation))

    checkAnalysis(
      Project(Seq(UnresolvedAttribute("TbL.a")),
        UnresolvedRelation(TableIdentifier("TaBlE"), Some("TbL"))),
      Project(testRelation.output, testRelation))

    // The qualifier's case differs from the alias, so this lookup is expected to fail.
    assertAnalysisError(
      Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(
        TableIdentifier("TaBlE"), Some("TbL"))),
      Seq("cannot resolve"))

    checkAnalysis(
      Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(
        TableIdentifier("TaBlE"), Some("TbL"))),
      Project(testRelation.output, testRelation),
      caseSensitive = false)

    checkAnalysis(
      Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(
        TableIdentifier("TaBlE"), Some("TbL"))),
      Project(testRelation.output, testRelation),
      caseSensitive = false)
  }

  // Sort keys that were projected away below the Sort: the expected plans carry them through
  // the intermediate projections and drop them again in a final select.
  test("resolve sort references - filter/limit") {
    val a = testRelation2.output(0)
    val b = testRelation2.output(1)
    val c = testRelation2.output(2)

    // Case 1: one missing attribute is in the leaf node and another is in the unary node
    val plan1 = testRelation2
      .where('a > "str").select('a, 'b)
      .where('b > "str").select('a)
      .sortBy('b.asc, 'c.desc)
    val expected1 = testRelation2
      .where(a > "str").select(a, b, c)
      .where(b > "str").select(a, b, c)
      .sortBy(b.asc, c.desc)
      .select(a)
    checkAnalysis(plan1, expected1)

    // Case 2: all the missing attributes are in the leaf node
    val plan2 = testRelation2
      .where('a > "str").select('a)
      .where('a > "str").select('a)
      .sortBy('b.asc, 'c.desc)
    val expected2 = testRelation2
      .where(a > "str").select(a, b, c)
      .where(a > "str").select(a, b, c)
      .sortBy(b.asc, c.desc)
      .select(a)
    checkAnalysis(plan2, expected2)
  }

  // Same as above, but the missing sort keys come from the two sides of a join.
  test("resolve sort references - join") {
    val a = testRelation2.output(0)
    val b = testRelation2.output(1)
    val c = testRelation2.output(2)
    val h = testRelation3.output(3)

    // Case: join itself can resolve all the missing attributes
    val plan = testRelation2.join(testRelation3)
      .where('a > "str").select('a, 'b)
      .sortBy('c.desc, 'h.asc)
    val expected = testRelation2.join(testRelation3)
      .where(a > "str").select(a, b, c, h)
      .sortBy(c.desc, h.asc)
      .select(a, b)
    checkAnalysis(plan, expected)
  }

  test("resolve sort references - aggregate") {
    val a = testRelation2.output(0)
    val b = testRelation2.output(1)
    val c = testRelation2.output(2)
    val alias_a3 = count(a).as("a3")
    val alias_b = b.as("aggOrder")

    // Case 1: when the child of Sort is not Aggregate,
    //   the sort reference is handled by the rule ResolveSortReferences
    val plan1 = testRelation2
      .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
      .select('a, 'c, 'a3)
      .orderBy('b.asc)

    val expected1 = testRelation2
      .groupBy(a, c, b)(a, c, alias_a3, b)
      .select(a, c, alias_a3.toAttribute, b)
      .orderBy(b.asc)
      .select(a, c, alias_a3.toAttribute)

    checkAnalysis(plan1, expected1)

    // Case 2: when the child of Sort is Aggregate,
    //   the sort reference is handled by the rule ResolveAggregateFunctions
    val plan2 = testRelation2
      .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
      .orderBy('b.asc)

    val expected2 = testRelation2
      .groupBy(a, c, b)(a, c, alias_a3, alias_b)
      .orderBy(alias_b.toAttribute.asc)
      .select(a, c, alias_a3.toAttribute)

    checkAnalysis(plan2, expected2)
  }

  // Table-name lookup with differently cased identifiers.
  test("resolve relations") {
    // Seq() supplies no expected message fragments; presumably only the failure itself is
    // asserted here -- confirm against AnalysisTest.assertAnalysisError.
    assertAnalysisError(UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq())
    checkAnalysis(UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation)
    checkAnalysis(
      UnresolvedRelation(TableIdentifier("tAbLe"), None), testRelation, caseSensitive = false)
    checkAnalysis(
      UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation, caseSensitive = false)
  }

  test("divide should be casted into fractional types") {
    val plan = caseInsensitiveAnalyzer.execute(
      testRelation2.select(
        'a / Literal(2) as 'div1,
        'a / 'b as 'div2,
        'a / 'c as 'div3,
        'a / 'd as 'div4,
        'e / 'e as 'div5))

    // Pattern match instead of an unchecked downcast so that an unexpected root operator
    // fails the test with the offending plan rather than a bare ClassCastException.
    val pl = plan match {
      case Project(projectList, _) => projectList
      case other => fail(s"Expected a Project at the root of the analyzed plan, but got:\n$other")
    }

    // All five division aliases must be present and typed as DoubleType.
    assert(pl.length == 5)
    pl.foreach { expr =>
      assert(expr.dataType == DoubleType)
    }
  }

  // The expected plan evaluates Rand in a Project below the operator, refers to it by
  // attribute, and restores the original output with an outer Project.
  test("pull out nondeterministic expressions from RepartitionByExpression") {
    val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
    val projected = Alias(Rand(33), "_nondeterministic")()
    val expected =
      Project(testRelation.output,
        RepartitionByExpression(Seq(projected.toAttribute),
          Project(testRelation.output :+ projected, testRelation)))
    checkAnalysis(plan, expected)
  }

  // Same expected shape as the RepartitionByExpression case, with Sort as the operator.
  test("pull out nondeterministic expressions from Sort") {
    val plan = Sort(Seq(SortOrder(Rand(33), Ascending)), false, testRelation)
    val projected = Alias(Rand(33), "_nondeterministic")()
    val expected =
      Project(testRelation.output,
        Sort(Seq(SortOrder(projected.toAttribute, Ascending)), false,
          Project(testRelation.output :+ projected, testRelation)))
    checkAnalysis(plan, expected)
  }

  test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") {
    val a = testRelation.output.head

    // Nested alias inside a projection.
    val projectPlan = testRelation.select(((a + 1).as("a+1") + 2).as("col"))
    val projectExpected = testRelation.select((a + 1 + 2).as("col"))
    checkAnalysis(projectPlan, projectExpected)

    // Aliases in both the grouping expressions and the aggregate expressions.
    val aggregatePlan =
      testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col"))
    val aggregateExpected = testRelation.groupBy(a)((min(a) + 1).as("col"))
    checkAnalysis(aggregatePlan, aggregateExpected)

    // CreateStruct is a special case that we should not trim Alias for it.
    val structPlan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
    val structExpected = testRelation.select(CreateNamedStruct(Seq(
      Literal(a.name), a,
      Literal("a+1"), (a + 1))).as("col"))
    checkAnalysis(structPlan, structExpected)
  }

  test("Analysis may leave unnecassary aliases") {
    val att1 = testRelation.output.head
    val initialPlan = testRelation.select(
      CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"),
      att1
    )
    // First analysis pass; its output is the child of both the second plan and the expectation.
    val prevPlan = getAnalyzer(true).execute(initialPlan)
    val arrayPlan = prevPlan.select(CreateArray(Seq(
      CreateStruct(Seq(att1, (att1 + 1).as("a_plus_1"))).as("col1"),
      /** alias should be eliminated by [[CleanupAliases]] */
      "col".attr.as("col2")
    )).as("arr"))
    val analyzedArrayPlan = getAnalyzer(true).execute(arrayPlan)

    val expectedPlan = prevPlan.select(
      CreateArray(Seq(
        CreateNamedStruct(Seq(
          Literal(att1.name), att1,
          Literal("a_plus_1"), (att1 + 1))),
        'col.struct(prevPlan.output(0).dataType.asInstanceOf[StructType]).notNull
      )).as("arr")
    )

    checkAnalysis(analyzedArrayPlan, expectedPlan)
  }

  // Order-by expression referencing a column that is not in the select list.
  test("SPARK-10534: resolve attribute references in order by clause") {
    val a = testRelation2.output(0)
    val c = testRelation2.output(2)

    val plan = testRelation2.select('c).orderBy(Floor('a).asc)
    val expected = testRelation2.select(c, a).orderBy(Floor(a.cast(DoubleType)).asc).select(c)

    checkAnalysis(plan, expected)
  }

  test("self intersect should resolve duplicate expression IDs") {
    val plan = testRelation.intersect(testRelation)
    assertAnalysisSuccess(plan)
  }

  // NULL IN (1, 2): list elements share a single type.
  test("SPARK-8654: invalid CAST in NULL IN(...) expression") {
    val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil,
      LocalRelation()
    )
    assertAnalysisSuccess(plan)
  }

  // NULL IN (1, 1.2345): integer and double literals in the same list.
  test("SPARK-8654: different types in inlist but can be converted to a common type") {
    val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil,
      LocalRelation()
    )
    assertAnalysisSuccess(plan)
  }

  // NULL IN (true, 1): boolean and integer literals are expected to be rejected.
  test("SPARK-8654: check type compatibility error") {
    val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil,
      LocalRelation()
    )
    assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type"))
  }

  test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
    val string = testRelation2.output(0)
    val double = testRelation2.output(2)
    val short = testRelation2.output(4)
    val nullResult = Literal.create(null, StringType)

    // Analyzes a projection of `udf` and compares it with a projection of `transformed`.
    def checkUDF(udf: Expression, transformed: Expression): Unit = {
      checkAnalysis(
        Project(Alias(udf, "")() :: Nil, testRelation2),
        Project(Alias(transformed, "")() :: Nil, testRelation2)
      )
    }

    // non-primitive parameters do not need special null handling
    val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil)
    val expected1 = udf1
    checkUDF(udf1, expected1)

    // only primitive parameter needs special null handling
    val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
    val expected2 = If(IsNull(double), nullResult, udf2)
    checkUDF(udf2, expected2)

    // special null handling should apply to all primitive parameters
    val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil)
    val expected3 = If(
      IsNull(short) || IsNull(double),
      nullResult,
      udf3)
    checkUDF(udf3, expected3)

    // we can skip special null handling for primitive parameters that are not nullable
    // TODO: this is disabled for now as we can not completely trust `nullable`.
    val udf4 = ScalaUDF(
      (s: Short, d: Double) => "x",
      StringType,
      short :: double.withNullability(false) :: Nil)
    val expected4 = If(
      IsNull(short),
      nullResult,
      udf4)
    // checkUDF(udf4, expected4)
  }

  // Order-by mixes an alias defined by the aggregate ('a1) with a raw grouping column ('c).
  test("SPARK-11863 mixture of aliases and real columns in order by clause - tpcds 19,55,71") {
    val a = testRelation2.output(0)
    val c = testRelation2.output(2)
    val alias1 = a.as("a1")
    val alias2 = c.as("a2")
    val alias3 = count(a).as("a3")

    val plan = testRelation2
      .groupBy('a, 'c)('a.as("a1"), 'c.as("a2"), count('a).as("a3"))
      .orderBy('a1.asc, 'c.asc)

    val expected = testRelation2
      .groupBy(a, c)(alias1, alias2, alias3)
      .orderBy(alias1.toAttribute.asc, alias2.toAttribute.asc)
      .select(alias1.toAttribute, alias2.toAttribute, alias3.toAttribute)
    checkAnalysis(plan, expected)
  }

  // A Union with a single child is expected to analyze to that child.
  test("Eliminate the unnecessary union") {
    val plan = Union(testRelation :: Nil)
    val expected = testRelation
    checkAnalysis(plan, expected)
  }

  // The two CASE branches are structs that differ only in the nullability of field `x`.
  test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
    val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
    val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val"))
    assertAnalysisSuccess(plan)
  }

  // Both join sides read the same relation under different aliases; the outer projection
  // still refers to the columns through the qualifiers `x` and `y`.
  test("Keep attribute qualifiers after dedup") {
    val input = LocalRelation('key.int, 'value.string)

    val query =
      Project(Seq($"x.key", $"y.key"),
        Join(
          Project(Seq($"x.key"), SubqueryAlias("x", input, None)),
          Project(Seq($"y.key"), SubqueryAlias("y", input, None)),
          Cross, None))

    assertAnalysisSuccess(query)
  }

  /**
   * Analyzes `expression` as the single aliased column of a projection over `OneRowRelation`
   * and fails the running test when the analyzed expression's data type differs from
   * `expectedDataType`.
   *
   * @param expression expression to analyze
   * @param expectedDataType data type the analyzed expression must have
   */
  private def assertExpressionType(
      expression: Expression,
      expectedDataType: DataType): Unit = {
    val afterAnalyze =
      Project(Seq(Alias(expression, "a")()), OneRowRelation).analyze.expressions.head
    if (afterAnalyze.dataType != expectedDataType) {
      fail(
        s"""
           |data type of expression $expression doesn't match expected:
           |Actual data type:
           |${afterAnalyze.dataType}
           |
           |Expected data type:
           |${expectedDataType}
         """.stripMargin)
    }
  }

  test("SPARK-15776: test whether Divide expression's data type can be deduced correctly by " +
    "analyzer") {
    assertExpressionType(sum(Divide(1, 2)), DoubleType)
    assertExpressionType(sum(Divide(1.0, 2)), DoubleType)
    assertExpressionType(sum(Divide(1, 2.0)), DoubleType)
    assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType)
    assertExpressionType(sum(Divide(1, 2.0f)), DoubleType)
    assertExpressionType(sum(Divide(1.0f, 2)), DoubleType)
    assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11))
    assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11))
    assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType)
    assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType)
  }

  test("SPARK-18058: union and set operations shall not care about the nullability" +
    " when comparing column types") {
    // The two relations differ only in the nullability of the struct's inner field.
    val firstTable = LocalRelation(
      AttributeReference("a",
        StructType(Seq(StructField("a", IntegerType, nullable = true))), nullable = false)())
    val secondTable = LocalRelation(
      AttributeReference("a",
        StructType(Seq(StructField("a", IntegerType, nullable = false))), nullable = false)())

    val unionPlan = Union(firstTable, secondTable)
    assertAnalysisSuccess(unionPlan)

    val r1 = Except(firstTable, secondTable)
    val r2 = Intersect(firstTable, secondTable)

    assertAnalysisSuccess(r1)
    assertAnalysisSuccess(r2)
  }
}