/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.streaming

import java.util.TimeZone

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkException
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.OutputMode._

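// Shared flag used by the "midbatch failure" test below to make the injected failure fire
// exactly once, so the retried batch succeeds.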
object FailureSingleton {
  var firstTime = true
}

class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll {

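  // Unload all loaded state stores and stop the maintenance task once the suite finishes.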
  override def afterAll(): Unit = {
    super.afterAll()
    StateStore.stop()
  }

  import testImplicits._

  test("simple count, update mode") {
    val inputData = MemoryStream[Int]

    val aggregated =
      inputData.toDF()
        .groupBy($"value")
        .agg(count("*"))
        .as[(Int, Long)]

    testStream(aggregated, Update)(
      AddData(inputData, 3),
      CheckLastBatch((3, 1)),
      AddData(inputData, 3, 2),
      CheckLastBatch((3, 2), (2, 1)),
      StopStream,
      StartStream(),
      AddData(inputData, 3, 2, 1),
      CheckLastBatch((3, 3), (2, 2), (1, 1)),
      // Update mode only emits the keys whose aggregates changed in this batch.
      AddData(inputData, 4, 4, 4, 4),
      CheckLastBatch((4, 4))
    )
  }

  test("simple count, complete mode") {
    val inputData = MemoryStream[Int]

    val aggregated =
      inputData.toDF()
        .groupBy($"value")
        .agg(count("*"))
        .as[(Int, Long)]

    testStream(aggregated, Complete)(
      AddData(inputData, 3),
      CheckLastBatch((3, 1)),
      AddData(inputData, 2),
      CheckLastBatch((3, 1), (2, 1)),
      StopStream,
      StartStream(),
      AddData(inputData, 3, 2, 1),
      CheckLastBatch((3, 2), (2, 2), (1, 1)),
      AddData(inputData, 4, 4, 4, 4),
      CheckLastBatch((4, 4), (3, 2), (2, 2), (1, 1))
    )
  }

  test("simple count, append mode") {
    val inputData = MemoryStream[Int]

    val aggregated =
      inputData.toDF()
        .groupBy($"value")
        .agg(count("*"))
        .as[(Int, Long)]

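    // Append mode is not supported for streaming aggregations, so this should fail analysis.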
    val e = intercept[AnalysisException] {
      testStream(aggregated, Append)()
    }
    Seq("append", "not supported").foreach { m =>
      assert(e.getMessage.toLowerCase.contains(m.toLowerCase))
    }
  }

  test("sort after aggregate in complete mode") {
    val inputData = MemoryStream[Int]

    val aggregated =
      inputData.toDF()
        .groupBy($"value")
        .agg(count("*"))
        .toDF("value", "count")
        .orderBy($"count".desc)
        .as[(Int, Long)]

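    // Sorting the aggregated result is only allowed in Complete mode, where the full
    // result table is re-emitted on every trigger.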
    testStream(aggregated, Complete)(
      AddData(inputData, 3),
      CheckLastBatch(isSorted = true, (3, 1)),
      AddData(inputData, 2, 3),
      CheckLastBatch(isSorted = true, (3, 2), (2, 1)),
      StopStream,
      StartStream(),
      AddData(inputData, 3, 2, 1),
      CheckLastBatch(isSorted = true, (3, 3), (2, 2), (1, 1)),
      AddData(inputData, 4, 4, 4, 4),
      CheckLastBatch(isSorted = true, (4, 4), (3, 3), (2, 2), (1, 1))
    )
  }

  test("state metrics") {
    val inputData = MemoryStream[Int]

    val aggregated =
      inputData.toDS()
        .flatMap(x => Seq(x, x + 1))
        .toDF("value")
        .groupBy($"value")
        .agg(count("*"))
        .as[(Int, Long)]

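    // Collects the StateStoreSaveExec nodes from the last executed physical plan so we can
    // inspect their SQL metrics.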
    implicit class RichStreamExecution(query: StreamExecution) {
      def stateNodes: Seq[SparkPlan] = {
        query.lastExecution.executedPlan.collect {
          case p: StateStoreSaveExec => p
        }
      }
    }

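    // numUpdatedStateRows counts the keys touched in the current batch, while
    // numTotalStateRows counts all keys currently held in the state store.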
    // Test with Update mode
    testStream(aggregated, Update)(
      AddData(inputData, 1),
      CheckLastBatch((1, 1), (2, 1)),
      AssertOnQuery { _.stateNodes.size === 1 },
      AssertOnQuery { _.stateNodes.head.metrics("numOutputRows").value === 2 },
      AssertOnQuery { _.stateNodes.head.metrics("numUpdatedStateRows").value === 2 },
      AssertOnQuery { _.stateNodes.head.metrics("numTotalStateRows").value === 2 },
      AddData(inputData, 2, 3),
      CheckLastBatch((2, 2), (3, 2), (4, 1)),
      AssertOnQuery { _.stateNodes.size === 1 },
      AssertOnQuery { _.stateNodes.head.metrics("numOutputRows").value === 3 },
      AssertOnQuery { _.stateNodes.head.metrics("numUpdatedStateRows").value === 3 },
      AssertOnQuery { _.stateNodes.head.metrics("numTotalStateRows").value === 4 }
    )

    // Test with Complete mode
    inputData.reset()
    testStream(aggregated, Complete)(
      AddData(inputData, 1),
      CheckLastBatch((1, 1), (2, 1)),
      AssertOnQuery { _.stateNodes.size === 1 },
      AssertOnQuery { _.stateNodes.head.metrics("numOutputRows").value === 2 },
      AssertOnQuery { _.stateNodes.head.metrics("numUpdatedStateRows").value === 2 },
      AssertOnQuery { _.stateNodes.head.metrics("numTotalStateRows").value === 2 },
      AddData(inputData, 2, 3),
      CheckLastBatch((1, 1), (2, 2), (3, 2), (4, 1)),
      AssertOnQuery { _.stateNodes.size === 1 },
      AssertOnQuery { _.stateNodes.head.metrics("numOutputRows").value === 4 },
      AssertOnQuery { _.stateNodes.head.metrics("numUpdatedStateRows").value === 3 },
      AssertOnQuery { _.stateNodes.head.metrics("numTotalStateRows").value === 4 }
    )
  }

  test("multiple keys") {
    val inputData = MemoryStream[Int]

    val aggregated =
      inputData.toDF()
        .groupBy($"value", $"value" + 1)
        .agg(count("*"))
        .as[(Int, Int, Long)]

    testStream(aggregated, Update)(
      AddData(inputData, 1, 2),
      CheckLastBatch((1, 2, 1), (2, 3, 1)),
      AddData(inputData, 1, 2),
      CheckLastBatch((1, 2, 2), (2, 3, 2))
    )
  }

  testQuietly("midbatch failure") {
    val inputData = MemoryStream[Int]
    FailureSingleton.firstTime = true
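    // Fail the first time the value 4 is seen; after restarting, the batch is retried and
    // the aggregation should still produce the complete counts.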
    val aggregated =
      inputData.toDS()
        .map { i =>
          if (i == 4 && FailureSingleton.firstTime) {
            FailureSingleton.firstTime = false
            sys.error("injected failure")
          }

          i
        }
        .groupBy($"value")
        .agg(count("*"))
        .as[(Int, Long)]

    testStream(aggregated, Update)(
      StartStream(),
      AddData(inputData, 1, 2, 3, 4),
      ExpectFailure[SparkException](),
      StartStream(),
      CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1))
    )
  }

  test("typed aggregators") {
    val inputData = MemoryStream[(String, Int)]
    val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2))

    testStream(aggregated, Update)(
      AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)),
      CheckLastBatch(("a", 30), ("b", 3), ("c", 1))
    )
  }

  test("prune results by current_time, complete mode") {
    import testImplicits._
    val clock = new StreamManualClock
    val inputData = MemoryStream[Long]
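    // current_timestamp() is fixed per batch, so this keeps only the keys that fall within
    // the last 10 seconds of the batch timestamp; older keys drop out of the complete output.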
    val aggregated =
      inputData.toDF()
        .groupBy($"value")
        .agg(count("*"))
        .where($"value" >= current_timestamp().cast("long") - 10L)

    testStream(aggregated, Complete)(
      StartStream(ProcessingTime("10 seconds"), triggerClock = clock),

      // advance clock to 10 seconds, all keys retained
      AddData(inputData, 0L, 5L, 5L, 10L),
      AdvanceManualClock(10 * 1000),
      CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),

      // advance clock to 20 seconds, should retain keys >= 10
      AddData(inputData, 15L, 15L, 20L),
      AdvanceManualClock(10 * 1000),
      CheckLastBatch((10L, 1), (15L, 2), (20L, 1)),

      // advance clock to 30 seconds, should retain keys >= 20
      AddData(inputData, 0L, 85L),
      AdvanceManualClock(10 * 1000),
      CheckLastBatch((20L, 1), (85L, 1)),

      // bounce stream and ensure correct batch timestamp is used
      // i.e., we don't take it from the clock, which is at 90 seconds.
      StopStream,
      AssertOnQuery { q => // clear the sink
        q.sink.asInstanceOf[MemorySink].clear()
        // advance by a minute i.e., 90 seconds total
        clock.advance(60 * 1000L)
        true
      },
      StartStream(ProcessingTime("10 seconds"), triggerClock = clock),
      CheckLastBatch((20L, 1), (85L, 1)),
      AssertOnQuery { q =>
        clock.getTimeMillis() == 90000L
      },

      // advance clock to 100 seconds, should retain keys >= 90
      AddData(inputData, 85L, 90L, 100L, 105L),
      AdvanceManualClock(10 * 1000),
      CheckLastBatch((90L, 1), (100L, 1), (105L, 1))
    )
  }

  test("prune results by current_date, complete mode") {
    import testImplicits._
    val clock = new StreamManualClock
    val tz = TimeZone.getDefault.getID
    val inputData = MemoryStream[Long]
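    // Same idea as above but in units of days: keys older than 10 days relative to the
    // batch's current_date() are pruned from the complete output.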
    val aggregated =
      inputData.toDF()
        .select(to_utc_timestamp(from_unixtime($"value" * DateTimeUtils.SECONDS_PER_DAY), tz))
        .toDF("value")
        .groupBy($"value")
        .agg(count("*"))
        .where($"value".cast("date") >= date_sub(current_date(), 10))
        .select(($"value".cast("long") / DateTimeUtils.SECONDS_PER_DAY).cast("long"), $"count(1)")
    testStream(aggregated, Complete)(
      StartStream(ProcessingTime("10 day"), triggerClock = clock),
      // advance clock to 10 days, should retain all keys
      AddData(inputData, 0L, 5L, 5L, 10L),
      AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
      CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),
      // advance clock to 20 days, should retain keys >= 10
      AddData(inputData, 15L, 15L, 20L),
      AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
      CheckLastBatch((10L, 1), (15L, 2), (20L, 1)),
      // advance clock to 30 days, should retain keys >= 20
      AddData(inputData, 85L),
      AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
      CheckLastBatch((20L, 1), (85L, 1)),

      // bounce stream and ensure correct batch timestamp is used
      // i.e., we don't take it from the clock, which is at 90 days.
      StopStream,
      AssertOnQuery { q => // clear the sink
        q.sink.asInstanceOf[MemorySink].clear()
        // advance by 60 days i.e., 90 days total
        clock.advance(DateTimeUtils.MILLIS_PER_DAY * 60)
        true
      },
      StartStream(ProcessingTime("10 day"), triggerClock = clock),
      CheckLastBatch((20L, 1), (85L, 1)),

      // advance clock to 100 days, should retain keys >= 90
      AddData(inputData, 85L, 90L, 100L, 105L),
      AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
      CheckLastBatch((90L, 1), (100L, 1), (105L, 1))
    )
  }
}