1 /*
2  * Copyright (c) 2002, 2018, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.  Oracle designates this
8  * particular file as subject to the "Classpath" exception as provided
9  * by Oracle in the LICENSE file that accompanied this code.
10  *
11  * This code is distributed in the hope that it will be useful, but WITHOUT
12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14  * version 2 for more details (a copy is included in the LICENSE file that
15  * accompanied this code).
16  *
17  * You should have received a copy of the GNU General Public License version
18  * 2 along with this work; if not, write to the Free Software Foundation,
19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20  *
21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22  * or visit www.oracle.com if you need additional information or have any
23  * questions.
24  */
25 
26 package com.sun.jndi.ldap;
27 
28 import java.util.*;
29 
30 import javax.naming.*;
31 import javax.naming.directory.*;
32 import javax.naming.spi.NamingManager;
33 import javax.naming.ldap.LdapName;
34 import javax.naming.ldap.Rdn;
35 
36 /**
37  * This class discovers the location of LDAP services by querying DNS.
38  * See http://www.ietf.org/internet-drafts/draft-ietf-ldapext-locate-07.txt
39  */
40 
class ServiceLocator {

    private static final String SRV_RR = "SRV";

    private static final String[] SRV_RR_ATTR = new String[]{SRV_RR};

    private static final Random random = new Random();

    private ServiceLocator() {
    }

    /**
     * Maps a distinguished name (RFC 2253) to a fully qualified domain name.
     * Processes the DN's RDNs from left to right, joining the values of
     * single-valued RDNs having a DC attribute with '.'.
     * The special RDN "DC=." denotes the root of the domain tree.
     * Multi-valued RDNs, non-DC attributes, binary-valued attributes and the
     * RDN "DC=." all reset the domain name and processing continues.
     *
     * @param dn A string distinguished name (RFC 2253); may be null.
     * @return A domain name or null if none can be derived.
     * @throws InvalidNameException If the distinguished name is invalid.
     */
    static String mapDnToDomainName(String dn) throws InvalidNameException {
        if (dn == null) {
            return null;
        }
        StringBuilder domain = new StringBuilder();
        List<Rdn> rdns = new LdapName(dn).getRdns();

        // LdapName stores the rightmost RDN at index 0, so walk the list
        // backwards to process the RDNs in their left-to-right string order.
        for (int i = rdns.size() - 1; i >= 0; i--) {
            Rdn rdn = rdns.get(i);

            if (rdn.size() != 1 || !"dc".equalsIgnoreCase(rdn.getType())) {
                domain.setLength(0); // reset (multi-valued RDN or non-DC)
                continue;
            }
            Object attrval = rdn.getValue();
            if (!(attrval instanceof String)) {
                domain.setLength(0); // reset (binary-valued attribute)
                continue;
            }
            // reset when the current or the previous RDN value is "DC=."
            if (".".equals(attrval)
                    || (domain.length() == 1 && domain.charAt(0) == '.')) {
                domain.setLength(0);
            }
            if (domain.length() > 0) {
                domain.append('.');
            }
            domain.append((String) attrval);
        }
        return (domain.length() != 0) ? domain.toString() : null;
    }

    /**
     * Locates the LDAP service for a given domain.
     * Queries DNS for a list of LDAP Service Location Records (SRV) for a
     * given domain name.
     *
     * @param domainName A string domain name.
     * @param environment The possibly null environment of the context.
     *        A non-Hashtable map is copied into a new Hashtable.
     * @return An ordered list of hostports for the LDAP service or null if
     *         the service has not been located.
     */
    static String[] getLdapService(String domainName, Map<?,?> environment) {
        // A null environment is passed through unchanged; copying it into a
        // new Hashtable would throw NullPointerException.
        if (environment == null || environment instanceof Hashtable) {
            return getLdapService(domainName, (Hashtable<?,?>) environment);
        }
        return getLdapService(domainName, new Hashtable<>(environment));
    }

    /**
     * Locates the LDAP service for a given domain.
     * Queries DNS for a list of LDAP Service Location Records (SRV) for a
     * given domain name.
     *
     * @param domainName A string domain name.
     * @param environment The possibly null environment of the context.
     * @return An ordered list of hostports for the LDAP service or null if
     *         the service has not been located.
     */
    static String[] getLdapService(String domainName, Hashtable<?,?> environment) {
        if (domainName == null || domainName.isEmpty()) {
            return null;
        }

        String dnsUrl = "dns:///_ldap._tcp." + domainName;
        Context ctx = null;
        try {
            // Create the DNS context using NamingManager rather than using
            // the initial context constructor. This avoids having the initial
            // context constructor call itself (when processing the URL
            // argument in the getAttributes call).
            ctx = NamingManager.getURLContext("dns", environment);
            if (!(ctx instanceof DirContext)) {
                return null; // cannot create a DNS context
            }
            Attributes attrs =
                ((DirContext) ctx).getAttributes(dnsUrl, SRV_RR_ATTR);
            Attribute attr = (attrs == null) ? null : attrs.get(SRV_RR);
            if (attr == null) {
                return null;
            }

            SrvRecord[] srvRecords = parseSrvRecords(attr);

            // Sort the service records in ascending order of their
            // priority value. For records with equal priority, move
            // those with weight 0 to the top of the list.
            Arrays.sort(srvRecords);

            // extract the host and port number from each service record
            return extractHostports(srvRecords);
        } catch (NamingException e) {
            return null; // the DNS lookup failed: service not located
        } finally {
            closeQuietly(ctx);
        }
    }

