1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 package org.apache.thrift.transport; 21 22 import java.io.IOException; 23 import java.nio.charset.StandardCharsets; 24 import java.util.HashMap; 25 import java.util.Map; 26 27 import javax.security.auth.callback.Callback; 28 import javax.security.auth.callback.CallbackHandler; 29 import javax.security.auth.callback.NameCallback; 30 import javax.security.auth.callback.PasswordCallback; 31 import javax.security.auth.callback.UnsupportedCallbackException; 32 import javax.security.sasl.AuthorizeCallback; 33 import javax.security.sasl.RealmCallback; 34 import javax.security.sasl.Sasl; 35 import javax.security.sasl.SaslClient; 36 import javax.security.sasl.SaslClientFactory; 37 import javax.security.sasl.SaslException; 38 import javax.security.sasl.SaslServer; 39 import javax.security.sasl.SaslServerFactory; 40 41 import junit.framework.TestCase; 42 43 import org.apache.thrift.TConfiguration; 44 import org.apache.thrift.TProcessor; 45 import org.apache.thrift.protocol.TProtocolFactory; 46 import org.apache.thrift.server.ServerTestBase; 47 import org.apache.thrift.server.TServer; 48 import org.apache.thrift.server.TSimpleServer; 49 import org.apache.thrift.server.TServer.Args; 50 import org.slf4j.Logger; 51 import org.slf4j.LoggerFactory; 52 53 public class TestTSaslTransports extends TestCase { 54 55 private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class); 56 57 public static final String HOST = "localhost"; 58 public static final String SERVICE = "thrift-test"; 59 public static final String PRINCIPAL = "thrift-test-principal"; 60 public static final String PASSWORD = "super secret password"; 61 public static final String REALM = "thrift-test-realm"; 62 63 public static final String UNWRAPPED_MECHANISM = "CRAM-MD5"; 64 public static final Map<String, String> UNWRAPPED_PROPS = null; 65 66 public static final String WRAPPED_MECHANISM = "DIGEST-MD5"; 67 public static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>(); 68 69 static { WRAPPED_PROPS.put(Sasl.QOP, R)70 WRAPPED_PROPS.put(Sasl.QOP, "auth-int"); 71 WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM); 72 } 73 74 private static final String testMessage1 = "Hello, world! Also, four " 75 + "score and seven years ago our fathers brought forth on this " 76 + "continent a new nation, conceived in liberty, and dedicated to the " 77 + "proposition that all men are created equal."; 78 79 private static final String testMessage2 = "I have a dream that one day " 80 + "this nation will rise up and live out the true meaning of its creed: " 81 + "'We hold these truths to be self-evident, that all men are created equal.'"; 82 83 84 public static class TestSaslCallbackHandler implements CallbackHandler { 85 private final String password; 86 TestSaslCallbackHandler(String password)87 public TestSaslCallbackHandler(String password) { 88 this.password = password; 89 } 90 91 @Override handle(Callback[] callbacks)92 public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { 93 for (Callback c : callbacks) { 94 if (c instanceof NameCallback) { 95 ((NameCallback) c).setName(PRINCIPAL); 96 } else if (c instanceof PasswordCallback) { 97 ((PasswordCallback) c).setPassword(password.toCharArray()); 98 } else if (c instanceof AuthorizeCallback) { 99 ((AuthorizeCallback) c).setAuthorized(true); 100 } else if (c instanceof RealmCallback) { 101 ((RealmCallback) c).setText(REALM); 102 } else { 103 throw new UnsupportedCallbackException(c); 104 } 105 } 106 } 107 } 108 109 private class ServerThread extends Thread { 110 final String mechanism; 111 final Map<String, String> props; 112 volatile Throwable thrown; 113 ServerThread(String mechanism, Map<String, String> props)114 public ServerThread(String mechanism, Map<String, String> props) { 115 this.mechanism = mechanism; 116 this.props = props; 117 } 118 run()119 public void run() { 120 try { 121 internalRun(); 122 } catch (Throwable t) { 123 thrown = t; 124 } 125 } 126 internalRun()127 private void internalRun() throws Exception { 128 TServerSocket serverSocket = new TServerSocket( 129 new TServerSocket.ServerSocketTransportArgs(). 130 port(ServerTestBase.PORT)); 131 try { 132 acceptAndWrite(serverSocket); 133 } finally { 134 serverSocket.close(); 135 } 136 } 137 acceptAndWrite(TServerSocket serverSocket)138 private void acceptAndWrite(TServerSocket serverSocket) 139 throws Exception { 140 TTransport serverTransport = serverSocket.accept(); 141 TTransport saslServerTransport = new TSaslServerTransport( 142 mechanism, SERVICE, HOST, 143 props, new TestSaslCallbackHandler(PASSWORD), serverTransport); 144 145 saslServerTransport.open(); 146 147 byte[] inBuf = new byte[testMessage1.getBytes().length]; 148 // Deliberately read less than the full buffer to ensure 149 // that TSaslTransport is correctly buffering reads. This 150 // will fail for the WRAPPED test, if it doesn't work. 151 saslServerTransport.readAll(inBuf, 0, 5); 152 saslServerTransport.readAll(inBuf, 5, 10); 153 saslServerTransport.readAll(inBuf, 15, inBuf.length - 15); 154 LOGGER.debug("server got: {}", new String(inBuf)); 155 assertEquals(new String(inBuf), testMessage1); 156 157 LOGGER.debug("server writing: {}", testMessage2); 158 saslServerTransport.write(testMessage2.getBytes()); 159 saslServerTransport.flush(); 160 161 saslServerTransport.close(); 162 } 163 } 164 testSaslOpen(final String mechanism, final Map<String, String> props)165 private void testSaslOpen(final String mechanism, final Map<String, String> props) 166 throws Exception { 167 ServerThread serverThread = new ServerThread(mechanism, props); 168 serverThread.start(); 169 170 try { 171 Thread.sleep(1000); 172 } catch (InterruptedException e) { 173 // Ah well. 174 } 175 176 try { 177 TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT); 178 TTransport saslClientTransport = new TSaslClientTransport(mechanism, 179 PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(PASSWORD), clientSocket); 180 saslClientTransport.open(); 181 LOGGER.debug("client writing: {}", testMessage1); 182 saslClientTransport.write(testMessage1.getBytes()); 183 saslClientTransport.flush(); 184 185 byte[] inBuf = new byte[testMessage2.getBytes().length]; 186 saslClientTransport.readAll(inBuf, 0, inBuf.length); 187 LOGGER.debug("client got: {}", new String(inBuf)); 188 assertEquals(new String(inBuf), testMessage2); 189 190 TTransportException expectedException = null; 191 try { 192 saslClientTransport.open(); 193 } catch (TTransportException e) { 194 expectedException = e; 195 } 196 assertNotNull(expectedException); 197 198 saslClientTransport.close(); 199 } catch (Exception e) { 200 LOGGER.warn("Exception caught", e); 201 throw e; 202 } finally { 203 serverThread.interrupt(); 204 try { 205 serverThread.join(); 206 } catch (InterruptedException e) { 207 // Ah well. 208 } 209 assertNull(serverThread.thrown); 210 } 211 } 212 testUnwrappedOpen()213 public void testUnwrappedOpen() throws Exception { 214 testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS); 215 } 216 testWrappedOpen()217 public void testWrappedOpen() throws Exception { 218 testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS); 219 } 220 testAnonymousOpen()221 public void testAnonymousOpen() throws Exception { 222 testSaslOpen("ANONYMOUS", null); 223 } 224 225 /** 226 * Test that we get the proper exceptions thrown back the server when 227 * the client provides invalid password. 228 */ testBadPassword()229 public void testBadPassword() throws Exception { 230 ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS); 231 serverThread.start(); 232 233 try { 234 Thread.sleep(1000); 235 } catch (InterruptedException e) { 236 // Ah well. 237 } 238 239 boolean clientSidePassed = true; 240 241 try { 242 TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT); 243 TTransport saslClientTransport = new TSaslClientTransport( 244 UNWRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, UNWRAPPED_PROPS, 245 new TestSaslCallbackHandler("NOT THE PASSWORD"), clientSocket); 246 saslClientTransport.open(); 247 clientSidePassed = false; 248 fail("Was able to open transport with bad password"); 249 } catch (TTransportException tte) { 250 LOGGER.error("Exception for bad password", tte); 251 assertNotNull(tte.getMessage()); 252 assertTrue(tte.getMessage().contains("Invalid response")); 253 254 } finally { 255 serverThread.interrupt(); 256 serverThread.join(); 257 258 if (clientSidePassed) { 259 assertNotNull(serverThread.thrown); 260 assertTrue(serverThread.thrown.getMessage().contains("Invalid response")); 261 } 262 } 263 } 264 testWithServer()265 public void testWithServer() throws Exception { 266 new TestTSaslTransportsWithServer().testIt(); 267 } 268 269 public static class TestTSaslTransportsWithServer extends ServerTestBase { 270 271 private Thread serverThread; 272 private TServer server; 273 274 @Override getClientTransport(TTransport underlyingTransport)275 public TTransport getClientTransport(TTransport underlyingTransport) throws Exception { 276 return new TSaslClientTransport( 277 WRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS, 278 new TestSaslCallbackHandler(PASSWORD), underlyingTransport); 279 } 280 281 @Override startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory)282 public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory) throws Exception { 283 serverThread = new Thread() { 284 public void run() { 285 try { 286 // Transport 287 TServerSocket socket = new TServerSocket(new TServerSocket.ServerSocketTransportArgs().port(PORT)); 288 289 TTransportFactory factory = new TSaslServerTransport.Factory( 290 WRAPPED_MECHANISM, SERVICE, HOST, WRAPPED_PROPS, 291 new TestSaslCallbackHandler(PASSWORD)); 292 server = new TSimpleServer(new Args(socket).processor(processor).transportFactory(factory).protocolFactory(protoFactory)); 293 294 // Run it 295 LOGGER.debug("Starting the server on port {}", PORT); 296 server.serve(); 297 } catch (Exception e) { 298 e.printStackTrace(); 299 fail(); 300 } 301 } 302 }; 303 serverThread.start(); 304 Thread.sleep(1000); 305 } 306 307 @Override stopServer()308 public void stopServer() throws Exception { 309 server.stop(); 310 try { 311 serverThread.join(); 312 } catch (InterruptedException e) {} 313 } 314 315 } 316 317 318 /** 319 * Implementation of SASL ANONYMOUS, used for testing client-side 320 * initial responses. 321 */ 322 private static class AnonymousClient implements SaslClient { 323 private final String username; 324 private boolean hasProvidedInitialResponse; 325 AnonymousClient(String username)326 public AnonymousClient(String username) { 327 this.username = username; 328 } 329 getMechanismName()330 public String getMechanismName() { return "ANONYMOUS"; } hasInitialResponse()331 public boolean hasInitialResponse() { return true; } evaluateChallenge(byte[] challenge)332 public byte[] evaluateChallenge(byte[] challenge) throws SaslException { 333 if (hasProvidedInitialResponse) { 334 throw new SaslException("Already complete!"); 335 } 336 337 hasProvidedInitialResponse = true; 338 return username.getBytes(StandardCharsets.UTF_8); 339 } isComplete()340 public boolean isComplete() { return hasProvidedInitialResponse; } unwrap(byte[] incoming, int offset, int len)341 public byte[] unwrap(byte[] incoming, int offset, int len) { 342 throw new UnsupportedOperationException(); 343 } wrap(byte[] outgoing, int offset, int len)344 public byte[] wrap(byte[] outgoing, int offset, int len) { 345 throw new UnsupportedOperationException(); 346 } getNegotiatedProperty(String propName)347 public Object getNegotiatedProperty(String propName) { return null; } dispose()348 public void dispose() {} 349 } 350 351 private static class AnonymousServer implements SaslServer { 352 private String user; getMechanismName()353 public String getMechanismName() { return "ANONYMOUS"; } evaluateResponse(byte[] response)354 public byte[] evaluateResponse(byte[] response) throws SaslException { 355 this.user = new String(response, StandardCharsets.UTF_8); 356 return null; 357 } isComplete()358 public boolean isComplete() { return user != null; } getAuthorizationID()359 public String getAuthorizationID() { return user; } unwrap(byte[] incoming, int offset, int len)360 public byte[] unwrap(byte[] incoming, int offset, int len) { 361 throw new UnsupportedOperationException(); 362 } wrap(byte[] outgoing, int offset, int len)363 public byte[] wrap(byte[] outgoing, int offset, int len) { 364 throw new UnsupportedOperationException(); 365 } getNegotiatedProperty(String propName)366 public Object getNegotiatedProperty(String propName) { return null; } dispose()367 public void dispose() {} 368 369 } 370 371 public static class SaslAnonymousFactory 372 implements SaslClientFactory, SaslServerFactory { 373 createSaslClient( String[] mechanisms, String authorizationId, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)374 public SaslClient createSaslClient( 375 String[] mechanisms, String authorizationId, String protocol, 376 String serverName, Map<String,?> props, CallbackHandler cbh) 377 { 378 for (String mech : mechanisms) { 379 if ("ANONYMOUS".equals(mech)) { 380 return new AnonymousClient(authorizationId); 381 } 382 } 383 return null; 384 } 385 createSaslServer( String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)386 public SaslServer createSaslServer( 387 String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh) 388 { 389 if ("ANONYMOUS".equals(mechanism)) { 390 return new AnonymousServer(); 391 } 392 return null; 393 } getMechanismNames(Map<String, ?> props)394 public String[] getMechanismNames(Map<String, ?> props) { 395 return new String[] { "ANONYMOUS" }; 396 } 397 } 398 399 static { java.security.Security.addProvider(new SaslAnonymousProvider())400 java.security.Security.addProvider(new SaslAnonymousProvider()); 401 } 402 public static class SaslAnonymousProvider extends java.security.Provider { SaslAnonymousProvider()403 public SaslAnonymousProvider() { 404 super("ThriftSaslAnonymous", 1.0, "Thrift Anonymous SASL provider"); 405 put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName()); 406 put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName()); 407 } 408 } 409 410 private static class MockTTransport extends TTransport { 411 412 byte[] badHeader = null; 413 private TMemoryInputTransport readBuffer; 414 MockTTransport(int mode)415 public MockTTransport(int mode) throws TTransportException { 416 readBuffer = new TMemoryInputTransport(); 417 if (mode==1) { 418 // Invalid status byte 419 badHeader = new byte[] { (byte)0xFF, (byte)0x00, (byte)0x00, (byte)0x00, (byte)0x05 }; 420 } else if (mode == 2) { 421 // Valid status byte, negative payload length 422 badHeader = new byte[] { (byte)0x01, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF }; 423 } else if (mode == 3) { 424 // Valid status byte, excessively large, bogus payload length 425 badHeader = new byte[] { (byte)0x01, (byte)0x64, (byte)0x00, (byte)0x00, (byte)0x00 }; 426 } 427 readBuffer.reset(badHeader); 428 } 429 430 @Override isOpen()431 public boolean isOpen() { 432 return true; 433 } 434 435 @Override open()436 public void open() throws TTransportException {} 437 438 @Override close()439 public void close() {} 440 441 @Override read(byte[] buf, int off, int len)442 public int read(byte[] buf, int off, int len) throws TTransportException { 443 return readBuffer.read(buf, off, len); 444 } 445 446 @Override write(byte[] buf, int off, int len)447 public void write(byte[] buf, int off, int len) throws TTransportException {} 448 449 @Override getConfiguration()450 public TConfiguration getConfiguration() { 451 return readBuffer.getConfiguration(); 452 } 453 454 @Override updateKnownMessageSize(long size)455 public void updateKnownMessageSize(long size) throws TTransportException { 456 readBuffer.updateKnownMessageSize(size); 457 } 458 459 @Override checkReadBytesAvailable(long numBytes)460 public void checkReadBytesAvailable(long numBytes) throws TTransportException { 461 readBuffer.checkReadBytesAvailable(numBytes); 462 } 463 } 464 testBadHeader()465 public void testBadHeader() { 466 TSaslTransport saslTransport; 467 try { 468 saslTransport = new TSaslServerTransport(new MockTTransport(1)); 469 saslTransport.receiveSaslMessage(); 470 fail("Should have gotten an error due to incorrect status byte value."); 471 } catch (TTransportException e) { 472 assertEquals(e.getMessage(), "Invalid status -1"); 473 } 474 try { 475 saslTransport = new TSaslServerTransport(new MockTTransport(2)); 476 saslTransport.receiveSaslMessage(); 477 fail("Should have gotten an error due to negative payload length."); 478 } catch (TTransportException e) { 479 assertEquals(e.getMessage(), "Invalid payload header length: -1"); 480 } 481 try { 482 saslTransport = new TSaslServerTransport(new MockTTransport(3)); 483 saslTransport.receiveSaslMessage(); 484 fail("Should have gotten an error due to bogus (large) payload length."); 485 } catch (TTransportException e) { 486 assertEquals(e.getMessage(), "Invalid payload header length: 1677721600"); 487 } 488 } 489 } 490