1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql.catalyst.plans
19
20import org.apache.spark.SparkFunSuite
21import org.apache.spark.sql.catalyst.analysis._
22import org.apache.spark.sql.catalyst.dsl.expressions._
23import org.apache.spark.sql.catalyst.dsl.plans._
24import org.apache.spark.sql.catalyst.expressions._
25import org.apache.spark.sql.catalyst.plans.logical._
26import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}
27
/**
 * Tests for the `constraints` exposed by analyzed logical plans.
 *
 * Every test builds a plan with the Catalyst DSL, analyzes it, and compares the plan's
 * `constraints` against a hand-written [[ExpressionSet]] of expected expressions. Expected
 * expressions are built from columns resolved once per test against the relevant plan.
 */
class ConstraintPropagationSuite extends SparkFunSuite {

  /**
   * Resolves `columnName` against the analyzed form of the relation `tr`.
   *
   * @param tr relation whose analyzed output is searched
   * @param columnName name of the column to look up
   * @return the resolved expression for the column
   */
  private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
    resolveColumn(tr.analyze, columnName)

  /**
   * Resolves `columnName` against `plan` using case-insensitive resolution.
   *
   * The plan is used as given (it is not analyzed here). If the name does not resolve, the
   * running test is failed with a message naming the column and the plan, instead of an
   * anonymous `None.get` error.
   *
   * @param plan plan whose output is searched
   * @param columnName name of the column to look up
   * @return the resolved expression for the column
   */
  private def resolveColumn(plan: LogicalPlan, columnName: String): Expression =
    plan.resolveQuoted(columnName, caseInsensitiveResolution).getOrElse(
      fail(s"Cannot resolve column '$columnName' in plan:\n$plan"))

  /**
   * Fails the running test unless `found` and `expected` contain the same expressions.
   *
   * The failure message lists both sets, the expected expressions that are missing from
   * `found`, and the expressions in `found` that were not expected.
   *
   * @param found constraints reported by the plan under test
   * @param expected constraints the test expects
   */
  private def verifyConstraints(found: ExpressionSet, expected: ExpressionSet): Unit = {
    val missing = expected -- found
    val extra = found -- expected
    if (missing.nonEmpty || extra.nonEmpty) {
      fail(
        s"""
           |== FAIL: Constraints do not match ===
           |Found: ${found.mkString(",")}
           |Expected: ${expected.mkString(",")}
           |== Result ==
           |Missing: ${if (missing.isEmpty) "N/A" else missing.mkString(",")}
           |Found but not expected: ${if (extra.isEmpty) "N/A" else extra.mkString(",")}
         """.stripMargin)
    }
  }

