/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import scala.concurrent.duration._
import scala.language.implicitConversions
import scala.language.postfixOps

import org.scalatest.Matchers
import org.scalatest.concurrent.Eventually._

import org.apache.spark.JobExecutionStatus._

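/**
 * Tests for the job/stage status reporting API exposed through `SparkContext#statusTracker`.
 * Jobs are launched with the asynchronous RDD actions (`collectAsync`, `countAsync`,
 * `takeAsync`) so that their status can be polled as they run and complete.
 */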
class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkContext {

  test("basic status API usage") {
    sc = new SparkContext("local", "test", new SparkConf(false))
    val jobFuture = sc.parallelize(1 to 10000, 2).map(identity).groupBy(identity).collectAsync()
    val jobId: Int = eventually(timeout(10 seconds)) {
      val jobIds = jobFuture.jobIds
      jobIds.size should be(1)
      jobIds.head
    }
    val jobInfo = eventually(timeout(10 seconds)) {
      sc.statusTracker.getJobInfo(jobId).get
    }
    jobInfo.status() should not be FAILED
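    // groupBy introduces a shuffle, so this job should be split into two stages: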
    val stageIds = jobInfo.stageIds()
    stageIds.size should be(2)

    val firstStageInfo = eventually(timeout(10 seconds)) {
      sc.statusTracker.getStageInfo(stageIds(0)).get
    }
    firstStageInfo.stageId() should be(stageIds(0))
    firstStageInfo.currentAttemptId() should be(0)
    firstStageInfo.numTasks() should be(2)
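    // Eventually both tasks of the first stage should be reported as completed, with no failures: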
    eventually(timeout(10 seconds)) {
      val updatedFirstStageInfo = sc.statusTracker.getStageInfo(stageIds(0)).get
      updatedFirstStageInfo.numCompletedTasks() should be(2)
      updatedFirstStageInfo.numActiveTasks() should be(0)
      updatedFirstStageInfo.numFailedTasks() should be(0)
    }
  }

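  // getJobIdsForGroup(null) should list jobs that ran outside of any job group, while passing a
  // group name should list only the jobs submitted under that group.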
  test("getJobIdsForGroup()") {
    sc = new SparkContext("local", "test", new SparkConf(false))
    // Passing `null` should return jobs that were not run in a job group:
    val defaultJobGroupFuture = sc.parallelize(1 to 1000).countAsync()
    val defaultJobGroupJobId = eventually(timeout(10 seconds)) {
      defaultJobGroupFuture.jobIds.head
    }
    eventually(timeout(10 seconds)) {
      sc.statusTracker.getJobIdsForGroup(null).toSet should be (Set(defaultJobGroupJobId))
    }
    // Test jobs submitted in job groups:
    sc.setJobGroup("my-job-group", "description")
    sc.statusTracker.getJobIdsForGroup("my-job-group") should be (Seq.empty)
    val firstJobFuture = sc.parallelize(1 to 1000).countAsync()
    val firstJobId = eventually(timeout(10 seconds)) {
      firstJobFuture.jobIds.head
    }
    eventually(timeout(10 seconds)) {
      sc.statusTracker.getJobIdsForGroup("my-job-group") should be (Seq(firstJobId))
    }
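    // A second job submitted under the same group should also be reported for it: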
    val secondJobFuture = sc.parallelize(1 to 1000).countAsync()
    val secondJobId = eventually(timeout(10 seconds)) {
      secondJobFuture.jobIds.head
    }
    eventually(timeout(10 seconds)) {
      sc.statusTracker.getJobIdsForGroup("my-job-group").toSet should be (
        Set(firstJobId, secondJobId))
    }
  }

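  // Jobs launched by takeAsync() should also be associated with the current job group: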
  test("getJobIdsForGroup() with takeAsync()") {
    sc = new SparkContext("local", "test", new SparkConf(false))
    sc.setJobGroup("my-job-group2", "description")
    sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty
    val firstJobFuture = sc.parallelize(1 to 1000, 1).takeAsync(1)
    val firstJobId = eventually(timeout(10 seconds)) {
      firstJobFuture.jobIds.head
    }
    eventually(timeout(10 seconds)) {
      sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq(firstJobId))
    }
  }

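  // takeAsync(999) on a 1000-element RDD split across two partitions cannot be satisfied by the
  // first partition alone, so take's incremental strategy should launch (and report) two jobs: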
  test("getJobIdsForGroup() with takeAsync() across multiple partitions") {
    sc = new SparkContext("local", "test", new SparkConf(false))
    sc.setJobGroup("my-job-group2", "description")
    sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty
    val firstJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999)
    val firstJobId = eventually(timeout(10 seconds)) {
      firstJobFuture.jobIds.head
    }
    eventually(timeout(10 seconds)) {
      sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2
    }
  }
}