1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 package org.apache.thrift.transport; 21 22 import java.io.IOException; 23 import java.util.HashMap; 24 import java.util.Map; 25 26 import javax.security.auth.callback.Callback; 27 import javax.security.auth.callback.CallbackHandler; 28 import javax.security.auth.callback.NameCallback; 29 import javax.security.auth.callback.PasswordCallback; 30 import javax.security.auth.callback.UnsupportedCallbackException; 31 import javax.security.sasl.AuthorizeCallback; 32 import javax.security.sasl.RealmCallback; 33 import javax.security.sasl.Sasl; 34 import javax.security.sasl.SaslClient; 35 import javax.security.sasl.SaslClientFactory; 36 import javax.security.sasl.SaslException; 37 import javax.security.sasl.SaslServer; 38 import javax.security.sasl.SaslServerFactory; 39 40 import junit.framework.TestCase; 41 42 import org.apache.thrift.TProcessor; 43 import org.apache.thrift.protocol.TProtocolFactory; 44 import org.apache.thrift.server.ServerTestBase; 45 import org.apache.thrift.server.TServer; 46 import org.apache.thrift.server.TSimpleServer; 47 import org.apache.thrift.server.TServer.Args; 48 import org.slf4j.Logger; 49 import org.slf4j.LoggerFactory; 50 51 public class TestTSaslTransports extends TestCase { 52 53 private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class); 54 55 private static final String HOST = "localhost"; 56 private static final String SERVICE = "thrift-test"; 57 private static final String PRINCIPAL = "thrift-test-principal"; 58 private static final String PASSWORD = "super secret password"; 59 private static final String REALM = "thrift-test-realm"; 60 61 private static final String UNWRAPPED_MECHANISM = "CRAM-MD5"; 62 private static final Map<String, String> UNWRAPPED_PROPS = null; 63 64 private static final String WRAPPED_MECHANISM = "DIGEST-MD5"; 65 private static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>(); 66 67 static { WRAPPED_PROPS.put(Sasl.QOP, R)68 WRAPPED_PROPS.put(Sasl.QOP, "auth-int"); 69 WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM); 70 } 71 72 private static final String testMessage1 = "Hello, world! Also, four " 73 + "score and seven years ago our fathers brought forth on this " 74 + "continent a new nation, conceived in liberty, and dedicated to the " 75 + "proposition that all men are created equal."; 76 77 private static final String testMessage2 = "I have a dream that one day " 78 + "this nation will rise up and live out the true meaning of its creed: " 79 + "'We hold these truths to be self-evident, that all men are created equal.'"; 80 81 82 private static class TestSaslCallbackHandler implements CallbackHandler { 83 private final String password; 84 TestSaslCallbackHandler(String password)85 public TestSaslCallbackHandler(String password) { 86 this.password = password; 87 } 88 89 @Override handle(Callback[] callbacks)90 public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { 91 for (Callback c : callbacks) { 92 if (c instanceof NameCallback) { 93 ((NameCallback) c).setName(PRINCIPAL); 94 } else if (c instanceof PasswordCallback) { 95 ((PasswordCallback) c).setPassword(password.toCharArray()); 96 } else if (c instanceof AuthorizeCallback) { 97 ((AuthorizeCallback) c).setAuthorized(true); 98 } else if (c instanceof RealmCallback) { 99 ((RealmCallback) c).setText(REALM); 100 } else { 101 throw new UnsupportedCallbackException(c); 102 } 103 } 104 } 105 } 106 107 private class ServerThread extends Thread { 108 final String mechanism; 109 final Map<String, String> props; 110 volatile Throwable thrown; 111 ServerThread(String mechanism, Map<String, String> props)112 public ServerThread(String mechanism, Map<String, String> props) { 113 this.mechanism = mechanism; 114 this.props = props; 115 } 116 run()117 public void run() { 118 try { 119 internalRun(); 120 } catch (Throwable t) { 121 thrown = t; 122 } 123 } 124 internalRun()125 private void internalRun() throws Exception { 126 TServerSocket serverSocket = new TServerSocket( 127 new TServerSocket.ServerSocketTransportArgs(). 128 port(ServerTestBase.PORT)); 129 try { 130 acceptAndWrite(serverSocket); 131 } finally { 132 serverSocket.close(); 133 } 134 } 135 acceptAndWrite(TServerSocket serverSocket)136 private void acceptAndWrite(TServerSocket serverSocket) 137 throws Exception { 138 TTransport serverTransport = serverSocket.accept(); 139 TTransport saslServerTransport = new TSaslServerTransport( 140 mechanism, SERVICE, HOST, 141 props, new TestSaslCallbackHandler(PASSWORD), serverTransport); 142 143 saslServerTransport.open(); 144 145 byte[] inBuf = new byte[testMessage1.getBytes().length]; 146 // Deliberately read less than the full buffer to ensure 147 // that TSaslTransport is correctly buffering reads. This 148 // will fail for the WRAPPED test, if it doesn't work. 149 saslServerTransport.readAll(inBuf, 0, 5); 150 saslServerTransport.readAll(inBuf, 5, 10); 151 saslServerTransport.readAll(inBuf, 15, inBuf.length - 15); 152 LOGGER.debug("server got: {}", new String(inBuf)); 153 assertEquals(new String(inBuf), testMessage1); 154 155 LOGGER.debug("server writing: {}", testMessage2); 156 saslServerTransport.write(testMessage2.getBytes()); 157 saslServerTransport.flush(); 158 159 saslServerTransport.close(); 160 } 161 } 162 testSaslOpen(final String mechanism, final Map<String, String> props)163 private void testSaslOpen(final String mechanism, final Map<String, String> props) 164 throws Exception { 165 ServerThread serverThread = new ServerThread(mechanism, props); 166 serverThread.start(); 167 168 try { 169 Thread.sleep(1000); 170 } catch (InterruptedException e) { 171 // Ah well. 172 } 173 174 try { 175 TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT); 176 TTransport saslClientTransport = new TSaslClientTransport(mechanism, 177 PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(PASSWORD), clientSocket); 178 saslClientTransport.open(); 179 LOGGER.debug("client writing: {}", testMessage1); 180 saslClientTransport.write(testMessage1.getBytes()); 181 saslClientTransport.flush(); 182 183 byte[] inBuf = new byte[testMessage2.getBytes().length]; 184 saslClientTransport.readAll(inBuf, 0, inBuf.length); 185 LOGGER.debug("client got: {}", new String(inBuf)); 186 assertEquals(new String(inBuf), testMessage2); 187 188 TTransportException expectedException = null; 189 try { 190 saslClientTransport.open(); 191 } catch (TTransportException e) { 192 expectedException = e; 193 } 194 assertNotNull(expectedException); 195 196 saslClientTransport.close(); 197 } catch (Exception e) { 198 LOGGER.warn("Exception caught", e); 199 throw e; 200 } finally { 201 serverThread.interrupt(); 202 try { 203 serverThread.join(); 204 } catch (InterruptedException e) { 205 // Ah well. 206 } 207 assertNull(serverThread.thrown); 208 } 209 } 210 testUnwrappedOpen()211 public void testUnwrappedOpen() throws Exception { 212 testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS); 213 } 214 testWrappedOpen()215 public void testWrappedOpen() throws Exception { 216 testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS); 217 } 218 testAnonymousOpen()219 public void testAnonymousOpen() throws Exception { 220 testSaslOpen("ANONYMOUS", null); 221 } 222 223 /** 224 * Test that we get the proper exceptions thrown back the server when 225 * the client provides invalid password. 226 */ testBadPassword()227 public void testBadPassword() throws Exception { 228 ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS); 229 serverThread.start(); 230 231 try { 232 Thread.sleep(1000); 233 } catch (InterruptedException e) { 234 // Ah well. 235 } 236 237 boolean clientSidePassed = true; 238 239 try { 240 TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT); 241 TTransport saslClientTransport = new TSaslClientTransport( 242 UNWRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, UNWRAPPED_PROPS, 243 new TestSaslCallbackHandler("NOT THE PASSWORD"), clientSocket); 244 saslClientTransport.open(); 245 clientSidePassed = false; 246 fail("Was able to open transport with bad password"); 247 } catch (TTransportException tte) { 248 LOGGER.error("Exception for bad password", tte); 249 assertNotNull(tte.getMessage()); 250 assertTrue(tte.getMessage().contains("Invalid response")); 251 252 } finally { 253 serverThread.interrupt(); 254 serverThread.join(); 255 256 if (clientSidePassed) { 257 assertNotNull(serverThread.thrown); 258 assertTrue(serverThread.thrown.getMessage().contains("Invalid response")); 259 } 260 } 261 } 262 testWithServer()263 public void testWithServer() throws Exception { 264 new TestTSaslTransportsWithServer().testIt(); 265 } 266 267 private static class TestTSaslTransportsWithServer extends ServerTestBase { 268 269 private Thread serverThread; 270 private TServer server; 271 272 @Override getClientTransport(TTransport underlyingTransport)273 public TTransport getClientTransport(TTransport underlyingTransport) throws Exception { 274 return new TSaslClientTransport( 275 WRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS, 276 new TestSaslCallbackHandler(PASSWORD), underlyingTransport); 277 } 278 279 @Override startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory)280 public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory) throws Exception { 281 serverThread = new Thread() { 282 public void run() { 283 try { 284 // Transport 285 TServerSocket socket = new TServerSocket(new TServerSocket.ServerSocketTransportArgs().port(PORT)); 286 287 TTransportFactory factory = new TSaslServerTransport.Factory( 288 WRAPPED_MECHANISM, SERVICE, HOST, WRAPPED_PROPS, 289 new TestSaslCallbackHandler(PASSWORD)); 290 server = new TSimpleServer(new Args(socket).processor(processor).transportFactory(factory).protocolFactory(protoFactory)); 291 292 // Run it 293 LOGGER.debug("Starting the server on port {}", PORT); 294 server.serve(); 295 } catch (Exception e) { 296 e.printStackTrace(); 297 fail(); 298 } 299 } 300 }; 301 serverThread.start(); 302 Thread.sleep(1000); 303 } 304 305 @Override stopServer()306 public void stopServer() throws Exception { 307 server.stop(); 308 try { 309 serverThread.join(); 310 } catch (InterruptedException e) {} 311 } 312 313 } 314 315 316 /** 317 * Implementation of SASL ANONYMOUS, used for testing client-side 318 * initial responses. 319 */ 320 private static class AnonymousClient implements SaslClient { 321 private final String username; 322 private boolean hasProvidedInitialResponse; 323 AnonymousClient(String username)324 public AnonymousClient(String username) { 325 this.username = username; 326 } 327 getMechanismName()328 public String getMechanismName() { return "ANONYMOUS"; } hasInitialResponse()329 public boolean hasInitialResponse() { return true; } evaluateChallenge(byte[] challenge)330 public byte[] evaluateChallenge(byte[] challenge) throws SaslException { 331 if (hasProvidedInitialResponse) { 332 throw new SaslException("Already complete!"); 333 } 334 335 try { 336 hasProvidedInitialResponse = true; 337 return username.getBytes("UTF-8"); 338 } catch (IOException e) { 339 throw new SaslException(e.toString()); 340 } 341 } isComplete()342 public boolean isComplete() { return hasProvidedInitialResponse; } unwrap(byte[] incoming, int offset, int len)343 public byte[] unwrap(byte[] incoming, int offset, int len) { 344 throw new UnsupportedOperationException(); 345 } wrap(byte[] outgoing, int offset, int len)346 public byte[] wrap(byte[] outgoing, int offset, int len) { 347 throw new UnsupportedOperationException(); 348 } getNegotiatedProperty(String propName)349 public Object getNegotiatedProperty(String propName) { return null; } dispose()350 public void dispose() {} 351 } 352 353 private static class AnonymousServer implements SaslServer { 354 private String user; getMechanismName()355 public String getMechanismName() { return "ANONYMOUS"; } evaluateResponse(byte[] response)356 public byte[] evaluateResponse(byte[] response) throws SaslException { 357 try { 358 this.user = new String(response, "UTF-8"); 359 } catch (IOException e) { 360 throw new SaslException(e.toString()); 361 } 362 return null; 363 } isComplete()364 public boolean isComplete() { return user != null; } getAuthorizationID()365 public String getAuthorizationID() { return user; } unwrap(byte[] incoming, int offset, int len)366 public byte[] unwrap(byte[] incoming, int offset, int len) { 367 throw new UnsupportedOperationException(); 368 } wrap(byte[] outgoing, int offset, int len)369 public byte[] wrap(byte[] outgoing, int offset, int len) { 370 throw new UnsupportedOperationException(); 371 } getNegotiatedProperty(String propName)372 public Object getNegotiatedProperty(String propName) { return null; } dispose()373 public void dispose() {} 374 375 } 376 377 public static class SaslAnonymousFactory 378 implements SaslClientFactory, SaslServerFactory { 379 createSaslClient( String[] mechanisms, String authorizationId, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)380 public SaslClient createSaslClient( 381 String[] mechanisms, String authorizationId, String protocol, 382 String serverName, Map<String,?> props, CallbackHandler cbh) 383 { 384 for (String mech : mechanisms) { 385 if ("ANONYMOUS".equals(mech)) { 386 return new AnonymousClient(authorizationId); 387 } 388 } 389 return null; 390 } 391 createSaslServer( String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)392 public SaslServer createSaslServer( 393 String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh) 394 { 395 if ("ANONYMOUS".equals(mechanism)) { 396 return new AnonymousServer(); 397 } 398 return null; 399 } getMechanismNames(Map<String, ?> props)400 public String[] getMechanismNames(Map<String, ?> props) { 401 return new String[] { "ANONYMOUS" }; 402 } 403 } 404 405 static { java.security.Security.addProvider(new SaslAnonymousProvider())406 java.security.Security.addProvider(new SaslAnonymousProvider()); 407 } 408 public static class SaslAnonymousProvider extends java.security.Provider { SaslAnonymousProvider()409 public SaslAnonymousProvider() { 410 super("ThriftSaslAnonymous", 1.0, "Thrift Anonymous SASL provider"); 411 put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName()); 412 put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName()); 413 } 414 } 415 416 private static class MockTTransport extends TTransport { 417 418 byte[] badHeader = null; 419 private TMemoryInputTransport readBuffer = new TMemoryInputTransport(); 420 MockTTransport(int mode)421 public MockTTransport(int mode) { 422 if (mode==1) { 423 // Invalid status byte 424 badHeader = new byte[] { (byte)0xFF, (byte)0x00, (byte)0x00, (byte)0x00, (byte)0x05 }; 425 } else if (mode == 2) { 426 // Valid status byte, negative payload length 427 badHeader = new byte[] { (byte)0x01, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF }; 428 } else if (mode == 3) { 429 // Valid status byte, excessively large, bogus payload length 430 badHeader = new byte[] { (byte)0x01, (byte)0x64, (byte)0x00, (byte)0x00, (byte)0x00 }; 431 } 432 readBuffer.reset(badHeader); 433 } 434 435 @Override isOpen()436 public boolean isOpen() { 437 return true; 438 } 439 440 @Override open()441 public void open() throws TTransportException {} 442 443 @Override close()444 public void close() {} 445 446 @Override read(byte[] buf, int off, int len)447 public int read(byte[] buf, int off, int len) throws TTransportException { 448 return readBuffer.read(buf, off, len); 449 } 450 451 @Override write(byte[] buf, int off, int len)452 public void write(byte[] buf, int off, int len) throws TTransportException {} 453 } 454 testBadHeader()455 public void testBadHeader() { 456 TSaslTransport saslTransport = new TSaslServerTransport(new MockTTransport(1)); 457 try { 458 saslTransport.receiveSaslMessage(); 459 fail("Should have gotten an error due to incorrect status byte value."); 460 } catch (TTransportException e) { 461 assertEquals(e.getMessage(), "Invalid status -1"); 462 } 463 saslTransport = new TSaslServerTransport(new MockTTransport(2)); 464 try { 465 saslTransport.receiveSaslMessage(); 466 fail("Should have gotten an error due to negative payload length."); 467 } catch (TTransportException e) { 468 assertEquals(e.getMessage(), "Invalid payload header length: -1"); 469 } 470 saslTransport = new TSaslServerTransport(new MockTTransport(3)); 471 try { 472 saslTransport.receiveSaslMessage(); 473 fail("Should have gotten an error due to bogus (large) payload length."); 474 } catch (TTransportException e) { 475 assertEquals(e.getMessage(), "Invalid payload header length: 1677721600"); 476 } 477 } 478 } 479