1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 package org.apache.thrift.transport.sasl;
21 
22 import java.nio.channels.SelectionKey;
23 import java.nio.charset.StandardCharsets;
24 
25 import javax.security.sasl.SaslServer;
26 
27 import org.apache.thrift.TByteArrayOutputStream;
28 import org.apache.thrift.TProcessor;
29 import org.apache.thrift.protocol.TProtocol;
30 import org.apache.thrift.protocol.TProtocolFactory;
31 import org.apache.thrift.server.ServerContext;
32 import org.apache.thrift.server.TServerEventHandler;
33 import org.apache.thrift.transport.TMemoryTransport;
34 import org.apache.thrift.transport.TNonblockingTransport;
35 import org.apache.thrift.transport.TTransportException;
36 import org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType;
37 import org.slf4j.Logger;
38 import org.slf4j.LoggerFactory;
39 
40 import static org.apache.thrift.transport.sasl.NegotiationStatus.COMPLETE;
41 import static org.apache.thrift.transport.sasl.NegotiationStatus.OK;
42 
43 /**
44  * State machine managing one sasl connection in a nonblocking way.
45  */
46 public class NonblockingSaslHandler {
47   private static final Logger LOGGER = LoggerFactory.getLogger(NonblockingSaslHandler.class);
48 
49   private static final int INTEREST_NONE = 0;
50   private static final int INTEREST_READ = SelectionKey.OP_READ;
51   private static final int INTEREST_WRITE = SelectionKey.OP_WRITE;
52 
53   // Tracking the current running phase
54   private Phase currentPhase = Phase.INITIIALIIZING;
55   // Tracking the next phase on the next invocation of the state machine.
56   // It should be the same as current phase if current phase is not yet finished.
57   // Otherwise, if it is different from current phase, the statemachine is in a transition state:
58   // current phase is done, and next phase is not yet started.
59   private Phase nextPhase = currentPhase;
60 
61   // Underlying nonblocking transport
62   private SelectionKey selectionKey;
63   private TNonblockingTransport underlyingTransport;
64 
65   // APIs for intercepting event / customizing behaviors:
66   // Factories (decorating the base implementations) & EventHandler (intercepting)
67   private TSaslServerFactory saslServerFactory;
68   private TSaslProcessorFactory processorFactory;
69   private TProtocolFactory inputProtocolFactory;
70   private TProtocolFactory outputProtocolFactory;
71   private TServerEventHandler eventHandler;
72   private ServerContext serverContext;
73   // It turns out the event handler implementation in hive sometimes creates a null ServerContext.
74   // In order to know whether TServerEventHandler#createContext is called we use such a flag.
75   private boolean serverContextCreated = false;
76 
77   // Wrapper around sasl server
78   private ServerSaslPeer saslPeer;
79 
80   // Sasl negotiation io
81   private SaslNegotiationFrameReader saslResponse;
82   private SaslNegotiationFrameWriter saslChallenge;
83   // IO for request from and response to the socket
84   private DataFrameReader requestReader;
85   private DataFrameWriter responseWriter;
86   // If sasl is negotiated for integrity/confidentiality protection
87   private boolean dataProtected;
88 
NonblockingSaslHandler(SelectionKey selectionKey, TNonblockingTransport underlyingTransport, TSaslServerFactory saslServerFactory, TSaslProcessorFactory processorFactory, TProtocolFactory inputProtocolFactory, TProtocolFactory outputProtocolFactory, TServerEventHandler eventHandler)89   public NonblockingSaslHandler(SelectionKey selectionKey, TNonblockingTransport underlyingTransport,
90                                 TSaslServerFactory saslServerFactory, TSaslProcessorFactory processorFactory,
91                                 TProtocolFactory inputProtocolFactory, TProtocolFactory outputProtocolFactory,
92                                 TServerEventHandler eventHandler) {
93     this.selectionKey = selectionKey;
94     this.underlyingTransport = underlyingTransport;
95     this.saslServerFactory = saslServerFactory;
96     this.processorFactory = processorFactory;
97     this.inputProtocolFactory = inputProtocolFactory;
98     this.outputProtocolFactory = outputProtocolFactory;
99     this.eventHandler = eventHandler;
100 
101     saslResponse = new SaslNegotiationFrameReader();
102     saslChallenge = new SaslNegotiationFrameWriter();
103     requestReader = new DataFrameReader();
104     responseWriter = new DataFrameWriter();
105   }
106 
107   /**
108    * Get current phase of the state machine.
109    *
110    * @return current phase.
111    */
getCurrentPhase()112   public Phase getCurrentPhase() {
113     return currentPhase;
114   }
115 
116   /**
117    * Get next phase of the state machine.
118    * It is different from current phase iff current phase is done (and next phase not yet started).
119    *
120    * @return next phase.
121    */
getNextPhase()122   public Phase getNextPhase() {
123     return nextPhase;
124   }
125 
126   /**
127    *
128    * @return underlying nonblocking socket
129    */
getUnderlyingTransport()130   public TNonblockingTransport getUnderlyingTransport() {
131     return underlyingTransport;
132   }
133 
134   /**
135    *
136    * @return SaslServer instance
137    */
getSaslServer()138   public SaslServer getSaslServer() {
139     return saslPeer.getSaslServer();
140   }
141 
142   /**
143    *
144    * @return true if current phase is done.
145    */
isCurrentPhaseDone()146   public boolean isCurrentPhaseDone() {
147     return currentPhase != nextPhase;
148   }
149 
150   /**
151    * Run state machine.
152    *
153    * @throws IllegalStateException if current state is already done.
154    */
runCurrentPhase()155   public void runCurrentPhase() {
156     currentPhase.runStateMachine(this);
157   }
158 
159   /**
160    * When current phase is intrested in read selection, calling this will run the current phase and
161    * its following phases if the following ones are interested to read, until there is nothing
162    * available in the underlying transport.
163    *
164    * @throws IllegalStateException if is called in an irrelevant phase.
165    */
handleRead()166   public void handleRead() {
167     handleOps(INTEREST_READ);
168   }
169 
170   /**
171    * Similiar to handleRead. But it is for write ops.
172    *
173    * @throws IllegalStateException if it is called in an irrelevant phase.
174    */
handleWrite()175   public void handleWrite() {
176     handleOps(INTEREST_WRITE);
177   }
178 
handleOps(int interestOps)179   private void handleOps(int interestOps) {
180     if (currentPhase.selectionInterest != interestOps) {
181       throw new IllegalStateException("Current phase " + currentPhase + " but got interest " +
182           interestOps);
183     }
184     runCurrentPhase();
185     if (isCurrentPhaseDone() && nextPhase.selectionInterest == interestOps) {
186       stepToNextPhase();
187       handleOps(interestOps);
188     }
189   }
190 
191   /**
192    * When current phase is finished, it's expected to call this method first before running the
193    * state machine again.
194    * By calling this, "next phase" is marked as started (and not done), thus is ready to run.
195    *
196    * @throws IllegalArgumentException if current phase is not yet done.
197    */
stepToNextPhase()198   public void stepToNextPhase() {
199     if (!isCurrentPhaseDone()) {
200       throw new IllegalArgumentException("Not yet done with current phase: " + currentPhase);
201     }
202     LOGGER.debug("Switch phase {} to {}", currentPhase, nextPhase);
203     switch (nextPhase) {
204       case INITIIALIIZING:
205         throw new IllegalStateException("INITIALIZING cannot be the next phase of " + currentPhase);
206       default:
207     }
208     // If next phase's interest is not the same as current,  nor the same as the selection key,
209     // we need to change interest on the selector.
210     if (!(nextPhase.selectionInterest == currentPhase.selectionInterest ||
211         nextPhase.selectionInterest == selectionKey.interestOps())) {
212       changeSelectionInterest(nextPhase.selectionInterest);
213     }
214     currentPhase = nextPhase;
215   }
216 
changeSelectionInterest(int selectionInterest)217   private void changeSelectionInterest(int selectionInterest) {
218     selectionKey.interestOps(selectionInterest);
219   }
220 
221   // sasl negotiaion failure handling
failSaslNegotiation(TSaslNegotiationException e)222   private void failSaslNegotiation(TSaslNegotiationException e) {
223     LOGGER.error("Sasl negotiation failed", e);
224     String errorMsg = e.getDetails();
225     saslChallenge.withHeaderAndPayload(new byte[]{e.getErrorType().code.getValue()},
226         errorMsg.getBytes(StandardCharsets.UTF_8));
227     nextPhase = Phase.WRITING_FAILURE_MESSAGE;
228   }
229 
fail(Exception e)230   private void fail(Exception e) {
231     LOGGER.error("Failed io in " + currentPhase, e);
232     nextPhase = Phase.CLOSING;
233   }
234 
failIO(TTransportException e)235   private void failIO(TTransportException e) {
236     StringBuilder errorMsg = new StringBuilder("IO failure ")
237         .append(e.getType())
238         .append(" in ")
239         .append(currentPhase);
240     if (e.getMessage() != null) {
241       errorMsg.append(": ").append(e.getMessage());
242     }
243     LOGGER.error(errorMsg.toString(), e);
244     nextPhase = Phase.CLOSING;
245   }
246 
247   // Read handlings
248 
handleInitializing()249   private void handleInitializing() {
250     try {
251       saslResponse.read(underlyingTransport);
252       if (saslResponse.isComplete()) {
253         SaslNegotiationHeaderReader startHeader = saslResponse.getHeader();
254         if (startHeader.getStatus() != NegotiationStatus.START) {
255           throw new TInvalidSaslFrameException("Expecting START status but got " + startHeader.getStatus());
256         }
257         String mechanism = new String(saslResponse.getPayload(), StandardCharsets.UTF_8);
258         saslPeer = saslServerFactory.getSaslPeer(mechanism);
259         saslResponse.clear();
260         nextPhase = Phase.READING_SASL_RESPONSE;
261       }
262     } catch (TSaslNegotiationException e) {
263       failSaslNegotiation(e);
264     } catch (TTransportException e) {
265       failIO(e);
266     }
267   }
268 
handleReadingSaslResponse()269   private void handleReadingSaslResponse() {
270     try {
271       saslResponse.read(underlyingTransport);
272       if (saslResponse.isComplete()) {
273         nextPhase = Phase.EVALUATING_SASL_RESPONSE;
274       }
275     } catch (TSaslNegotiationException e) {
276       failSaslNegotiation(e);
277     } catch (TTransportException e) {
278       failIO(e);
279     }
280   }
281 
handleReadingRequest()282   private void handleReadingRequest() {
283     try {
284       requestReader.read(underlyingTransport);
285       if (requestReader.isComplete()) {
286         nextPhase = Phase.PROCESSING;
287       }
288     } catch (TTransportException e) {
289       failIO(e);
290     }
291   }
292 
293   // Computation executions
294 
executeEvaluatingSaslResponse()295   private void executeEvaluatingSaslResponse() {
296     if (!(saslResponse.getHeader().getStatus() == OK || saslResponse.getHeader().getStatus() == COMPLETE)) {
297       String error = "Expect status OK or COMPLETE, but got " + saslResponse.getHeader().getStatus();
298       failSaslNegotiation(new TSaslNegotiationException(ErrorType.PROTOCOL_ERROR, error));
299       return;
300     }
301     try {
302       byte[] response = saslResponse.getPayload();
303       saslResponse.clear();
304       byte[] newChallenge = saslPeer.evaluate(response);
305       if (saslPeer.isAuthenticated()) {
306         dataProtected = saslPeer.isDataProtected();
307         saslChallenge.withHeaderAndPayload(new byte[]{COMPLETE.getValue()}, newChallenge);
308         nextPhase = Phase.WRITING_SUCCESS_MESSAGE;
309       } else {
310         saslChallenge.withHeaderAndPayload(new byte[]{OK.getValue()}, newChallenge);
311         nextPhase = Phase.WRITING_SASL_CHALLENGE;
312       }
313     } catch (TSaslNegotiationException e) {
314       failSaslNegotiation(e);
315     }
316   }
317 
executeProcessing()318   private void executeProcessing() {
319     try {
320       byte[] inputPayload = requestReader.getPayload();
321       requestReader.clear();
322       byte[] rawInput = dataProtected ? saslPeer.unwrap(inputPayload) : inputPayload;
323       TMemoryTransport memoryTransport = new TMemoryTransport(rawInput);
324       TProtocol requestProtocol = inputProtocolFactory.getProtocol(memoryTransport);
325       TProtocol responseProtocol = outputProtocolFactory.getProtocol(memoryTransport);
326 
327       if (eventHandler != null) {
328         if (!serverContextCreated) {
329           serverContext = eventHandler.createContext(requestProtocol, responseProtocol);
330           serverContextCreated = true;
331         }
332         eventHandler.processContext(serverContext, memoryTransport, memoryTransport);
333       }
334 
335       TProcessor processor = processorFactory.getProcessor(this);
336       processor.process(requestProtocol, responseProtocol);
337       TByteArrayOutputStream rawOutput = memoryTransport.getOutput();
338       if (rawOutput.len() == 0) {
339         // This is a oneway request, no response to send back. Waiting for next incoming request.
340         nextPhase = Phase.READING_REQUEST;
341         return;
342       }
343       if (dataProtected) {
344         byte[] outputPayload = saslPeer.wrap(rawOutput.get(), 0, rawOutput.len());
345         responseWriter.withOnlyPayload(outputPayload);
346       } else {
347         responseWriter.withOnlyPayload(rawOutput.get(), 0 ,rawOutput.len());
348       }
349       nextPhase = Phase.WRITING_RESPONSE;
350     } catch (TTransportException e) {
351       failIO(e);
352     } catch (Exception e) {
353       fail(e);
354     }
355   }
356 
357   // Write handlings
358 
handleWritingSaslChallenge()359   private void handleWritingSaslChallenge() {
360     try {
361       saslChallenge.write(underlyingTransport);
362       if (saslChallenge.isComplete()) {
363         saslChallenge.clear();
364         nextPhase = Phase.READING_SASL_RESPONSE;
365       }
366     } catch (TTransportException e) {
367       fail(e);
368     }
369   }
370 
handleWritingSuccessMessage()371   private void handleWritingSuccessMessage() {
372     try {
373       saslChallenge.write(underlyingTransport);
374       if (saslChallenge.isComplete()) {
375         LOGGER.debug("Authentication is done.");
376         saslChallenge = null;
377         saslResponse = null;
378         nextPhase = Phase.READING_REQUEST;
379       }
380     } catch (TTransportException e) {
381       fail(e);
382     }
383   }
384 
handleWritingFailureMessage()385   private void handleWritingFailureMessage() {
386     try {
387       saslChallenge.write(underlyingTransport);
388       if (saslChallenge.isComplete()) {
389         nextPhase = Phase.CLOSING;
390       }
391     } catch (TTransportException e) {
392       fail(e);
393     }
394   }
395 
handleWritingResponse()396   private void handleWritingResponse() {
397     try {
398       responseWriter.write(underlyingTransport);
399       if (responseWriter.isComplete()) {
400         responseWriter.clear();
401         nextPhase = Phase.READING_REQUEST;
402       }
403     } catch (TTransportException e) {
404       fail(e);
405     }
406   }
407 
408   /**
409    * Release all the resources managed by this state machine (connection, selection and sasl server).
410    * To avoid being blocked, this should be invoked in the network thread that manages the selector.
411    */
close()412   public void close() {
413     underlyingTransport.close();
414     selectionKey.cancel();
415     if (saslPeer != null) {
416       saslPeer.dispose();
417     }
418     if (serverContextCreated) {
419       eventHandler.deleteContext(serverContext,
420           inputProtocolFactory.getProtocol(underlyingTransport),
421           outputProtocolFactory.getProtocol(underlyingTransport));
422     }
423     nextPhase = Phase.CLOSED;
424     currentPhase = Phase.CLOSED;
425     LOGGER.trace("Connection closed: {}", underlyingTransport);
426   }
427 
428   public enum Phase {
INITIIALIIZING(INTEREST_READ)429     INITIIALIIZING(INTEREST_READ) {
430       @Override
431       void unsafeRun(NonblockingSaslHandler statemachine) {
432         statemachine.handleInitializing();
433       }
434     },
READING_SASL_RESPONSE(INTEREST_READ)435     READING_SASL_RESPONSE(INTEREST_READ) {
436       @Override
437       void unsafeRun(NonblockingSaslHandler statemachine) {
438         statemachine.handleReadingSaslResponse();
439       }
440     },
EVALUATING_SASL_RESPONSE(INTEREST_NONE)441     EVALUATING_SASL_RESPONSE(INTEREST_NONE) {
442       @Override
443       void unsafeRun(NonblockingSaslHandler statemachine) {
444         statemachine.executeEvaluatingSaslResponse();
445       }
446     },
WRITING_SASL_CHALLENGE(INTEREST_WRITE)447     WRITING_SASL_CHALLENGE(INTEREST_WRITE) {
448       @Override
449       void unsafeRun(NonblockingSaslHandler statemachine) {
450         statemachine.handleWritingSaslChallenge();
451       }
452     },
WRITING_SUCCESS_MESSAGE(INTEREST_WRITE)453     WRITING_SUCCESS_MESSAGE(INTEREST_WRITE) {
454       @Override
455       void unsafeRun(NonblockingSaslHandler statemachine) {
456         statemachine.handleWritingSuccessMessage();
457       }
458     },
WRITING_FAILURE_MESSAGE(INTEREST_WRITE)459     WRITING_FAILURE_MESSAGE(INTEREST_WRITE) {
460       @Override
461       void unsafeRun(NonblockingSaslHandler statemachine) {
462         statemachine.handleWritingFailureMessage();
463       }
464     },
READING_REQUEST(INTEREST_READ)465     READING_REQUEST(INTEREST_READ) {
466       @Override
467       void unsafeRun(NonblockingSaslHandler statemachine) {
468         statemachine.handleReadingRequest();
469       }
470     },
PROCESSING(INTEREST_NONE)471     PROCESSING(INTEREST_NONE) {
472       @Override
473       void unsafeRun(NonblockingSaslHandler statemachine) {
474         statemachine.executeProcessing();
475       }
476     },
WRITING_RESPONSE(INTEREST_WRITE)477     WRITING_RESPONSE(INTEREST_WRITE) {
478       @Override
479       void unsafeRun(NonblockingSaslHandler statemachine) {
480         statemachine.handleWritingResponse();
481       }
482     },
CLOSING(INTEREST_NONE)483     CLOSING(INTEREST_NONE) {
484       @Override
485       void unsafeRun(NonblockingSaslHandler statemachine) {
486         statemachine.close();
487       }
488     },
CLOSED(INTEREST_NONE)489     CLOSED(INTEREST_NONE) {
490       @Override
491       void unsafeRun(NonblockingSaslHandler statemachine) {
492         // Do nothing.
493       }
494     }
495     ;
496 
497     // The interest on the selection key during the phase
498     private int selectionInterest;
499 
Phase(int selectionInterest)500     Phase(int selectionInterest) {
501       this.selectionInterest = selectionInterest;
502     }
503 
504     /**
505      * Provide the execution to run for the state machine in current phase. The execution should
506      * return the next phase after running on the state machine.
507      *
508      * @param statemachine The state machine to run.
509      * @throws IllegalArgumentException if the state machine's current phase is different.
510      * @throws IllegalStateException if the state machine' current phase is already done.
511      */
runStateMachine(NonblockingSaslHandler statemachine)512     void runStateMachine(NonblockingSaslHandler statemachine) {
513       if (statemachine.currentPhase != this) {
514         throw new IllegalArgumentException("State machine is " + statemachine.currentPhase +
515             " but is expected to be " + this);
516       }
517       if (statemachine.isCurrentPhaseDone()) {
518         throw new IllegalStateException("State machine should step into " + statemachine.nextPhase);
519       }
520       unsafeRun(statemachine);
521     }
522 
523     // Run the state machine without checkiing its own phase
524     // It should not be called direcly by users.
unsafeRun(NonblockingSaslHandler statemachine)525     abstract void unsafeRun(NonblockingSaslHandler statemachine);
526   }
527 }
528