/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}

/**
 * Checks the `constraints` reported by analyzed logical plans built with the Catalyst DSL:
 * filters, aggregates, expand, aliases, set operations, joins, and constraints inferred from
 * equality, casts, compound expressions and attribute nullability.
 */
class ConstraintPropagationSuite extends SparkFunSuite {

  /**
   * Resolves `columnName` against the analyzed form of a [[LocalRelation]].
   *
   * @param tr the relation whose column is looked up
   * @param columnName the (quoted) column name to resolve
   * @return the resolved expression; the test fails if the name cannot be resolved
   */
  private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
    resolveColumn(tr.analyze, columnName)

  /**
   * Resolves `columnName` against `plan` using case-insensitive resolution.
   *
   * @param plan the plan to resolve the column against (used as given, not re-analyzed)
   * @param columnName the (quoted) column name to resolve
   * @return the resolved expression; the test fails with a descriptive message (instead of
   *         an opaque `None.get`) if the name cannot be resolved
   */
  private def resolveColumn(plan: LogicalPlan, columnName: String): Expression =
    plan.resolveQuoted(columnName, caseInsensitiveResolution).getOrElse(
      fail(s"Column '$columnName' could not be resolved against plan:\n$plan"))

  /**
   * Fails the current test unless `found` and `expected` contain the same expressions.
   * The failure message lists both sets plus the missing and the unexpected entries.
   *
   * @param found the constraints produced by the plan under test
   * @param expected the constraints the test expects
   */
  private def verifyConstraints(found: ExpressionSet, expected: ExpressionSet): Unit = {
    val missing = expected -- found
    val extra = found -- expected
    if (missing.nonEmpty || extra.nonEmpty) {
      fail(
        s"""
           |== FAIL: Constraints do not match ===
           |Found: ${found.mkString(",")}
           |Expected: ${expected.mkString(",")}
           |== Result ==
           |Missing: ${if (missing.isEmpty) "N/A" else missing.mkString(",")}
           |Found but not expected: ${if (extra.isEmpty) "N/A" else extra.mkString(",")}
         """.stripMargin)
    }
  }

