1 /**
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 package org.apache.hadoop.ipc;
20 
21 import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION;
22 import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_RPC_PROTECTION;
23 import static org.apache.hadoop.security.SaslRpcServer.AuthMethod.KERBEROS;
24 import static org.apache.hadoop.security.SaslRpcServer.AuthMethod.SIMPLE;
25 import static org.apache.hadoop.security.SaslRpcServer.AuthMethod.TOKEN;
26 import static org.junit.Assert.assertEquals;
27 import static org.junit.Assert.assertFalse;
28 import static org.junit.Assert.assertNotNull;
29 import static org.junit.Assert.assertNotSame;
30 import static org.junit.Assert.assertNull;
31 import static org.junit.Assert.assertTrue;
32 
33 import java.io.DataInput;
34 import java.io.DataOutput;
35 import java.io.IOException;
36 import java.lang.annotation.Annotation;
37 import java.net.InetAddress;
38 import java.net.InetSocketAddress;
39 import java.security.PrivilegedExceptionAction;
40 import java.security.Security;
41 import java.util.ArrayList;
42 import java.util.Collection;
43 import java.util.HashMap;
44 import java.util.Map;
45 import java.util.Set;
46 import java.util.regex.Pattern;
47 
48 import javax.security.auth.callback.Callback;
49 import javax.security.auth.callback.CallbackHandler;
50 import javax.security.auth.callback.NameCallback;
51 import javax.security.auth.callback.PasswordCallback;
52 import javax.security.auth.callback.UnsupportedCallbackException;
53 import javax.security.sasl.AuthorizeCallback;
54 import javax.security.sasl.Sasl;
55 import javax.security.sasl.SaslClient;
56 import javax.security.sasl.SaslException;
57 import javax.security.sasl.SaslServer;
58 
59 import org.apache.commons.lang.StringUtils;
60 import org.apache.commons.logging.Log;
61 import org.apache.commons.logging.LogFactory;
62 import org.apache.commons.logging.impl.Log4JLogger;
63 import org.apache.hadoop.conf.Configuration;
64 import org.apache.hadoop.fs.CommonConfigurationKeys;
65 import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
66 import org.apache.hadoop.io.Text;
67 import org.apache.hadoop.ipc.Client.ConnectionId;
68 import org.apache.hadoop.net.NetUtils;
69 import org.apache.hadoop.security.KerberosInfo;
70 import org.apache.hadoop.security.SaslInputStream;
71 import org.apache.hadoop.security.SaslPlainServer;
72 import org.apache.hadoop.security.SaslPropertiesResolver;
73 import org.apache.hadoop.security.SaslRpcClient;
74 import org.apache.hadoop.security.SaslRpcServer;
75 import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
76 import org.apache.hadoop.security.SaslRpcServer.QualityOfProtection;
77 import org.apache.hadoop.security.SecurityInfo;
78 import org.apache.hadoop.security.SecurityUtil;
79 import org.apache.hadoop.security.TestUserGroupInformation;
80 import org.apache.hadoop.security.UserGroupInformation;
81 import org.apache.hadoop.security.token.SecretManager;
82 import org.apache.hadoop.security.token.SecretManager.InvalidToken;
83 import org.apache.hadoop.security.token.Token;
84 import org.apache.hadoop.security.token.TokenIdentifier;
85 import org.apache.hadoop.security.token.TokenInfo;
86 import org.apache.hadoop.security.token.TokenSelector;
87 import org.apache.log4j.Level;
88 import org.junit.Before;
89 import org.junit.BeforeClass;
90 import org.junit.Test;
91 import org.junit.runner.RunWith;
92 import org.junit.runners.Parameterized;
93 import org.junit.runners.Parameterized.Parameters;
94 
95 /** Unit tests for using Sasl over RPC. */
96 @RunWith(Parameterized.class)
97 public class TestSaslRPC {
98   @Parameters
data()99   public static Collection<Object[]> data() {
100     Collection<Object[]> params = new ArrayList<Object[]>();
101     for (QualityOfProtection qop : QualityOfProtection.values()) {
102       params.add(new Object[]{ new QualityOfProtection[]{qop},qop, null });
103     }
104     params.add(new Object[]{ new QualityOfProtection[]{
105         QualityOfProtection.PRIVACY,QualityOfProtection.AUTHENTICATION },
106         QualityOfProtection.PRIVACY, null});
107     params.add(new Object[]{ new QualityOfProtection[]{
108         QualityOfProtection.PRIVACY,QualityOfProtection.AUTHENTICATION },
109         QualityOfProtection.AUTHENTICATION ,
110         "org.apache.hadoop.ipc.TestSaslRPC$AuthSaslPropertiesResolver" });
111 
112     return params;
113   }
114 
115   QualityOfProtection[] qop;
116   QualityOfProtection expectedQop;
117   String saslPropertiesResolver ;
118 
TestSaslRPC(QualityOfProtection[] qop, QualityOfProtection expectedQop, String saslPropertiesResolver)119   public TestSaslRPC(QualityOfProtection[] qop,
120       QualityOfProtection expectedQop,
121       String saslPropertiesResolver) {
122     this.qop=qop;
123     this.expectedQop = expectedQop;
124     this.saslPropertiesResolver = saslPropertiesResolver;
125   }
126 
  /** Wildcard address the test RPC servers bind to. */
  private static final String ADDRESS = "0.0.0.0";

  public static final Log LOG =
    LogFactory.getLog(TestSaslRPC.class);

  /** Message of the InvalidToken thrown by BadTokenSecretManager. */
  static final String ERROR_MESSAGE = "Token is invalid";
  /** Configuration key that holds the server's Kerberos principal. */
  static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal";
  /** Configuration key that holds the server's keytab file. */
  static final String SERVER_KEYTAB_KEY = "test.ipc.server.keytab";
  static final String SERVER_PRINCIPAL_1 = "p1/foo@BAR";
  static final String SERVER_PRINCIPAL_2 = "p2/foo@BAR";
  /** Shared configuration; rebuilt by setup() before every test. */
  private static Configuration conf;
  // If this is set to true AND the auth-method is not simple, secretManager
  // will be enabled. Reset to null by setup() before every test.
  static Boolean enableSecretManager = null;
  // If this is set to true, secretManager will be forcefully enabled
  // irrespective of auth-method. Reset to null by setup() before every test.
  static Boolean forceSecretManager = null;
  // Whether a client may accept a server's request to fall back to SIMPLE
  // auth (see the NoFallback pattern). Reset to true by setup();
  // testNoClientFallbackToSimple sets it to false.
  static Boolean clientFallBackToSimpleAllowed = true;

  /**
   * Which token, if any, the client presents during negotiation.
   * NOTE(review): interpreted by getAuthMethod, which is not visible here.
   * Presumably VALID/INVALID are tokens the server accepts/rejects and
   * OTHER is a token the server does not recognize. Confirm against
   * getAuthMethod.
   */
  static enum UseToken {
    NONE(),
    VALID(),
    INVALID(),
    OTHER();
  }
152 
153   @BeforeClass
setupKerb()154   public static void setupKerb() {
155     System.setProperty("java.security.krb5.kdc", "");
156     System.setProperty("java.security.krb5.realm", "NONE");
157     Security.addProvider(new SaslPlainServer.SecurityProvider());
158   }
159 
160   @Before
setup()161   public void setup() {
162     LOG.info("---------------------------------");
163     LOG.info("Testing QOP:"+ getQOPNames(qop));
164     LOG.info("---------------------------------");
165     conf = new Configuration();
166     // the specific tests for kerberos will enable kerberos.  forcing it
167     // for all tests will cause tests to fail if the user has a TGT
168     conf.set(HADOOP_SECURITY_AUTHENTICATION, SIMPLE.toString());
169     conf.set(HADOOP_RPC_PROTECTION, getQOPNames(qop));
170     if (saslPropertiesResolver != null){
171       conf.set(CommonConfigurationKeys.HADOOP_SECURITY_SASL_PROPS_RESOLVER_CLASS,
172         saslPropertiesResolver);
173     }
174     UserGroupInformation.setConfiguration(conf);
175     enableSecretManager = null;
176     forceSecretManager = null;
177     clientFallBackToSimpleAllowed = true;
178   }
179 
getQOPNames(QualityOfProtection[] qops)180   static String getQOPNames (QualityOfProtection[] qops){
181     StringBuilder sb = new StringBuilder();
182     int i = 0;
183     for (QualityOfProtection qop:qops){
184      sb.append(org.apache.hadoop.util.StringUtils.toLowerCase(qop.name()));
185      if (++i < qops.length){
186        sb.append(",");
187      }
188     }
189     return sb.toString();
190   }
191 
  // Raise the RPC client/server, SASL and security-utility loggers to
  // Level.ALL, presumably so negotiation details appear in the test output.
  // NOTE(review): each cast assumes the commons-logging Log is backed by
  // log4j (Log4JLogger). Any other backend would throw ClassCastException
  // while this class initializes.
  static {
    ((Log4JLogger) Client.LOG).getLogger().setLevel(Level.ALL);
    ((Log4JLogger) Server.LOG).getLogger().setLevel(Level.ALL);
    ((Log4JLogger) SaslRpcClient.LOG).getLogger().setLevel(Level.ALL);
    ((Log4JLogger) SaslRpcServer.LOG).getLogger().setLevel(Level.ALL);
    ((Log4JLogger) SaslInputStream.LOG).getLogger().setLevel(Level.ALL);
    ((Log4JLogger) SecurityUtil.LOG).getLogger().setLevel(Level.ALL);
  }
200 
201   public static class TestTokenIdentifier extends TokenIdentifier {
202     private Text tokenid;
203     private Text realUser;
204     final static Text KIND_NAME = new Text("test.token");
205 
TestTokenIdentifier()206     public TestTokenIdentifier() {
207       this(new Text(), new Text());
208     }
TestTokenIdentifier(Text tokenid)209     public TestTokenIdentifier(Text tokenid) {
210       this(tokenid, new Text());
211     }
TestTokenIdentifier(Text tokenid, Text realUser)212     public TestTokenIdentifier(Text tokenid, Text realUser) {
213       this.tokenid = tokenid == null ? new Text() : tokenid;
214       this.realUser = realUser == null ? new Text() : realUser;
215     }
216     @Override
getKind()217     public Text getKind() {
218       return KIND_NAME;
219     }
220     @Override
getUser()221     public UserGroupInformation getUser() {
222       if (realUser.toString().isEmpty()) {
223         return UserGroupInformation.createRemoteUser(tokenid.toString());
224       } else {
225         UserGroupInformation realUgi = UserGroupInformation
226             .createRemoteUser(realUser.toString());
227         return UserGroupInformation
228             .createProxyUser(tokenid.toString(), realUgi);
229       }
230     }
231 
232     @Override
readFields(DataInput in)233     public void readFields(DataInput in) throws IOException {
234       tokenid.readFields(in);
235       realUser.readFields(in);
236     }
237     @Override
write(DataOutput out)238     public void write(DataOutput out) throws IOException {
239       tokenid.write(out);
240       realUser.write(out);
241     }
242   }
243 
244   public static class TestTokenSecretManager extends
245       SecretManager<TestTokenIdentifier> {
246     @Override
createPassword(TestTokenIdentifier id)247     public byte[] createPassword(TestTokenIdentifier id) {
248       return id.getBytes();
249     }
250 
251     @Override
retrievePassword(TestTokenIdentifier id)252     public byte[] retrievePassword(TestTokenIdentifier id)
253         throws InvalidToken {
254       return id.getBytes();
255     }
256 
257     @Override
createIdentifier()258     public TestTokenIdentifier createIdentifier() {
259       return new TestTokenIdentifier();
260     }
261   }
262 
263   public static class BadTokenSecretManager extends TestTokenSecretManager {
264 
265     @Override
retrievePassword(TestTokenIdentifier id)266     public byte[] retrievePassword(TestTokenIdentifier id)
267         throws InvalidToken {
268       throw new InvalidToken(ERROR_MESSAGE);
269     }
270   }
271 
272   public static class TestTokenSelector implements
273       TokenSelector<TestTokenIdentifier> {
274     @SuppressWarnings("unchecked")
275     @Override
selectToken(Text service, Collection<Token<? extends TokenIdentifier>> tokens)276     public Token<TestTokenIdentifier> selectToken(Text service,
277         Collection<Token<? extends TokenIdentifier>> tokens) {
278       if (service == null) {
279         return null;
280       }
281       for (Token<? extends TokenIdentifier> token : tokens) {
282         if (TestTokenIdentifier.KIND_NAME.equals(token.getKind())
283             && service.equals(token.getService())) {
284           return (Token<TestTokenIdentifier>) token;
285         }
286       }
287       return null;
288     }
289   }
290 
291   @KerberosInfo(
292       serverPrincipal = SERVER_PRINCIPAL_KEY)
293   @TokenInfo(TestTokenSelector.class)
294   public interface TestSaslProtocol extends TestRPC.TestProtocol {
getAuthMethod()295     public AuthMethod getAuthMethod() throws IOException;
getAuthUser()296     public String getAuthUser() throws IOException;
297   }
298 
299   public static class TestSaslImpl extends TestRPC.TestImpl implements
300       TestSaslProtocol {
301     @Override
getAuthMethod()302     public AuthMethod getAuthMethod() throws IOException {
303       return UserGroupInformation.getCurrentUser()
304           .getAuthenticationMethod().getAuthMethod();
305     }
306     @Override
getAuthUser()307     public String getAuthUser() throws IOException {
308       return UserGroupInformation.getCurrentUser().getUserName();
309     }
310   }
311 
312   public static class CustomSecurityInfo extends SecurityInfo {
313 
314     @Override
getKerberosInfo(Class<?> protocol, Configuration conf)315     public KerberosInfo getKerberosInfo(Class<?> protocol, Configuration conf) {
316       return new KerberosInfo() {
317         @Override
318         public Class<? extends Annotation> annotationType() {
319           return null;
320         }
321         @Override
322         public String serverPrincipal() {
323           return SERVER_PRINCIPAL_KEY;
324         }
325         @Override
326         public String clientPrincipal() {
327           return null;
328         }
329       };
330     }
331 
332     @Override
getTokenInfo(Class<?> protocol, Configuration conf)333     public TokenInfo getTokenInfo(Class<?> protocol, Configuration conf) {
334       return new TokenInfo() {
335         @Override
336         public Class<? extends TokenSelector<? extends
337             TokenIdentifier>> value() {
338           return TestTokenSelector.class;
339         }
340         @Override
341         public Class<? extends Annotation> annotationType() {
342           return null;
343         }
344       };
345     }
346   }
347 
348   @Test
349   public void testDigestRpc() throws Exception {
350     TestTokenSecretManager sm = new TestTokenSecretManager();
351     final Server server = new RPC.Builder(conf)
352         .setProtocol(TestSaslProtocol.class).setInstance(new TestSaslImpl())
353         .setBindAddress(ADDRESS).setPort(0).setNumHandlers(5).setVerbose(true)
354         .setSecretManager(sm).build();
355 
356     doDigestRpc(server, sm);
357   }
358 
359   @Test
360   public void testDigestRpcWithoutAnnotation() throws Exception {
361     TestTokenSecretManager sm = new TestTokenSecretManager();
362     try {
363       SecurityUtil.setSecurityInfoProviders(new CustomSecurityInfo());
364       final Server server = new RPC.Builder(conf)
365           .setProtocol(TestSaslProtocol.class).setInstance(new TestSaslImpl())
366           .setBindAddress(ADDRESS).setPort(0).setNumHandlers(5)
367           .setVerbose(true).setSecretManager(sm).build();
368       doDigestRpc(server, sm);
369     } finally {
370       SecurityUtil.setSecurityInfoProviders(new SecurityInfo[0]);
371     }
372   }
373 
374   @Test
375   public void testErrorMessage() throws Exception {
376     BadTokenSecretManager sm = new BadTokenSecretManager();
377     final Server server = new RPC.Builder(conf)
378         .setProtocol(TestSaslProtocol.class).setInstance(new TestSaslImpl())
379         .setBindAddress(ADDRESS).setPort(0).setNumHandlers(5).setVerbose(true)
380         .setSecretManager(sm).build();
381 
382     boolean succeeded = false;
383     try {
384       doDigestRpc(server, sm);
385     } catch (RemoteException e) {
386       LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage());
387       assertEquals(ERROR_MESSAGE, e.getLocalizedMessage());
388       assertTrue(e.unwrapRemoteException() instanceof InvalidToken);
389       succeeded = true;
390     }
391     assertTrue(succeeded);
392   }
393 
394   private void doDigestRpc(Server server, TestTokenSecretManager sm
395                            ) throws Exception {
396     server.start();
397 
398     final UserGroupInformation current = UserGroupInformation.getCurrentUser();
399     final InetSocketAddress addr = NetUtils.getConnectAddress(server);
400     TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current
401         .getUserName()));
402     Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId,
403         sm);
404     SecurityUtil.setTokenService(token, addr);
405     current.addToken(token);
406 
407     TestSaslProtocol proxy = null;
408     try {
409       proxy = RPC.getProxy(TestSaslProtocol.class,
410           TestSaslProtocol.versionID, addr, conf);
411       AuthMethod authMethod = proxy.getAuthMethod();
412       assertEquals(TOKEN, authMethod);
413       //QOP must be auth
414       assertEquals(expectedQop.saslQop,
415                    RPC.getConnectionIdForProxy(proxy).getSaslQop());
416       proxy.ping();
417     } finally {
418       server.stop();
419       if (proxy != null) {
420         RPC.stopProxy(proxy);
421       }
422     }
423   }
424 
425   @Test
426   public void testPingInterval() throws Exception {
427     Configuration newConf = new Configuration(conf);
428     newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);
429     conf.setInt(CommonConfigurationKeys.IPC_PING_INTERVAL_KEY,
430         CommonConfigurationKeys.IPC_PING_INTERVAL_DEFAULT);
431 
432     // set doPing to true
433     newConf.setBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, true);
434     ConnectionId remoteId = ConnectionId.getConnectionId(
435         new InetSocketAddress(0), TestSaslProtocol.class, null, 0, newConf);
436     assertEquals(CommonConfigurationKeys.IPC_PING_INTERVAL_DEFAULT,
437         remoteId.getPingInterval());
438     // set doPing to false
439     newConf.setBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, false);
440     remoteId = ConnectionId.getConnectionId(
441         new InetSocketAddress(0), TestSaslProtocol.class, null, 0, newConf);
442     assertEquals(0, remoteId.getPingInterval());
443   }
444 
445   @Test
446   public void testPerConnectionConf() throws Exception {
447     TestTokenSecretManager sm = new TestTokenSecretManager();
448     final Server server = new RPC.Builder(conf)
449         .setProtocol(TestSaslProtocol.class).setInstance(new TestSaslImpl())
450         .setBindAddress(ADDRESS).setPort(0).setNumHandlers(5).setVerbose(true)
451         .setSecretManager(sm).build();
452     server.start();
453     final UserGroupInformation current = UserGroupInformation.getCurrentUser();
454     final InetSocketAddress addr = NetUtils.getConnectAddress(server);
455     TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current
456         .getUserName()));
457     Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId,
458         sm);
459     SecurityUtil.setTokenService(token, addr);
460     current.addToken(token);
461 
462     Configuration newConf = new Configuration(conf);
463     newConf.set(CommonConfigurationKeysPublic.
464         HADOOP_RPC_SOCKET_FACTORY_CLASS_DEFAULT_KEY, "");
465 
466     Client client = null;
467     TestSaslProtocol proxy1 = null;
468     TestSaslProtocol proxy2 = null;
469     TestSaslProtocol proxy3 = null;
470     int timeouts[] = {111222, 3333333};
471     try {
472       newConf.setInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY, timeouts[0]);
473       proxy1 = RPC.getProxy(TestSaslProtocol.class,
474           TestSaslProtocol.versionID, addr, newConf);
475       proxy1.getAuthMethod();
476       client = WritableRpcEngine.getClient(newConf);
477       Set<ConnectionId> conns = client.getConnectionIds();
478       assertEquals("number of connections in cache is wrong", 1, conns.size());
479       // same conf, connection should be re-used
480       proxy2 = RPC.getProxy(TestSaslProtocol.class,
481           TestSaslProtocol.versionID, addr, newConf);
482       proxy2.getAuthMethod();
483       assertEquals("number of connections in cache is wrong", 1, conns.size());
484       // different conf, new connection should be set up
485       newConf.setInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY, timeouts[1]);
486       proxy3 = RPC.getProxy(TestSaslProtocol.class,
487           TestSaslProtocol.versionID, addr, newConf);
488       proxy3.getAuthMethod();
489       assertEquals("number of connections in cache is wrong", 2, conns.size());
490       // now verify the proxies have the correct connection ids and timeouts
491       ConnectionId[] connsArray = {
492           RPC.getConnectionIdForProxy(proxy1),
493           RPC.getConnectionIdForProxy(proxy2),
494           RPC.getConnectionIdForProxy(proxy3)
495       };
496       assertEquals(connsArray[0], connsArray[1]);
497       assertEquals(connsArray[0].getMaxIdleTime(), timeouts[0]);
498       assertFalse(connsArray[0].equals(connsArray[2]));
499       assertNotSame(connsArray[2].getMaxIdleTime(), timeouts[1]);
500     } finally {
501       server.stop();
502       // this is dirty, but clear out connection cache for next run
503       if (client != null) {
504         client.getConnectionIds().clear();
505       }
506       if (proxy1 != null) RPC.stopProxy(proxy1);
507       if (proxy2 != null) RPC.stopProxy(proxy2);
508       if (proxy3 != null) RPC.stopProxy(proxy3);
509     }
510   }
511 
512   static void testKerberosRpc(String principal, String keytab) throws Exception {
513     final Configuration newConf = new Configuration(conf);
514     newConf.set(SERVER_PRINCIPAL_KEY, principal);
515     newConf.set(SERVER_KEYTAB_KEY, keytab);
516     SecurityUtil.login(newConf, SERVER_KEYTAB_KEY, SERVER_PRINCIPAL_KEY);
517     TestUserGroupInformation.verifyLoginMetrics(1, 0);
518     UserGroupInformation current = UserGroupInformation.getCurrentUser();
519     System.out.println("UGI: " + current);
520 
521     Server server = new RPC.Builder(newConf)
522         .setProtocol(TestSaslProtocol.class).setInstance(new TestSaslImpl())
523         .setBindAddress(ADDRESS).setPort(0).setNumHandlers(5).setVerbose(true)
524         .build();
525     TestSaslProtocol proxy = null;
526 
527     server.start();
528 
529     InetSocketAddress addr = NetUtils.getConnectAddress(server);
530     try {
531       proxy = RPC.getProxy(TestSaslProtocol.class,
532           TestSaslProtocol.versionID, addr, newConf);
533       proxy.ping();
534     } finally {
535       server.stop();
536       if (proxy != null) {
537         RPC.stopProxy(proxy);
538       }
539     }
540     System.out.println("Test is successful.");
541   }
542 
543   @Test
544   public void testSaslPlainServer() throws IOException {
545     runNegotiation(
546         new TestPlainCallbacks.Client("user", "pass"),
547         new TestPlainCallbacks.Server("user", "pass"));
548   }
549 
550   @Test
551   public void testSaslPlainServerBadPassword() {
552     SaslException e = null;
553     try {
554       runNegotiation(
555           new TestPlainCallbacks.Client("user", "pass1"),
556           new TestPlainCallbacks.Server("user", "pass2"));
557     } catch (SaslException se) {
558       e = se;
559     }
560     assertNotNull(e);
561     assertEquals("PLAIN auth failed: wrong password", e.getMessage());
562   }
563 
564 
565   private void runNegotiation(CallbackHandler clientCbh,
566                               CallbackHandler serverCbh)
567                                   throws SaslException {
568     String mechanism = AuthMethod.PLAIN.getMechanismName();
569 
570     SaslClient saslClient = Sasl.createSaslClient(
571         new String[]{ mechanism }, null, null, null, null, clientCbh);
572     assertNotNull(saslClient);
573 
574     SaslServer saslServer = Sasl.createSaslServer(
575         mechanism, null, "localhost", null, serverCbh);
576     assertNotNull("failed to find PLAIN server", saslServer);
577 
578     byte[] response = saslClient.evaluateChallenge(new byte[0]);
579     assertNotNull(response);
580     assertTrue(saslClient.isComplete());
581 
582     response = saslServer.evaluateResponse(response);
583     assertNull(response);
584     assertTrue(saslServer.isComplete());
585     assertNotNull(saslServer.getAuthorizationID());
586   }
587 
588   static class TestPlainCallbacks {
589     public static class Client implements CallbackHandler {
590       String user = null;
591       String password = null;
592 
593       Client(String user, String password) {
594         this.user = user;
595         this.password = password;
596       }
597 
598       @Override
599       public void handle(Callback[] callbacks)
600           throws UnsupportedCallbackException {
601         for (Callback callback : callbacks) {
602           if (callback instanceof NameCallback) {
603             ((NameCallback) callback).setName(user);
604           } else if (callback instanceof PasswordCallback) {
605             ((PasswordCallback) callback).setPassword(password.toCharArray());
606           } else {
607             throw new UnsupportedCallbackException(callback,
608                 "Unrecognized SASL PLAIN Callback");
609           }
610         }
611       }
612     }
613 
614     public static class Server implements CallbackHandler {
615       String user = null;
616       String password = null;
617 
618       Server(String user, String password) {
619         this.user = user;
620         this.password = password;
621       }
622 
623       @Override
624       public void handle(Callback[] callbacks)
625           throws UnsupportedCallbackException, SaslException {
626         NameCallback nc = null;
627         PasswordCallback pc = null;
628         AuthorizeCallback ac = null;
629 
630         for (Callback callback : callbacks) {
631           if (callback instanceof NameCallback) {
632             nc = (NameCallback)callback;
633             assertEquals(user, nc.getName());
634           } else if (callback instanceof PasswordCallback) {
635             pc = (PasswordCallback)callback;
636             if (!password.equals(new String(pc.getPassword()))) {
637               throw new IllegalArgumentException("wrong password");
638             }
639           } else if (callback instanceof AuthorizeCallback) {
640             ac = (AuthorizeCallback)callback;
641             assertEquals(user, ac.getAuthorizationID());
642             assertEquals(user, ac.getAuthenticationID());
643             ac.setAuthorized(true);
644             ac.setAuthorizedID(ac.getAuthenticationID());
645           } else {
646             throw new UnsupportedCallbackException(callback,
647                 "Unsupported SASL PLAIN Callback");
648           }
649         }
650         assertNotNull(nc);
651         assertNotNull(pc);
652         assertNotNull(ac);
653       }
654     }
655   }
656 
657   private static Pattern BadToken =
658       Pattern.compile(".*DIGEST-MD5: digest response format violation.*");
659   private static Pattern KrbFailed =
660       Pattern.compile(".*Failed on local exception:.* " +
661                       "Failed to specify server's Kerberos principal name.*");
662   private static Pattern Denied(AuthMethod method) {
663       return Pattern.compile(".*RemoteException.*AccessControlException.*: "
664           + method + " authentication is not enabled.*");
665   }
666   private static Pattern No(AuthMethod ... method) {
667     String methods = StringUtils.join(method, ",\\s*");
668     return Pattern.compile(".*Failed on local exception:.* " +
669         "Client cannot authenticate via:\\[" + methods + "\\].*");
670   }
671   private static Pattern NoTokenAuth =
672       Pattern.compile(".*IllegalArgumentException: " +
673                       "TOKEN authentication requires a secret manager");
674   private static Pattern NoFallback =
675       Pattern.compile(".*Failed on local exception:.* " +
676           "Server asks us to fall back to SIMPLE auth, " +
677           "but this client is configured to only allow secure connections.*");
678 
  /*
   *  simple server
   */
  /**
   * Against a SIMPLE-only server, both SIMPLE and KERBEROS clients end up
   * authenticated as SIMPLE, with or without a UseToken.OTHER token.
   * NOTE(review): getAuthMethod (not visible here) appears to take
   * (client auth, server auth[, token choice]).
   */
  @Test
  public void testSimpleServer() throws Exception {
    assertAuthEquals(SIMPLE,    getAuthMethod(SIMPLE,   SIMPLE));
    assertAuthEquals(SIMPLE,    getAuthMethod(SIMPLE,   SIMPLE, UseToken.OTHER));
    // SASL methods are normally reverted to SIMPLE
    assertAuthEquals(SIMPLE,    getAuthMethod(KERBEROS, SIMPLE));
    assertAuthEquals(SIMPLE,    getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER));
  }
