1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4 
5 using System.Collections;
6 using System.Globalization;
7 using System.Runtime.InteropServices;
8 
9 namespace System.DirectoryServices.ActiveDirectory
10 {
11     internal sealed class Locator
12     {
13         // To disable public/protected constructors for this class
Locator()14         private Locator() { }
15 
GetDomainControllerInfo(string computerName, string domainName, string siteName, long flags)16         internal static DomainControllerInfo GetDomainControllerInfo(string computerName, string domainName, string siteName, long flags)
17         {
18             int errorCode = 0;
19             DomainControllerInfo domainControllerInfo;
20 
21             errorCode = DsGetDcNameWrapper(computerName, domainName, siteName, flags, out domainControllerInfo);
22 
23             if (errorCode != 0)
24             {
25                 throw ExceptionHelper.GetExceptionFromErrorCode(errorCode, domainName);
26             }
27 
28             return domainControllerInfo;
29         }
30 
DsGetDcNameWrapper(string computerName, string domainName, string siteName, long flags, out DomainControllerInfo domainControllerInfo)31         internal static int DsGetDcNameWrapper(string computerName, string domainName, string siteName, long flags, out DomainControllerInfo domainControllerInfo)
32         {
33             IntPtr pDomainControllerInfo = IntPtr.Zero;
34             int result = 0;
35 
36             // empty siteName/computerName should be treated as null
37             if ((computerName != null) && (computerName.Length == 0))
38             {
39                 computerName = null;
40             }
41             if ((siteName != null) && (siteName.Length == 0))
42             {
43                 siteName = null;
44             }
45 
46             result = NativeMethods.DsGetDcName(computerName, domainName, IntPtr.Zero, siteName, (int)(flags | (long)PrivateLocatorFlags.ReturnDNSName), out pDomainControllerInfo);
47             if (result == 0)
48             {
49                 try
50                 {
51                     // success case
52                     domainControllerInfo = new DomainControllerInfo();
53                     Marshal.PtrToStructure(pDomainControllerInfo, domainControllerInfo);
54                 }
55                 finally
56                 {
57                     // free the buffer
58                     // what to do with error code??
59                     if (pDomainControllerInfo != IntPtr.Zero)
60                     {
61                         result = NativeMethods.NetApiBufferFree(pDomainControllerInfo);
62                     }
63                 }
64             }
65             else
66             {
67                 domainControllerInfo = new DomainControllerInfo();
68             }
69 
70             return result;
71         }
72 
EnumerateDomainControllers(DirectoryContext context, string domainName, string siteName, long dcFlags)73         internal static ArrayList EnumerateDomainControllers(DirectoryContext context, string domainName, string siteName, long dcFlags)
74         {
75             Hashtable allDCs = null;
76             ArrayList dcs = new ArrayList();
77 
78             //
79             // this api obtains the list of DCs/GCs based on dns records. The DCs/GCs that have registered
80             // non site specific records for the domain/forest are returned. Additonally DCs/GCs that have registered site specific records
81             // (site is either specified or defaulted to the site of the local machine) are also returned in this list.
82             //
83 
84             if (siteName == null)
85             {
86                 //
87                 // if the site name is not specified then we get the site specific records for the local machine's site (in the context of the domain/forest/application partition that is specified)
88                 // (sitename could still be null if the machine is not in any site for the specified domain/forest, in that case we don't look for any site specific records)
89                 //
90                 DomainControllerInfo domainControllerInfo;
91 
92                 int errorCode = DsGetDcNameWrapper(null, domainName, null, dcFlags & (long)(PrivateLocatorFlags.GCRequired | PrivateLocatorFlags.DSWriteableRequired | PrivateLocatorFlags.OnlyLDAPNeeded), out domainControllerInfo);
93                 if (errorCode == 0)
94                 {
95                     siteName = domainControllerInfo.ClientSiteName;
96                 }
97                 else if (errorCode == NativeMethods.ERROR_NO_SUCH_DOMAIN)
98                 {
99                     // return an empty collection
100                     return dcs;
101                 }
102                 else
103                 {
104                     throw ExceptionHelper.GetExceptionFromErrorCode(errorCode);
105                 }
106             }
107 
108             // this will get both the non site specific and the site specific records
109             allDCs = DnsGetDcWrapper(domainName, siteName, dcFlags);
110 
111             foreach (string dcName in allDCs.Keys)
112             {
113                 DirectoryContext dcContext = Utils.GetNewDirectoryContext(dcName, DirectoryContextType.DirectoryServer, context);
114 
115                 if ((dcFlags & (long)PrivateLocatorFlags.GCRequired) != 0)
116                 {
117                     // add a GlobalCatalog object
118                     dcs.Add(new GlobalCatalog(dcContext, dcName));
119                 }
120                 else
121                 {
122                     // add a domain controller object
123                     dcs.Add(new DomainController(dcContext, dcName));
124                 }
125             }
126 
127             return dcs;
128         }
129 
DnsGetDcWrapper(string domainName, string siteName, long dcFlags)130         private static Hashtable DnsGetDcWrapper(string domainName, string siteName, long dcFlags)
131         {
132             Hashtable domainControllers = new Hashtable();
133 
134             int optionFlags = 0;
135             IntPtr retGetDcContext = IntPtr.Zero;
136             IntPtr dcDnsHostNamePtr = IntPtr.Zero;
137             int sockAddressCount = 0;
138             IntPtr sockAddressCountPtr = new IntPtr(sockAddressCount);
139             IntPtr sockAddressList = IntPtr.Zero;
140             string dcDnsHostName = null;
141             int result = 0;
142 
143             result = NativeMethods.DsGetDcOpen(domainName, (int)optionFlags, siteName, IntPtr.Zero, null, (int)dcFlags, out retGetDcContext);
144             if (result == 0)
145             {
146                 try
147                 {
148                     result = NativeMethods.DsGetDcNext(retGetDcContext, ref sockAddressCountPtr, out sockAddressList, out dcDnsHostNamePtr);
149 
150                     if (result != 0 && result != NativeMethods.ERROR_FILE_MARK_DETECTED && result != NativeMethods.DNS_ERROR_RCODE_NAME_ERROR && result != NativeMethods.ERROR_NO_MORE_ITEMS)
151                     {
152                         throw ExceptionHelper.GetExceptionFromErrorCode(result);
153                     }
154 
155                     while (result != NativeMethods.ERROR_NO_MORE_ITEMS)
156                     {
157                         if (result != NativeMethods.ERROR_FILE_MARK_DETECTED && result != NativeMethods.DNS_ERROR_RCODE_NAME_ERROR)
158                         {
159                             try
160                             {
161                                 dcDnsHostName = Marshal.PtrToStringUni(dcDnsHostNamePtr);
162                                 string key = dcDnsHostName.ToLower(CultureInfo.InvariantCulture);
163 
164                                 if (!domainControllers.Contains(key))
165                                 {
166                                     domainControllers.Add(key, null);
167                                 }
168                             }
169                             finally
170                             {
171                                 // what to do with the error?
172                                 if (dcDnsHostNamePtr != IntPtr.Zero)
173                                 {
174                                     result = NativeMethods.NetApiBufferFree(dcDnsHostNamePtr);
175                                 }
176                             }
177                         }
178 
179                         result = NativeMethods.DsGetDcNext(retGetDcContext, ref sockAddressCountPtr, out sockAddressList, out dcDnsHostNamePtr);
180                         if (result != 0 && result != NativeMethods.ERROR_FILE_MARK_DETECTED && result != NativeMethods.DNS_ERROR_RCODE_NAME_ERROR && result != NativeMethods.ERROR_NO_MORE_ITEMS)
181                         {
182                             throw ExceptionHelper.GetExceptionFromErrorCode(result);
183                         }
184                     }
185                 }
186                 finally
187                 {
188                     NativeMethods.DsGetDcClose(retGetDcContext);
189                 }
190             }
191             else if (result != 0)
192             {
193                 throw ExceptionHelper.GetExceptionFromErrorCode(result);
194             }
195 
196             return domainControllers;
197         }
198 
DnsQueryWrapper(string domainName, string siteName, long dcFlags)199         private static Hashtable DnsQueryWrapper(string domainName, string siteName, long dcFlags)
200         {
201             Hashtable domainControllers = new Hashtable();
202             string recordName = "_ldap._tcp.";
203             int result = 0;
204             int options = 0;
205             IntPtr dnsResults = IntPtr.Zero;
206 
207             // construct the record name
208             if ((siteName != null) && (!(siteName.Length == 0)))
209             {
210                 // only looking for domain controllers / global catalogs within a
211                 // particular site
212                 recordName = recordName + siteName + "._sites.";
213             }
214 
215             // check if gc or dc
216             if (((long)dcFlags & (long)(PrivateLocatorFlags.GCRequired)) != 0)
217             {
218                 // global catalog
219                 recordName += "gc._msdcs.";
220             }
221             else if (((long)dcFlags & (long)(PrivateLocatorFlags.DSWriteableRequired)) != 0)
222             {
223                 // domain controller
224                 recordName += "dc._msdcs.";
225             }
226 
227             // now add the domainName
228             recordName = recordName + domainName;
229 
230             // set the BYPASS CACHE option is specified
231             if (((long)dcFlags & (long)LocatorOptions.ForceRediscovery) != 0)
232             {
233                 options |= NativeMethods.DnsQueryBypassCache;
234             }
235 
236             // Call DnsQuery
237             result = NativeMethods.DnsQuery(recordName, NativeMethods.DnsSrvData, options, IntPtr.Zero, out dnsResults, IntPtr.Zero);
238             if (result == 0)
239             {
240                 try
241                 {
242                     IntPtr currentDnsRecord = dnsResults;
243 
244                     while (currentDnsRecord != IntPtr.Zero)
245                     {
246                         // partial marshalling of dns record data
247                         PartialDnsRecord partialDnsRecord = new PartialDnsRecord();
248                         Marshal.PtrToStructure(currentDnsRecord, partialDnsRecord);
249 
250                         //check if the record is of type DNS_SRV_DATA
251                         if (partialDnsRecord.type == NativeMethods.DnsSrvData)
252                         {
253                             // remarshal to get the srv record data
254                             DnsRecord dnsRecord = new DnsRecord();
255                             Marshal.PtrToStructure(currentDnsRecord, dnsRecord);
256                             string targetName = dnsRecord.data.targetName;
257                             string key = targetName.ToLower(CultureInfo.InvariantCulture);
258 
259                             if (!domainControllers.Contains(key))
260                             {
261                                 domainControllers.Add(key, null);
262                             }
263                         }
264                         // move to next record
265                         currentDnsRecord = partialDnsRecord.next;
266                     }
267                 }
268                 finally
269                 {
270                     // release the dns results buffer
271                     if (dnsResults != IntPtr.Zero)
272                     {
273                         NativeMethods.DnsRecordListFree(dnsResults, true);
274                     }
275                 }
276             }
277             else if (result != 0)
278             {
279                 throw ExceptionHelper.GetExceptionFromErrorCode(result);
280             }
281 
282             return domainControllers;
283         }
284     }
285 }
286