1package ldaputil
2
3import (
4	"bytes"
5	"crypto/tls"
6	"crypto/x509"
7	"encoding/binary"
8	"fmt"
9	"math"
10	"net"
11	"net/url"
12	"strings"
13	"text/template"
14
15	"github.com/go-ldap/ldap"
16	"github.com/hashicorp/errwrap"
17	hclog "github.com/hashicorp/go-hclog"
18	multierror "github.com/hashicorp/go-multierror"
19	"github.com/hashicorp/vault/sdk/helper/tlsutil"
20)
21
22type Client struct {
23	Logger hclog.Logger
24	LDAP   LDAP
25}
26
27func (c *Client) DialLDAP(cfg *ConfigEntry) (Connection, error) {
28	var retErr *multierror.Error
29	var conn Connection
30	urls := strings.Split(cfg.Url, ",")
31	for _, uut := range urls {
32		u, err := url.Parse(uut)
33		if err != nil {
34			retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("error parsing url %q: {{err}}", uut), err))
35			continue
36		}
37		host, port, err := net.SplitHostPort(u.Host)
38		if err != nil {
39			host = u.Host
40		}
41
42		var tlsConfig *tls.Config
43		switch u.Scheme {
44		case "ldap":
45			if port == "" {
46				port = "389"
47			}
48			conn, err = c.LDAP.Dial("tcp", net.JoinHostPort(host, port))
49			if err != nil {
50				break
51			}
52			if conn == nil {
53				err = fmt.Errorf("empty connection after dialing")
54				break
55			}
56			if cfg.StartTLS {
57				tlsConfig, err = getTLSConfig(cfg, host)
58				if err != nil {
59					break
60				}
61				err = conn.StartTLS(tlsConfig)
62			}
63		case "ldaps":
64			if port == "" {
65				port = "636"
66			}
67			tlsConfig, err = getTLSConfig(cfg, host)
68			if err != nil {
69				break
70			}
71			conn, err = c.LDAP.DialTLS("tcp", net.JoinHostPort(host, port), tlsConfig)
72		default:
73			retErr = multierror.Append(retErr, fmt.Errorf("invalid LDAP scheme in url %q", net.JoinHostPort(host, port)))
74			continue
75		}
76		if err == nil {
77			if retErr != nil {
78				if c.Logger.IsDebug() {
79					c.Logger.Debug("errors connecting to some hosts: %s", retErr.Error())
80				}
81			}
82			retErr = nil
83			break
84		}
85		retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("error connecting to host %q: {{err}}", uut), err))
86	}
87
88	return conn, retErr.ErrorOrNil()
89}
90
91/*
92 * Discover and return the bind string for the user attempting to authenticate.
93 * This is handled in one of several ways:
94 *
95 * 1. If DiscoverDN is set, the user object will be searched for using userdn (base search path)
96 *    and userattr (the attribute that maps to the provided username).
97 *    The bind will either be anonymous or use binddn and bindpassword if they were provided.
98 * 2. If upndomain is set, the user dn is constructed as 'username@upndomain'. See https://msdn.microsoft.com/en-us/library/cc223499.aspx
99 *
100 */
101func (c *Client) GetUserBindDN(cfg *ConfigEntry, conn Connection, username string) (string, error) {
102	bindDN := ""
103	// Note: The logic below drives the logic in ConfigEntry.Validate().
104	// If updated, please update there as well.
105	if cfg.DiscoverDN || (cfg.BindDN != "" && cfg.BindPassword != "") {
106		var err error
107		if cfg.BindPassword != "" {
108			err = conn.Bind(cfg.BindDN, cfg.BindPassword)
109		} else {
110			err = conn.UnauthenticatedBind(cfg.BindDN)
111		}
112		if err != nil {
113			return bindDN, errwrap.Wrapf("LDAP bind (service) failed: {{err}}", err)
114		}
115
116		filter := fmt.Sprintf("(%s=%s)", cfg.UserAttr, ldap.EscapeFilter(username))
117		if c.Logger.IsDebug() {
118			c.Logger.Debug("discovering user", "userdn", cfg.UserDN, "filter", filter)
119		}
120		result, err := conn.Search(&ldap.SearchRequest{
121			BaseDN:    cfg.UserDN,
122			Scope:     ldap.ScopeWholeSubtree,
123			Filter:    filter,
124			SizeLimit: math.MaxInt32,
125		})
126		if err != nil {
127			return bindDN, errwrap.Wrapf("LDAP search for binddn failed: {{err}}", err)
128		}
129		if len(result.Entries) != 1 {
130			return bindDN, fmt.Errorf("LDAP search for binddn 0 or not unique")
131		}
132		bindDN = result.Entries[0].DN
133	} else {
134		if cfg.UPNDomain != "" {
135			bindDN = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain)
136		} else {
137			bindDN = fmt.Sprintf("%s=%s,%s", cfg.UserAttr, EscapeLDAPValue(username), cfg.UserDN)
138		}
139	}
140
141	return bindDN, nil
142}
143
144/*
145 * Returns the DN of the object representing the authenticated user.
146 */
147func (c *Client) GetUserDN(cfg *ConfigEntry, conn Connection, bindDN string) (string, error) {
148	userDN := ""
149	if cfg.UPNDomain != "" {
150		// Find the distinguished name for the user if userPrincipalName used for login
151		filter := fmt.Sprintf("(userPrincipalName=%s)", ldap.EscapeFilter(bindDN))
152		if c.Logger.IsDebug() {
153			c.Logger.Debug("searching upn", "userdn", cfg.UserDN, "filter", filter)
154		}
155		result, err := conn.Search(&ldap.SearchRequest{
156			BaseDN:    cfg.UserDN,
157			Scope:     ldap.ScopeWholeSubtree,
158			Filter:    filter,
159			SizeLimit: math.MaxInt32,
160		})
161		if err != nil {
162			return userDN, errwrap.Wrapf("LDAP search failed for detecting user: {{err}}", err)
163		}
164		for _, e := range result.Entries {
165			userDN = e.DN
166		}
167	} else {
168		userDN = bindDN
169	}
170
171	return userDN, nil
172}
173
174func (c *Client) performLdapFilterGroupsSearch(cfg *ConfigEntry, conn Connection, userDN string, username string) ([]*ldap.Entry, error) {
175	if cfg.GroupFilter == "" {
176		c.Logger.Warn("groupfilter is empty, will not query server")
177		return make([]*ldap.Entry, 0), nil
178	}
179
180	if cfg.GroupDN == "" {
181		c.Logger.Warn("groupdn is empty, will not query server")
182		return make([]*ldap.Entry, 0), nil
183	}
184
185	// If groupfilter was defined, resolve it as a Go template and use the query for
186	// returning the user's groups
187	if c.Logger.IsDebug() {
188		c.Logger.Debug("compiling group filter", "group_filter", cfg.GroupFilter)
189	}
190
191	// Parse the configuration as a template.
192	// Example template "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
193	t, err := template.New("queryTemplate").Parse(cfg.GroupFilter)
194	if err != nil {
195		return nil, errwrap.Wrapf("LDAP search failed due to template compilation error: {{err}}", err)
196	}
197
198	// Build context to pass to template - we will be exposing UserDn and Username.
199	context := struct {
200		UserDN   string
201		Username string
202	}{
203		ldap.EscapeFilter(userDN),
204		ldap.EscapeFilter(username),
205	}
206
207	var renderedQuery bytes.Buffer
208	t.Execute(&renderedQuery, context)
209
210	if c.Logger.IsDebug() {
211		c.Logger.Debug("searching", "groupdn", cfg.GroupDN, "rendered_query", renderedQuery.String())
212	}
213
214	result, err := conn.Search(&ldap.SearchRequest{
215		BaseDN: cfg.GroupDN,
216		Scope:  ldap.ScopeWholeSubtree,
217		Filter: renderedQuery.String(),
218		Attributes: []string{
219			cfg.GroupAttr,
220		},
221		SizeLimit: math.MaxInt32,
222	})
223	if err != nil {
224		return nil, errwrap.Wrapf("LDAP search failed: {{err}}", err)
225	}
226
227	return result.Entries, nil
228}
229
230func sidBytesToString(b []byte) (string, error) {
231	reader := bytes.NewReader(b)
232
233	var revision, subAuthorityCount uint8
234	var identifierAuthorityParts [3]uint16
235
236	if err := binary.Read(reader, binary.LittleEndian, &revision); err != nil {
237		return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading Revision: {{err}}", b), err)
238	}
239
240	if err := binary.Read(reader, binary.LittleEndian, &subAuthorityCount); err != nil {
241		return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading SubAuthorityCount: {{err}}", b), err)
242	}
243
244	if err := binary.Read(reader, binary.BigEndian, &identifierAuthorityParts); err != nil {
245		return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading IdentifierAuthority: {{err}}", b), err)
246	}
247	identifierAuthority := (uint64(identifierAuthorityParts[0]) << 32) + (uint64(identifierAuthorityParts[1]) << 16) + uint64(identifierAuthorityParts[2])
248
249	subAuthority := make([]uint32, subAuthorityCount)
250	if err := binary.Read(reader, binary.LittleEndian, &subAuthority); err != nil {
251		return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading SubAuthority: {{err}}", b), err)
252	}
253
254	result := fmt.Sprintf("S-%d-%d", revision, identifierAuthority)
255	for _, subAuthorityPart := range subAuthority {
256		result += fmt.Sprintf("-%d", subAuthorityPart)
257	}
258
259	return result, nil
260}
261
262func (c *Client) performLdapTokenGroupsSearch(cfg *ConfigEntry, conn Connection, userDN string) ([]*ldap.Entry, error) {
263	result, err := conn.Search(&ldap.SearchRequest{
264		BaseDN: userDN,
265		Scope:  ldap.ScopeBaseObject,
266		Filter: "(objectClass=*)",
267		Attributes: []string{
268			"tokenGroups",
269		},
270		SizeLimit: 1,
271	})
272	if err != nil {
273		return nil, errwrap.Wrapf("LDAP search failed: {{err}}", err)
274	}
275	if len(result.Entries) == 0 {
276		c.Logger.Warn("unable to read object for group attributes", "userdn", userDN, "groupattr", cfg.GroupAttr)
277		return make([]*ldap.Entry, 0), nil
278	}
279
280	userEntry := result.Entries[0]
281	groupAttrValues := userEntry.GetRawAttributeValues("tokenGroups")
282
283	groupEntries := make([]*ldap.Entry, 0, len(groupAttrValues))
284	for _, sidBytes := range groupAttrValues {
285		sidString, err := sidBytesToString(sidBytes)
286		if err != nil {
287			c.Logger.Warn("unable to read sid", "err", err)
288			continue
289		}
290
291		groupResult, err := conn.Search(&ldap.SearchRequest{
292			BaseDN: fmt.Sprintf("<SID=%s>", sidString),
293			Scope:  ldap.ScopeBaseObject,
294			Filter: "(objectClass=*)",
295			Attributes: []string{
296				"1.1", // RFC no attributes
297			},
298			SizeLimit: 1,
299		})
300		if err != nil {
301			c.Logger.Warn("unable to read the group sid", "sid", sidString)
302			continue
303		}
304		if len(groupResult.Entries) == 0 {
305			c.Logger.Warn("unable to find the group", "sid", sidString)
306			continue
307		}
308
309		groupEntries = append(groupEntries, groupResult.Entries[0])
310	}
311
312	return groupEntries, nil
313}
314
315/*
316 * getLdapGroups queries LDAP and returns a slice describing the set of groups the authenticated user is a member of.
317 *
318 * If cfg.UseTokenGroups is true then the search is performed directly on the userDN.
319 * The values of those attributes are converted to string SIDs, and then looked up to get ldap.Entry objects.
320 * Otherwise, the search query is constructed according to cfg.GroupFilter, and run in context of cfg.GroupDN.
321 * Groups will be resolved from the query results by following the attribute defined in cfg.GroupAttr.
322 *
323 * cfg.GroupFilter is a go template and is compiled with the following context: [UserDN, Username]
324 *    UserDN - The DN of the authenticated user
325 *    Username - The Username of the authenticated user
326 *
327 * Example:
328 *   cfg.GroupFilter = "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
329 *   cfg.GroupDN     = "OU=Groups,DC=myorg,DC=com"
330 *   cfg.GroupAttr   = "cn"
331 *
332 * NOTE - If cfg.GroupFilter is empty, no query is performed and an empty result slice is returned.
333 *
334 */
335func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string, username string) ([]string, error) {
336	var entries []*ldap.Entry
337	var err error
338	if cfg.UseTokenGroups {
339		entries, err = c.performLdapTokenGroupsSearch(cfg, conn, userDN)
340	} else {
341		entries, err = c.performLdapFilterGroupsSearch(cfg, conn, userDN, username)
342	}
343	if err != nil {
344		return nil, err
345	}
346
347	// retrieve the groups in a string/bool map as a structure to avoid duplicates inside
348	ldapMap := make(map[string]bool)
349
350	for _, e := range entries {
351		dn, err := ldap.ParseDN(e.DN)
352		if err != nil || len(dn.RDNs) == 0 {
353			continue
354		}
355
356		// Enumerate attributes of each result, parse out CN and add as group
357		values := e.GetAttributeValues(cfg.GroupAttr)
358		if len(values) > 0 {
359			for _, val := range values {
360				groupCN := getCN(val)
361				ldapMap[groupCN] = true
362			}
363		} else {
364			// If groupattr didn't resolve, use self (enumerating group objects)
365			groupCN := getCN(e.DN)
366			ldapMap[groupCN] = true
367		}
368	}
369
370	ldapGroups := make([]string, 0, len(ldapMap))
371	for key := range ldapMap {
372		ldapGroups = append(ldapGroups, key)
373	}
374
375	return ldapGroups, nil
376}
377
// EscapeLDAPValue escapes input for use as an attribute value inside a
// distinguished name. It is exported because a plugin uses it outside this
// package.
//
// RFC4514 forbids un-escaped:
//   - leading space or hash
//   - trailing space
//   - special characters '"', '+', ',', ';', '<', '>', '\\'
//   - null
//
// Each special character gets a backslash prefix. A backslash that is already
// followed by a special character is taken to be an existing escape and the
// pair is copied unchanged; a backslash followed by any other character, or
// one that ends the input, is itself escaped. A leading space or hash and a
// trailing space are escaped as well. NUL bytes are not handled by this
// function. The empty string is returned unchanged.
func EscapeLDAPValue(input string) string {
	if input == "" {
		return ""
	}

	const special = "\"+,;<>\\"

	var sb strings.Builder
	sb.Grow(len(input) + 2)
	for i := 0; i < len(input); i++ {
		ch := input[i]

		// A backslash with a character after it introduces a pair. The bounds
		// check matters: a backslash that ends the input has no partner and is
		// handled below as an ordinary special character instead of indexing
		// past the end of the string.
		if ch == '\\' && i+1 < len(input) {
			next := input[i+1]
			if strings.IndexByte(special, next) < 0 {
				// Not an existing escape: the backslash is literal and needs
				// an escape of its own.
				sb.WriteByte('\\')
			}
			sb.WriteByte('\\')
			sb.WriteByte(next)
			i++ // the partner has been consumed
			continue
		}

		if strings.IndexByte(special, ch) >= 0 {
			sb.WriteByte('\\')
		}
		sb.WriteByte(ch)
	}

	out := sb.String()
	// Fix up the trailing space before the leading character so that a value
	// consisting of a single space is escaped once ("\ ") rather than twice.
	if out[len(out)-1] == ' ' {
		out = out[:len(out)-1] + "\\ "
	}
	if out[0] == ' ' || out[0] == '#' {
		out = "\\" + out
	}
	return out
}
416
417/*
418 * Parses a distinguished name and returns the CN portion.
419 * Given a non-conforming string (such as an already-extracted CN),
420 * it will be returned as-is.
421 */
422func getCN(dn string) string {
423	parsedDN, err := ldap.ParseDN(dn)
424	if err != nil || len(parsedDN.RDNs) == 0 {
425		// It was already a CN, return as-is
426		return dn
427	}
428
429	for _, rdn := range parsedDN.RDNs {
430		for _, rdnAttr := range rdn.Attributes {
431			if strings.EqualFold(rdnAttr.Type, "CN") {
432				return rdnAttr.Value
433			}
434		}
435	}
436
437	// Default, return self
438	return dn
439}
440
441func getTLSConfig(cfg *ConfigEntry, host string) (*tls.Config, error) {
442	tlsConfig := &tls.Config{
443		ServerName: host,
444	}
445
446	if cfg.TLSMinVersion != "" {
447		tlsMinVersion, ok := tlsutil.TLSLookup[cfg.TLSMinVersion]
448		if !ok {
449			return nil, fmt.Errorf("invalid 'tls_min_version' in config")
450		}
451		tlsConfig.MinVersion = tlsMinVersion
452	}
453
454	if cfg.TLSMaxVersion != "" {
455		tlsMaxVersion, ok := tlsutil.TLSLookup[cfg.TLSMaxVersion]
456		if !ok {
457			return nil, fmt.Errorf("invalid 'tls_max_version' in config")
458		}
459		tlsConfig.MaxVersion = tlsMaxVersion
460	}
461
462	if cfg.InsecureTLS {
463		tlsConfig.InsecureSkipVerify = true
464	}
465	if cfg.Certificate != "" {
466		caPool := x509.NewCertPool()
467		ok := caPool.AppendCertsFromPEM([]byte(cfg.Certificate))
468		if !ok {
469			return nil, fmt.Errorf("could not append CA certificate")
470		}
471		tlsConfig.RootCAs = caPool
472	}
473	return tlsConfig, nil
474}
475