    /**
     * Converts each value of an SRV attribute into a service record.
     * Values that cannot be read or parsed are skipped.
     */
    private static SrvRecord[] parseSrvRecords(Attribute attr) {
        int numValues = attr.size();
        List<SrvRecord> records = new ArrayList<>(numValues);
        for (int i = 0; i < numValues; i++) {
            try {
                records.add(new SrvRecord((String) attr.get(i)));
            } catch (NamingException | RuntimeException ignored) {
                // ignore an unreadable, non-string or malformed value
            }
        }
        return records.toArray(new SrvRecord[0]);
    }

    /**
     * Closes a context, ignoring a null context and any failure to close.
     */
    private static void closeQuietly(Context ctx) {
        if (ctx != null) {
            try {
                ctx.close();
            } catch (NamingException ignored) {
                // best effort: nothing useful can be done if closing fails
            }
        }
    }

    /**
     * Extract hosts and port numbers from a list of SRV records that is
     * sorted by priority. Within each run of records sharing a priority,
     * hostports are chosen by weighted random selection.
     * An array of hostports is returned or null if none were found.
     * Note: entries of {@code srvRecords} are set to null as they are used.
     */
    private static String[] extractHostports(SrvRecord[] srvRecords) {
        if (srvRecords.length == 0) {
            return null;
        }
        String[] hostports = new String[srvRecords.length];
        int k = 0;
        int head = 0;
        while (head < srvRecords.length) {
            // find the tail of the run of records having the same priority
            int tail = head;
            while (tail < srvRecords.length - 1
                    && srvRecords[tail].priority == srvRecords[tail + 1].priority) {
                tail++;
            }
            // select every hostport from the sublist [head, tail]
            for (int j = head; j <= tail; j++) {
                hostports[k++] = selectHostport(srvRecords, head, tail);
            }
            head = tail + 1;
        }
        return hostports;
    }

    /*
     * Randomly select a service record in the range [head, tail] and return
     * its hostport value. Follows the algorithm in RFC 2782. The selected
     * record is set to null so a later call over the same range skips it.
     */
    private static String selectHostport(SrvRecord[] srvRecords, int head,
            int tail) {
        if (head == tail) {
            return srvRecords[head].hostport;
        }

        // compute the running sum of weights for the still-available records
        int sum = 0;
        for (int i = head; i <= tail; i++) {
            if (srvRecords[i] != null) {
                sum += srvRecords[i].weight;
                srvRecords[i].sum = sum;
            }
        }

        // If all records have zero weight, select the first available one;
        // otherwise pick a uniform random number in [0, sum] and select the
        // first record whose running sum is greater than or equal to it.
        int target = (sum == 0 ? 0 : random.nextInt(sum + 1));
        for (int i = head; i <= tail; i++) {
            if (srvRecords[i] != null && srvRecords[i].sum >= target) {
                String hostport = srvRecords[i].hostport;
                srvRecords[i] = null; // make this record unavailable
                return hostport;
            }
        }
        return null;
    }

    /**
     * This class holds a DNS service (SRV) record.
     * See http://www.ietf.org/rfc/rfc2782.txt
     */
    static class SrvRecord implements Comparable<SrvRecord> {

        /** Largest value of a 16-bit unsigned SRV field (RFC 2782). */
        private static final int MAX_UNSIGNED_16 = 0xFFFF;

        int priority;
        int weight;
        int sum;        // scratch running sum used by selectHostport
        String hostport;

        /**
         * Creates a service record object from a string record.
         * DNS supplies the string record in the following format:
         * <pre>
         *     <Priority> " " <Weight> " " <Port> " " <Host>
         * </pre>
         *
         * @throws IllegalArgumentException if the record does not have four
         *         space-separated fields, or if priority or weight is not an
         *         integer in the 16-bit unsigned range [0, 65535].
         */
        SrvRecord(String srvRecord) {
            StringTokenizer tokenizer = new StringTokenizer(srvRecord, " ");
            if (tokenizer.countTokens() != 4) {
                throw new IllegalArgumentException();
            }
            priority = parseUnsigned16(tokenizer.nextToken());
            // Range-checking the weight keeps selectHostport's running sum
            // non-negative, so random.nextInt(sum + 1) gets a valid bound.
            weight = parseUnsigned16(tokenizer.nextToken());
            String port = tokenizer.nextToken();
            hostport = tokenizer.nextToken() + ":" + port;
        }

        /** Parses a decimal 16-bit unsigned field of an SRV record. */
        private static int parseUnsigned16(String token) {
            int value = Integer.parseInt(token);
            if (value < 0 || value > MAX_UNSIGNED_16) {
                throw new IllegalArgumentException(
                    "SRV field out of range: " + token);
            }
            return value;
        }

        /*
         * Sort records in ascending order of priority value. For records with
         * equal priority move those with weight 0 to the top of the list.
         */
        @Override
        public int compareTo(SrvRecord that) {
            int byPriority = Integer.compare(priority, that.priority);
            if (byPriority != 0) {
                return byPriority;
            }
            // false (weight 0) sorts before true (non-zero weight)
            return Boolean.compare(weight != 0, that.weight != 0);
        }
    }
}
311