/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.streaming

import java.util.TimeZone

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkException
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.OutputMode._

/**
 * Mutable flag used by the "midbatch failure" test to inject an error exactly once:
 * the test sets it to `true` before building the query, and the mapping function flips
 * it to `false` just before raising the injected failure.
 */
object FailureSinglton {
  var firstTime = true
}

/**
 * Tests for streaming aggregations: counts under the different output modes, sorting on
 * top of an aggregate, state store metrics, recovery after a failure in the middle of a
 * batch, typed aggregators, and filters based on `current_timestamp` / `current_date`.
 */
class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll {

  override def afterAll(): Unit = {
    // Stop the state store even when the parent teardown throws, so that a failing
    // teardown does not skip StateStore.stop().
    try {
      super.afterAll()
    } finally {
      StateStore.stop()
    }
  }

  import testImplicits._

  test("simple count, update mode") {
    val inputData = MemoryStream[Int]

    val aggregated =
      inputData.toDF()
        .groupBy($"value")
        .agg(count("*"))
        .as[(Int, Long)]

    testStream(aggregated, Update)(
      AddData(inputData, 3),
      CheckLastBatch((3, 1)),
      AddData(inputData, 3, 2),
      CheckLastBatch((3, 2), (2, 1)),
      StopStream,
      StartStream(),
      AddData(inputData, 3, 2, 1),
      CheckLastBatch((3, 3), (2, 2), (1, 1)),
      // In update mode the last batch holds only the keys whose count changed.
      AddData(inputData, 4, 4, 4, 4),
      CheckLastBatch((4, 4))
    )
  }

  test("simple count, complete mode") {
    val inputData = MemoryStream[Int]

    val aggregated =
      inputData.toDF()
        .groupBy($"value")
        .agg(count("*"))
        .as[(Int, Long)]

    testStream(aggregated, Complete)(
      AddData(inputData, 3),
      CheckLastBatch((3, 1)),
      AddData(inputData, 2),
      CheckLastBatch((3, 1), (2, 1)),
      StopStream,
      StartStream(),
      AddData(inputData, 3, 2, 1),
      CheckLastBatch((3, 2), (2, 2), (1, 1)),
      AddData(inputData, 4, 4, 4, 4),
      CheckLastBatch((4, 4), (3, 2), (2, 2), (1, 1))
    )
  }

  test("simple count, append mode") {
    val inputData = MemoryStream[Int]

    val aggregated =
      inputData.toDF()
        .groupBy($"value")
        .agg(count("*"))
        .as[(Int, Long)]

    // Starting this aggregation in append mode is expected to be rejected at analysis time.
    val e = intercept[AnalysisException] {
      testStream(aggregated, Append)()
    }
    Seq("append", "not supported").foreach { m =>
      assert(e.getMessage.toLowerCase.contains(m.toLowerCase))
    }
  }

  test("sort after aggregate in complete mode") {
    val inputData = MemoryStream[Int]

    val aggregated =
      inputData.toDF()
        .groupBy($"value")
        .agg(count("*"))
        .toDF("value", "count")
        .orderBy($"count".desc)
        .as[(Int, Long)]

    testStream(aggregated, Complete)(
      AddData(inputData, 3),
      CheckLastBatch(isSorted = true, (3, 1)),
      AddData(inputData, 2, 3),
      CheckLastBatch(isSorted = true, (3, 2), (2, 1)),
      StopStream,
      StartStream(),
      AddData(inputData, 3, 2, 1),
      CheckLastBatch(isSorted = true, (3, 3), (2, 2), (1, 1)),
      AddData(inputData, 4, 4, 4, 4),
      CheckLastBatch(isSorted = true, (4, 4), (3, 3), (2, 2), (1, 1))
    )
  }

