/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

import org.apache.spark.SparkConf
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout}

/** Helpers that read RPC-related settings from a `SparkConf`. */
private[spark] object RpcUtils {

  /**
   * Look up, by name, an `RpcEndpointRef` that lives in the driver.
   *
   * The driver address is read from `spark.driver.host` (default `localhost`) and
   * `spark.driver.port` (default `7077`). The host is passed through `Utils.checkHost`
   * before the endpoint is looked up.
   *
   * @param name   name of the endpoint to look up
   * @param conf   configuration the driver host and port are read from
   * @param rpcEnv environment used to perform the lookup
   */
  def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = {
    val host = conf.get("spark.driver.host", "localhost")
    val port = conf.getInt("spark.driver.port", 7077)
    Utils.checkHost(host, "Expected hostname")
    val driverAddress = RpcAddress(host, port)
    rpcEnv.setupEndpointRef(driverAddress, name)
  }

  /** Number of connection retries, from `spark.rpc.numRetries` (default 3). */
  def numRetries(conf: SparkConf): Int =
    conf.getInt("spark.rpc.numRetries", 3)

  /** Wait between retries in milliseconds, from `spark.rpc.retry.wait` (default `3s`). */
  def retryWaitMs(conf: SparkConf): Long =
    conf.getTimeAsMs("spark.rpc.retry.wait", "3s")

  /**
   * Default timeout for RPC ask operations.
   *
   * Built from the keys `spark.rpc.askTimeout` and `spark.network.timeout` (handed to
   * `RpcTimeout` in that order) with a default of `120s`.
   */
  def askRpcTimeout(conf: SparkConf): RpcTimeout = {
    val timeoutKeys = Seq("spark.rpc.askTimeout", "spark.network.timeout")
    RpcTimeout(conf, timeoutKeys, "120s")
  }

  /**
   * Default timeout for RPC remote endpoint lookup.
   *
   * Built from the keys `spark.rpc.lookupTimeout` and `spark.network.timeout` (handed to
   * `RpcTimeout` in that order) with a default of `120s`.
   */
  def lookupRpcTimeout(conf: SparkConf): RpcTimeout = {
    val timeoutKeys = Seq("spark.rpc.lookupTimeout", "spark.network.timeout")
    RpcTimeout(conf, timeoutKeys, "120s")
  }

  // Largest whole number of megabytes whose size in bytes still fits in an Int.
  private val MAX_MESSAGE_SIZE_IN_MB: Int = Int.MaxValue / 1024 / 1024

  /**
   * Maximum RPC message size in bytes.
   *
   * Reads `spark.rpc.message.maxSize` (in MB, default 128) and converts it to bytes.
   *
   * @throws IllegalArgumentException if the configured size exceeds the largest number of
   *                                  megabytes representable in bytes as an `Int`
   */
  def maxMessageSizeBytes(conf: SparkConf): Int = {
    val configuredMb = conf.getInt("spark.rpc.message.maxSize", 128)
    if (configuredMb <= MAX_MESSAGE_SIZE_IN_MB) {
      configuredMb * 1024 * 1024
    } else {
      throw new IllegalArgumentException(
        s"spark.rpc.message.maxSize should not be greater than $MAX_MESSAGE_SIZE_IN_MB MB")
    }
  }
}