1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.spark.network.client; 19 20 import java.io.Closeable; 21 import java.io.IOException; 22 import java.net.InetSocketAddress; 23 import java.net.SocketAddress; 24 import java.util.List; 25 import java.util.Random; 26 import java.util.concurrent.ConcurrentHashMap; 27 import java.util.concurrent.atomic.AtomicReference; 28 29 import com.google.common.base.Preconditions; 30 import com.google.common.base.Throwables; 31 import com.google.common.collect.Lists; 32 import io.netty.bootstrap.Bootstrap; 33 import io.netty.buffer.PooledByteBufAllocator; 34 import io.netty.channel.Channel; 35 import io.netty.channel.ChannelFuture; 36 import io.netty.channel.ChannelInitializer; 37 import io.netty.channel.ChannelOption; 38 import io.netty.channel.EventLoopGroup; 39 import io.netty.channel.socket.SocketChannel; 40 import org.slf4j.Logger; 41 import org.slf4j.LoggerFactory; 42 43 import org.apache.spark.network.TransportContext; 44 import org.apache.spark.network.server.TransportChannelHandler; 45 import org.apache.spark.network.util.IOMode; 46 import org.apache.spark.network.util.JavaUtils; 47 import org.apache.spark.network.util.NettyUtils; 48 import org.apache.spark.network.util.TransportConf; 49 50 /** 51 * Factory for creating {@link TransportClient}s by using createClient. 52 * 53 * The factory maintains a connection pool to other hosts and should return the same 54 * TransportClient for the same remote host. It also shares a single worker thread pool for 55 * all TransportClients. 56 * 57 * TransportClients will be reused whenever possible. Prior to completing the creation of a new 58 * TransportClient, all given {@link TransportClientBootstrap}s will be run. 59 */ 60 public class TransportClientFactory implements Closeable { 61 62 /** A simple data structure to track the pool of clients between two peer nodes. */ 63 private static class ClientPool { 64 TransportClient[] clients; 65 Object[] locks; 66 ClientPool(int size)67 ClientPool(int size) { 68 clients = new TransportClient[size]; 69 locks = new Object[size]; 70 for (int i = 0; i < size; i++) { 71 locks[i] = new Object(); 72 } 73 } 74 } 75 76 private static final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); 77 78 private final TransportContext context; 79 private final TransportConf conf; 80 private final List<TransportClientBootstrap> clientBootstraps; 81 private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool; 82 83 /** Random number generator for picking connections between peers. */ 84 private final Random rand; 85 private final int numConnectionsPerPeer; 86 87 private final Class<? extends Channel> socketChannelClass; 88 private EventLoopGroup workerGroup; 89 private PooledByteBufAllocator pooledAllocator; 90 TransportClientFactory( TransportContext context, List<TransportClientBootstrap> clientBootstraps)91 public TransportClientFactory( 92 TransportContext context, 93 List<TransportClientBootstrap> clientBootstraps) { 94 this.context = Preconditions.checkNotNull(context); 95 this.conf = context.getConf(); 96 this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); 97 this.connectionPool = new ConcurrentHashMap<>(); 98 this.numConnectionsPerPeer = conf.numConnectionsPerPeer(); 99 this.rand = new Random(); 100 101 IOMode ioMode = IOMode.valueOf(conf.ioMode()); 102 this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); 103 this.workerGroup = NettyUtils.createEventLoop( 104 ioMode, 105 conf.clientThreads(), 106 conf.getModuleName() + "-client"); 107 this.pooledAllocator = NettyUtils.createPooledByteBufAllocator( 108 conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads()); 109 } 110 111 /** 112 * Create a {@link TransportClient} connecting to the given remote host / port. 113 * 114 * We maintains an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer) 115 * and randomly picks one to use. If no client was previously created in the randomly selected 116 * spot, this function creates a new client and places it there. 117 * 118 * Prior to the creation of a new TransportClient, we will execute all 119 * {@link TransportClientBootstrap}s that are registered with this factory. 120 * 121 * This blocks until a connection is successfully established and fully bootstrapped. 122 * 123 * Concurrency: This method is safe to call from multiple threads. 124 */ createClient(String remoteHost, int remotePort)125 public TransportClient createClient(String remoteHost, int remotePort) 126 throws IOException, InterruptedException { 127 // Get connection from the connection pool first. 128 // If it is not found or not active, create a new one. 129 // Use unresolved address here to avoid DNS resolution each time we creates a client. 130 final InetSocketAddress unresolvedAddress = 131 InetSocketAddress.createUnresolved(remoteHost, remotePort); 132 133 // Create the ClientPool if we don't have it yet. 134 ClientPool clientPool = connectionPool.get(unresolvedAddress); 135 if (clientPool == null) { 136 connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer)); 137 clientPool = connectionPool.get(unresolvedAddress); 138 } 139 140 int clientIndex = rand.nextInt(numConnectionsPerPeer); 141 TransportClient cachedClient = clientPool.clients[clientIndex]; 142 143 if (cachedClient != null && cachedClient.isActive()) { 144 // Make sure that the channel will not timeout by updating the last use time of the 145 // handler. Then check that the client is still alive, in case it timed out before 146 // this code was able to update things. 147 TransportChannelHandler handler = cachedClient.getChannel().pipeline() 148 .get(TransportChannelHandler.class); 149 synchronized (handler) { 150 handler.getResponseHandler().updateTimeOfLastRequest(); 151 } 152 153 if (cachedClient.isActive()) { 154 logger.trace("Returning cached connection to {}: {}", 155 cachedClient.getSocketAddress(), cachedClient); 156 return cachedClient; 157 } 158 } 159 160 // If we reach here, we don't have an existing connection open. Let's create a new one. 161 // Multiple threads might race here to create new connections. Keep only one of them active. 162 final long preResolveHost = System.nanoTime(); 163 final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort); 164 final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000; 165 if (hostResolveTimeMs > 2000) { 166 logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs); 167 } else { 168 logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs); 169 } 170 171 synchronized (clientPool.locks[clientIndex]) { 172 cachedClient = clientPool.clients[clientIndex]; 173 174 if (cachedClient != null) { 175 if (cachedClient.isActive()) { 176 logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient); 177 return cachedClient; 178 } else { 179 logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress); 180 } 181 } 182 clientPool.clients[clientIndex] = createClient(resolvedAddress); 183 return clientPool.clients[clientIndex]; 184 } 185 } 186 187 /** 188 * Create a completely new {@link TransportClient} to the given remote host / port. 189 * This connection is not pooled. 190 * 191 * As with {@link #createClient(String, int)}, this method is blocking. 192 */ createUnmanagedClient(String remoteHost, int remotePort)193 public TransportClient createUnmanagedClient(String remoteHost, int remotePort) 194 throws IOException, InterruptedException { 195 final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); 196 return createClient(address); 197 } 198 199 /** Create a completely new {@link TransportClient} to the remote address. */ createClient(InetSocketAddress address)200 private TransportClient createClient(InetSocketAddress address) 201 throws IOException, InterruptedException { 202 logger.debug("Creating new connection to {}", address); 203 204 Bootstrap bootstrap = new Bootstrap(); 205 bootstrap.group(workerGroup) 206 .channel(socketChannelClass) 207 // Disable Nagle's Algorithm since we don't want packets to wait 208 .option(ChannelOption.TCP_NODELAY, true) 209 .option(ChannelOption.SO_KEEPALIVE, true) 210 .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()) 211 .option(ChannelOption.ALLOCATOR, pooledAllocator); 212 213 final AtomicReference<TransportClient> clientRef = new AtomicReference<>(); 214 final AtomicReference<Channel> channelRef = new AtomicReference<>(); 215 216 bootstrap.handler(new ChannelInitializer<SocketChannel>() { 217 @Override 218 public void initChannel(SocketChannel ch) { 219 TransportChannelHandler clientHandler = context.initializePipeline(ch); 220 clientRef.set(clientHandler.getClient()); 221 channelRef.set(ch); 222 } 223 }); 224 225 // Connect to the remote server 226 long preConnect = System.nanoTime(); 227 ChannelFuture cf = bootstrap.connect(address); 228 if (!cf.await(conf.connectionTimeoutMs())) { 229 throw new IOException( 230 String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); 231 } else if (cf.cause() != null) { 232 throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); 233 } 234 235 TransportClient client = clientRef.get(); 236 Channel channel = channelRef.get(); 237 assert client != null : "Channel future completed successfully with null client"; 238 239 // Execute any client bootstraps synchronously before marking the Client as successful. 240 long preBootstrap = System.nanoTime(); 241 logger.debug("Connection to {} successful, running bootstraps...", address); 242 try { 243 for (TransportClientBootstrap clientBootstrap : clientBootstraps) { 244 clientBootstrap.doBootstrap(client, channel); 245 } 246 } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala 247 long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000; 248 logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e); 249 client.close(); 250 throw Throwables.propagate(e); 251 } 252 long postBootstrap = System.nanoTime(); 253 254 logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", 255 address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000); 256 257 return client; 258 } 259 260 /** Close all connections in the connection pool, and shutdown the worker thread pool. */ 261 @Override close()262 public void close() { 263 // Go through all clients and close them if they are active. 264 for (ClientPool clientPool : connectionPool.values()) { 265 for (int i = 0; i < clientPool.clients.length; i++) { 266 TransportClient client = clientPool.clients[i]; 267 if (client != null) { 268 clientPool.clients[i] = null; 269 JavaUtils.closeQuietly(client); 270 } 271 } 272 } 273 connectionPool.clear(); 274 275 if (workerGroup != null) { 276 workerGroup.shutdownGracefully(); 277 workerGroup = null; 278 } 279 } 280 } 281