  test("state metrics") {
    val inputData = MemoryStream[Int]

    val aggregated =
      inputData.toDS()
        .flatMap(x => Seq(x, x + 1))
        .toDF("value")
        .groupBy($"value")
        .agg(count("*"))
        .as[(Int, Long)]

    implicit class RichStreamExecution(query: StreamExecution) {
      /** The [[StateStoreSaveExec]] nodes of the plan executed for the last batch. */
      def stateNodes: Seq[SparkPlan] = {
        query.lastExecution.executedPlan.collect {
          case p: StateStoreSaveExec => p
        }
      }

      /**
       * Value of the metric `name` on the first state node. Throws
       * `NoSuchElementException` when there is no state node or no such metric.
       */
      def stateMetric(name: String): Long = stateNodes.head.metrics(name).value
    }

    // Test with Update mode
    testStream(aggregated, Update)(
      AddData(inputData, 1),
      CheckLastBatch((1, 1), (2, 1)),
      AssertOnQuery { _.stateNodes.size === 1 },
      AssertOnQuery { _.stateMetric("numOutputRows") === 2 },
      AssertOnQuery { _.stateMetric("numUpdatedStateRows") === 2 },
      AssertOnQuery { _.stateMetric("numTotalStateRows") === 2 },
      AddData(inputData, 2, 3),
      CheckLastBatch((2, 2), (3, 2), (4, 1)),
      AssertOnQuery { _.stateNodes.size === 1 },
      AssertOnQuery { _.stateMetric("numOutputRows") === 3 },
      AssertOnQuery { _.stateMetric("numUpdatedStateRows") === 3 },
      AssertOnQuery { _.stateMetric("numTotalStateRows") === 4 }
    )

    // Test with Complete mode
    inputData.reset()
    testStream(aggregated, Complete)(
      AddData(inputData, 1),
      CheckLastBatch((1, 1), (2, 1)),
      AssertOnQuery { _.stateNodes.size === 1 },
      AssertOnQuery { _.stateMetric("numOutputRows") === 2 },
      AssertOnQuery { _.stateMetric("numUpdatedStateRows") === 2 },
      AssertOnQuery { _.stateMetric("numTotalStateRows") === 2 },
      AddData(inputData, 2, 3),
      CheckLastBatch((1, 1), (2, 2), (3, 2), (4, 1)),
      AssertOnQuery { _.stateNodes.size === 1 },
      AssertOnQuery { _.stateMetric("numOutputRows") === 4 },
      AssertOnQuery { _.stateMetric("numUpdatedStateRows") === 3 },
      AssertOnQuery { _.stateMetric("numTotalStateRows") === 4 }
    )
  }

  test("multiple keys") {
    val inputData = MemoryStream[Int]

    val aggregated =
      inputData.toDF()
        .groupBy($"value", $"value" + 1)
        .agg(count("*"))
        .as[(Int, Int, Long)]

    testStream(aggregated, Update)(
      AddData(inputData, 1, 2),
      CheckLastBatch((1, 2, 1), (2, 3, 1)),
      AddData(inputData, 1, 2),
      CheckLastBatch((1, 2, 2), (2, 3, 2))
    )
  }

