1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 package org.apache.spark.network.client;
19 
20 import java.io.Closeable;
21 import java.io.IOException;
22 import java.net.InetSocketAddress;
23 import java.net.SocketAddress;
24 import java.util.List;
25 import java.util.Random;
26 import java.util.concurrent.ConcurrentHashMap;
27 import java.util.concurrent.atomic.AtomicReference;
28 
29 import com.google.common.base.Preconditions;
30 import com.google.common.base.Throwables;
31 import com.google.common.collect.Lists;
32 import io.netty.bootstrap.Bootstrap;
33 import io.netty.buffer.PooledByteBufAllocator;
34 import io.netty.channel.Channel;
35 import io.netty.channel.ChannelFuture;
36 import io.netty.channel.ChannelInitializer;
37 import io.netty.channel.ChannelOption;
38 import io.netty.channel.EventLoopGroup;
39 import io.netty.channel.socket.SocketChannel;
40 import org.slf4j.Logger;
41 import org.slf4j.LoggerFactory;
42 
43 import org.apache.spark.network.TransportContext;
44 import org.apache.spark.network.server.TransportChannelHandler;
45 import org.apache.spark.network.util.IOMode;
46 import org.apache.spark.network.util.JavaUtils;
47 import org.apache.spark.network.util.NettyUtils;
48 import org.apache.spark.network.util.TransportConf;
49 
50 /**
51  * Factory for creating {@link TransportClient}s by using createClient.
52  *
 * The factory maintains a pool of connections to other hosts (up to numConnectionsPerPeer per
 * remote host) and reuses a pooled TransportClient for the same remote host whenever possible.
 * It also shares a single worker thread pool for all TransportClients.
56  *
57  * TransportClients will be reused whenever possible. Prior to completing the creation of a new
58  * TransportClient, all given {@link TransportClientBootstrap}s will be run.
59  */
60 public class TransportClientFactory implements Closeable {
61 
62   /** A simple data structure to track the pool of clients between two peer nodes. */
63   private static class ClientPool {
64     TransportClient[] clients;
65     Object[] locks;
66 
ClientPool(int size)67     ClientPool(int size) {
68       clients = new TransportClient[size];
69       locks = new Object[size];
70       for (int i = 0; i < size; i++) {
71         locks[i] = new Object();
72       }
73     }
74   }
75 
76   private static final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
77 
78   private final TransportContext context;
79   private final TransportConf conf;
80   private final List<TransportClientBootstrap> clientBootstraps;
81   private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
82 
83   /** Random number generator for picking connections between peers. */
84   private final Random rand;
85   private final int numConnectionsPerPeer;
86 
87   private final Class<? extends Channel> socketChannelClass;
88   private EventLoopGroup workerGroup;
89   private PooledByteBufAllocator pooledAllocator;
90 
TransportClientFactory( TransportContext context, List<TransportClientBootstrap> clientBootstraps)91   public TransportClientFactory(
92       TransportContext context,
93       List<TransportClientBootstrap> clientBootstraps) {
94     this.context = Preconditions.checkNotNull(context);
95     this.conf = context.getConf();
96     this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
97     this.connectionPool = new ConcurrentHashMap<>();
98     this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
99     this.rand = new Random();
100 
101     IOMode ioMode = IOMode.valueOf(conf.ioMode());
102     this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
103     this.workerGroup = NettyUtils.createEventLoop(
104         ioMode,
105         conf.clientThreads(),
106         conf.getModuleName() + "-client");
107     this.pooledAllocator = NettyUtils.createPooledByteBufAllocator(
108       conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads());
109   }
110 
111   /**
112    * Create a {@link TransportClient} connecting to the given remote host / port.
113    *
114    * We maintains an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer)
115    * and randomly picks one to use. If no client was previously created in the randomly selected
116    * spot, this function creates a new client and places it there.
117    *
118    * Prior to the creation of a new TransportClient, we will execute all
119    * {@link TransportClientBootstrap}s that are registered with this factory.
120    *
121    * This blocks until a connection is successfully established and fully bootstrapped.
122    *
123    * Concurrency: This method is safe to call from multiple threads.
124    */
createClient(String remoteHost, int remotePort)125   public TransportClient createClient(String remoteHost, int remotePort)
126       throws IOException, InterruptedException {
127     // Get connection from the connection pool first.
128     // If it is not found or not active, create a new one.
129     // Use unresolved address here to avoid DNS resolution each time we creates a client.
130     final InetSocketAddress unresolvedAddress =
131       InetSocketAddress.createUnresolved(remoteHost, remotePort);
132 
133     // Create the ClientPool if we don't have it yet.
134     ClientPool clientPool = connectionPool.get(unresolvedAddress);
135     if (clientPool == null) {
136       connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
137       clientPool = connectionPool.get(unresolvedAddress);
138     }
139 
140     int clientIndex = rand.nextInt(numConnectionsPerPeer);
141     TransportClient cachedClient = clientPool.clients[clientIndex];
142 
143     if (cachedClient != null && cachedClient.isActive()) {
144       // Make sure that the channel will not timeout by updating the last use time of the
145       // handler. Then check that the client is still alive, in case it timed out before
146       // this code was able to update things.
147       TransportChannelHandler handler = cachedClient.getChannel().pipeline()
148         .get(TransportChannelHandler.class);
149       synchronized (handler) {
150         handler.getResponseHandler().updateTimeOfLastRequest();
151       }
152 
153       if (cachedClient.isActive()) {
154         logger.trace("Returning cached connection to {}: {}",
155           cachedClient.getSocketAddress(), cachedClient);
156         return cachedClient;
157       }
158     }
159 
160     // If we reach here, we don't have an existing connection open. Let's create a new one.
161     // Multiple threads might race here to create new connections. Keep only one of them active.
162     final long preResolveHost = System.nanoTime();
163     final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort);
164     final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000;
165     if (hostResolveTimeMs > 2000) {
166       logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
167     } else {
168       logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
169     }
170 
171     synchronized (clientPool.locks[clientIndex]) {
172       cachedClient = clientPool.clients[clientIndex];
173 
174       if (cachedClient != null) {
175         if (cachedClient.isActive()) {
176           logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient);
177           return cachedClient;
178         } else {
179           logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
180         }
181       }
182       clientPool.clients[clientIndex] = createClient(resolvedAddress);
183       return clientPool.clients[clientIndex];
184     }
185   }
186 
187   /**
188    * Create a completely new {@link TransportClient} to the given remote host / port.
189    * This connection is not pooled.
190    *
191    * As with {@link #createClient(String, int)}, this method is blocking.
192    */
createUnmanagedClient(String remoteHost, int remotePort)193   public TransportClient createUnmanagedClient(String remoteHost, int remotePort)
194       throws IOException, InterruptedException {
195     final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
196     return createClient(address);
197   }
198 
199   /** Create a completely new {@link TransportClient} to the remote address. */
createClient(InetSocketAddress address)200   private TransportClient createClient(InetSocketAddress address)
201       throws IOException, InterruptedException {
202     logger.debug("Creating new connection to {}", address);
203 
204     Bootstrap bootstrap = new Bootstrap();
205     bootstrap.group(workerGroup)
206       .channel(socketChannelClass)
207       // Disable Nagle's Algorithm since we don't want packets to wait
208       .option(ChannelOption.TCP_NODELAY, true)
209       .option(ChannelOption.SO_KEEPALIVE, true)
210       .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
211       .option(ChannelOption.ALLOCATOR, pooledAllocator);
212 
213     final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
214     final AtomicReference<Channel> channelRef = new AtomicReference<>();
215 
216     bootstrap.handler(new ChannelInitializer<SocketChannel>() {
217       @Override
218       public void initChannel(SocketChannel ch) {
219         TransportChannelHandler clientHandler = context.initializePipeline(ch);
220         clientRef.set(clientHandler.getClient());
221         channelRef.set(ch);
222       }
223     });
224 
225     // Connect to the remote server
226     long preConnect = System.nanoTime();
227     ChannelFuture cf = bootstrap.connect(address);
228     if (!cf.await(conf.connectionTimeoutMs())) {
229       throw new IOException(
230         String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
231     } else if (cf.cause() != null) {
232       throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
233     }
234 
235     TransportClient client = clientRef.get();
236     Channel channel = channelRef.get();
237     assert client != null : "Channel future completed successfully with null client";
238 
239     // Execute any client bootstraps synchronously before marking the Client as successful.
240     long preBootstrap = System.nanoTime();
241     logger.debug("Connection to {} successful, running bootstraps...", address);
242     try {
243       for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
244         clientBootstrap.doBootstrap(client, channel);
245       }
246     } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
247       long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
248       logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
249       client.close();
250       throw Throwables.propagate(e);
251     }
252     long postBootstrap = System.nanoTime();
253 
254     logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
255       address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);
256 
257     return client;
258   }
259 
260   /** Close all connections in the connection pool, and shutdown the worker thread pool. */
261   @Override
close()262   public void close() {
263     // Go through all clients and close them if they are active.
264     for (ClientPool clientPool : connectionPool.values()) {
265       for (int i = 0; i < clientPool.clients.length; i++) {
266         TransportClient client = clientPool.clients[i];
267         if (client != null) {
268           clientPool.clients[i] = null;
269           JavaUtils.closeQuietly(client);
270         }
271       }
272     }
273     connectionPool.clear();
274 
275     if (workerGroup != null) {
276       workerGroup.shutdownGracefully();
277       workerGroup = null;
278     }
279   }
280 }
281