690 
  /**
   * With clientFallBackToSimpleAllowed cleared, KERBEROS clients must refuse
   * a SIMPLE server's fallback (NoFallback) unless they can authenticate
   * with a valid token. SIMPLE clients are unaffected.
   * The assertions are order-dependent: forceSecretManager is switched on
   * part-way through and stays on for the rest of the method.
   */
  @Test
  public void testNoClientFallbackToSimple()
      throws Exception {
    clientFallBackToSimpleAllowed = false;
    // tokens are irrelevant w/o secret manager enabled
    assertAuthEquals(SIMPLE,     getAuthMethod(SIMPLE, SIMPLE));
    assertAuthEquals(SIMPLE,     getAuthMethod(SIMPLE, SIMPLE, UseToken.OTHER));
    assertAuthEquals(SIMPLE,     getAuthMethod(SIMPLE, SIMPLE, UseToken.VALID));
    assertAuthEquals(SIMPLE,     getAuthMethod(SIMPLE, SIMPLE, UseToken.INVALID));

    // A secure client must not fallback
    assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE));
    assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER));
    assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID));
    assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID));

    // Now set server to simple and also force the secret-manager. Now server
    // should have both simple and token enabled.
    forceSecretManager = true;
    assertAuthEquals(SIMPLE,     getAuthMethod(SIMPLE, SIMPLE));
    assertAuthEquals(SIMPLE,     getAuthMethod(SIMPLE, SIMPLE, UseToken.OTHER));
    assertAuthEquals(TOKEN,      getAuthMethod(SIMPLE, SIMPLE, UseToken.VALID));
    assertAuthEquals(BadToken,   getAuthMethod(SIMPLE, SIMPLE, UseToken.INVALID));

    // A secure client must not fallback
    assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE));
    assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER));
    assertAuthEquals(TOKEN,      getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID));
    assertAuthEquals(BadToken,   getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID));

    // doesn't try SASL
    assertAuthEquals(Denied(SIMPLE), getAuthMethod(SIMPLE, TOKEN));
    // does try SASL
    assertAuthEquals(No(TOKEN),      getAuthMethod(SIMPLE, TOKEN, UseToken.OTHER));
    assertAuthEquals(TOKEN,          getAuthMethod(SIMPLE, TOKEN, UseToken.VALID));
    assertAuthEquals(BadToken,       getAuthMethod(SIMPLE, TOKEN, UseToken.INVALID));

    assertAuthEquals(No(TOKEN),      getAuthMethod(KERBEROS, TOKEN));
    assertAuthEquals(No(TOKEN),      getAuthMethod(KERBEROS, TOKEN, UseToken.OTHER));
    assertAuthEquals(TOKEN,          getAuthMethod(KERBEROS, TOKEN, UseToken.VALID));
    assertAuthEquals(BadToken,       getAuthMethod(KERBEROS, TOKEN, UseToken.INVALID));
  }