  testQuietly("midbatch failure") {
    val inputData = MemoryStream[Int]
    FailureSinglton.firstTime = true
    val aggregated =
      inputData.toDS()
        .map { i =>
          // Fail only the first time the value 4 is seen; later attempts pass through.
          if (i == 4 && FailureSinglton.firstTime) {
            FailureSinglton.firstTime = false
            sys.error("injected failure")
          }

          i
        }
        .groupBy($"value")
        .agg(count("*"))
        .as[(Int, Long)]

    testStream(aggregated, Update)(
      StartStream(),
      AddData(inputData, 1, 2, 3, 4),
      ExpectFailure[SparkException](),
      StartStream(),
      CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1))
    )
  }

  test("typed aggregators") {
    val inputData = MemoryStream[(String, Int)]
    val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2))

    testStream(aggregated, Update)(
      AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)),
      CheckLastBatch(("a", 30), ("b", 3), ("c", 1))
    )
  }

  test("prune results by current_time, complete mode") {
    val clock = new StreamManualClock
    val inputData = MemoryStream[Long]
    val aggregated =
      inputData.toDF()
        .groupBy($"value")
        .agg(count("*"))
        .where('value >= current_timestamp().cast("long") - 10L)

    testStream(aggregated, Complete)(
      StartStream(ProcessingTime("10 seconds"), triggerClock = clock),

      // advance clock to 10 seconds, all keys retained
      AddData(inputData, 0L, 5L, 5L, 10L),
      AdvanceManualClock(10 * 1000),
      CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),

      // advance clock to 20 seconds, should retain keys >= 10
      AddData(inputData, 15L, 15L, 20L),
      AdvanceManualClock(10 * 1000),
      CheckLastBatch((10L, 1), (15L, 2), (20L, 1)),

      // advance clock to 30 seconds, should retain keys >= 20
      AddData(inputData, 0L, 85L),
      AdvanceManualClock(10 * 1000),
      CheckLastBatch((20L, 1), (85L, 1)),

      // bounce stream and ensure correct batch timestamp is used
      // i.e., we don't take it from the clock, which is at 90 seconds.
      StopStream,
      AssertOnQuery { q => // clear the sink
        q.sink.asInstanceOf[MemorySink].clear()
        // advance by a minute i.e., 90 seconds total
        clock.advance(60 * 1000L)
        true
      },
      StartStream(ProcessingTime("10 seconds"), triggerClock = clock),
      CheckLastBatch((20L, 1), (85L, 1)),
      AssertOnQuery { _ =>
        clock.getTimeMillis() == 90000L
      },

      // advance clock to 100 seconds, should retain keys >= 90
      AddData(inputData, 85L, 90L, 100L, 105L),
      AdvanceManualClock(10 * 1000),
      CheckLastBatch((90L, 1), (100L, 1), (105L, 1))
    )
  }

  test("prune results by current_date, complete mode") {
    val clock = new StreamManualClock
    val tz = TimeZone.getDefault.getID
    val inputData = MemoryStream[Long]
    // Input values are day numbers: they are scaled to seconds, turned into timestamps,
    // filtered against current_date, and scaled back to day numbers for the checks.
    val aggregated =
      inputData.toDF()
        .select(to_utc_timestamp(from_unixtime('value * DateTimeUtils.SECONDS_PER_DAY), tz))
        .toDF("value")
        .groupBy($"value")
        .agg(count("*"))
        .where($"value".cast("date") >= date_sub(current_date(), 10))
        .select(($"value".cast("long") / DateTimeUtils.SECONDS_PER_DAY).cast("long"), $"count(1)")
    testStream(aggregated, Complete)(
      StartStream(ProcessingTime("10 day"), triggerClock = clock),
      // advance clock to 10 days, should retain all keys
      AddData(inputData, 0L, 5L, 5L, 10L),
      AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
      CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),
      // advance clock to 20 days, should retain keys >= 10
      AddData(inputData, 15L, 15L, 20L),
      AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
      CheckLastBatch((10L, 1), (15L, 2), (20L, 1)),
      // advance clock to 30 days, should retain keys >= 20
      AddData(inputData, 85L),
      AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
      CheckLastBatch((20L, 1), (85L, 1)),

      // bounce stream and ensure correct batch timestamp is used
      // i.e., we don't take it from the clock, which is at 90 days.
      StopStream,
      AssertOnQuery { q => // clear the sink
        q.sink.asInstanceOf[MemorySink].clear()
        // advance by 60 days i.e., 90 days total
        clock.advance(DateTimeUtils.MILLIS_PER_DAY * 60)
        true
      },
      StartStream(ProcessingTime("10 day"), triggerClock = clock),
      CheckLastBatch((20L, 1), (85L, 1)),

      // advance clock to 100 days, should retain keys >= 90
      AddData(inputData, 85L, 90L, 100L, 105L),
      AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
      CheckLastBatch((90L, 1), (100L, 1), (105L, 1))
    )
  }
}