  test("propagating constraints in filters") {
    val tr = LocalRelation('a.int, 'b.string, 'c.int)
    val a = resolveColumn(tr, "a")
    val c = resolveColumn(tr, "c")

    assert(tr.analyze.constraints.isEmpty)

    // The filtered column `a` is not part of the projection; no constraints are expected.
    assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty)

    verifyConstraints(
      tr.where('a.attr > 10).analyze.constraints,
      ExpressionSet(Seq(a > 10, IsNotNull(a))))

    verifyConstraints(
      tr.where('a.attr > 10)
        .select('c.attr, 'a.attr)
        .where('c.attr =!= 100)
        .analyze.constraints,
      ExpressionSet(Seq(a > 10, c =!= 100, IsNotNull(a), IsNotNull(c))))
  }

  test("propagating constraints in aggregate") {
    val tr = LocalRelation('a.int, 'b.string, 'c.int)

    assert(tr.analyze.constraints.isEmpty)

    // This plan is analyzed once here and used as-is below.
    val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
      .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3).analyze

    val c1 = resolveColumn(aliasedRelation, "c1")
    val a = resolveColumn(aliasedRelation, "a")
    val a3 = resolveColumn(aliasedRelation, "a3")

    // SPARK-16644: aggregate expression count(a) should not appear in the constraints.
    verifyConstraints(
      aliasedRelation.constraints,
      ExpressionSet(Seq(c1 > 10, IsNotNull(c1), a < 5, IsNotNull(a), IsNotNull(a3))))
  }

  test("propagating constraints in expand") {
    val tr = LocalRelation('a.int, 'b.int, 'c.int)

    assert(tr.analyze.constraints.isEmpty)

    // We add IsNotNull constraints for 'a, 'b and 'c into LocalRelation
    // by creating notNullRelation.
    val notNullRelation = tr.where('c.attr > 10 && 'a.attr < 5 && 'b.attr > 2)
    val analyzedNotNull = notNullRelation.analyze
    val a = resolveColumn(analyzedNotNull, "a")
    val b = resolveColumn(analyzedNotNull, "b")
    val c = resolveColumn(analyzedNotNull, "c")

    verifyConstraints(
      analyzedNotNull.constraints,
      ExpressionSet(Seq(c > 10, IsNotNull(c), a < 5, IsNotNull(a), b > 2, IsNotNull(b))))

    val expand = Expand(
      Seq(
        Seq('c, Literal.create(null, StringType), 1),
        Seq('c, 'a, 2)),
      Seq('c, 'a, 'gid.int),
      Project(Seq('a, 'c),
        notNullRelation))

    // No constraints are expected on the output of Expand.
    verifyConstraints(expand.analyze.constraints,
      ExpressionSet(Seq.empty[Expression]))
  }

  test("propagating constraints in aliases") {
    val tr = LocalRelation('a.int, 'b.string, 'c.int)

    // The filtered column `c` is not part of the projection; no constraints are expected.
    assert(tr.where('c.attr > 10).select('a.as('x), 'b.as('y)).analyze.constraints.isEmpty)

    val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z))
    val analyzedAliased = aliasedRelation.analyze
    val x = resolveColumn(analyzedAliased, "x")
    val y = resolveColumn(analyzedAliased, "y")
    val z = resolveColumn(analyzedAliased, "z")
    val b = resolveColumn(analyzedAliased, "b")

    verifyConstraints(
      analyzedAliased.constraints,
      ExpressionSet(Seq(
        x > 10,
        IsNotNull(x),
        b <=> y,
        z <=> x,
        z > 10,
        IsNotNull(z))))

    val multiAlias = tr.where('a === 'c + 10).select('a.as('x), 'c.as('y))
    val analyzedMulti = multiAlias.analyze
    val multiX = resolveColumn(analyzedMulti, "x")
    val multiY = resolveColumn(analyzedMulti, "y")

    verifyConstraints(
      analyzedMulti.constraints,
      ExpressionSet(Seq(
        IsNotNull(multiX),
        IsNotNull(multiY),
        multiX === multiY + 10)))
  }

  test("propagating constraints in union") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
    val tr2 = LocalRelation('d.int, 'e.int, 'f.int)
    val tr3 = LocalRelation('g.int, 'h.int, 'i.int)
    val a = resolveColumn(tr1, "a")
    val b = resolveColumn(tr1, "b")

    // Each branch filters a column at a different position (1st, 2nd, 3rd).
    assert(tr1
      .where('a.attr > 10)
      .union(tr2.where('e.attr > 10)
      .union(tr3.where('i.attr > 10)))
      .analyze.constraints.isEmpty)

    // Every branch applies the same predicate to its first column.
    verifyConstraints(tr1
      .where('a.attr > 10)
      .union(tr2.where('d.attr > 10)
      .union(tr3.where('g.attr > 10)))
      .analyze.constraints,
      ExpressionSet(Seq(a > 10, IsNotNull(a))))

    // Different predicates on the same column position are expected as a disjunction.
    verifyConstraints(tr1
      .where('a.attr > 10)
      .union(tr2.where('d.attr > 11))
      .analyze.constraints,
      ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a))))

    verifyConstraints(tr1
      .where('a.attr > 10 && 'b.attr < 10)
      .union(tr2.where('d.attr > 11 && 'e.attr < 11))
      .analyze.constraints,
      ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b))))
  }

  test("propagating constraints in intersect") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
    val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
    val a = resolveColumn(tr1, "a")
    val b = resolveColumn(tr1, "b")

    // Constraints from both sides are expected, expressed on the left side's attributes.
    verifyConstraints(tr1
      .where('a.attr > 10)
      .intersect(tr2.where('b.attr < 100))
      .analyze.constraints,
      ExpressionSet(Seq(a > 10, b < 100, IsNotNull(a), IsNotNull(b))))
  }

  test("propagating constraints in except") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
    val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
    val a = resolveColumn(tr1, "a")

    // Only the left side's constraints are expected.
    verifyConstraints(tr1
      .where('a.attr > 10)
      .except(tr2.where('b.attr < 100))
      .analyze.constraints,
      ExpressionSet(Seq(a > 10, IsNotNull(a))))
  }

  test("propagating constraints in inner join") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
    val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
    val leftA = resolveColumn(tr1, "a")
    val rightA = resolveColumn(tr2, "a")
    val rightD = resolveColumn(tr2, "d")

    verifyConstraints(tr1
      .where('a.attr > 10)
      .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr))
      .analyze.constraints,
      ExpressionSet(Seq(
        leftA > 10,
        rightD < 100,
        leftA === rightA,
        rightA > 10,
        IsNotNull(rightA),
        IsNotNull(leftA),
        IsNotNull(rightD))))
  }

  test("propagating constraints in left-semi join") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
    val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
    val leftA = resolveColumn(tr1, "a")

    // Only the left side's constraints are expected.
    verifyConstraints(tr1
      .where('a.attr > 10)
      .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr))
      .analyze.constraints,
      ExpressionSet(Seq(leftA > 10, IsNotNull(leftA))))
  }

  test("propagating constraints in left-outer join") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
    val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
    val leftA = resolveColumn(tr1, "a")

    // Only the left side's constraints are expected.
    verifyConstraints(tr1
      .where('a.attr > 10)
      .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
      .analyze.constraints,
      ExpressionSet(Seq(leftA > 10, IsNotNull(leftA))))
  }

  test("propagating constraints in right-outer join") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
    val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
    val rightD = resolveColumn(tr2, "d")

    // Only the right side's constraints are expected.
    verifyConstraints(tr1
      .where('a.attr > 10)
      .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr))
      .analyze.constraints,
      ExpressionSet(Seq(rightD < 100, IsNotNull(rightD))))
  }

  test("propagating constraints in full-outer join") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
    val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)

    // No constraints are expected from either side.
    assert(tr1.where('a.attr > 10)
      .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr))
      .analyze.constraints.isEmpty)
  }

  test("infer additional constraints in filters") {
    val tr = LocalRelation('a.int, 'b.int, 'c.int)
    val a = resolveColumn(tr, "a")
    val b = resolveColumn(tr, "b")

    // `b > 10` is expected in addition to the two predicates that were written explicitly.
    verifyConstraints(tr
      .where('a.attr > 10 && 'a.attr === 'b.attr)
      .analyze.constraints,
      ExpressionSet(Seq(
        a > 10,
        b > 10,
        a === b,
        IsNotNull(a),
        IsNotNull(b))))
  }

  test("infer constraints on cast") {
    val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
    val a = resolveColumn(tr, "a")
    val b = resolveColumn(tr, "b")
    val c = resolveColumn(tr, "c")
    val d = resolveColumn(tr, "d")
    val e = resolveColumn(tr, "e")

    verifyConstraints(
      tr.where('a.attr === 'b.attr &&
        'c.attr + 100 > 'd.attr &&
        IsNotNull(Cast(Cast(e, LongType), LongType))).analyze.constraints,
      ExpressionSet(Seq(
        Cast(a, LongType) === b,
        Cast(c + 100, LongType) > d,
        IsNotNull(a),
        IsNotNull(b),
        IsNotNull(c),
        IsNotNull(d),
        IsNotNull(e),
        IsNotNull(Cast(Cast(e, LongType), LongType)))))
  }

  test("infer isnotnull constraints from compound expressions") {
    val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
    val a = resolveColumn(tr, "a")
    val b = resolveColumn(tr, "b")
    val c = resolveColumn(tr, "c")
    val d = resolveColumn(tr, "d")
    val e = resolveColumn(tr, "e")

    verifyConstraints(
      tr.where('a.attr + 'b.attr === 'c.attr &&
        IsNotNull(
          Cast(
            Cast(Cast(e, LongType), LongType), LongType))).analyze.constraints,
      ExpressionSet(Seq(
        Cast(a, LongType) + b === Cast(c, LongType),
        IsNotNull(a),
        IsNotNull(b),
        IsNotNull(c),
        IsNotNull(e),
        IsNotNull(Cast(Cast(Cast(e, LongType), LongType), LongType)))))

    verifyConstraints(
      tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints,
      ExpressionSet(Seq(
        Cast(a, LongType) * b + Cast(100, LongType) === Cast(c, LongType),
        Cast(d, DoubleType) / Cast(10, DoubleType) === Cast(e, DoubleType),
        IsNotNull(a),
        IsNotNull(b),
        IsNotNull(c),
        IsNotNull(d),
        IsNotNull(e))))

    verifyConstraints(
      tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints,
      ExpressionSet(Seq(
        Cast(a, LongType) * b - Cast(10, LongType) >= Cast(c, LongType),
        Cast(d, DoubleType) / Cast(10, DoubleType) < Cast(e, DoubleType),
        IsNotNull(a),
        IsNotNull(b),
        IsNotNull(c),
        IsNotNull(d),
        IsNotNull(e))))

    verifyConstraints(
      tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints,
      ExpressionSet(Seq(
        (Cast(a, LongType) + b) - (Cast(c, LongType) * d) > Cast(e * 1000, LongType),
        IsNotNull(a),
        IsNotNull(b),
        IsNotNull(c),
        IsNotNull(d),
        IsNotNull(e))))

    // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null.
    verifyConstraints(
      tr.where('a.attr === 'c.attr &&
        IsNotNull(IsNotNull(b))).analyze.constraints,
      ExpressionSet(Seq(
        a === c,
        IsNotNull(IsNotNull(b)),
        IsNotNull(a),
        IsNotNull(c))))

    verifyConstraints(
      tr.where('a.attr === 1 && IsNotNull(b) &&
        IsNotNull(c)).analyze.constraints,
      ExpressionSet(Seq(
        a === 1,
        IsNotNull(c),
        IsNotNull(a),
        IsNotNull(b))))
  }

  test("infer IsNotNull constraints from non-nullable attributes") {
    val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(),
      AttributeReference("c", StringType, nullable = false)())
    val b = resolveColumn(tr, "b")
    val c = resolveColumn(tr, "c")

    // Only the attributes declared with nullable = false are expected to be constrained.
    verifyConstraints(tr.analyze.constraints,
      ExpressionSet(Seq(IsNotNull(b), IsNotNull(c))))
  }

  test("not infer non-deterministic constraints") {
    val tr = LocalRelation('a.int, 'b.string, 'c.int)
    val a = resolveColumn(tr, "a")
    val c = resolveColumn(tr, "c")

    // The comparison against Rand(0) itself is not expected, only IsNotNull(a).
    verifyConstraints(tr
      .where('a.attr === Rand(0))
      .analyze.constraints,
      ExpressionSet(Seq(IsNotNull(a))))

    // The comparison against InputFileName() is likewise not expected in the result.
    verifyConstraints(tr
      .where('a.attr === InputFileName())
      .where('a.attr =!= 'c.attr)
      .analyze.constraints,
      ExpressionSet(Seq(
        a =!= c,
        IsNotNull(a),
        IsNotNull(c))))
  }
}