733 
  /**
   * A SIMPLE server only honours valid tokens once forceSecretManager is set.
   * Setting enableSecretManager alone still leaves every client on SIMPLE.
   * The assertions are order-dependent: both toggles stay on once set.
   */
  @Test
  public void testSimpleServerWithTokens() throws Exception {
    // Client not using tokens
    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE));
    // SASL methods are reverted to SIMPLE
    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE));

    // Use tokens. But tokens are ignored because client is reverted to simple
    // due to server not using tokens
    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID));
    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER));

    // server isn't really advertising tokens
    enableSecretManager = true;
    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, UseToken.VALID));
    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, UseToken.OTHER));

    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID));
    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER));

    // now the simple server takes tokens
    forceSecretManager = true;
    assertAuthEquals(TOKEN,  getAuthMethod(SIMPLE,   SIMPLE, UseToken.VALID));
    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, UseToken.OTHER));

    assertAuthEquals(TOKEN,  getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID));
    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER));
  }
762 
  /**
   * SIMPLE server, client holding a token with a bad password.
   *
   * <p>The invalid token is ignored while the SIMPLE server has no secret
   * manager. {@code enableSecretManager} alone does not give it one. Once
   * {@code forceSecretManager} is set, the server checks the token and the
   * connection fails with the {@code BadToken} outcome.
   */
  @Test
  public void testSimpleServerWithInvalidTokens() throws Exception {
    // Tokens are ignored because client is reverted to simple
    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, UseToken.INVALID));
    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID));
    // still no secret manager on a SIMPLE server
    enableSecretManager = true;
    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, UseToken.INVALID));
    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID));
    // server now verifies tokens, so the bad password is rejected
    forceSecretManager = true;
    assertAuthEquals(BadToken, getAuthMethod(SIMPLE,   SIMPLE, UseToken.INVALID));
    assertAuthEquals(BadToken, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID));
  }
