1package ldaputil 2 3import ( 4 "bytes" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/binary" 8 "fmt" 9 "math" 10 "net" 11 "net/url" 12 "strings" 13 "text/template" 14 15 "github.com/go-ldap/ldap" 16 "github.com/hashicorp/errwrap" 17 hclog "github.com/hashicorp/go-hclog" 18 multierror "github.com/hashicorp/go-multierror" 19 "github.com/hashicorp/vault/sdk/helper/tlsutil" 20) 21 22type Client struct { 23 Logger hclog.Logger 24 LDAP LDAP 25} 26 27func (c *Client) DialLDAP(cfg *ConfigEntry) (Connection, error) { 28 var retErr *multierror.Error 29 var conn Connection 30 urls := strings.Split(cfg.Url, ",") 31 for _, uut := range urls { 32 u, err := url.Parse(uut) 33 if err != nil { 34 retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("error parsing url %q: {{err}}", uut), err)) 35 continue 36 } 37 host, port, err := net.SplitHostPort(u.Host) 38 if err != nil { 39 host = u.Host 40 } 41 42 var tlsConfig *tls.Config 43 switch u.Scheme { 44 case "ldap": 45 if port == "" { 46 port = "389" 47 } 48 conn, err = c.LDAP.Dial("tcp", net.JoinHostPort(host, port)) 49 if err != nil { 50 break 51 } 52 if conn == nil { 53 err = fmt.Errorf("empty connection after dialing") 54 break 55 } 56 if cfg.StartTLS { 57 tlsConfig, err = getTLSConfig(cfg, host) 58 if err != nil { 59 break 60 } 61 err = conn.StartTLS(tlsConfig) 62 } 63 case "ldaps": 64 if port == "" { 65 port = "636" 66 } 67 tlsConfig, err = getTLSConfig(cfg, host) 68 if err != nil { 69 break 70 } 71 conn, err = c.LDAP.DialTLS("tcp", net.JoinHostPort(host, port), tlsConfig) 72 default: 73 retErr = multierror.Append(retErr, fmt.Errorf("invalid LDAP scheme in url %q", net.JoinHostPort(host, port))) 74 continue 75 } 76 if err == nil { 77 if retErr != nil { 78 if c.Logger.IsDebug() { 79 c.Logger.Debug("errors connecting to some hosts: %s", retErr.Error()) 80 } 81 } 82 retErr = nil 83 break 84 } 85 retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("error connecting to host %q: {{err}}", uut), err)) 86 } 87 88 return conn, retErr.ErrorOrNil() 89} 90 91/* 92 * Discover and return the bind string for the user attempting to authenticate. 93 * This is handled in one of several ways: 94 * 95 * 1. If DiscoverDN is set, the user object will be searched for using userdn (base search path) 96 * and userattr (the attribute that maps to the provided username). 97 * The bind will either be anonymous or use binddn and bindpassword if they were provided. 98 * 2. If upndomain is set, the user dn is constructed as 'username@upndomain'. See https://msdn.microsoft.com/en-us/library/cc223499.aspx 99 * 100 */ 101func (c *Client) GetUserBindDN(cfg *ConfigEntry, conn Connection, username string) (string, error) { 102 bindDN := "" 103 // Note: The logic below drives the logic in ConfigEntry.Validate(). 104 // If updated, please update there as well. 105 if cfg.DiscoverDN || (cfg.BindDN != "" && cfg.BindPassword != "") { 106 var err error 107 if cfg.BindPassword != "" { 108 err = conn.Bind(cfg.BindDN, cfg.BindPassword) 109 } else { 110 err = conn.UnauthenticatedBind(cfg.BindDN) 111 } 112 if err != nil { 113 return bindDN, errwrap.Wrapf("LDAP bind (service) failed: {{err}}", err) 114 } 115 116 filter := fmt.Sprintf("(%s=%s)", cfg.UserAttr, ldap.EscapeFilter(username)) 117 if c.Logger.IsDebug() { 118 c.Logger.Debug("discovering user", "userdn", cfg.UserDN, "filter", filter) 119 } 120 result, err := conn.Search(&ldap.SearchRequest{ 121 BaseDN: cfg.UserDN, 122 Scope: ldap.ScopeWholeSubtree, 123 Filter: filter, 124 SizeLimit: math.MaxInt32, 125 }) 126 if err != nil { 127 return bindDN, errwrap.Wrapf("LDAP search for binddn failed: {{err}}", err) 128 } 129 if len(result.Entries) != 1 { 130 return bindDN, fmt.Errorf("LDAP search for binddn 0 or not unique") 131 } 132 bindDN = result.Entries[0].DN 133 } else { 134 if cfg.UPNDomain != "" { 135 bindDN = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain) 136 } else { 137 bindDN = fmt.Sprintf("%s=%s,%s", cfg.UserAttr, EscapeLDAPValue(username), cfg.UserDN) 138 } 139 } 140 141 return bindDN, nil 142} 143 144/* 145 * Returns the DN of the object representing the authenticated user. 146 */ 147func (c *Client) GetUserDN(cfg *ConfigEntry, conn Connection, bindDN string) (string, error) { 148 userDN := "" 149 if cfg.UPNDomain != "" { 150 // Find the distinguished name for the user if userPrincipalName used for login 151 filter := fmt.Sprintf("(userPrincipalName=%s)", ldap.EscapeFilter(bindDN)) 152 if c.Logger.IsDebug() { 153 c.Logger.Debug("searching upn", "userdn", cfg.UserDN, "filter", filter) 154 } 155 result, err := conn.Search(&ldap.SearchRequest{ 156 BaseDN: cfg.UserDN, 157 Scope: ldap.ScopeWholeSubtree, 158 Filter: filter, 159 SizeLimit: math.MaxInt32, 160 }) 161 if err != nil { 162 return userDN, errwrap.Wrapf("LDAP search failed for detecting user: {{err}}", err) 163 } 164 for _, e := range result.Entries { 165 userDN = e.DN 166 } 167 } else { 168 userDN = bindDN 169 } 170 171 return userDN, nil 172} 173 174func (c *Client) performLdapFilterGroupsSearch(cfg *ConfigEntry, conn Connection, userDN string, username string) ([]*ldap.Entry, error) { 175 if cfg.GroupFilter == "" { 176 c.Logger.Warn("groupfilter is empty, will not query server") 177 return make([]*ldap.Entry, 0), nil 178 } 179 180 if cfg.GroupDN == "" { 181 c.Logger.Warn("groupdn is empty, will not query server") 182 return make([]*ldap.Entry, 0), nil 183 } 184 185 // If groupfilter was defined, resolve it as a Go template and use the query for 186 // returning the user's groups 187 if c.Logger.IsDebug() { 188 c.Logger.Debug("compiling group filter", "group_filter", cfg.GroupFilter) 189 } 190 191 // Parse the configuration as a template. 192 // Example template "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))" 193 t, err := template.New("queryTemplate").Parse(cfg.GroupFilter) 194 if err != nil { 195 return nil, errwrap.Wrapf("LDAP search failed due to template compilation error: {{err}}", err) 196 } 197 198 // Build context to pass to template - we will be exposing UserDn and Username. 199 context := struct { 200 UserDN string 201 Username string 202 }{ 203 ldap.EscapeFilter(userDN), 204 ldap.EscapeFilter(username), 205 } 206 207 var renderedQuery bytes.Buffer 208 t.Execute(&renderedQuery, context) 209 210 if c.Logger.IsDebug() { 211 c.Logger.Debug("searching", "groupdn", cfg.GroupDN, "rendered_query", renderedQuery.String()) 212 } 213 214 result, err := conn.Search(&ldap.SearchRequest{ 215 BaseDN: cfg.GroupDN, 216 Scope: ldap.ScopeWholeSubtree, 217 Filter: renderedQuery.String(), 218 Attributes: []string{ 219 cfg.GroupAttr, 220 }, 221 SizeLimit: math.MaxInt32, 222 }) 223 if err != nil { 224 return nil, errwrap.Wrapf("LDAP search failed: {{err}}", err) 225 } 226 227 return result.Entries, nil 228} 229 230func sidBytesToString(b []byte) (string, error) { 231 reader := bytes.NewReader(b) 232 233 var revision, subAuthorityCount uint8 234 var identifierAuthorityParts [3]uint16 235 236 if err := binary.Read(reader, binary.LittleEndian, &revision); err != nil { 237 return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading Revision: {{err}}", b), err) 238 } 239 240 if err := binary.Read(reader, binary.LittleEndian, &subAuthorityCount); err != nil { 241 return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading SubAuthorityCount: {{err}}", b), err) 242 } 243 244 if err := binary.Read(reader, binary.BigEndian, &identifierAuthorityParts); err != nil { 245 return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading IdentifierAuthority: {{err}}", b), err) 246 } 247 identifierAuthority := (uint64(identifierAuthorityParts[0]) << 32) + (uint64(identifierAuthorityParts[1]) << 16) + uint64(identifierAuthorityParts[2]) 248 249 subAuthority := make([]uint32, subAuthorityCount) 250 if err := binary.Read(reader, binary.LittleEndian, &subAuthority); err != nil { 251 return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading SubAuthority: {{err}}", b), err) 252 } 253 254 result := fmt.Sprintf("S-%d-%d", revision, identifierAuthority) 255 for _, subAuthorityPart := range subAuthority { 256 result += fmt.Sprintf("-%d", subAuthorityPart) 257 } 258 259 return result, nil 260} 261 262func (c *Client) performLdapTokenGroupsSearch(cfg *ConfigEntry, conn Connection, userDN string) ([]*ldap.Entry, error) { 263 result, err := conn.Search(&ldap.SearchRequest{ 264 BaseDN: userDN, 265 Scope: ldap.ScopeBaseObject, 266 Filter: "(objectClass=*)", 267 Attributes: []string{ 268 "tokenGroups", 269 }, 270 SizeLimit: 1, 271 }) 272 if err != nil { 273 return nil, errwrap.Wrapf("LDAP search failed: {{err}}", err) 274 } 275 if len(result.Entries) == 0 { 276 c.Logger.Warn("unable to read object for group attributes", "userdn", userDN, "groupattr", cfg.GroupAttr) 277 return make([]*ldap.Entry, 0), nil 278 } 279 280 userEntry := result.Entries[0] 281 groupAttrValues := userEntry.GetRawAttributeValues("tokenGroups") 282 283 groupEntries := make([]*ldap.Entry, 0, len(groupAttrValues)) 284 for _, sidBytes := range groupAttrValues { 285 sidString, err := sidBytesToString(sidBytes) 286 if err != nil { 287 c.Logger.Warn("unable to read sid", "err", err) 288 continue 289 } 290 291 groupResult, err := conn.Search(&ldap.SearchRequest{ 292 BaseDN: fmt.Sprintf("<SID=%s>", sidString), 293 Scope: ldap.ScopeBaseObject, 294 Filter: "(objectClass=*)", 295 Attributes: []string{ 296 "1.1", // RFC no attributes 297 }, 298 SizeLimit: 1, 299 }) 300 if err != nil { 301 c.Logger.Warn("unable to read the group sid", "sid", sidString) 302 continue 303 } 304 if len(groupResult.Entries) == 0 { 305 c.Logger.Warn("unable to find the group", "sid", sidString) 306 continue 307 } 308 309 groupEntries = append(groupEntries, groupResult.Entries[0]) 310 } 311 312 return groupEntries, nil 313} 314 315/* 316 * getLdapGroups queries LDAP and returns a slice describing the set of groups the authenticated user is a member of. 317 * 318 * If cfg.UseTokenGroups is true then the search is performed directly on the userDN. 319 * The values of those attributes are converted to string SIDs, and then looked up to get ldap.Entry objects. 320 * Otherwise, the search query is constructed according to cfg.GroupFilter, and run in context of cfg.GroupDN. 321 * Groups will be resolved from the query results by following the attribute defined in cfg.GroupAttr. 322 * 323 * cfg.GroupFilter is a go template and is compiled with the following context: [UserDN, Username] 324 * UserDN - The DN of the authenticated user 325 * Username - The Username of the authenticated user 326 * 327 * Example: 328 * cfg.GroupFilter = "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))" 329 * cfg.GroupDN = "OU=Groups,DC=myorg,DC=com" 330 * cfg.GroupAttr = "cn" 331 * 332 * NOTE - If cfg.GroupFilter is empty, no query is performed and an empty result slice is returned. 333 * 334 */ 335func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string, username string) ([]string, error) { 336 var entries []*ldap.Entry 337 var err error 338 if cfg.UseTokenGroups { 339 entries, err = c.performLdapTokenGroupsSearch(cfg, conn, userDN) 340 } else { 341 entries, err = c.performLdapFilterGroupsSearch(cfg, conn, userDN, username) 342 } 343 if err != nil { 344 return nil, err 345 } 346 347 // retrieve the groups in a string/bool map as a structure to avoid duplicates inside 348 ldapMap := make(map[string]bool) 349 350 for _, e := range entries { 351 dn, err := ldap.ParseDN(e.DN) 352 if err != nil || len(dn.RDNs) == 0 { 353 continue 354 } 355 356 // Enumerate attributes of each result, parse out CN and add as group 357 values := e.GetAttributeValues(cfg.GroupAttr) 358 if len(values) > 0 { 359 for _, val := range values { 360 groupCN := getCN(val) 361 ldapMap[groupCN] = true 362 } 363 } else { 364 // If groupattr didn't resolve, use self (enumerating group objects) 365 groupCN := getCN(e.DN) 366 ldapMap[groupCN] = true 367 } 368 } 369 370 ldapGroups := make([]string, 0, len(ldapMap)) 371 for key := range ldapMap { 372 ldapGroups = append(ldapGroups, key) 373 } 374 375 return ldapGroups, nil 376} 377 378// EscapeLDAPValue is exported because a plugin uses it outside this package. 379func EscapeLDAPValue(input string) string { 380 if input == "" { 381 return "" 382 } 383 384 // RFC4514 forbids un-escaped: 385 // - leading space or hash 386 // - trailing space 387 // - special characters '"', '+', ',', ';', '<', '>', '\\' 388 // - null 389 for i := 0; i < len(input); i++ { 390 escaped := false 391 if input[i] == '\\' { 392 i++ 393 escaped = true 394 } 395 switch input[i] { 396 case '"', '+', ',', ';', '<', '>', '\\': 397 if !escaped { 398 input = input[0:i] + "\\" + input[i:] 399 i++ 400 } 401 continue 402 } 403 if escaped { 404 input = input[0:i] + "\\" + input[i:] 405 i++ 406 } 407 } 408 if input[0] == ' ' || input[0] == '#' { 409 input = "\\" + input 410 } 411 if input[len(input)-1] == ' ' { 412 input = input[0:len(input)-1] + "\\ " 413 } 414 return input 415} 416 417/* 418 * Parses a distinguished name and returns the CN portion. 419 * Given a non-conforming string (such as an already-extracted CN), 420 * it will be returned as-is. 421 */ 422func getCN(dn string) string { 423 parsedDN, err := ldap.ParseDN(dn) 424 if err != nil || len(parsedDN.RDNs) == 0 { 425 // It was already a CN, return as-is 426 return dn 427 } 428 429 for _, rdn := range parsedDN.RDNs { 430 for _, rdnAttr := range rdn.Attributes { 431 if strings.EqualFold(rdnAttr.Type, "CN") { 432 return rdnAttr.Value 433 } 434 } 435 } 436 437 // Default, return self 438 return dn 439} 440 441func getTLSConfig(cfg *ConfigEntry, host string) (*tls.Config, error) { 442 tlsConfig := &tls.Config{ 443 ServerName: host, 444 } 445 446 if cfg.TLSMinVersion != "" { 447 tlsMinVersion, ok := tlsutil.TLSLookup[cfg.TLSMinVersion] 448 if !ok { 449 return nil, fmt.Errorf("invalid 'tls_min_version' in config") 450 } 451 tlsConfig.MinVersion = tlsMinVersion 452 } 453 454 if cfg.TLSMaxVersion != "" { 455 tlsMaxVersion, ok := tlsutil.TLSLookup[cfg.TLSMaxVersion] 456 if !ok { 457 return nil, fmt.Errorf("invalid 'tls_max_version' in config") 458 } 459 tlsConfig.MaxVersion = tlsMaxVersion 460 } 461 462 if cfg.InsecureTLS { 463 tlsConfig.InsecureSkipVerify = true 464 } 465 if cfg.Certificate != "" { 466 caPool := x509.NewCertPool() 467 ok := caPool.AppendCertsFromPEM([]byte(cfg.Certificate)) 468 if !ok { 469 return nil, fmt.Errorf("could not append CA certificate") 470 } 471 tlsConfig.RootCAs = caPool 472 } 473 return tlsConfig, nil 474} 475