1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 package org.apache.thrift.transport.sasl; 21 22 import java.nio.channels.SelectionKey; 23 import java.nio.charset.StandardCharsets; 24 25 import javax.security.sasl.SaslServer; 26 27 import org.apache.thrift.TByteArrayOutputStream; 28 import org.apache.thrift.TProcessor; 29 import org.apache.thrift.protocol.TProtocol; 30 import org.apache.thrift.protocol.TProtocolFactory; 31 import org.apache.thrift.server.ServerContext; 32 import org.apache.thrift.server.TServerEventHandler; 33 import org.apache.thrift.transport.TMemoryTransport; 34 import org.apache.thrift.transport.TNonblockingTransport; 35 import org.apache.thrift.transport.TTransportException; 36 import org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType; 37 import org.slf4j.Logger; 38 import org.slf4j.LoggerFactory; 39 40 import static org.apache.thrift.transport.sasl.NegotiationStatus.COMPLETE; 41 import static org.apache.thrift.transport.sasl.NegotiationStatus.OK; 42 43 /** 44 * State machine managing one sasl connection in a nonblocking way. 45 */ 46 public class NonblockingSaslHandler { 47 private static final Logger LOGGER = LoggerFactory.getLogger(NonblockingSaslHandler.class); 48 49 private static final int INTEREST_NONE = 0; 50 private static final int INTEREST_READ = SelectionKey.OP_READ; 51 private static final int INTEREST_WRITE = SelectionKey.OP_WRITE; 52 53 // Tracking the current running phase 54 private Phase currentPhase = Phase.INITIIALIIZING; 55 // Tracking the next phase on the next invocation of the state machine. 56 // It should be the same as current phase if current phase is not yet finished. 57 // Otherwise, if it is different from current phase, the statemachine is in a transition state: 58 // current phase is done, and next phase is not yet started. 59 private Phase nextPhase = currentPhase; 60 61 // Underlying nonblocking transport 62 private SelectionKey selectionKey; 63 private TNonblockingTransport underlyingTransport; 64 65 // APIs for intercepting event / customizing behaviors: 66 // Factories (decorating the base implementations) & EventHandler (intercepting) 67 private TSaslServerFactory saslServerFactory; 68 private TSaslProcessorFactory processorFactory; 69 private TProtocolFactory inputProtocolFactory; 70 private TProtocolFactory outputProtocolFactory; 71 private TServerEventHandler eventHandler; 72 private ServerContext serverContext; 73 // It turns out the event handler implementation in hive sometimes creates a null ServerContext. 74 // In order to know whether TServerEventHandler#createContext is called we use such a flag. 75 private boolean serverContextCreated = false; 76 77 // Wrapper around sasl server 78 private ServerSaslPeer saslPeer; 79 80 // Sasl negotiation io 81 private SaslNegotiationFrameReader saslResponse; 82 private SaslNegotiationFrameWriter saslChallenge; 83 // IO for request from and response to the socket 84 private DataFrameReader requestReader; 85 private DataFrameWriter responseWriter; 86 // If sasl is negotiated for integrity/confidentiality protection 87 private boolean dataProtected; 88 NonblockingSaslHandler(SelectionKey selectionKey, TNonblockingTransport underlyingTransport, TSaslServerFactory saslServerFactory, TSaslProcessorFactory processorFactory, TProtocolFactory inputProtocolFactory, TProtocolFactory outputProtocolFactory, TServerEventHandler eventHandler)89 public NonblockingSaslHandler(SelectionKey selectionKey, TNonblockingTransport underlyingTransport, 90 TSaslServerFactory saslServerFactory, TSaslProcessorFactory processorFactory, 91 TProtocolFactory inputProtocolFactory, TProtocolFactory outputProtocolFactory, 92 TServerEventHandler eventHandler) { 93 this.selectionKey = selectionKey; 94 this.underlyingTransport = underlyingTransport; 95 this.saslServerFactory = saslServerFactory; 96 this.processorFactory = processorFactory; 97 this.inputProtocolFactory = inputProtocolFactory; 98 this.outputProtocolFactory = outputProtocolFactory; 99 this.eventHandler = eventHandler; 100 101 saslResponse = new SaslNegotiationFrameReader(); 102 saslChallenge = new SaslNegotiationFrameWriter(); 103 requestReader = new DataFrameReader(); 104 responseWriter = new DataFrameWriter(); 105 } 106 107 /** 108 * Get current phase of the state machine. 109 * 110 * @return current phase. 111 */ getCurrentPhase()112 public Phase getCurrentPhase() { 113 return currentPhase; 114 } 115 116 /** 117 * Get next phase of the state machine. 118 * It is different from current phase iff current phase is done (and next phase not yet started). 119 * 120 * @return next phase. 121 */ getNextPhase()122 public Phase getNextPhase() { 123 return nextPhase; 124 } 125 126 /** 127 * 128 * @return underlying nonblocking socket 129 */ getUnderlyingTransport()130 public TNonblockingTransport getUnderlyingTransport() { 131 return underlyingTransport; 132 } 133 134 /** 135 * 136 * @return SaslServer instance 137 */ getSaslServer()138 public SaslServer getSaslServer() { 139 return saslPeer.getSaslServer(); 140 } 141 142 /** 143 * 144 * @return true if current phase is done. 145 */ isCurrentPhaseDone()146 public boolean isCurrentPhaseDone() { 147 return currentPhase != nextPhase; 148 } 149 150 /** 151 * Run state machine. 152 * 153 * @throws IllegalStateException if current state is already done. 154 */ runCurrentPhase()155 public void runCurrentPhase() { 156 currentPhase.runStateMachine(this); 157 } 158 159 /** 160 * When current phase is intrested in read selection, calling this will run the current phase and 161 * its following phases if the following ones are interested to read, until there is nothing 162 * available in the underlying transport. 163 * 164 * @throws IllegalStateException if is called in an irrelevant phase. 165 */ handleRead()166 public void handleRead() { 167 handleOps(INTEREST_READ); 168 } 169 170 /** 171 * Similiar to handleRead. But it is for write ops. 172 * 173 * @throws IllegalStateException if it is called in an irrelevant phase. 174 */ handleWrite()175 public void handleWrite() { 176 handleOps(INTEREST_WRITE); 177 } 178 handleOps(int interestOps)179 private void handleOps(int interestOps) { 180 if (currentPhase.selectionInterest != interestOps) { 181 throw new IllegalStateException("Current phase " + currentPhase + " but got interest " + 182 interestOps); 183 } 184 runCurrentPhase(); 185 if (isCurrentPhaseDone() && nextPhase.selectionInterest == interestOps) { 186 stepToNextPhase(); 187 handleOps(interestOps); 188 } 189 } 190 191 /** 192 * When current phase is finished, it's expected to call this method first before running the 193 * state machine again. 194 * By calling this, "next phase" is marked as started (and not done), thus is ready to run. 195 * 196 * @throws IllegalArgumentException if current phase is not yet done. 197 */ stepToNextPhase()198 public void stepToNextPhase() { 199 if (!isCurrentPhaseDone()) { 200 throw new IllegalArgumentException("Not yet done with current phase: " + currentPhase); 201 } 202 LOGGER.debug("Switch phase {} to {}", currentPhase, nextPhase); 203 switch (nextPhase) { 204 case INITIIALIIZING: 205 throw new IllegalStateException("INITIALIZING cannot be the next phase of " + currentPhase); 206 default: 207 } 208 // If next phase's interest is not the same as current, nor the same as the selection key, 209 // we need to change interest on the selector. 210 if (!(nextPhase.selectionInterest == currentPhase.selectionInterest || 211 nextPhase.selectionInterest == selectionKey.interestOps())) { 212 changeSelectionInterest(nextPhase.selectionInterest); 213 } 214 currentPhase = nextPhase; 215 } 216 changeSelectionInterest(int selectionInterest)217 private void changeSelectionInterest(int selectionInterest) { 218 selectionKey.interestOps(selectionInterest); 219 } 220 221 // sasl negotiaion failure handling failSaslNegotiation(TSaslNegotiationException e)222 private void failSaslNegotiation(TSaslNegotiationException e) { 223 LOGGER.error("Sasl negotiation failed", e); 224 String errorMsg = e.getDetails(); 225 saslChallenge.withHeaderAndPayload(new byte[]{e.getErrorType().code.getValue()}, 226 errorMsg.getBytes(StandardCharsets.UTF_8)); 227 nextPhase = Phase.WRITING_FAILURE_MESSAGE; 228 } 229 fail(Exception e)230 private void fail(Exception e) { 231 LOGGER.error("Failed io in " + currentPhase, e); 232 nextPhase = Phase.CLOSING; 233 } 234 failIO(TTransportException e)235 private void failIO(TTransportException e) { 236 StringBuilder errorMsg = new StringBuilder("IO failure ") 237 .append(e.getType()) 238 .append(" in ") 239 .append(currentPhase); 240 if (e.getMessage() != null) { 241 errorMsg.append(": ").append(e.getMessage()); 242 } 243 LOGGER.error(errorMsg.toString(), e); 244 nextPhase = Phase.CLOSING; 245 } 246 247 // Read handlings 248 handleInitializing()249 private void handleInitializing() { 250 try { 251 saslResponse.read(underlyingTransport); 252 if (saslResponse.isComplete()) { 253 SaslNegotiationHeaderReader startHeader = saslResponse.getHeader(); 254 if (startHeader.getStatus() != NegotiationStatus.START) { 255 throw new TInvalidSaslFrameException("Expecting START status but got " + startHeader.getStatus()); 256 } 257 String mechanism = new String(saslResponse.getPayload(), StandardCharsets.UTF_8); 258 saslPeer = saslServerFactory.getSaslPeer(mechanism); 259 saslResponse.clear(); 260 nextPhase = Phase.READING_SASL_RESPONSE; 261 } 262 } catch (TSaslNegotiationException e) { 263 failSaslNegotiation(e); 264 } catch (TTransportException e) { 265 failIO(e); 266 } 267 } 268 handleReadingSaslResponse()269 private void handleReadingSaslResponse() { 270 try { 271 saslResponse.read(underlyingTransport); 272 if (saslResponse.isComplete()) { 273 nextPhase = Phase.EVALUATING_SASL_RESPONSE; 274 } 275 } catch (TSaslNegotiationException e) { 276 failSaslNegotiation(e); 277 } catch (TTransportException e) { 278 failIO(e); 279 } 280 } 281 handleReadingRequest()282 private void handleReadingRequest() { 283 try { 284 requestReader.read(underlyingTransport); 285 if (requestReader.isComplete()) { 286 nextPhase = Phase.PROCESSING; 287 } 288 } catch (TTransportException e) { 289 failIO(e); 290 } 291 } 292 293 // Computation executions 294 executeEvaluatingSaslResponse()295 private void executeEvaluatingSaslResponse() { 296 if (!(saslResponse.getHeader().getStatus() == OK || saslResponse.getHeader().getStatus() == COMPLETE)) { 297 String error = "Expect status OK or COMPLETE, but got " + saslResponse.getHeader().getStatus(); 298 failSaslNegotiation(new TSaslNegotiationException(ErrorType.PROTOCOL_ERROR, error)); 299 return; 300 } 301 try { 302 byte[] response = saslResponse.getPayload(); 303 saslResponse.clear(); 304 byte[] newChallenge = saslPeer.evaluate(response); 305 if (saslPeer.isAuthenticated()) { 306 dataProtected = saslPeer.isDataProtected(); 307 saslChallenge.withHeaderAndPayload(new byte[]{COMPLETE.getValue()}, newChallenge); 308 nextPhase = Phase.WRITING_SUCCESS_MESSAGE; 309 } else { 310 saslChallenge.withHeaderAndPayload(new byte[]{OK.getValue()}, newChallenge); 311 nextPhase = Phase.WRITING_SASL_CHALLENGE; 312 } 313 } catch (TSaslNegotiationException e) { 314 failSaslNegotiation(e); 315 } 316 } 317 executeProcessing()318 private void executeProcessing() { 319 try { 320 byte[] inputPayload = requestReader.getPayload(); 321 requestReader.clear(); 322 byte[] rawInput = dataProtected ? saslPeer.unwrap(inputPayload) : inputPayload; 323 TMemoryTransport memoryTransport = new TMemoryTransport(rawInput); 324 TProtocol requestProtocol = inputProtocolFactory.getProtocol(memoryTransport); 325 TProtocol responseProtocol = outputProtocolFactory.getProtocol(memoryTransport); 326 327 if (eventHandler != null) { 328 if (!serverContextCreated) { 329 serverContext = eventHandler.createContext(requestProtocol, responseProtocol); 330 serverContextCreated = true; 331 } 332 eventHandler.processContext(serverContext, memoryTransport, memoryTransport); 333 } 334 335 TProcessor processor = processorFactory.getProcessor(this); 336 processor.process(requestProtocol, responseProtocol); 337 TByteArrayOutputStream rawOutput = memoryTransport.getOutput(); 338 if (rawOutput.len() == 0) { 339 // This is a oneway request, no response to send back. Waiting for next incoming request. 340 nextPhase = Phase.READING_REQUEST; 341 return; 342 } 343 if (dataProtected) { 344 byte[] outputPayload = saslPeer.wrap(rawOutput.get(), 0, rawOutput.len()); 345 responseWriter.withOnlyPayload(outputPayload); 346 } else { 347 responseWriter.withOnlyPayload(rawOutput.get(), 0 ,rawOutput.len()); 348 } 349 nextPhase = Phase.WRITING_RESPONSE; 350 } catch (TTransportException e) { 351 failIO(e); 352 } catch (Exception e) { 353 fail(e); 354 } 355 } 356 357 // Write handlings 358 handleWritingSaslChallenge()359 private void handleWritingSaslChallenge() { 360 try { 361 saslChallenge.write(underlyingTransport); 362 if (saslChallenge.isComplete()) { 363 saslChallenge.clear(); 364 nextPhase = Phase.READING_SASL_RESPONSE; 365 } 366 } catch (TTransportException e) { 367 fail(e); 368 } 369 } 370 handleWritingSuccessMessage()371 private void handleWritingSuccessMessage() { 372 try { 373 saslChallenge.write(underlyingTransport); 374 if (saslChallenge.isComplete()) { 375 LOGGER.debug("Authentication is done."); 376 saslChallenge = null; 377 saslResponse = null; 378 nextPhase = Phase.READING_REQUEST; 379 } 380 } catch (TTransportException e) { 381 fail(e); 382 } 383 } 384 handleWritingFailureMessage()385 private void handleWritingFailureMessage() { 386 try { 387 saslChallenge.write(underlyingTransport); 388 if (saslChallenge.isComplete()) { 389 nextPhase = Phase.CLOSING; 390 } 391 } catch (TTransportException e) { 392 fail(e); 393 } 394 } 395 handleWritingResponse()396 private void handleWritingResponse() { 397 try { 398 responseWriter.write(underlyingTransport); 399 if (responseWriter.isComplete()) { 400 responseWriter.clear(); 401 nextPhase = Phase.READING_REQUEST; 402 } 403 } catch (TTransportException e) { 404 fail(e); 405 } 406 } 407 408 /** 409 * Release all the resources managed by this state machine (connection, selection and sasl server). 410 * To avoid being blocked, this should be invoked in the network thread that manages the selector. 411 */ close()412 public void close() { 413 underlyingTransport.close(); 414 selectionKey.cancel(); 415 if (saslPeer != null) { 416 saslPeer.dispose(); 417 } 418 if (serverContextCreated) { 419 eventHandler.deleteContext(serverContext, 420 inputProtocolFactory.getProtocol(underlyingTransport), 421 outputProtocolFactory.getProtocol(underlyingTransport)); 422 } 423 nextPhase = Phase.CLOSED; 424 currentPhase = Phase.CLOSED; 425 LOGGER.trace("Connection closed: {}", underlyingTransport); 426 } 427 428 public enum Phase { INITIIALIIZING(INTEREST_READ)429 INITIIALIIZING(INTEREST_READ) { 430 @Override 431 void unsafeRun(NonblockingSaslHandler statemachine) { 432 statemachine.handleInitializing(); 433 } 434 }, READING_SASL_RESPONSE(INTEREST_READ)435 READING_SASL_RESPONSE(INTEREST_READ) { 436 @Override 437 void unsafeRun(NonblockingSaslHandler statemachine) { 438 statemachine.handleReadingSaslResponse(); 439 } 440 }, EVALUATING_SASL_RESPONSE(INTEREST_NONE)441 EVALUATING_SASL_RESPONSE(INTEREST_NONE) { 442 @Override 443 void unsafeRun(NonblockingSaslHandler statemachine) { 444 statemachine.executeEvaluatingSaslResponse(); 445 } 446 }, WRITING_SASL_CHALLENGE(INTEREST_WRITE)447 WRITING_SASL_CHALLENGE(INTEREST_WRITE) { 448 @Override 449 void unsafeRun(NonblockingSaslHandler statemachine) { 450 statemachine.handleWritingSaslChallenge(); 451 } 452 }, WRITING_SUCCESS_MESSAGE(INTEREST_WRITE)453 WRITING_SUCCESS_MESSAGE(INTEREST_WRITE) { 454 @Override 455 void unsafeRun(NonblockingSaslHandler statemachine) { 456 statemachine.handleWritingSuccessMessage(); 457 } 458 }, WRITING_FAILURE_MESSAGE(INTEREST_WRITE)459 WRITING_FAILURE_MESSAGE(INTEREST_WRITE) { 460 @Override 461 void unsafeRun(NonblockingSaslHandler statemachine) { 462 statemachine.handleWritingFailureMessage(); 463 } 464 }, READING_REQUEST(INTEREST_READ)465 READING_REQUEST(INTEREST_READ) { 466 @Override 467 void unsafeRun(NonblockingSaslHandler statemachine) { 468 statemachine.handleReadingRequest(); 469 } 470 }, PROCESSING(INTEREST_NONE)471 PROCESSING(INTEREST_NONE) { 472 @Override 473 void unsafeRun(NonblockingSaslHandler statemachine) { 474 statemachine.executeProcessing(); 475 } 476 }, WRITING_RESPONSE(INTEREST_WRITE)477 WRITING_RESPONSE(INTEREST_WRITE) { 478 @Override 479 void unsafeRun(NonblockingSaslHandler statemachine) { 480 statemachine.handleWritingResponse(); 481 } 482 }, CLOSING(INTEREST_NONE)483 CLOSING(INTEREST_NONE) { 484 @Override 485 void unsafeRun(NonblockingSaslHandler statemachine) { 486 statemachine.close(); 487 } 488 }, CLOSED(INTEREST_NONE)489 CLOSED(INTEREST_NONE) { 490 @Override 491 void unsafeRun(NonblockingSaslHandler statemachine) { 492 // Do nothing. 493 } 494 } 495 ; 496 497 // The interest on the selection key during the phase 498 private int selectionInterest; 499 Phase(int selectionInterest)500 Phase(int selectionInterest) { 501 this.selectionInterest = selectionInterest; 502 } 503 504 /** 505 * Provide the execution to run for the state machine in current phase. The execution should 506 * return the next phase after running on the state machine. 507 * 508 * @param statemachine The state machine to run. 509 * @throws IllegalArgumentException if the state machine's current phase is different. 510 * @throws IllegalStateException if the state machine' current phase is already done. 511 */ runStateMachine(NonblockingSaslHandler statemachine)512 void runStateMachine(NonblockingSaslHandler statemachine) { 513 if (statemachine.currentPhase != this) { 514 throw new IllegalArgumentException("State machine is " + statemachine.currentPhase + 515 " but is expected to be " + this); 516 } 517 if (statemachine.isCurrentPhaseDone()) { 518 throw new IllegalStateException("State machine should step into " + statemachine.nextPhase); 519 } 520 unsafeRun(statemachine); 521 } 522 523 // Run the state machine without checkiing its own phase 524 // It should not be called direcly by users. unsafeRun(NonblockingSaslHandler statemachine)525 abstract void unsafeRun(NonblockingSaslHandler statemachine); 526 } 527 } 528