1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 package org.apache.thrift.transport;
21 
22 import java.io.IOException;
23 import java.nio.charset.StandardCharsets;
24 import java.util.HashMap;
25 import java.util.Map;
26 
27 import javax.security.auth.callback.Callback;
28 import javax.security.auth.callback.CallbackHandler;
29 import javax.security.auth.callback.NameCallback;
30 import javax.security.auth.callback.PasswordCallback;
31 import javax.security.auth.callback.UnsupportedCallbackException;
32 import javax.security.sasl.AuthorizeCallback;
33 import javax.security.sasl.RealmCallback;
34 import javax.security.sasl.Sasl;
35 import javax.security.sasl.SaslClient;
36 import javax.security.sasl.SaslClientFactory;
37 import javax.security.sasl.SaslException;
38 import javax.security.sasl.SaslServer;
39 import javax.security.sasl.SaslServerFactory;
40 
41 import junit.framework.TestCase;
42 
43 import org.apache.thrift.TConfiguration;
44 import org.apache.thrift.TProcessor;
45 import org.apache.thrift.protocol.TProtocolFactory;
46 import org.apache.thrift.server.ServerTestBase;
47 import org.apache.thrift.server.TServer;
48 import org.apache.thrift.server.TSimpleServer;
49 import org.apache.thrift.server.TServer.Args;
50 import org.slf4j.Logger;
51 import org.slf4j.LoggerFactory;
52 
53 public class TestTSaslTransports extends TestCase {
54 
/** Logger shared by the tests and the helper classes declared in this file. */
private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class);

// Connection and identity values handed to both the SASL client and server sides.
public static final String HOST = "localhost";
public static final String SERVICE = "thrift-test";
public static final String PRINCIPAL = "thrift-test-principal";
public static final String PASSWORD = "super secret password";
public static final String REALM = "thrift-test-realm";

// "Unwrapped" configuration: CRAM-MD5 negotiated without any SASL properties.
public static final String UNWRAPPED_MECHANISM = "CRAM-MD5";
public static final Map<String, String> UNWRAPPED_PROPS = null;

// "Wrapped" configuration: DIGEST-MD5 with the properties populated just below.
public static final String WRAPPED_MECHANISM = "DIGEST-MD5";
public static final Map<String, String> WRAPPED_PROPS = new HashMap<>();

static {
  // Pin the digest realm and request the "auth-int" quality of protection.
  WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM);
  WRAPPED_PROPS.put(Sasl.QOP, "auth-int");
}

/** Payload the client sends to the server in the open tests. */
private static final String testMessage1 =
    "Hello, world! Also, four " +
    "score and seven years ago our fathers brought forth on this " +
    "continent a new nation, conceived in liberty, and dedicated to the " +
    "proposition that all men are created equal.";

/** Payload the server sends back to the client in the open tests. */
private static final String testMessage2 =
    "I have a dream that one day " +
    "this nation will rise up and live out the true meaning of its creed: " +
    "'We hold these truths to be self-evident, that all men are created equal.'";
