1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql.streaming
19
20import java.{util => ju}
21import java.text.SimpleDateFormat
22import java.util.Date
23
24import org.scalatest.BeforeAndAfter
25
26import org.apache.spark.internal.Logging
27import org.apache.spark.sql.AnalysisException
28import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
29import org.apache.spark.sql.execution.streaming._
30import org.apache.spark.sql.functions.{count, window}
31import org.apache.spark.sql.streaming.OutputMode._
32
/**
 * Tests for event-time watermarks declared with `withWatermark` on streaming DataFrames.
 *
 * Covered here: validation of the watermark column and delay threshold, the event-time
 * statistics published in query progress, and windowed aggregation with a watermark under
 * the default, Update and Complete output modes, including stop/restart recovery.
 */
class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Logging {

  import testImplicits._

  // Stop any streaming query that is still active once a test finishes.
  after {
    sqlContext.streams.active.foreach(_.stop())
  }

  test("error on bad column") {
    val inputData = MemoryStream[Int].toDF()
    val e = intercept[AnalysisException] {
      inputData.withWatermark("badColumn", "1 minute")
    }
    assert(e.getMessage.contains("badColumn"))
  }

  test("error on wrong type") {
    // The only column, "value", is an int here, not a timestamp.
    val inputData = MemoryStream[Int].toDF()
    val e = intercept[AnalysisException] {
      inputData.withWatermark("value", "1 minute")
    }
    assert(e.getMessage.contains("value"))
    assert(e.getMessage.contains("int"))
  }

  test("event time and watermark metrics") {
    // No event time metrics when there is no watermarking
    val inputData1 = MemoryStream[Int]
    val aggWithoutWatermark = inputData1.toDF()
      .withColumn("eventTime", $"value".cast("timestamp"))
      .groupBy(window($"eventTime", "5 seconds") as 'window)
      .agg(count("*") as 'count)
      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])

    testStream(aggWithoutWatermark, outputMode = Complete)(
      AddData(inputData1, 15),
      CheckAnswer((15, 1)),
      assertEventStats { e => assert(e.isEmpty) },
      AddData(inputData1, 10, 12, 14),
      CheckAnswer((10, 3), (15, 1)),
      assertEventStats { e => assert(e.isEmpty) }
    )

    // All event time metrics where watermarking is set
    val inputData2 = MemoryStream[Int]
    val aggWithWatermark = inputData2.toDF()
      .withColumn("eventTime", $"value".cast("timestamp"))
      .withWatermark("eventTime", "10 seconds")
      .groupBy(window($"eventTime", "5 seconds") as 'window)
      .agg(count("*") as 'count)
      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])

    testStream(aggWithWatermark)(
      AddData(inputData2, 15),
      CheckAnswer(),
      assertEventStats { e =>
        assert(e.get("max") === formatTimestamp(15))
        assert(e.get("min") === formatTimestamp(15))
        assert(e.get("avg") === formatTimestamp(15))
        assert(e.get("watermark") === formatTimestamp(0))
      },
      AddData(inputData2, 10, 12, 14),
      CheckAnswer(),
      assertEventStats { e =>
        assert(e.get("max") === formatTimestamp(14))
        assert(e.get("min") === formatTimestamp(10))
        assert(e.get("avg") === formatTimestamp(12))
        assert(e.get("watermark") === formatTimestamp(5))
      },
      AddData(inputData2, 25),
      CheckAnswer(),
      assertEventStats { e =>
        assert(e.get("max") === formatTimestamp(25))
        assert(e.get("min") === formatTimestamp(25))
        assert(e.get("avg") === formatTimestamp(25))
        assert(e.get("watermark") === formatTimestamp(5))
      },
      AddData(inputData2, 25),
      CheckAnswer((10, 3)),
      assertEventStats { e =>
        assert(e.get("max") === formatTimestamp(25))
        assert(e.get("min") === formatTimestamp(25))
        assert(e.get("avg") === formatTimestamp(25))
        assert(e.get("watermark") === formatTimestamp(15))
      }
    )
  }

  test("append mode") {
    val inputData = MemoryStream[Int]

    val windowedAggregation = inputData.toDF()
      .withColumn("eventTime", $"value".cast("timestamp"))
      .withWatermark("eventTime", "10 seconds")
      .groupBy(window($"eventTime", "5 seconds") as 'window)
      .agg(count("*") as 'count)
      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])

    testStream(windowedAggregation)(
      AddData(inputData, 10, 11, 12, 13, 14, 15),
      CheckLastBatch(),
      AddData(inputData, 25),   // Advance watermark to 15 seconds
      CheckLastBatch(),
      assertNumStateRows(3),
      AddData(inputData, 25),   // Emit items less than watermark and drop their state
      CheckLastBatch((10, 5)),
      assertNumStateRows(2),
      AddData(inputData, 10),   // Should not emit anything as data less than watermark
      CheckLastBatch(),
      assertNumStateRows(2)
    )
  }

