1 /*
2  * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 import java.io.Closeable;
25 import java.io.IOException;
26 import java.io.ObjectInputStream;
27 import java.io.ObjectOutputStream;
28 import java.io.Serializable;
29 import java.net.ServerSocket;
30 import java.net.Socket;
31 import java.net.UnknownHostException;
32 import java.util.ArrayList;
33 import java.util.Arrays;
34 import java.util.HashMap;
35 import java.util.Map;
36 import java.util.StringJoiner;
37 import javax.security.auth.callback.Callback;
38 import javax.security.auth.callback.CallbackHandler;
39 import javax.security.auth.callback.NameCallback;
40 import javax.security.auth.callback.PasswordCallback;
41 import javax.security.auth.callback.UnsupportedCallbackException;
42 import javax.security.sasl.AuthorizeCallback;
43 import javax.security.sasl.RealmCallback;
44 import javax.security.sasl.RealmChoiceCallback;
45 import javax.security.sasl.Sasl;
46 import javax.security.sasl.SaslClient;
47 import javax.security.sasl.SaslException;
48 import javax.security.sasl.SaslServer;
49 
50 /*
51  * @test
52  * @bug 8049814
53  * @summary JAVA SASL server and client tests with CRAM-MD5 and
54  *          DIGEST-MD5 mechanisms. The tests try different QOP values on
55  *          client and server side.
56  * @modules java.security.sasl/javax.security.sasl
57  */
58 public class ClientServerTest {
59 
60     private static final int DELAY = 100;
61     private static final String LOCALHOST = "localhost";
62     private static final String DIGEST_MD5 = "DIGEST-MD5";
63     private static final String CRAM_MD5 = "CRAM-MD5";
64     private static final String PROTOCOL = "saslservice";
65     private static final String USER_ID = "sasltester";
66     private static final String PASSWD = "password";
67     private static final String QOP_AUTH = "auth";
68     private static final String QOP_AUTH_CONF = "auth-conf";
69     private static final String QOP_AUTH_INT = "auth-int";
70     private static final String AUTHID_SASL_TESTER = "sasl_tester";
71     private static final ArrayList<String> SUPPORT_MECHS = new ArrayList<>();
72 
73     static {
74         SUPPORT_MECHS.add(DIGEST_MD5);
75         SUPPORT_MECHS.add(CRAM_MD5);
76     }
77 
main(String[] args)78     public static void main(String[] args) throws Exception {
79         String[] allQops = { QOP_AUTH_CONF, QOP_AUTH_INT, QOP_AUTH };
80         String[] twoQops = { QOP_AUTH_INT, QOP_AUTH };
81         String[] authQop = { QOP_AUTH };
82         String[] authIntQop = { QOP_AUTH_INT };
83         String[] authConfQop = { QOP_AUTH_CONF };
84         String[] emptyQop = {};
85 
86         boolean success = true;
87 
88         success &= runTest("", CRAM_MD5, new String[] { QOP_AUTH },
89                 new String[] { QOP_AUTH }, false);
90         success &= runTest("", DIGEST_MD5, new String[] { QOP_AUTH },
91                 new String[] { QOP_AUTH }, false);
92         success &= runTest(AUTHID_SASL_TESTER, DIGEST_MD5,
93                 new String[] { QOP_AUTH }, new String[] { QOP_AUTH }, false);
94         success &= runTest("", DIGEST_MD5, allQops, authQop, false);
95         success &= runTest("", DIGEST_MD5, allQops, authIntQop, false);
96         success &= runTest("", DIGEST_MD5, allQops, authConfQop, false);
97         success &= runTest("", DIGEST_MD5, twoQops, authQop, false);
98         success &= runTest("", DIGEST_MD5, twoQops, authIntQop, false);
99         success &= runTest("", DIGEST_MD5, twoQops, authConfQop, true);
100         success &= runTest("", DIGEST_MD5, authIntQop, authQop, true);
101         success &= runTest("", DIGEST_MD5, authConfQop, authQop, true);
102         success &= runTest("", DIGEST_MD5, authConfQop, emptyQop, true);
103         success &= runTest("", DIGEST_MD5, authIntQop, emptyQop, true);
104         success &= runTest("", DIGEST_MD5, authQop, emptyQop, true);
105 
106         if (!success) {
107             throw new RuntimeException("At least one test case failed");
108         }
109 
110         System.out.println("Test passed");
111     }
112 
runTest(String authId, String mech, String[] clientQops, String[] serverQops, boolean expectException)113     private static boolean runTest(String authId, String mech,
114             String[] clientQops, String[] serverQops, boolean expectException)
115             throws Exception {
116 
117         System.out.println("AuthId:" + authId
118                 + " mechanism:" + mech
119                 + " clientQops: " + Arrays.toString(clientQops)
120                 + " serverQops: " + Arrays.toString(serverQops)
121                 + " expect exception:" + expectException);
122 
123         try (Server server = Server.start(LOCALHOST, authId, serverQops)) {
124             new Client(LOCALHOST, server.getPort(), mech, authId, clientQops)
125                     .run();
126             if (expectException) {
127                 System.out.println("Expected exception not thrown");
128                 return false;
129             }
130         } catch (SaslException e) {
131             if (!expectException) {
132                 System.out.println("Unexpected exception: " + e);
133                 return false;
134             }
135             System.out.println("Expected exception: " + e);
136         }
137 
138         return true;
139     }
140 
141     static enum SaslStatus {
142         SUCCESS, FAILURE, CONTINUE
143     }
144 
145     static class Message implements Serializable {
146 
147         private final SaslStatus status;
148         private final byte[] data;
149 
Message(SaslStatus status, byte[] data)150         public Message(SaslStatus status, byte[] data) {
151             this.status = status;
152             this.data = data;
153         }
154 
getStatus()155         public SaslStatus getStatus() {
156             return status;
157         }
158 
getData()159         public byte[] getData() {
160             return data;
161         }
162     }
163 
164     static class SaslPeer {
165 
166         final String host;
167         final String mechanism;
168         final String qop;
169         final CallbackHandler callback;
170 
SaslPeer(String host, String authId, String... qops)171         SaslPeer(String host, String authId, String... qops) {
172             this(host, null, authId, qops);
173         }
174 
SaslPeer(String host, String mechanism, String authId, String... qops)175         SaslPeer(String host, String mechanism, String authId, String... qops) {
176             this.host = host;
177             this.mechanism = mechanism;
178 
179             StringJoiner sj = new StringJoiner(",");
180             for (String q : qops) {
181                 sj.add(q);
182             }
183             qop = sj.toString();
184 
185             callback = new TestCallbackHandler(USER_ID, PASSWD, host, authId);
186         }
187 
getMessage(Object ob)188         Message getMessage(Object ob) {
189             if (!(ob instanceof Message)) {
190                 throw new RuntimeException("Expected an instance of Message");
191             }
192             return (Message) ob;
193         }
194     }
195 
196     static class Server extends SaslPeer implements Runnable, Closeable {
197 
198         private volatile boolean ready = false;
199         private volatile ServerSocket ssocket;
200 
start(String host, String authId, String[] serverQops)201         static Server start(String host, String authId, String[] serverQops)
202                 throws UnknownHostException {
203             Server server = new Server(host, authId, serverQops);
204             Thread thread = new Thread(server);
205             thread.setDaemon(true);
206             thread.start();
207 
208             while (!server.ready) {
209                 try {
210                     Thread.sleep(DELAY);
211                 } catch (InterruptedException e) {
212                     throw new RuntimeException(e);
213                 }
214             }
215 
216             return server;
217         }
218 
Server(String host, String authId, String... qops)219         Server(String host, String authId, String... qops) {
220             super(host, authId, qops);
221         }
222 
getPort()223         int getPort() {
224             return ssocket.getLocalPort();
225         }
226 
processConnection(SaslEndpoint endpoint)227         private void processConnection(SaslEndpoint endpoint)
228                 throws SaslException, IOException, ClassNotFoundException {
229             System.out.println("process connection");
230             endpoint.send(SUPPORT_MECHS);
231             Object o = endpoint.receive();
232             if (!(o instanceof String)) {
233                 throw new RuntimeException("Received unexpected object: " + o);
234             }
235             String mech = (String) o;
236             SaslServer saslServer = createSaslServer(mech);
237             Message msg = getMessage(endpoint.receive());
238             while (!saslServer.isComplete()) {
239                 byte[] data = processData(msg.getData(), endpoint,
240                         saslServer);
241                 if (saslServer.isComplete()) {
242                     System.out.println("server is complete");
243                     endpoint.send(new Message(SaslStatus.SUCCESS, data));
244                 } else {
245                     System.out.println("server continues");
246                     endpoint.send(new Message(SaslStatus.CONTINUE, data));
247                     msg = getMessage(endpoint.receive());
248                 }
249             }
250         }
251 
processData(byte[] data, SaslEndpoint endpoint, SaslServer server)252         private byte[] processData(byte[] data, SaslEndpoint endpoint,
253                 SaslServer server) throws SaslException, IOException {
254             try {
255                 return server.evaluateResponse(data);
256             } catch (SaslException e) {
257                 endpoint.send(new Message(SaslStatus.FAILURE, null));
258                 System.out.println("Error while processing data");
259                 throw e;
260             }
261         }
262 
createSaslServer(String mechanism)263         private SaslServer createSaslServer(String mechanism)
264                 throws SaslException {
265             Map<String, String> props = new HashMap<>();
266             props.put(Sasl.QOP, qop);
267             return Sasl.createSaslServer(mechanism, PROTOCOL, host, props,
268                     callback);
269         }
270 
271         @Override
run()272         public void run() {
273             try (ServerSocket ss = new ServerSocket(0)) {
274                 ssocket = ss;
275                 System.out.println("server started on port " + getPort());
276                 ready = true;
277                 Socket socket = ss.accept();
278                 try (SaslEndpoint endpoint = new SaslEndpoint(socket)) {
279                     System.out.println("server accepted connection");
280                     processConnection(endpoint);
281                 }
282             } catch (Exception e) {
283                 // ignore it for now, client will throw an exception
284             }
285         }
286 
287         @Override
close()288         public void close() throws IOException {
289             if (!ssocket.isClosed()) {
290                 ssocket.close();
291             }
292         }
293     }
294 
295     static class Client extends SaslPeer {
296 
297         private final int port;
298 
Client(String host, int port, String mech, String authId, String... qops)299         Client(String host, int port, String mech, String authId,
300                 String... qops) {
301             super(host, mech, authId, qops);
302             this.port = port;
303         }
304 
run()305         public void run() throws Exception {
306             System.out.println("Host:" + host + " port: "
307                     + port);
308             try (SaslEndpoint endpoint = SaslEndpoint.create(host, port)) {
309                 negotiateMechanism(endpoint);
310                 SaslClient client = createSaslClient();
311                 byte[] data = new byte[0];
312                 if (client.hasInitialResponse()) {
313                     data = client.evaluateChallenge(data);
314                 }
315                 endpoint.send(new Message(SaslStatus.CONTINUE, data));
316                 Message msg = getMessage(endpoint.receive());
317                 while (!client.isComplete()
318                         && msg.getStatus() != SaslStatus.FAILURE) {
319                     switch (msg.getStatus()) {
320                         case CONTINUE:
321                             System.out.println("client continues");
322                             data = client.evaluateChallenge(msg.getData());
323                             endpoint.send(new Message(SaslStatus.CONTINUE,
324                                     data));
325                             msg = getMessage(endpoint.receive());
326                             break;
327                         case SUCCESS:
328                             System.out.println("client succeeded");
329                             data = client.evaluateChallenge(msg.getData());
330                             if (data != null) {
331                                 throw new SaslException("data should be null");
332                             }
333                             break;
334                         default:
335                             throw new RuntimeException("Wrong status:"
336                                     + msg.getStatus());
337                     }
338                 }
339 
340                 if (msg.getStatus() == SaslStatus.FAILURE) {
341                     throw new RuntimeException("Status is FAILURE");
342                 }
343             }
344 
345             System.out.println("Done");
346         }
347 
createSaslClient()348         private SaslClient createSaslClient() throws SaslException {
349             Map<String, String> props = new HashMap<>();
350             props.put(Sasl.QOP, qop);
351             return Sasl.createSaslClient(new String[] {mechanism}, USER_ID,
352                     PROTOCOL, host, props, callback);
353         }
354 
negotiateMechanism(SaslEndpoint endpoint)355         private void negotiateMechanism(SaslEndpoint endpoint)
356                 throws ClassNotFoundException, IOException {
357             Object o = endpoint.receive();
358             if (o instanceof ArrayList) {
359                 ArrayList list = (ArrayList) o;
360                 if (!list.contains(mechanism)) {
361                     throw new RuntimeException(
362                             "Server does not support specified mechanism:"
363                                     + mechanism);
364                 }
365             } else {
366                 throw new RuntimeException(
367                         "Expected an instance of ArrayList, but received " + o);
368             }
369 
370             endpoint.send(mechanism);
371         }
372 
373     }
374 
375     static class SaslEndpoint implements AutoCloseable {
376 
377         private final Socket socket;
378         private ObjectInputStream input;
379         private ObjectOutputStream output;
380 
create(String host, int port)381         static SaslEndpoint create(String host, int port) throws IOException {
382             return new SaslEndpoint(new Socket(host, port));
383         }
384 
SaslEndpoint(Socket socket)385         SaslEndpoint(Socket socket) throws IOException {
386             this.socket = socket;
387         }
388 
getInput()389         private ObjectInputStream getInput() throws IOException {
390             if (input == null && socket != null) {
391                 input = new ObjectInputStream(socket.getInputStream());
392             }
393             return input;
394         }
395 
getOutput()396         private ObjectOutputStream getOutput() throws IOException {
397             if (output == null && socket != null) {
398                 output = new ObjectOutputStream(socket.getOutputStream());
399             }
400             return output;
401         }
402 
receive()403         public Object receive() throws IOException, ClassNotFoundException {
404             return getInput().readObject();
405         }
406 
send(Object obj)407         public void send(Object obj) throws IOException {
408             getOutput().writeObject(obj);
409             getOutput().flush();
410         }
411 
412         @Override
close()413         public void close() throws IOException {
414             if (socket != null && !socket.isClosed()) {
415                 socket.close();
416             }
417         }
418 
419     }
420 
421     static class TestCallbackHandler implements CallbackHandler {
422 
423         private final String userId;
424         private final char[] passwd;
425         private final String realm;
426         private String authId;
427 
TestCallbackHandler(String userId, String passwd, String realm, String authId)428         TestCallbackHandler(String userId, String passwd, String realm,
429                 String authId) {
430             this.userId = userId;
431             this.passwd = passwd.toCharArray();
432             this.realm = realm;
433             this.authId = authId;
434         }
435 
436         @Override
handle(Callback[] callbacks)437         public void handle(Callback[] callbacks) throws IOException,
438                 UnsupportedCallbackException {
439             for (Callback callback : callbacks) {
440                 if (callback instanceof NameCallback) {
441                     System.out.println("NameCallback");
442                     ((NameCallback) callback).setName(userId);
443                 } else if (callback instanceof PasswordCallback) {
444                     System.out.println("PasswordCallback");
445                     ((PasswordCallback) callback).setPassword(passwd);
446                 } else if (callback instanceof RealmCallback) {
447                     System.out.println("RealmCallback");
448                     ((RealmCallback) callback).setText(realm);
449                 } else if (callback instanceof RealmChoiceCallback) {
450                     System.out.println("RealmChoiceCallback");
451                     RealmChoiceCallback choice = (RealmChoiceCallback) callback;
452                     if (realm == null) {
453                         choice.setSelectedIndex(choice.getDefaultChoice());
454                     } else {
455                         String[] choices = choice.getChoices();
456                         for (int j = 0; j < choices.length; j++) {
457                             if (realm.equals(choices[j])) {
458                                 choice.setSelectedIndex(j);
459                                 break;
460                             }
461                         }
462                     }
463                 } else if (callback instanceof AuthorizeCallback) {
464                     System.out.println("AuthorizeCallback");
465                     ((AuthorizeCallback) callback).setAuthorized(true);
466                     if (authId == null || authId.trim().length() == 0) {
467                         authId = userId;
468                     }
469                     ((AuthorizeCallback) callback).setAuthorizedID(authId);
470                 } else {
471                     throw new UnsupportedCallbackException(callback);
472                 }
473             }
474         }
475     }
476 
477 }
478