775 
776   /*
777    *  token server
778    */
  /**
   * TOKEN server, client without a usable token: every combination fails.
   *
   * <p>{@code Denied(...)} and {@code No(...)} presumably build the expected
   * failure patterns (defined outside this chunk). A SIMPLE client with no
   * token at all never starts SASL, while a client holding an unrelated
   * token, or a KERBEROS client, does.
   */
  @Test
  public void testTokenOnlyServer() throws Exception {
    // simple client w/o tokens won't try SASL, so server denies
    assertAuthEquals(Denied(SIMPLE), getAuthMethod(SIMPLE,   TOKEN));
    assertAuthEquals(No(TOKEN),      getAuthMethod(SIMPLE,   TOKEN, UseToken.OTHER));
    assertAuthEquals(No(TOKEN),      getAuthMethod(KERBEROS, TOKEN));
    assertAuthEquals(No(TOKEN),      getAuthMethod(KERBEROS, TOKEN, UseToken.OTHER));
  }
787 
  /**
   * TOKEN server, client holding a valid token.
   *
   * <p>With the default secret manager the token authenticates. After
   * {@code enableSecretManager = false} the TOKEN server runs without a secret
   * manager, and the connection fails with the {@code NoTokenAuth} outcome.
   */
  @Test
  public void testTokenOnlyServerWithTokens() throws Exception {
    assertAuthEquals(TOKEN,       getAuthMethod(SIMPLE,   TOKEN, UseToken.VALID));
    assertAuthEquals(TOKEN,       getAuthMethod(KERBEROS, TOKEN, UseToken.VALID));
    // server can no longer verify tokens
    enableSecretManager = false;
    assertAuthEquals(NoTokenAuth, getAuthMethod(SIMPLE,   TOKEN, UseToken.VALID));
    assertAuthEquals(NoTokenAuth, getAuthMethod(KERBEROS, TOKEN, UseToken.VALID));
  }