  test("update mode") {
    val inputData = MemoryStream[Int]

    // The session is shared, so remember the current setting and put it back afterwards;
    // otherwise the override would leak into every test that runs after this one.
    val shufflePartitionsKey = "spark.sql.shuffle.partitions"
    val previousShufflePartitions = spark.conf.getOption(shufflePartitionsKey)
    spark.conf.set(shufflePartitionsKey, "10")

    try {
      val windowedAggregation = inputData.toDF()
        .withColumn("eventTime", $"value".cast("timestamp"))
        .withWatermark("eventTime", "10 seconds")
        .groupBy(window($"eventTime", "5 seconds") as 'window)
        .agg(count("*") as 'count)
        .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])

      testStream(windowedAggregation, OutputMode.Update)(
        AddData(inputData, 10, 11, 12, 13, 14, 15),
        CheckLastBatch((10, 5), (15, 1)),
        AddData(inputData, 25),     // Advance watermark to 15 seconds
        CheckLastBatch((25, 1)),
        assertNumStateRows(3),
        AddData(inputData, 10, 25), // Ignore 10 as its less than watermark
        CheckLastBatch((25, 2)),
        assertNumStateRows(2),
        AddData(inputData, 10),     // Should not emit anything as data less than watermark
        CheckLastBatch(),
        assertNumStateRows(2)
      )
    } finally {
      previousShufflePartitions match {
        case Some(value) => spark.conf.set(shufflePartitionsKey, value)
        case None => spark.conf.unset(shufflePartitionsKey)
      }
    }
  }

  test("delay in months and years handled correctly") {
    val currentTimeMs = System.currentTimeMillis
    val currentTime = new Date(currentTimeMs)

    val input = MemoryStream[Long]
    val aggWithWatermark = input.toDF()
      .withColumn("eventTime", $"value".cast("timestamp"))
      .withWatermark("eventTime", "2 years 5 months")
      .groupBy(window($"eventTime", "5 seconds") as 'window)
      .agg(count("*") as 'count)
      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])

    // Calendar month index of `date` in the default time zone. Only the difference between
    // two results is used below, so the absolute origin of the count does not matter.
    def monthsSinceEpoch(date: Date): Int = {
      val calendar = new ju.GregorianCalendar()
      calendar.setTime(date)
      calendar.get(ju.Calendar.YEAR) * 12 + calendar.get(ju.Calendar.MONTH)
    }

    testStream(aggWithWatermark)(
      AddData(input, currentTimeMs / 1000),
      CheckAnswer(),
      AddData(input, currentTimeMs / 1000),
      CheckAnswer(),
      assertEventStats { e =>
        // The input is in whole seconds, so compare against the truncated current time.
        assert(timestampFormat.parse(e.get("max")).getTime === (currentTimeMs / 1000) * 1000)
        val watermarkTime = timestampFormat.parse(e.get("watermark"))
        val monthDiff = monthsSinceEpoch(currentTime) - monthsSinceEpoch(watermarkTime)
        // monthsSinceEpoch is like `math.floor(num)`, so monthDiff has two possible values.
        assert(monthDiff === 29 || monthDiff === 30,
          s"currentTime: $currentTime, watermarkTime: $watermarkTime")
      }
    )
  }

  test("recovery") {
    val inputData = MemoryStream[Int]
    val df = inputData.toDF()
      .withColumn("eventTime", $"value".cast("timestamp"))
      .withWatermark("eventTime", "10 seconds")
      .groupBy(window($"eventTime", "5 seconds") as 'window)
      .agg(count("*") as 'count)
      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])

    testStream(df)(
      AddData(inputData, 10, 11, 12, 13, 14, 15),
      CheckLastBatch(),
      AddData(inputData, 25), // Advance watermark to 15 seconds
      StopStream,
      StartStream(),
      CheckLastBatch(),
      AddData(inputData, 25), // Evict items less than previous watermark.
      CheckLastBatch((10, 5)),
      StopStream,
      AssertOnQuery { q => // clear the sink
        q.sink.asInstanceOf[MemorySink].clear()
        true
      },
      StartStream(),
      CheckLastBatch((10, 5)), // Recompute last batch and re-evict timestamp 10
      AddData(inputData, 30), // Advance watermark to 20 seconds
      CheckLastBatch(),
      StopStream,
      StartStream(), // Watermark should still be 15 seconds
      AddData(inputData, 17),
      CheckLastBatch(), // We still do not see next batch
      AddData(inputData, 30), // Advance watermark to 20 seconds
      CheckLastBatch(),
      AddData(inputData, 30), // Evict items less than previous watermark.
      CheckLastBatch((15, 2)) // Ensure we see next window
    )
  }