  test("propagating constraints in filters") {
    val tr = LocalRelation('a.int, 'b.string, 'c.int)
    val a = resolveColumn(tr, "a")
    val c = resolveColumn(tr, "c")

    // A bare relation carries no constraints.
    assert(tr.analyze.constraints.isEmpty)

    // The filtered column 'a is projected away, so no constraint is expected to remain.
    assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty)

    verifyConstraints(
      tr.where('a.attr > 10).analyze.constraints,
      ExpressionSet(Seq(a > 10, IsNotNull(a))))

    // Both filters are expected to show up because 'a and 'c survive the projection.
    verifyConstraints(
      tr.where('a.attr > 10)
        .select('c.attr, 'a.attr)
        .where('c.attr =!= 100)
        .analyze.constraints,
      ExpressionSet(Seq(a > 10, c =!= 100, IsNotNull(a), IsNotNull(c))))
  }

  test("propagating constraints in aggregate") {
    val tr = LocalRelation('a.int, 'b.string, 'c.int)

    assert(tr.analyze.constraints.isEmpty)

    // The plan is analyzed exactly once; every expected column is resolved against it.
    val aggregated = tr.where('c.attr > 10 && 'a.attr < 5)
      .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3).analyze
    val c1 = resolveColumn(aggregated, "c1")
    val a = resolveColumn(aggregated, "a")
    val a3 = resolveColumn(aggregated, "a3")

    // SPARK-16644: aggregate expression count(a) should not appear in the constraints.
    verifyConstraints(
      aggregated.constraints,
      ExpressionSet(Seq(c1 > 10, IsNotNull(c1), a < 5, IsNotNull(a), IsNotNull(a3))))
  }

  test("propagating constraints in expand") {
    val tr = LocalRelation('a.int, 'b.int, 'c.int)

    assert(tr.analyze.constraints.isEmpty)

    // We add IsNotNull constraints for 'a, 'b and 'c into LocalRelation
    // by creating notNullRelation.
    val notNullRelation = tr.where('c.attr > 10 && 'a.attr < 5 && 'b.attr > 2)
    val analyzedFilter = notNullRelation.analyze
    val a = resolveColumn(analyzedFilter, "a")
    val b = resolveColumn(analyzedFilter, "b")
    val c = resolveColumn(analyzedFilter, "c")
    verifyConstraints(
      analyzedFilter.constraints,
      ExpressionSet(Seq(c > 10, IsNotNull(c), a < 5, IsNotNull(a), b > 2, IsNotNull(b))))

    // The Expand sits on top of the constrained relation; no constraint is expected from it.
    val expand = Expand(
      Seq(
        Seq('c, Literal.create(null, StringType), 1),
        Seq('c, 'a, 2)),
      Seq('c, 'a, 'gid.int),
      Project(Seq('a, 'c),
        notNullRelation))
    verifyConstraints(expand.analyze.constraints, ExpressionSet(Seq.empty[Expression]))
  }

  test("propagating constraints in aliases") {
    val tr = LocalRelation('a.int, 'b.string, 'c.int)

    // The filtered column 'c is not part of the projection.
    assert(tr.where('c.attr > 10).select('a.as('x), 'b.as('y)).analyze.constraints.isEmpty)

    val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z))
    val analyzedAliases = aliasedRelation.analyze
    val x = resolveColumn(analyzedAliases, "x")
    val y = resolveColumn(analyzedAliases, "y")
    val z = resolveColumn(analyzedAliases, "z")
    val b = resolveColumn(analyzedAliases, "b")

    // 'x and 'z both alias 'a, 'y aliases 'b.
    verifyConstraints(
      analyzedAliases.constraints,
      ExpressionSet(Seq(
        x > 10,
        IsNotNull(x),
        b <=> y,
        z <=> x,
        z > 10,
        IsNotNull(z))))

    val multiAlias = tr.where('a === 'c + 10).select('a.as('x), 'c.as('y))
    val analyzedMultiAlias = multiAlias.analyze
    val mx = resolveColumn(analyzedMultiAlias, "x")
    val my = resolveColumn(analyzedMultiAlias, "y")
    verifyConstraints(
      analyzedMultiAlias.constraints,
      ExpressionSet(Seq(IsNotNull(mx), IsNotNull(my), mx === my + 10)))
  }

  test("propagating constraints in union") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
    val tr2 = LocalRelation('d.int, 'e.int, 'f.int)
    val tr3 = LocalRelation('g.int, 'h.int, 'i.int)
    val a = resolveColumn(tr1, "a")
    val b = resolveColumn(tr1, "b")

    // Each child filters a column at a different position (1st, 2nd, 3rd).
    assert(tr1
      .where('a.attr > 10)
      .union(tr2.where('e.attr > 10)
      .union(tr3.where('i.attr > 10)))
      .analyze.constraints.isEmpty)

    // Every child applies the same predicate to its first column.
    verifyConstraints(
      tr1.where('a.attr > 10)
        .union(tr2.where('d.attr > 10)
        .union(tr3.where('g.attr > 10)))
        .analyze.constraints,
      ExpressionSet(Seq(a > 10, IsNotNull(a))))

    // Different predicates on the same position are expected to be OR-ed together.
    verifyConstraints(
      tr1.where('a.attr > 10)
        .union(tr2.where('d.attr > 11))
        .analyze.constraints,
      ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a))))

    verifyConstraints(
      tr1.where('a.attr > 10 && 'b.attr < 10)
        .union(tr2.where('d.attr > 11 && 'e.attr < 11))
        .analyze.constraints,
      ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b))))
  }

  test("propagating constraints in intersect") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
    val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
    val a = resolveColumn(tr1, "a")
    val b = resolveColumn(tr1, "b")

    // Constraints from both sides are expected, expressed over the left side's columns.
    verifyConstraints(
      tr1.where('a.attr > 10)
        .intersect(tr2.where('b.attr < 100))
        .analyze.constraints,
      ExpressionSet(Seq(a > 10, b < 100, IsNotNull(a), IsNotNull(b))))
  }

  test("propagating constraints in except") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
    val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
    val a = resolveColumn(tr1, "a")

    // Only the left side's constraints are expected.
    verifyConstraints(
      tr1.where('a.attr > 10)
        .except(tr2.where('b.attr < 100))
        .analyze.constraints,
      ExpressionSet(Seq(a > 10, IsNotNull(a))))
  }

  test("propagating constraints in inner join") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
    val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
    // Resolved directly against the aliased (not analyzed) plans.
    val leftA = resolveColumn(tr1, "a")
    val rightA = resolveColumn(tr2, "a")
    val rightD = resolveColumn(tr2, "d")

    // Both sides' filters, the join condition, and `tr2.a > 10` are expected.
    verifyConstraints(
      tr1.where('a.attr > 10)
        .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr))
        .analyze.constraints,
      ExpressionSet(Seq(
        leftA > 10,
        rightD < 100,
        leftA === rightA,
        rightA > 10,
        IsNotNull(rightA),
        IsNotNull(leftA),
        IsNotNull(rightD))))
  }

  test("propagating constraints in left-semi join") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
    val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
    val leftA = resolveColumn(tr1, "a")

    // Only the left side's constraints are expected.
    verifyConstraints(
      tr1.where('a.attr > 10)
        .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr))
        .analyze.constraints,
      ExpressionSet(Seq(leftA > 10, IsNotNull(leftA))))
  }

  test("propagating constraints in left-outer join") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
    val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
    val leftA = resolveColumn(tr1, "a")

    // Only the left side's constraints are expected.
    verifyConstraints(
      tr1.where('a.attr > 10)
        .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
        .analyze.constraints,
      ExpressionSet(Seq(leftA > 10, IsNotNull(leftA))))
  }

  test("propagating constraints in right-outer join") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
    val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
    val rightD = resolveColumn(tr2, "d")

    // Only the right side's constraints are expected.
    verifyConstraints(
      tr1.where('a.attr > 10)
        .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr))
        .analyze.constraints,
      ExpressionSet(Seq(rightD < 100, IsNotNull(rightD))))
  }

  test("propagating constraints in full-outer join") {
    val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
    val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)

    // No constraint from either side is expected.
    assert(tr1.where('a.attr > 10)
      .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr))
      .analyze.constraints.isEmpty)
  }

  test("infer additional constraints in filters") {
    val tr = LocalRelation('a.int, 'b.int, 'c.int)
    val a = resolveColumn(tr, "a")
    val b = resolveColumn(tr, "b")

    // `b > 10` is expected in addition to the written predicates `a > 10` and `a = b`.
    verifyConstraints(
      tr.where('a.attr > 10 && 'a.attr === 'b.attr).analyze.constraints,
      ExpressionSet(Seq(
        a > 10,
        b > 10,
        a === b,
        IsNotNull(a),
        IsNotNull(b))))
  }

  test("infer constraints on cast") {
    val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
    val a = resolveColumn(tr, "a")
    val b = resolveColumn(tr, "b")
    val c = resolveColumn(tr, "c")
    val d = resolveColumn(tr, "d")
    val e = resolveColumn(tr, "e")

    verifyConstraints(
      tr.where('a.attr === 'b.attr &&
        'c.attr + 100 > 'd.attr &&
        IsNotNull(Cast(Cast(e, LongType), LongType))).analyze.constraints,
      ExpressionSet(Seq(
        Cast(a, LongType) === b,
        Cast(c + 100, LongType) > d,
        IsNotNull(a),
        IsNotNull(b),
        IsNotNull(c),
        IsNotNull(d),
        IsNotNull(e),
        IsNotNull(Cast(Cast(e, LongType), LongType)))))
  }

  test("infer isnotnull constraints from compound expressions") {
    val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
    val a = resolveColumn(tr, "a")
    val b = resolveColumn(tr, "b")
    val c = resolveColumn(tr, "c")
    val d = resolveColumn(tr, "d")
    val e = resolveColumn(tr, "e")

    verifyConstraints(
      tr.where('a.attr + 'b.attr === 'c.attr &&
        IsNotNull(
          Cast(
            Cast(Cast(e, LongType), LongType), LongType))).analyze.constraints,
      ExpressionSet(Seq(
        Cast(a, LongType) + b === Cast(c, LongType),
        IsNotNull(a),
        IsNotNull(b),
        IsNotNull(c),
        IsNotNull(e),
        IsNotNull(Cast(Cast(Cast(e, LongType), LongType), LongType)))))

    verifyConstraints(
      tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints,
      ExpressionSet(Seq(
        Cast(a, LongType) * b + Cast(100, LongType) === Cast(c, LongType),
        Cast(d, DoubleType) / Cast(10, DoubleType) === Cast(e, DoubleType),
        IsNotNull(a),
        IsNotNull(b),
        IsNotNull(c),
        IsNotNull(d),
        IsNotNull(e))))

    verifyConstraints(
      tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints,
      ExpressionSet(Seq(
        Cast(a, LongType) * b - Cast(10, LongType) >= Cast(c, LongType),
        Cast(d, DoubleType) / Cast(10, DoubleType) < Cast(e, DoubleType),
        IsNotNull(a),
        IsNotNull(b),
        IsNotNull(c),
        IsNotNull(d),
        IsNotNull(e))))

    verifyConstraints(
      tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints,
      ExpressionSet(Seq(
        (Cast(a, LongType) + b) - (Cast(c, LongType) * d) > Cast(e * 1000, LongType),
        IsNotNull(a),
        IsNotNull(b),
        IsNotNull(c),
        IsNotNull(d),
        IsNotNull(e))))

    // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null.
    verifyConstraints(
      tr.where('a.attr === 'c.attr && IsNotNull(IsNotNull(b))).analyze.constraints,
      ExpressionSet(Seq(
        a === c,
        IsNotNull(IsNotNull(b)),
        IsNotNull(a),
        IsNotNull(c))))

    verifyConstraints(
      tr.where('a.attr === 1 && IsNotNull(b) && IsNotNull(c)).analyze.constraints,
      ExpressionSet(Seq(
        a === 1,
        IsNotNull(c),
        IsNotNull(a),
        IsNotNull(b))))
  }

  test("infer IsNotNull constraints from non-nullable attributes") {
    // 'b and 'c are declared non-nullable; 'a keeps the DSL default.
    val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(),
      AttributeReference("c", StringType, nullable = false)())
    val b = resolveColumn(tr, "b")
    val c = resolveColumn(tr, "c")

    verifyConstraints(
      tr.analyze.constraints,
      ExpressionSet(Seq(IsNotNull(b), IsNotNull(c))))
  }

  test("not infer non-deterministic constraints") {
    val tr = LocalRelation('a.int, 'b.string, 'c.int)
    val a = resolveColumn(tr, "a")
    val c = resolveColumn(tr, "c")

    // The comparison against Rand(0) itself is not expected among the constraints.
    verifyConstraints(
      tr.where('a.attr === Rand(0)).analyze.constraints,
      ExpressionSet(Seq(IsNotNull(a))))

    // Likewise for InputFileName(); the second filter is expected as usual.
    verifyConstraints(
      tr.where('a.attr === InputFileName())
        .where('a.attr =!= 'c.attr)
        .analyze.constraints,
      ExpressionSet(Seq(
        a =!= c,
        IsNotNull(a),
        IsNotNull(c))))
  }
}
390