796 
  /**
   * TOKEN server, client holding a token with a bad password.
   *
   * <p>With a secret manager the bad password yields {@code BadToken}.
   * Without one ({@code enableSecretManager = false}) the server does not get
   * as far as checking the token and reports {@code NoTokenAuth}.
   */
  @Test
  public void testTokenOnlyServerWithInvalidTokens() throws Exception {
    assertAuthEquals(BadToken,    getAuthMethod(SIMPLE,   TOKEN, UseToken.INVALID));
    assertAuthEquals(BadToken,    getAuthMethod(KERBEROS, TOKEN, UseToken.INVALID));
    // server can no longer verify tokens
    enableSecretManager = false;
    assertAuthEquals(NoTokenAuth, getAuthMethod(SIMPLE,   TOKEN, UseToken.INVALID));
    assertAuthEquals(NoTokenAuth, getAuthMethod(KERBEROS, TOKEN, UseToken.INVALID));
  }
805 
806   /*
807    * kerberos server
808    */
  /**
   * KERBEROS server, client without a usable token.
   *
   * <p>A SIMPLE client either never starts SASL (no token) or finds neither
   * TOKEN nor KERBEROS usable (unrelated token). A KERBEROS client has no TGT
   * in this test, so Kerberos authentication fails ({@code KrbFailed}).
   */
  @Test
  public void testKerberosServer() throws Exception {
    // doesn't try SASL
    assertAuthEquals(Denied(SIMPLE),     getAuthMethod(SIMPLE,   KERBEROS));
    // does try SASL
    assertAuthEquals(No(TOKEN,KERBEROS), getAuthMethod(SIMPLE,   KERBEROS, UseToken.OTHER));
    // no tgt
    assertAuthEquals(KrbFailed,          getAuthMethod(KERBEROS, KERBEROS));
    assertAuthEquals(KrbFailed,          getAuthMethod(KERBEROS, KERBEROS, UseToken.OTHER));
  }
