/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


package org.apache.spark.util

import java.util.concurrent.{CountDownLatch, TimeUnit}

import scala.concurrent.Future
import scala.concurrent.duration._
import scala.util.Random

import org.scalatest.concurrent.Eventually._

import org.apache.spark.SparkFunSuite

/**
 * Tests for the thread/executor factory helpers in [[ThreadUtils]]: thread naming,
 * cached-pool sizing and keep-alive, same-thread execution, and `runInNewThread`.
 *
 * Every wait is bounded (10 seconds). Each bounded wait's result is asserted, so a
 * timeout is reported as a timeout rather than as an unrelated value mismatch.
 */
class ThreadUtilsSuite extends SparkFunSuite {

  test("newDaemonSingleThreadExecutor") {
    val executor = ThreadUtils.newDaemonSingleThreadExecutor("this-is-a-thread-name")
    try {
      @volatile var threadName = ""
      executor.submit(new Runnable {
        override def run(): Unit = {
          threadName = Thread.currentThread().getName()
        }
      })
      // Shut down and wait so that the submitted task has completed before threadName is read.
      executor.shutdown()
      assert(executor.awaitTermination(10, TimeUnit.SECONDS),
        "executor did not terminate within 10 seconds")
      assert(threadName === "this-is-a-thread-name")
    } finally {
      executor.shutdownNow()
    }
  }

  test("newDaemonSingleThreadScheduledExecutor") {
    val executor = ThreadUtils.newDaemonSingleThreadScheduledExecutor("this-is-a-thread-name")
    try {
      val latch = new CountDownLatch(1)
      @volatile var threadName = ""
      executor.schedule(new Runnable {
        override def run(): Unit = {
          threadName = Thread.currentThread().getName()
          latch.countDown()
        }
      }, 1, TimeUnit.MILLISECONDS)
      // The latch is released only after threadName has been written by the scheduled task.
      assert(latch.await(10, TimeUnit.SECONDS),
        "scheduled task did not run within 10 seconds")
      assert(threadName === "this-is-a-thread-name")
    } finally {
      executor.shutdownNow()
    }
  }

  test("newDaemonCachedThreadPool") {
    val maxThreadNumber = 10
    // Counted down once by each of the first maxThreadNumber tasks when it starts running.
    val startThreadsLatch = new CountDownLatch(maxThreadNumber)
    // Holds every task blocked until the test releases it.
    val latch = new CountDownLatch(1)
    val cachedThreadPool = ThreadUtils.newDaemonCachedThreadPool(
      "ThreadUtilsSuite-newDaemonCachedThreadPool",
      maxThreadNumber,
      keepAliveSeconds = 2)
    try {
      for (_ <- 1 to maxThreadNumber) {
        cachedThreadPool.execute(new Runnable {
          override def run(): Unit = {
            startThreadsLatch.countDown()
            latch.await(10, TimeUnit.SECONDS)
          }
        })
      }
      assert(startThreadsLatch.await(10, TimeUnit.SECONDS),
        s"not all $maxThreadNumber tasks started within 10 seconds")
      assert(cachedThreadPool.getActiveCount === maxThreadNumber)
      assert(cachedThreadPool.getQueue.size === 0)

      // Submit a new task and it should be put into the queue since the thread number reaches the
      // limitation
      cachedThreadPool.execute(new Runnable {
        override def run(): Unit = {
          latch.await(10, TimeUnit.SECONDS)
        }
      })

      assert(cachedThreadPool.getActiveCount === maxThreadNumber)
      assert(cachedThreadPool.getQueue.size === 1)

      latch.countDown()
      eventually(timeout(10.seconds)) {
        // All threads should be stopped after keepAliveSeconds
        assert(cachedThreadPool.getActiveCount === 0)
        assert(cachedThreadPool.getPoolSize === 0)
      }
    } finally {
      cachedThreadPool.shutdownNow()
    }
  }

  test("sameThread") {
    // A Future run on ThreadUtils.sameThread is expected to execute on the calling thread.
    val callerThreadName = Thread.currentThread().getName()
    val f = Future {
      Thread.currentThread().getName()
    }(ThreadUtils.sameThread)
    val futureThreadName = ThreadUtils.awaitResult(f, 10.seconds)
    assert(futureThreadName === callerThreadName)
  }

  test("runInNewThread") {
    import ThreadUtils._
    assert(runInNewThread("thread-name") { Thread.currentThread().getName } === "thread-name")
    // Without an explicit isDaemon argument the thread is expected to be a daemon.
    assert(runInNewThread("thread-name") { Thread.currentThread().isDaemon } === true)
    assert(
      runInNewThread("thread-name", isDaemon = false) { Thread.currentThread().isDaemon } === false
    )

    // An exception thrown by the body must surface to the caller with its message intact.
    val uniqueExceptionMessage = "test" + Random.nextInt()
    val exception = intercept[IllegalArgumentException] {
      runInNewThread("thread-name") { throw new IllegalArgumentException(uniqueExceptionMessage) }
    }
    // intercept[IllegalArgumentException] already returns that type; no cast is needed.
    assert(exception.getMessage === uniqueExceptionMessage)

    // The propagated stack trace should contain the placeholder frame and no ThreadUtils frames.
    val stackTrace = exception.getStackTrace.mkString("\n")
    val placeholder = "... run in separate thread using org.apache.spark.util.ThreadUtils ..."
    assert(stackTrace.contains(placeholder),
      "stack trace does not contain expected place holder"
    )
    assert(!stackTrace.contains("ThreadUtils.scala"),
      "stack trace contains unexpected references to ThreadUtils"
    )
  }
}