82 
83 
84   public static class TestSaslCallbackHandler implements CallbackHandler {
85     private final String password;
86 
TestSaslCallbackHandler(String password)87     public TestSaslCallbackHandler(String password) {
88       this.password = password;
89     }
90 
91     @Override
handle(Callback[] callbacks)92     public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
93       for (Callback c : callbacks) {
94         if (c instanceof NameCallback) {
95           ((NameCallback) c).setName(PRINCIPAL);
96         } else if (c instanceof PasswordCallback) {
97           ((PasswordCallback) c).setPassword(password.toCharArray());
98         } else if (c instanceof AuthorizeCallback) {
99           ((AuthorizeCallback) c).setAuthorized(true);
100         } else if (c instanceof RealmCallback) {
101           ((RealmCallback) c).setText(REALM);
102         } else {
103           throw new UnsupportedCallbackException(c);
104         }
105       }
106     }
107   }
108 
109   private class ServerThread extends Thread {
110     final String mechanism;
111     final Map<String, String> props;
112     volatile Throwable thrown;
113 
ServerThread(String mechanism, Map<String, String> props)114     public ServerThread(String mechanism, Map<String, String> props) {
115       this.mechanism = mechanism;
116       this.props = props;
117     }
118 
run()119     public void run() {
120       try {
121         internalRun();
122       } catch (Throwable t) {
123         thrown = t;
124       }
125     }
126 
internalRun()127     private void internalRun() throws Exception {
128       TServerSocket serverSocket = new TServerSocket(
129         new TServerSocket.ServerSocketTransportArgs().
130           port(ServerTestBase.PORT));
131       try {
132         acceptAndWrite(serverSocket);
133       } finally {
134         serverSocket.close();
135       }
136     }
137 
acceptAndWrite(TServerSocket serverSocket)138     private void acceptAndWrite(TServerSocket serverSocket)
139       throws Exception {
140       TTransport serverTransport = serverSocket.accept();
141       TTransport saslServerTransport = new TSaslServerTransport(
142         mechanism, SERVICE, HOST,
143         props, new TestSaslCallbackHandler(PASSWORD), serverTransport);
144 
145       saslServerTransport.open();
146 
147       byte[] inBuf = new byte[testMessage1.getBytes().length];
148       // Deliberately read less than the full buffer to ensure
149       // that TSaslTransport is correctly buffering reads. This
150       // will fail for the WRAPPED test, if it doesn't work.
151       saslServerTransport.readAll(inBuf, 0, 5);
152       saslServerTransport.readAll(inBuf, 5, 10);
153       saslServerTransport.readAll(inBuf, 15, inBuf.length - 15);
154       LOGGER.debug("server got: {}", new String(inBuf));
155       assertEquals(new String(inBuf), testMessage1);
156 
157       LOGGER.debug("server writing: {}", testMessage2);
158       saslServerTransport.write(testMessage2.getBytes());
159       saslServerTransport.flush();
160 
161       saslServerTransport.close();
162     }
163   }
164 
testSaslOpen(final String mechanism, final Map<String, String> props)165   private void testSaslOpen(final String mechanism, final Map<String, String> props)
166       throws Exception {
167     ServerThread serverThread = new ServerThread(mechanism, props);
168     serverThread.start();
169 
170     try {
171       Thread.sleep(1000);
172     } catch (InterruptedException e) {
173       // Ah well.
174     }
175 
176     try {
177       TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
178       TTransport saslClientTransport = new TSaslClientTransport(mechanism,
179                                                                 PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(PASSWORD), clientSocket);
180       saslClientTransport.open();
181       LOGGER.debug("client writing: {}", testMessage1);
182       saslClientTransport.write(testMessage1.getBytes());
183       saslClientTransport.flush();
184 
185       byte[] inBuf = new byte[testMessage2.getBytes().length];
186       saslClientTransport.readAll(inBuf, 0, inBuf.length);
187       LOGGER.debug("client got: {}", new String(inBuf));
188       assertEquals(new String(inBuf), testMessage2);
189 
190       TTransportException expectedException = null;
191       try {
192         saslClientTransport.open();
193       } catch (TTransportException e) {
194         expectedException = e;
195       }
196       assertNotNull(expectedException);
197 
198       saslClientTransport.close();
199     } catch (Exception e) {
200       LOGGER.warn("Exception caught", e);
201       throw e;
202     } finally {
203       serverThread.interrupt();
204       try {
205         serverThread.join();
206       } catch (InterruptedException e) {
207         // Ah well.
208       }
209       assertNull(serverThread.thrown);
210     }
211   }
212 
testUnwrappedOpen()213   public void testUnwrappedOpen() throws Exception {
214     testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
215   }
216 
testWrappedOpen()217   public void testWrappedOpen() throws Exception {
218     testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS);
219   }
220 
testAnonymousOpen()221   public void testAnonymousOpen() throws Exception {
222     testSaslOpen("ANONYMOUS", null);
223   }
224 
225   /**
226    * Test that we get the proper exceptions thrown back the server when
227    * the client provides invalid password.
228    */
testBadPassword()229   public void testBadPassword() throws Exception {
230     ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
231     serverThread.start();
232 
233     try {
234       Thread.sleep(1000);
235     } catch (InterruptedException e) {
236       // Ah well.
237     }
238 
239     boolean clientSidePassed = true;
240 
241     try {
242       TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
243       TTransport saslClientTransport = new TSaslClientTransport(
244         UNWRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, UNWRAPPED_PROPS,
245         new TestSaslCallbackHandler("NOT THE PASSWORD"), clientSocket);
246       saslClientTransport.open();
247       clientSidePassed = false;
248       fail("Was able to open transport with bad password");
249     } catch (TTransportException tte) {
250       LOGGER.error("Exception for bad password", tte);
251       assertNotNull(tte.getMessage());
252       assertTrue(tte.getMessage().contains("Invalid response"));
253 
254     } finally {
255       serverThread.interrupt();
256       serverThread.join();
257 
258       if (clientSidePassed) {
259         assertNotNull(serverThread.thrown);
260         assertTrue(serverThread.thrown.getMessage().contains("Invalid response"));
261       }
262     }
263   }
264 
testWithServer()265   public void testWithServer() throws Exception {
266     new TestTSaslTransportsWithServer().testIt();
267   }
268 
269   public static class TestTSaslTransportsWithServer extends ServerTestBase {
270 
271     private Thread serverThread;
272     private TServer server;
273 
274     @Override
getClientTransport(TTransport underlyingTransport)275     public TTransport getClientTransport(TTransport underlyingTransport) throws Exception {
276       return new TSaslClientTransport(
277         WRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS,
278         new TestSaslCallbackHandler(PASSWORD), underlyingTransport);
279     }
280 
281     @Override
startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory)282     public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory) throws Exception {
283       serverThread = new Thread() {
284         public void run() {
285           try {
286             // Transport
287             TServerSocket socket = new TServerSocket(new TServerSocket.ServerSocketTransportArgs().port(PORT));
288 
289             TTransportFactory factory = new TSaslServerTransport.Factory(
290               WRAPPED_MECHANISM, SERVICE, HOST, WRAPPED_PROPS,
291               new TestSaslCallbackHandler(PASSWORD));
292             server = new TSimpleServer(new Args(socket).processor(processor).transportFactory(factory).protocolFactory(protoFactory));
293 
294             // Run it
295             LOGGER.debug("Starting the server on port {}", PORT);
296             server.serve();
297           } catch (Exception e) {
298             e.printStackTrace();
299             fail();
300           }
301         }
302       };
303       serverThread.start();
304       Thread.sleep(1000);
305     }
306 
307     @Override
stopServer()308     public void stopServer() throws Exception {
309       server.stop();
310       try {
311         serverThread.join();
312       } catch (InterruptedException e) {}
313     }
314 
315   }
316 
317 
318   /**
319    * Implementation of SASL ANONYMOUS, used for testing client-side
320    * initial responses.
321    */
322   private static class AnonymousClient implements SaslClient {
323     private final String username;
324     private boolean hasProvidedInitialResponse;
325 
AnonymousClient(String username)326     public AnonymousClient(String username) {
327       this.username = username;
328     }
329 
getMechanismName()330     public String getMechanismName() { return "ANONYMOUS"; }
hasInitialResponse()331     public boolean hasInitialResponse() { return true; }
evaluateChallenge(byte[] challenge)332     public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
333       if (hasProvidedInitialResponse) {
334         throw new SaslException("Already complete!");
335       }
336 
337       hasProvidedInitialResponse = true;
338       return username.getBytes(StandardCharsets.UTF_8);
339     }
isComplete()340     public boolean isComplete() { return hasProvidedInitialResponse; }
unwrap(byte[] incoming, int offset, int len)341     public byte[] unwrap(byte[] incoming, int offset, int len) {
342       throw new UnsupportedOperationException();
343     }
wrap(byte[] outgoing, int offset, int len)344     public byte[] wrap(byte[] outgoing, int offset, int len) {
345       throw new UnsupportedOperationException();
346     }
getNegotiatedProperty(String propName)347     public Object getNegotiatedProperty(String propName) { return null; }
dispose()348     public void dispose() {}
349   }
350 
351   private static class AnonymousServer implements SaslServer {
352     private String user;
getMechanismName()353     public String getMechanismName() { return "ANONYMOUS"; }
evaluateResponse(byte[] response)354     public byte[] evaluateResponse(byte[] response) throws SaslException {
355       this.user = new String(response, StandardCharsets.UTF_8);
356       return null;
357     }
isComplete()358     public boolean isComplete() { return user != null; }
getAuthorizationID()359     public String getAuthorizationID() { return user; }
unwrap(byte[] incoming, int offset, int len)360     public byte[] unwrap(byte[] incoming, int offset, int len) {
361       throw new UnsupportedOperationException();
362     }
wrap(byte[] outgoing, int offset, int len)363     public byte[] wrap(byte[] outgoing, int offset, int len) {
364       throw new UnsupportedOperationException();
365     }
getNegotiatedProperty(String propName)366     public Object getNegotiatedProperty(String propName) { return null; }
dispose()367     public void dispose() {}
368 
369   }
370 
371   public static class SaslAnonymousFactory
372     implements SaslClientFactory, SaslServerFactory {
373 
createSaslClient( String[] mechanisms, String authorizationId, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)374     public SaslClient createSaslClient(
375       String[] mechanisms, String authorizationId, String protocol,
376       String serverName, Map<String,?> props, CallbackHandler cbh)
377     {
378       for (String mech : mechanisms) {
379         if ("ANONYMOUS".equals(mech)) {
380           return new AnonymousClient(authorizationId);
381         }
382       }
383       return null;
384     }
385 
createSaslServer( String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)386     public SaslServer createSaslServer(
387       String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)
388     {
389       if ("ANONYMOUS".equals(mechanism)) {
390         return new AnonymousServer();
391       }
392       return null;
393     }
getMechanismNames(Map<String, ?> props)394     public String[] getMechanismNames(Map<String, ?> props) {
395       return new String[] { "ANONYMOUS" };
396     }
397   }
398 
399   static {
java.security.Security.addProvider(new SaslAnonymousProvider())400     java.security.Security.addProvider(new SaslAnonymousProvider());
401   }
402   public static class SaslAnonymousProvider extends java.security.Provider {
SaslAnonymousProvider()403     public SaslAnonymousProvider() {
404       super("ThriftSaslAnonymous", 1.0, "Thrift Anonymous SASL provider");
405       put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
406       put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
407     }
408   }
409 
410   private static class MockTTransport extends TTransport {
411 
412     byte[] badHeader = null;
413     private TMemoryInputTransport readBuffer;
414 
MockTTransport(int mode)415     public MockTTransport(int mode) throws TTransportException {
416       readBuffer = new TMemoryInputTransport();
417       if (mode==1) {
418         // Invalid status byte
419         badHeader = new byte[] { (byte)0xFF, (byte)0x00, (byte)0x00, (byte)0x00, (byte)0x05 };
420       } else if (mode == 2) {
421         // Valid status byte, negative payload length
422         badHeader = new byte[] { (byte)0x01, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF };
423       } else if (mode == 3) {
424         // Valid status byte, excessively large, bogus payload length
425         badHeader = new byte[] { (byte)0x01, (byte)0x64, (byte)0x00, (byte)0x00, (byte)0x00 };
426       }
427       readBuffer.reset(badHeader);
428     }
429 
430     @Override
isOpen()431     public boolean isOpen() {
432       return true;
433     }
434 
435     @Override
open()436     public void open() throws TTransportException {}
437 
438     @Override
close()439     public void close() {}
440 
441     @Override
read(byte[] buf, int off, int len)442     public int read(byte[] buf, int off, int len) throws TTransportException {
443       return readBuffer.read(buf, off, len);
444     }
445 
446     @Override
write(byte[] buf, int off, int len)447     public void write(byte[] buf, int off, int len) throws TTransportException {}
448 
449     @Override
getConfiguration()450     public TConfiguration getConfiguration() {
451       return readBuffer.getConfiguration();
452     }
453 
454     @Override
updateKnownMessageSize(long size)455     public void updateKnownMessageSize(long size) throws TTransportException {
456       readBuffer.updateKnownMessageSize(size);
457     }
458 
459     @Override
checkReadBytesAvailable(long numBytes)460     public void checkReadBytesAvailable(long numBytes) throws TTransportException {
461       readBuffer.checkReadBytesAvailable(numBytes);
462     }
463   }
464 
testBadHeader()465   public void testBadHeader() {
466     TSaslTransport saslTransport;
467     try {
468       saslTransport = new TSaslServerTransport(new MockTTransport(1));
469       saslTransport.receiveSaslMessage();
470       fail("Should have gotten an error due to incorrect status byte value.");
471     } catch (TTransportException e) {
472       assertEquals(e.getMessage(), "Invalid status -1");
473     }
474     try {
475       saslTransport = new TSaslServerTransport(new MockTTransport(2));
476       saslTransport.receiveSaslMessage();
477       fail("Should have gotten an error due to negative payload length.");
478     } catch (TTransportException e) {
479       assertEquals(e.getMessage(), "Invalid payload header length: -1");
480     }
481     try {
482       saslTransport = new TSaslServerTransport(new MockTTransport(3));
483       saslTransport.receiveSaslMessage();
484       fail("Should have gotten an error due to bogus (large) payload length.");
485     } catch (TTransportException e) {
486       assertEquals(e.getMessage(), "Invalid payload header length: 1677721600");
487     }
488   }
489 }
490