819 
  /**
   * KERBEROS server, client holding a valid token.
   *
   * <p>While the server has a secret manager, the token authenticates for
   * either client auth method. Once {@code enableSecretManager = false}, the
   * token is not offered: a SIMPLE client finds no usable method, and a
   * KERBEROS client falls through to Kerberos, which fails for lack of a TGT.
   */
  @Test
  public void testKerberosServerWithTokens() throws Exception {
    // can use tokens regardless of auth
    assertAuthEquals(TOKEN,        getAuthMethod(SIMPLE,   KERBEROS, UseToken.VALID));
    assertAuthEquals(TOKEN,        getAuthMethod(KERBEROS, KERBEROS, UseToken.VALID));
    enableSecretManager = false;
    // shouldn't even try token because server didn't tell us to
    assertAuthEquals(No(KERBEROS), getAuthMethod(SIMPLE,   KERBEROS, UseToken.VALID));
    assertAuthEquals(KrbFailed,    getAuthMethod(KERBEROS, KERBEROS, UseToken.VALID));
  }
830 
  /**
   * KERBEROS server, client holding a token with a bad password.
   *
   * <p>With a secret manager the bad token is rejected ({@code BadToken}).
   * Without one, the token is never tried and the outcomes match
   * testKerberosServerWithTokens.
   */
  @Test
  public void testKerberosServerWithInvalidTokens() throws Exception {
    assertAuthEquals(BadToken,     getAuthMethod(SIMPLE,   KERBEROS, UseToken.INVALID));
    assertAuthEquals(BadToken,     getAuthMethod(KERBEROS, KERBEROS, UseToken.INVALID));
    // token no longer advertised by the server
    enableSecretManager = false;
    assertAuthEquals(No(KERBEROS), getAuthMethod(SIMPLE,   KERBEROS, UseToken.INVALID));
    assertAuthEquals(KrbFailed,    getAuthMethod(KERBEROS, KERBEROS, UseToken.INVALID));
  }
