1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 package org.apache.thrift.transport;
21 
22 import java.io.IOException;
23 import java.util.HashMap;
24 import java.util.Map;
25 
26 import javax.security.auth.callback.Callback;
27 import javax.security.auth.callback.CallbackHandler;
28 import javax.security.auth.callback.NameCallback;
29 import javax.security.auth.callback.PasswordCallback;
30 import javax.security.auth.callback.UnsupportedCallbackException;
31 import javax.security.sasl.AuthorizeCallback;
32 import javax.security.sasl.RealmCallback;
33 import javax.security.sasl.Sasl;
34 import javax.security.sasl.SaslClient;
35 import javax.security.sasl.SaslClientFactory;
36 import javax.security.sasl.SaslException;
37 import javax.security.sasl.SaslServer;
38 import javax.security.sasl.SaslServerFactory;
39 
40 import junit.framework.TestCase;
41 
42 import org.apache.thrift.TProcessor;
43 import org.apache.thrift.protocol.TProtocolFactory;
44 import org.apache.thrift.server.ServerTestBase;
45 import org.apache.thrift.server.TServer;
46 import org.apache.thrift.server.TSimpleServer;
47 import org.apache.thrift.server.TServer.Args;
48 import org.slf4j.Logger;
49 import org.slf4j.LoggerFactory;
50 
51 public class TestTSaslTransports extends TestCase {
52 
53   private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class);
54 
55   private static final String HOST = "localhost";
56   private static final String SERVICE = "thrift-test";
57   private static final String PRINCIPAL = "thrift-test-principal";
58   private static final String PASSWORD = "super secret password";
59   private static final String REALM = "thrift-test-realm";
60 
61   private static final String UNWRAPPED_MECHANISM = "CRAM-MD5";
62   private static final Map<String, String> UNWRAPPED_PROPS = null;
63 
64   private static final String WRAPPED_MECHANISM = "DIGEST-MD5";
65   private static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>();
66 
67   static {
WRAPPED_PROPS.put(Sasl.QOP, R)68     WRAPPED_PROPS.put(Sasl.QOP, "auth-int");
69     WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM);
70   }
71 
72   private static final String testMessage1 = "Hello, world! Also, four "
73       + "score and seven years ago our fathers brought forth on this "
74       + "continent a new nation, conceived in liberty, and dedicated to the "
75       + "proposition that all men are created equal.";
76 
77   private static final String testMessage2 = "I have a dream that one day "
78       + "this nation will rise up and live out the true meaning of its creed: "
79       + "'We hold these truths to be self-evident, that all men are created equal.'";
80 
81 
82   private static class TestSaslCallbackHandler implements CallbackHandler {
83     private final String password;
84 
TestSaslCallbackHandler(String password)85     public TestSaslCallbackHandler(String password) {
86       this.password = password;
87     }
88 
89     @Override
handle(Callback[] callbacks)90     public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
91       for (Callback c : callbacks) {
92         if (c instanceof NameCallback) {
93           ((NameCallback) c).setName(PRINCIPAL);
94         } else if (c instanceof PasswordCallback) {
95           ((PasswordCallback) c).setPassword(password.toCharArray());
96         } else if (c instanceof AuthorizeCallback) {
97           ((AuthorizeCallback) c).setAuthorized(true);
98         } else if (c instanceof RealmCallback) {
99           ((RealmCallback) c).setText(REALM);
100         } else {
101           throw new UnsupportedCallbackException(c);
102         }
103       }
104     }
105   }
106 
107   private class ServerThread extends Thread {
108     final String mechanism;
109     final Map<String, String> props;
110     volatile Throwable thrown;
111 
ServerThread(String mechanism, Map<String, String> props)112     public ServerThread(String mechanism, Map<String, String> props) {
113       this.mechanism = mechanism;
114       this.props = props;
115     }
116 
run()117     public void run() {
118       try {
119         internalRun();
120       } catch (Throwable t) {
121         thrown = t;
122       }
123     }
124 
internalRun()125     private void internalRun() throws Exception {
126       TServerSocket serverSocket = new TServerSocket(
127         new TServerSocket.ServerSocketTransportArgs().
128           port(ServerTestBase.PORT));
129       try {
130         acceptAndWrite(serverSocket);
131       } finally {
132         serverSocket.close();
133       }
134     }
135 
acceptAndWrite(TServerSocket serverSocket)136     private void acceptAndWrite(TServerSocket serverSocket)
137       throws Exception {
138       TTransport serverTransport = serverSocket.accept();
139       TTransport saslServerTransport = new TSaslServerTransport(
140         mechanism, SERVICE, HOST,
141         props, new TestSaslCallbackHandler(PASSWORD), serverTransport);
142 
143       saslServerTransport.open();
144 
145       byte[] inBuf = new byte[testMessage1.getBytes().length];
146       // Deliberately read less than the full buffer to ensure
147       // that TSaslTransport is correctly buffering reads. This
148       // will fail for the WRAPPED test, if it doesn't work.
149       saslServerTransport.readAll(inBuf, 0, 5);
150       saslServerTransport.readAll(inBuf, 5, 10);
151       saslServerTransport.readAll(inBuf, 15, inBuf.length - 15);
152       LOGGER.debug("server got: {}", new String(inBuf));
153       assertEquals(new String(inBuf), testMessage1);
154 
155       LOGGER.debug("server writing: {}", testMessage2);
156       saslServerTransport.write(testMessage2.getBytes());
157       saslServerTransport.flush();
158 
159       saslServerTransport.close();
160     }
161   }
162 
testSaslOpen(final String mechanism, final Map<String, String> props)163   private void testSaslOpen(final String mechanism, final Map<String, String> props)
164       throws Exception {
165     ServerThread serverThread = new ServerThread(mechanism, props);
166     serverThread.start();
167 
168     try {
169       Thread.sleep(1000);
170     } catch (InterruptedException e) {
171       // Ah well.
172     }
173 
174     try {
175       TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
176       TTransport saslClientTransport = new TSaslClientTransport(mechanism,
177                                                                 PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(PASSWORD), clientSocket);
178       saslClientTransport.open();
179       LOGGER.debug("client writing: {}", testMessage1);
180       saslClientTransport.write(testMessage1.getBytes());
181       saslClientTransport.flush();
182 
183       byte[] inBuf = new byte[testMessage2.getBytes().length];
184       saslClientTransport.readAll(inBuf, 0, inBuf.length);
185       LOGGER.debug("client got: {}", new String(inBuf));
186       assertEquals(new String(inBuf), testMessage2);
187 
188       TTransportException expectedException = null;
189       try {
190         saslClientTransport.open();
191       } catch (TTransportException e) {
192         expectedException = e;
193       }
194       assertNotNull(expectedException);
195 
196       saslClientTransport.close();
197     } catch (Exception e) {
198       LOGGER.warn("Exception caught", e);
199       throw e;
200     } finally {
201       serverThread.interrupt();
202       try {
203         serverThread.join();
204       } catch (InterruptedException e) {
205         // Ah well.
206       }
207       assertNull(serverThread.thrown);
208     }
209   }
210 
testUnwrappedOpen()211   public void testUnwrappedOpen() throws Exception {
212     testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
213   }
214 
testWrappedOpen()215   public void testWrappedOpen() throws Exception {
216     testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS);
217   }
218 
testAnonymousOpen()219   public void testAnonymousOpen() throws Exception {
220     testSaslOpen("ANONYMOUS", null);
221   }
222 
223   /**
224    * Test that we get the proper exceptions thrown back the server when
225    * the client provides invalid password.
226    */
testBadPassword()227   public void testBadPassword() throws Exception {
228     ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
229     serverThread.start();
230 
231     try {
232       Thread.sleep(1000);
233     } catch (InterruptedException e) {
234       // Ah well.
235     }
236 
237     boolean clientSidePassed = true;
238 
239     try {
240       TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
241       TTransport saslClientTransport = new TSaslClientTransport(
242         UNWRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, UNWRAPPED_PROPS,
243         new TestSaslCallbackHandler("NOT THE PASSWORD"), clientSocket);
244       saslClientTransport.open();
245       clientSidePassed = false;
246       fail("Was able to open transport with bad password");
247     } catch (TTransportException tte) {
248       LOGGER.error("Exception for bad password", tte);
249       assertNotNull(tte.getMessage());
250       assertTrue(tte.getMessage().contains("Invalid response"));
251 
252     } finally {
253       serverThread.interrupt();
254       serverThread.join();
255 
256       if (clientSidePassed) {
257         assertNotNull(serverThread.thrown);
258         assertTrue(serverThread.thrown.getMessage().contains("Invalid response"));
259       }
260     }
261   }
262 
testWithServer()263   public void testWithServer() throws Exception {
264     new TestTSaslTransportsWithServer().testIt();
265   }
266 
267   private static class TestTSaslTransportsWithServer extends ServerTestBase {
268 
269     private Thread serverThread;
270     private TServer server;
271 
272     @Override
getClientTransport(TTransport underlyingTransport)273     public TTransport getClientTransport(TTransport underlyingTransport) throws Exception {
274       return new TSaslClientTransport(
275         WRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS,
276         new TestSaslCallbackHandler(PASSWORD), underlyingTransport);
277     }
278 
279     @Override
startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory)280     public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory) throws Exception {
281       serverThread = new Thread() {
282         public void run() {
283           try {
284             // Transport
285             TServerSocket socket = new TServerSocket(new TServerSocket.ServerSocketTransportArgs().port(PORT));
286 
287             TTransportFactory factory = new TSaslServerTransport.Factory(
288               WRAPPED_MECHANISM, SERVICE, HOST, WRAPPED_PROPS,
289               new TestSaslCallbackHandler(PASSWORD));
290             server = new TSimpleServer(new Args(socket).processor(processor).transportFactory(factory).protocolFactory(protoFactory));
291 
292             // Run it
293             LOGGER.debug("Starting the server on port {}", PORT);
294             server.serve();
295           } catch (Exception e) {
296             e.printStackTrace();
297             fail();
298           }
299         }
300       };
301       serverThread.start();
302       Thread.sleep(1000);
303     }
304 
305     @Override
stopServer()306     public void stopServer() throws Exception {
307       server.stop();
308       try {
309         serverThread.join();
310       } catch (InterruptedException e) {}
311     }
312 
313   }
314 
315 
316   /**
317    * Implementation of SASL ANONYMOUS, used for testing client-side
318    * initial responses.
319    */
320   private static class AnonymousClient implements SaslClient {
321     private final String username;
322     private boolean hasProvidedInitialResponse;
323 
AnonymousClient(String username)324     public AnonymousClient(String username) {
325       this.username = username;
326     }
327 
getMechanismName()328     public String getMechanismName() { return "ANONYMOUS"; }
hasInitialResponse()329     public boolean hasInitialResponse() { return true; }
evaluateChallenge(byte[] challenge)330     public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
331       if (hasProvidedInitialResponse) {
332         throw new SaslException("Already complete!");
333       }
334 
335       try {
336         hasProvidedInitialResponse = true;
337         return username.getBytes("UTF-8");
338       } catch (IOException e) {
339         throw new SaslException(e.toString());
340       }
341     }
isComplete()342     public boolean isComplete() { return hasProvidedInitialResponse; }
unwrap(byte[] incoming, int offset, int len)343     public byte[] unwrap(byte[] incoming, int offset, int len) {
344       throw new UnsupportedOperationException();
345     }
wrap(byte[] outgoing, int offset, int len)346     public byte[] wrap(byte[] outgoing, int offset, int len) {
347       throw new UnsupportedOperationException();
348     }
getNegotiatedProperty(String propName)349     public Object getNegotiatedProperty(String propName) { return null; }
dispose()350     public void dispose() {}
351   }
352 
353   private static class AnonymousServer implements SaslServer {
354     private String user;
getMechanismName()355     public String getMechanismName() { return "ANONYMOUS"; }
evaluateResponse(byte[] response)356     public byte[] evaluateResponse(byte[] response) throws SaslException {
357       try {
358         this.user = new String(response, "UTF-8");
359       } catch (IOException e) {
360         throw new SaslException(e.toString());
361       }
362       return null;
363     }
isComplete()364     public boolean isComplete() { return user != null; }
getAuthorizationID()365     public String getAuthorizationID() { return user; }
unwrap(byte[] incoming, int offset, int len)366     public byte[] unwrap(byte[] incoming, int offset, int len) {
367       throw new UnsupportedOperationException();
368     }
wrap(byte[] outgoing, int offset, int len)369     public byte[] wrap(byte[] outgoing, int offset, int len) {
370       throw new UnsupportedOperationException();
371     }
getNegotiatedProperty(String propName)372     public Object getNegotiatedProperty(String propName) { return null; }
dispose()373     public void dispose() {}
374 
375   }
376 
377   public static class SaslAnonymousFactory
378     implements SaslClientFactory, SaslServerFactory {
379 
createSaslClient( String[] mechanisms, String authorizationId, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)380     public SaslClient createSaslClient(
381       String[] mechanisms, String authorizationId, String protocol,
382       String serverName, Map<String,?> props, CallbackHandler cbh)
383     {
384       for (String mech : mechanisms) {
385         if ("ANONYMOUS".equals(mech)) {
386           return new AnonymousClient(authorizationId);
387         }
388       }
389       return null;
390     }
391 
createSaslServer( String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)392     public SaslServer createSaslServer(
393       String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)
394     {
395       if ("ANONYMOUS".equals(mechanism)) {
396         return new AnonymousServer();
397       }
398       return null;
399     }
getMechanismNames(Map<String, ?> props)400     public String[] getMechanismNames(Map<String, ?> props) {
401       return new String[] { "ANONYMOUS" };
402     }
403   }
404 
405   static {
java.security.Security.addProvider(new SaslAnonymousProvider())406     java.security.Security.addProvider(new SaslAnonymousProvider());
407   }
408   public static class SaslAnonymousProvider extends java.security.Provider {
SaslAnonymousProvider()409     public SaslAnonymousProvider() {
410       super("ThriftSaslAnonymous", 1.0, "Thrift Anonymous SASL provider");
411       put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
412       put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
413     }
414   }
415 
416   private static class MockTTransport extends TTransport {
417 
418     byte[] badHeader = null;
419     private TMemoryInputTransport readBuffer = new TMemoryInputTransport();
420 
MockTTransport(int mode)421     public MockTTransport(int mode) {
422       if (mode==1) {
423         // Invalid status byte
424         badHeader = new byte[] { (byte)0xFF, (byte)0x00, (byte)0x00, (byte)0x00, (byte)0x05 };
425       } else if (mode == 2) {
426         // Valid status byte, negative payload length
427         badHeader = new byte[] { (byte)0x01, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF };
428       } else if (mode == 3) {
429         // Valid status byte, excessively large, bogus payload length
430         badHeader = new byte[] { (byte)0x01, (byte)0x64, (byte)0x00, (byte)0x00, (byte)0x00 };
431       }
432       readBuffer.reset(badHeader);
433     }
434 
435     @Override
isOpen()436     public boolean isOpen() {
437       return true;
438     }
439 
440     @Override
open()441     public void open() throws TTransportException {}
442 
443     @Override
close()444     public void close() {}
445 
446     @Override
read(byte[] buf, int off, int len)447     public int read(byte[] buf, int off, int len) throws TTransportException {
448       return readBuffer.read(buf, off, len);
449     }
450 
451     @Override
write(byte[] buf, int off, int len)452     public void write(byte[] buf, int off, int len) throws TTransportException {}
453   }
454 
testBadHeader()455   public void testBadHeader() {
456     TSaslTransport saslTransport = new TSaslServerTransport(new MockTTransport(1));
457     try {
458       saslTransport.receiveSaslMessage();
459       fail("Should have gotten an error due to incorrect status byte value.");
460     } catch (TTransportException e) {
461       assertEquals(e.getMessage(), "Invalid status -1");
462     }
463     saslTransport = new TSaslServerTransport(new MockTTransport(2));
464     try {
465       saslTransport.receiveSaslMessage();
466       fail("Should have gotten an error due to negative payload length.");
467     } catch (TTransportException e) {
468       assertEquals(e.getMessage(), "Invalid payload header length: -1");
469     }
470     saslTransport = new TSaslServerTransport(new MockTTransport(3));
471     try {
472       saslTransport.receiveSaslMessage();
473       fail("Should have gotten an error due to bogus (large) payload length.");
474     } catch (TTransportException e) {
475       assertEquals(e.getMessage(), "Invalid payload header length: 1677721600");
476     }
477   }
478 }
479