1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18
19package org.apache.spark.util
20
21import java.util.concurrent.{CountDownLatch, TimeUnit}
22
23import scala.concurrent.Future
24import scala.concurrent.duration._
25import scala.util.Random
26
27import org.scalatest.concurrent.Eventually._
28
29import org.apache.spark.SparkFunSuite
30
/**
 * Tests for the executor factories and thread helpers exposed by [[ThreadUtils]].
 *
 * Every blocking wait (latch / termination) asserts its own result, so that a timeout is
 * reported directly instead of surfacing later as a confusing assertion on stale state.
 */
class ThreadUtilsSuite extends SparkFunSuite {

  // The single-thread executor must run submitted work on a thread carrying the given name.
  test("newDaemonSingleThreadExecutor") {
    val executor = ThreadUtils.newDaemonSingleThreadExecutor("this-is-a-thread-name")
    try {
      @volatile var threadName = ""
      executor.submit(new Runnable {
        override def run(): Unit = {
          threadName = Thread.currentThread().getName()
        }
      })
      executor.shutdown()
      // If the task never finished, fail here rather than on the still-empty threadName.
      assert(executor.awaitTermination(10, TimeUnit.SECONDS),
        "executor did not terminate within 10 seconds")
      assert(threadName === "this-is-a-thread-name")
    } finally {
      // Tear the pool down even if an assertion above failed; harmless after shutdown().
      executor.shutdownNow()
    }
  }

  // A task scheduled on the single-thread scheduled executor must run on the named thread.
  test("newDaemonSingleThreadScheduledExecutor") {
    val executor = ThreadUtils.newDaemonSingleThreadScheduledExecutor("this-is-a-thread-name")
    try {
      val latch = new CountDownLatch(1)
      @volatile var threadName = ""
      executor.schedule(new Runnable {
        override def run(): Unit = {
          threadName = Thread.currentThread().getName()
          latch.countDown()
        }
      }, 1, TimeUnit.MILLISECONDS)
      assert(latch.await(10, TimeUnit.SECONDS),
        "scheduled task did not run within 10 seconds")
      assert(threadName === "this-is-a-thread-name")
    } finally {
      executor.shutdownNow()
    }
  }

  // The cached pool grows to maxThreadNumber threads, queues any further task instead of
  // spawning extra threads, and lets idle threads expire after keepAliveSeconds.
  test("newDaemonCachedThreadPool") {
    val maxThreadNumber = 10
    val startThreadsLatch = new CountDownLatch(maxThreadNumber)
    // Released by the test to let every blocked task finish.
    val latch = new CountDownLatch(1)
    val cachedThreadPool = ThreadUtils.newDaemonCachedThreadPool(
      "ThreadUtilsSuite-newDaemonCachedThreadPool",
      maxThreadNumber,
      keepAliveSeconds = 2)
    try {
      for (_ <- 1 to maxThreadNumber) {
        cachedThreadPool.execute(new Runnable {
          override def run(): Unit = {
            startThreadsLatch.countDown()
            latch.await(10, TimeUnit.SECONDS)
          }
        })
      }
      // Without this check, a timeout would let the pool-size assertions below race the
      // workers that have not started yet.
      assert(startThreadsLatch.await(10, TimeUnit.SECONDS),
        s"not all $maxThreadNumber tasks started within 10 seconds")
      assert(cachedThreadPool.getActiveCount === maxThreadNumber)
      assert(cachedThreadPool.getQueue.size === 0)

      // Submit a new task and it should be put into the queue since the thread number reaches the
      // limitation
      cachedThreadPool.execute(new Runnable {
        override def run(): Unit = {
          latch.await(10, TimeUnit.SECONDS)
        }
      })

      assert(cachedThreadPool.getActiveCount === maxThreadNumber)
      assert(cachedThreadPool.getQueue.size === 1)

      latch.countDown()
      eventually(timeout(10.seconds)) {
        // All threads should be stopped after keepAliveSeconds
        assert(cachedThreadPool.getActiveCount === 0)
        assert(cachedThreadPool.getPoolSize === 0)
      }
    } finally {
      cachedThreadPool.shutdownNow()
    }
  }

  // A Future run on ThreadUtils.sameThread executes on the calling thread itself.
  test("sameThread") {
    val callerThreadName = Thread.currentThread().getName()
    val f = Future {
      Thread.currentThread().getName()
    }(ThreadUtils.sameThread)
    val futureThreadName = ThreadUtils.awaitResult(f, 10.seconds)
    assert(futureThreadName === callerThreadName)
  }

  // runInNewThread runs the body on a named thread (daemon unless isDaemon = false), returns
  // its result, and rethrows the body's exception with a stack trace that contains a
  // placeholder frame instead of ThreadUtils.scala frames.
  test("runInNewThread") {
    import ThreadUtils._
    assert(runInNewThread("thread-name") { Thread.currentThread().getName } === "thread-name")
    assert(runInNewThread("thread-name") { Thread.currentThread().isDaemon } === true)
    assert(
      runInNewThread("thread-name", isDaemon = false) { Thread.currentThread().isDaemon } === false
    )
    val uniqueExceptionMessage = "test" + Random.nextInt()
    // intercept already yields an IllegalArgumentException, so no cast is needed below.
    val exception = intercept[IllegalArgumentException] {
      runInNewThread("thread-name") { throw new IllegalArgumentException(uniqueExceptionMessage) }
    }
    assert(exception.getMessage === uniqueExceptionMessage)
    val stackTrace = exception.getStackTrace.mkString("\n")
    assert(stackTrace.contains(
      "... run in separate thread using org.apache.spark.util.ThreadUtils ..."),
      "stack trace does not contain expected place holder")
    assert(!stackTrace.contains("ThreadUtils.scala"),
      "stack trace contains unexpected references to ThreadUtils")
  }
}
137