839 
840 
841   // test helpers
842 
843   private String getAuthMethod(
844       final AuthMethod clientAuth,
845       final AuthMethod serverAuth) throws Exception {
846     try {
847       return internalGetAuthMethod(clientAuth, serverAuth, UseToken.NONE);
848     } catch (Exception e) {
849       LOG.warn("Auth method failure", e);
850       return e.toString();
851     }
852   }
853 
854   private String getAuthMethod(
855       final AuthMethod clientAuth,
856       final AuthMethod serverAuth,
857       final UseToken tokenType) throws Exception {
858     try {
859       return internalGetAuthMethod(clientAuth, serverAuth, tokenType);
860     } catch (Exception e) {
861       LOG.warn("Auth method failure", e);
862       return e.toString();
863     }
864   }
865 
866   private String internalGetAuthMethod(
867       final AuthMethod clientAuth,
868       final AuthMethod serverAuth,
869       final UseToken tokenType) throws Exception {
870 
871     final Configuration serverConf = new Configuration(conf);
872     serverConf.set(HADOOP_SECURITY_AUTHENTICATION, serverAuth.toString());
873     UserGroupInformation.setConfiguration(serverConf);
874 
875     final UserGroupInformation serverUgi = (serverAuth == KERBEROS)
876         ? UserGroupInformation.createRemoteUser("server/localhost@NONE")
877         : UserGroupInformation.createRemoteUser("server");
878     serverUgi.setAuthenticationMethod(serverAuth);
879 
880     final TestTokenSecretManager sm = new TestTokenSecretManager();
881     boolean useSecretManager = (serverAuth != SIMPLE);
882     if (enableSecretManager != null) {
883       useSecretManager &= enableSecretManager.booleanValue();
884     }
885     if (forceSecretManager != null) {
886       useSecretManager |= forceSecretManager.booleanValue();
887     }
888     final SecretManager<?> serverSm = useSecretManager ? sm : null;
889 
890     Server server = serverUgi.doAs(new PrivilegedExceptionAction<Server>() {
891       @Override
892       public Server run() throws IOException {
893         Server server = new RPC.Builder(serverConf)
894         .setProtocol(TestSaslProtocol.class)
895         .setInstance(new TestSaslImpl()).setBindAddress(ADDRESS).setPort(0)
896         .setNumHandlers(5).setVerbose(true)
897         .setSecretManager(serverSm)
898         .build();
899         server.start();
900         return server;
901       }
902     });
903 
904     final Configuration clientConf = new Configuration(conf);
905     clientConf.set(HADOOP_SECURITY_AUTHENTICATION, clientAuth.toString());
906     clientConf.setBoolean(
907         CommonConfigurationKeys.IPC_CLIENT_FALLBACK_TO_SIMPLE_AUTH_ALLOWED_KEY,
908         clientFallBackToSimpleAllowed);
909     UserGroupInformation.setConfiguration(clientConf);
910 
911     final UserGroupInformation clientUgi =
912         UserGroupInformation.createRemoteUser("client");
913     clientUgi.setAuthenticationMethod(clientAuth);
914 
915     final InetSocketAddress addr = NetUtils.getConnectAddress(server);
916     if (tokenType != UseToken.NONE) {
917       TestTokenIdentifier tokenId = new TestTokenIdentifier(
918           new Text(clientUgi.getUserName()));
919       Token<TestTokenIdentifier> token = null;
920       switch (tokenType) {
921         case VALID:
922           token = new Token<TestTokenIdentifier>(tokenId, sm);
923           SecurityUtil.setTokenService(token, addr);
924           break;
925         case INVALID:
926           token = new Token<TestTokenIdentifier>(
927               tokenId.getBytes(), "bad-password!".getBytes(),
928               tokenId.getKind(), null);
929           SecurityUtil.setTokenService(token, addr);
930           break;
931         case OTHER:
932           token = new Token<TestTokenIdentifier>();
933           break;
934         case NONE: // won't get here
935       }
936       clientUgi.addToken(token);
937     }
938 
939     try {
940       LOG.info("trying ugi:"+clientUgi+" tokens:"+clientUgi.getTokens());
941       return clientUgi.doAs(new PrivilegedExceptionAction<String>() {
942         @Override
943         public String run() throws IOException {
944           TestSaslProtocol proxy = null;
945           try {
946             proxy = RPC.getProxy(TestSaslProtocol.class,
947                 TestSaslProtocol.versionID, addr, clientConf);
948 
949             proxy.ping();
950             // make sure the other side thinks we are who we said we are!!!
951             assertEquals(clientUgi.getUserName(), proxy.getAuthUser());
952             AuthMethod authMethod = proxy.getAuthMethod();
953             // verify sasl completed with correct QOP
954             assertEquals((authMethod != SIMPLE) ? expectedQop.saslQop : null,
955                          RPC.getConnectionIdForProxy(proxy).getSaslQop());
956             return authMethod.toString();
957           } finally {
958             if (proxy != null) {
959               RPC.stopProxy(proxy);
960             }
961           }
962         }
963       });
964     } finally {
965       server.stop();
966     }
967   }
968 
969   private static void assertAuthEquals(AuthMethod expect,
970       String actual) {
971     assertEquals(expect.toString(), actual);
972   }
973 
974   private static void assertAuthEquals(Pattern expect,
975       String actual) {
976     // this allows us to see the regexp and the value it didn't match
977     if (!expect.matcher(actual).matches()) {
978       assertEquals(expect, actual); // it failed
979     } else {
980       assertTrue(true); // it matched
981     }
982   }
983 
984   /*
985    * Class used to test overriding QOP values using SaslPropertiesResolver
986    */
987   static class AuthSaslPropertiesResolver extends SaslPropertiesResolver{
988 
989     @Override
990     public Map<String, String> getServerProperties(InetAddress address) {
991       Map<String, String> newPropertes = new HashMap<String, String>(getDefaultProperties());
992       newPropertes.put(Sasl.QOP, QualityOfProtection.AUTHENTICATION.getSaslQop());
993       return newPropertes;
994     }
995   }
996 
997   public static void main(String[] args) throws Exception {
998     System.out.println("Testing Kerberos authentication over RPC");
999     if (args.length != 2) {
1000       System.err
1001           .println("Usage: java <options> org.apache.hadoop.ipc.TestSaslRPC "
1002               + " <serverPrincipal> <keytabFile>");
1003       System.exit(-1);
1004     }
1005     String principal = args[0];
1006     String keytab = args[1];
1007     testKerberosRpc(principal, keytab);
1008   }
1009 }
1010