1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 package org.apache.spark.launcher;
19 
20 import java.io.Closeable;
21 import java.io.IOException;
22 import java.net.InetAddress;
23 import java.net.InetSocketAddress;
24 import java.net.ServerSocket;
25 import java.net.Socket;
26 import java.security.SecureRandom;
27 import java.util.ArrayList;
28 import java.util.List;
29 import java.util.Timer;
30 import java.util.TimerTask;
31 import java.util.concurrent.ConcurrentHashMap;
32 import java.util.concurrent.ConcurrentMap;
33 import java.util.concurrent.ThreadFactory;
34 import java.util.concurrent.atomic.AtomicLong;
35 import java.util.logging.Level;
36 import java.util.logging.Logger;
37 
38 import static org.apache.spark.launcher.LauncherProtocol.*;
39 
40 /**
 * A server that listens locally for connections from clients launched by the library. Each client
42  * has a secret that it needs to send to the server to identify itself and establish the session.
43  *
44  * I/O is currently blocking (one thread per client). Clients have a limited time to connect back
45  * to the server, otherwise the server will ignore the connection.
46  *
47  * === Architecture Overview ===
48  *
 * The launcher server is used when Spark apps are launched as separate processes from the calling
50  * app. It looks more or less like the following:
51  *
52  *         -----------------------                       -----------------------
53  *         |      User App       |     spark-submit      |      Spark App      |
54  *         |                     |  -------------------> |                     |
55  *         |         ------------|                       |-------------        |
56  *         |         |           |        hello          |            |        |
57  *         |         | L. Server |<----------------------| L. Backend |        |
58  *         |         |           |                       |            |        |
59  *         |         -------------                       -----------------------
60  *         |               |     |                              ^
61  *         |               v     |                              |
62  *         |        -------------|                              |
63  *         |        |            |      <per-app channel>       |
64  *         |        | App Handle |<------------------------------
65  *         |        |            |
66  *         -----------------------
67  *
68  * The server is started on demand and remains active while there are active or outstanding clients,
69  * to avoid opening too many ports when multiple clients are launched. Each client is given a unique
 * secret, and has a limited amount of time to connect back
71  * ({@link SparkLauncher#CHILD_CONNECTION_TIMEOUT}), at which point the server will throw away
72  * that client's state. A client is only allowed to connect back to the server once.
73  *
74  * The launcher server listens on the localhost only, so it doesn't need access controls (aside from
75  * the per-app secret) nor encryption. It thus requires that the launched app has a local process
76  * that communicates with the server. In cluster mode, this means that the client that launches the
77  * application must remain alive for the duration of the application (or until the app handle is
78  * disconnected).
79  */
80 class LauncherServer implements Closeable {
81 
82   private static final Logger LOG = Logger.getLogger(LauncherServer.class.getName());
83   private static final String THREAD_NAME_FMT = "LauncherServer-%d";
84   private static final long DEFAULT_CONNECT_TIMEOUT = 10000L;
85 
86   /** For creating secrets used for communication with child processes. */
87   private static final SecureRandom RND = new SecureRandom();
88 
89   private static volatile LauncherServer serverInstance;
90 
91   /**
92    * Creates a handle for an app to be launched. This method will start a server if one hasn't been
93    * started yet. The server is shared for multiple handles, and once all handles are disposed of,
94    * the server is shut down.
95    */
newAppHandle()96   static synchronized ChildProcAppHandle newAppHandle() throws IOException {
97     LauncherServer server = serverInstance != null ? serverInstance : new LauncherServer();
98     server.ref();
99     serverInstance = server;
100 
101     String secret = server.createSecret();
102     while (server.pending.containsKey(secret)) {
103       secret = server.createSecret();
104     }
105 
106     return server.newAppHandle(secret);
107   }
108 
getServerInstance()109   static LauncherServer getServerInstance() {
110     return serverInstance;
111   }
112 
113   private final AtomicLong refCount;
114   private final AtomicLong threadIds;
115   private final ConcurrentMap<String, ChildProcAppHandle> pending;
116   private final List<ServerConnection> clients;
117   private final ServerSocket server;
118   private final Thread serverThread;
119   private final ThreadFactory factory;
120   private final Timer timeoutTimer;
121 
122   private volatile boolean running;
123 
LauncherServer()124   private LauncherServer() throws IOException {
125     this.refCount = new AtomicLong(0);
126 
127     ServerSocket server = new ServerSocket();
128     try {
129       server.setReuseAddress(true);
130       server.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
131 
132       this.clients = new ArrayList<>();
133       this.threadIds = new AtomicLong();
134       this.factory = new NamedThreadFactory(THREAD_NAME_FMT);
135       this.pending = new ConcurrentHashMap<>();
136       this.timeoutTimer = new Timer("LauncherServer-TimeoutTimer", true);
137       this.server = server;
138       this.running = true;
139 
140       this.serverThread = factory.newThread(new Runnable() {
141         @Override
142         public void run() {
143           acceptConnections();
144         }
145       });
146       serverThread.start();
147     } catch (IOException ioe) {
148       close();
149       throw ioe;
150     } catch (Exception e) {
151       close();
152       throw new IOException(e);
153     }
154   }
155 
156   /**
157    * Creates a new app handle. The handle will wait for an incoming connection for a configurable
158    * amount of time, and if one doesn't arrive, it will transition to an error state.
159    */
newAppHandle(String secret)160   ChildProcAppHandle newAppHandle(String secret) {
161     ChildProcAppHandle handle = new ChildProcAppHandle(secret, this);
162     ChildProcAppHandle existing = pending.putIfAbsent(secret, handle);
163     CommandBuilderUtils.checkState(existing == null, "Multiple handles with the same secret.");
164     return handle;
165   }
166 
167   @Override
close()168   public void close() throws IOException {
169     synchronized (this) {
170       if (running) {
171         running = false;
172         timeoutTimer.cancel();
173         server.close();
174         synchronized (clients) {
175           List<ServerConnection> copy = new ArrayList<>(clients);
176           clients.clear();
177           for (ServerConnection client : copy) {
178             client.close();
179           }
180         }
181       }
182     }
183     if (serverThread != null) {
184       try {
185         serverThread.join();
186       } catch (InterruptedException ie) {
187         // no-op
188       }
189     }
190   }
191 
ref()192   void ref() {
193     refCount.incrementAndGet();
194   }
195 
unref()196   void unref() {
197     synchronized(LauncherServer.class) {
198       if (refCount.decrementAndGet() == 0) {
199         try {
200           close();
201         } catch (IOException ioe) {
202           // no-op.
203         } finally {
204           serverInstance = null;
205         }
206       }
207     }
208   }
209 
getPort()210   int getPort() {
211     return server.getLocalPort();
212   }
213 
214   /**
215    * Removes the client handle from the pending list (in case it's still there), and unrefs
216    * the server.
217    */
unregister(ChildProcAppHandle handle)218   void unregister(ChildProcAppHandle handle) {
219     pending.remove(handle.getSecret());
220     unref();
221   }
222 
acceptConnections()223   private void acceptConnections() {
224     try {
225       while (running) {
226         final Socket client = server.accept();
227         TimerTask timeout = new TimerTask() {
228           @Override
229           public void run() {
230             LOG.warning("Timed out waiting for hello message from client.");
231             try {
232               client.close();
233             } catch (IOException ioe) {
234               // no-op.
235             }
236           }
237         };
238         ServerConnection clientConnection = new ServerConnection(client, timeout);
239         Thread clientThread = factory.newThread(clientConnection);
240         synchronized (timeout) {
241           clientThread.start();
242           synchronized (clients) {
243             clients.add(clientConnection);
244           }
245           long timeoutMs = getConnectionTimeout();
246           // 0 is used for testing to avoid issues with clock resolution / thread scheduling,
247           // and force an immediate timeout.
248           if (timeoutMs > 0) {
249             timeoutTimer.schedule(timeout, getConnectionTimeout());
250           } else {
251             timeout.run();
252           }
253         }
254       }
255     } catch (IOException ioe) {
256       if (running) {
257         LOG.log(Level.SEVERE, "Error in accept loop.", ioe);
258       }
259     }
260   }
261 
getConnectionTimeout()262   private long getConnectionTimeout() {
263     String value = SparkLauncher.launcherConfig.get(SparkLauncher.CHILD_CONNECTION_TIMEOUT);
264     return (value != null) ? Long.parseLong(value) : DEFAULT_CONNECT_TIMEOUT;
265   }
266 
createSecret()267   private String createSecret() {
268     byte[] secret = new byte[128];
269     RND.nextBytes(secret);
270 
271     StringBuilder sb = new StringBuilder();
272     for (byte b : secret) {
273       int ival = b >= 0 ? b : Byte.MAX_VALUE - b;
274       if (ival < 0x10) {
275         sb.append("0");
276       }
277       sb.append(Integer.toHexString(ival));
278     }
279     return sb.toString();
280   }
281 
282   private class ServerConnection extends LauncherConnection {
283 
284     private TimerTask timeout;
285     private ChildProcAppHandle handle;
286 
ServerConnection(Socket socket, TimerTask timeout)287     ServerConnection(Socket socket, TimerTask timeout) throws IOException {
288       super(socket);
289       this.timeout = timeout;
290     }
291 
292     @Override
handle(Message msg)293     protected void handle(Message msg) throws IOException {
294       try {
295         if (msg instanceof Hello) {
296           timeout.cancel();
297           timeout = null;
298           Hello hello = (Hello) msg;
299           ChildProcAppHandle handle = pending.remove(hello.secret);
300           if (handle != null) {
301             handle.setConnection(this);
302             handle.setState(SparkAppHandle.State.CONNECTED);
303             this.handle = handle;
304           } else {
305             throw new IllegalArgumentException("Received Hello for unknown client.");
306           }
307         } else {
308           if (handle == null) {
309             throw new IllegalArgumentException("Expected hello, got: " +
310             msg != null ? msg.getClass().getName() : null);
311           }
312           if (msg instanceof SetAppId) {
313             SetAppId set = (SetAppId) msg;
314             handle.setAppId(set.appId);
315           } else if (msg instanceof SetState) {
316             handle.setState(((SetState)msg).state);
317           } else {
318             throw new IllegalArgumentException("Invalid message: " +
319               msg != null ? msg.getClass().getName() : null);
320           }
321         }
322       } catch (Exception e) {
323         LOG.log(Level.INFO, "Error handling message from client.", e);
324         if (timeout != null) {
325           timeout.cancel();
326         }
327         close();
328       } finally {
329         timeoutTimer.purge();
330       }
331     }
332 
333     @Override
close()334     public void close() throws IOException {
335       synchronized (clients) {
336         clients.remove(this);
337       }
338       super.close();
339       if (handle != null) {
340         if (!handle.getState().isFinal()) {
341           LOG.log(Level.WARNING, "Lost connection to spark application.");
342           handle.setState(SparkAppHandle.State.LOST);
343         }
344         handle.disconnect();
345       }
346     }
347 
348   }
349 
350 }
351