  test("dropping old data") {
    val inputData = MemoryStream[Int]

    val windowedAggregation = inputData.toDF()
      .withColumn("eventTime", $"value".cast("timestamp"))
      .withWatermark("eventTime", "10 seconds")
      .groupBy(window($"eventTime", "5 seconds") as 'window)
      .agg(count("*") as 'count)
      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])

    testStream(windowedAggregation)(
      AddData(inputData, 10, 11, 12),
      CheckAnswer(),
      AddData(inputData, 25),     // Advance watermark to 15 seconds
      CheckAnswer(),
      AddData(inputData, 25),     // Evict items less than previous watermark.
      CheckAnswer((10, 3)),
      AddData(inputData, 10),     // 10 is older than the 15 second watermark (late data)
      CheckAnswer((10, 3)),
      AddData(inputData, 25),
      CheckAnswer((10, 3))        // Should not emit an incorrect partial result.
    )
  }

  test("complete mode") {
    val inputData = MemoryStream[Int]

    val windowedAggregation = inputData.toDF()
      .withColumn("eventTime", $"value".cast("timestamp"))
      .withWatermark("eventTime", "10 seconds")
      .groupBy(window($"eventTime", "5 seconds") as 'window)
      .agg(count("*") as 'count)
      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])

    // No eviction when asked to compute complete results.
    testStream(windowedAggregation, OutputMode.Complete)(
      AddData(inputData, 10, 11, 12),
      CheckAnswer((10, 3)),
      AddData(inputData, 25),
      CheckAnswer((10, 3), (25, 1)),
      AddData(inputData, 25),
      CheckAnswer((10, 3), (25, 2)),
      AddData(inputData, 10),
      CheckAnswer((10, 4), (25, 2)),
      AddData(inputData, 25),
      CheckAnswer((10, 4), (25, 3))
    )
  }

  test("group by on raw timestamp") {
    val inputData = MemoryStream[Int]

    // Groups directly on the watermarked column instead of on a window over it.
    val windowedAggregation = inputData.toDF()
      .withColumn("eventTime", $"value".cast("timestamp"))
      .withWatermark("eventTime", "10 seconds")
      .groupBy($"eventTime")
      .agg(count("*") as 'count)
      .select($"eventTime".cast("long").as[Long], $"count".as[Long])

    testStream(windowedAggregation)(
      AddData(inputData, 10),
      CheckAnswer(),
      AddData(inputData, 25), // Advance watermark to 15 seconds
      CheckAnswer(),
      AddData(inputData, 25), // Evict items less than previous watermark.
      CheckAnswer((10, 1))
    )
  }

  test("delay threshold should not be negative.") {
    val inputData = MemoryStream[Int].toDF()
    // Each delay below is expected to be rejected with the same error message.
    val negativeDelays = Seq("-1 year", "1 year -13 months", "1 month -40 days", "-10 seconds")

    negativeDelays.foreach { delay =>
      val e = intercept[IllegalArgumentException] {
        inputData.withWatermark("value", delay)
      }
      assert(e.getMessage.contains("should not be negative."), s"delay: $delay")
    }
  }

  test("the new watermark should override the old one") {
    val df = MemoryStream[(Long, Long)].toDF()
      .withColumn("first", $"_1".cast("timestamp"))
      .withColumn("second", $"_2".cast("timestamp"))
      .withWatermark("first", "1 minute")
      .withWatermark("second", "2 minutes")

    // Only the column named in the last withWatermark call may carry the delay metadata.
    val eventTimeColumns = df.logicalPlan.output
      .filter(_.metadata.contains(EventTimeWatermark.delayKey))
    assert(eventTimeColumns.size === 1)
    assert(eventTimeColumns(0).name === "second")
  }

  /**
   * Builds a stream action asserting that the first state operator of the most recent
   * progress update with input rows reports `numTotalRows` total rows.
   *
   * Fails the test with a descriptive message if no such progress update exists.
   */
  private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q =>
    val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.getOrElse {
      fail("No progress update with input rows has been reported yet")
    }
    assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows)
    true
  }

  /**
   * Builds a stream action that passes the `eventTime` map of the most recent progress
   * update with input rows to `body`, which is expected to run its own assertions.
   *
   * Fails the test with a descriptive message if no such progress update exists.
   */
  private def assertEventStats(body: ju.Map[String, String] => Unit): AssertOnQuery = {
    AssertOnQuery { q =>
      val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.getOrElse {
        fail("No progress update with input rows has been reported yet")
      }
      body(progressWithData.eventTime)
      true
    }
  }

  // Formatter used to build and parse the timestamps compared in the event-time stats; UTC.
  private val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601
  timestampFormat.setTimeZone(ju.TimeZone.getTimeZone("UTC"))

  /** Formats `sec` seconds since the epoch with `timestampFormat`. */
  private def formatTimestamp(sec: Long): String = {
    timestampFormat.format(new ju.Date(sec * 1000))
  }
}
365