1package ldap 2 3import ( 4 "crypto/tls" 5 "errors" 6 "fmt" 7 "log" 8 "net" 9 "net/url" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "gopkg.in/asn1-ber.v1" 15) 16 17const ( 18 // MessageQuit causes the processMessages loop to exit 19 MessageQuit = 0 20 // MessageRequest sends a request to the server 21 MessageRequest = 1 22 // MessageResponse receives a response from the server 23 MessageResponse = 2 24 // MessageFinish indicates the client considers a particular message ID to be finished 25 MessageFinish = 3 26 // MessageTimeout indicates the client-specified timeout for a particular message ID has been reached 27 MessageTimeout = 4 28) 29 30const ( 31 // DefaultLdapPort default ldap port for pure TCP connection 32 DefaultLdapPort = "389" 33 // DefaultLdapsPort default ldap port for SSL connection 34 DefaultLdapsPort = "636" 35) 36 37// PacketResponse contains the packet or error encountered reading a response 38type PacketResponse struct { 39 // Packet is the packet read from the server 40 Packet *ber.Packet 41 // Error is an error encountered while reading 42 Error error 43} 44 45// ReadPacket returns the packet or an error 46func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) { 47 if (pr == nil) || (pr.Packet == nil && pr.Error == nil) { 48 return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response")) 49 } 50 return pr.Packet, pr.Error 51} 52 53type messageContext struct { 54 id int64 55 // close(done) should only be called from finishMessage() 56 done chan struct{} 57 // close(responses) should only be called from processMessages(), and only sent to from sendResponse() 58 responses chan *PacketResponse 59} 60 61// sendResponse should only be called within the processMessages() loop which 62// is also responsible for closing the responses channel. 63func (msgCtx *messageContext) sendResponse(packet *PacketResponse) { 64 select { 65 case msgCtx.responses <- packet: 66 // Successfully sent packet to message handler. 67 case <-msgCtx.done: 68 // The request handler is done and will not receive more 69 // packets. 70 } 71} 72 73type messagePacket struct { 74 Op int 75 MessageID int64 76 Packet *ber.Packet 77 Context *messageContext 78} 79 80type sendMessageFlags uint 81 82const ( 83 startTLS sendMessageFlags = 1 << iota 84) 85 86// Conn represents an LDAP Connection 87type Conn struct { 88 // requestTimeout is loaded atomically 89 // so we need to ensure 64-bit alignment on 32-bit platforms. 90 requestTimeout int64 91 conn net.Conn 92 isTLS bool 93 closing uint32 94 closeErr atomic.Value 95 isStartingTLS bool 96 Debug debugging 97 chanConfirm chan struct{} 98 messageContexts map[int64]*messageContext 99 chanMessage chan *messagePacket 100 chanMessageID chan int64 101 wgClose sync.WaitGroup 102 outstandingRequests uint 103 messageMutex sync.Mutex 104} 105 106var _ Client = &Conn{} 107 108// DefaultTimeout is a package-level variable that sets the timeout value 109// used for the Dial and DialTLS methods. 110// 111// WARNING: since this is a package-level variable, setting this value from 112// multiple places will probably result in undesired behaviour. 113var DefaultTimeout = 60 * time.Second 114 115// Dial connects to the given address on the given network using net.Dial 116// and then returns a new Conn for the connection. 117func Dial(network, addr string) (*Conn, error) { 118 c, err := net.DialTimeout(network, addr, DefaultTimeout) 119 if err != nil { 120 return nil, NewError(ErrorNetwork, err) 121 } 122 conn := NewConn(c, false) 123 conn.Start() 124 return conn, nil 125} 126 127// DialTLS connects to the given address on the given network using tls.Dial 128// and then returns a new Conn for the connection. 129func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { 130 c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config) 131 if err != nil { 132 return nil, NewError(ErrorNetwork, err) 133 } 134 conn := NewConn(c, true) 135 conn.Start() 136 return conn, nil 137} 138 139// DialURL connects to the given ldap URL vie TCP using tls.Dial or net.Dial if ldaps:// 140// or ldap:// specified as protocol. On success a new Conn for the connection 141// is returned. 142func DialURL(addr string) (*Conn, error) { 143 144 lurl, err := url.Parse(addr) 145 if err != nil { 146 return nil, NewError(ErrorNetwork, err) 147 } 148 149 host, port, err := net.SplitHostPort(lurl.Host) 150 if err != nil { 151 // we asume that error is due to missing port 152 host = lurl.Host 153 port = "" 154 } 155 156 switch lurl.Scheme { 157 case "ldap": 158 if port == "" { 159 port = DefaultLdapPort 160 } 161 return Dial("tcp", net.JoinHostPort(host, port)) 162 case "ldaps": 163 if port == "" { 164 port = DefaultLdapsPort 165 } 166 tlsConf := &tls.Config{ 167 ServerName: host, 168 } 169 return DialTLS("tcp", net.JoinHostPort(host, port), tlsConf) 170 } 171 172 return nil, NewError(ErrorNetwork, fmt.Errorf("Unknown scheme '%s'", lurl.Scheme)) 173} 174 175// NewConn returns a new Conn using conn for network I/O. 176func NewConn(conn net.Conn, isTLS bool) *Conn { 177 return &Conn{ 178 conn: conn, 179 chanConfirm: make(chan struct{}), 180 chanMessageID: make(chan int64), 181 chanMessage: make(chan *messagePacket, 10), 182 messageContexts: map[int64]*messageContext{}, 183 requestTimeout: 0, 184 isTLS: isTLS, 185 } 186} 187 188// Start initializes goroutines to read responses and process messages 189func (l *Conn) Start() { 190 go l.reader() 191 go l.processMessages() 192 l.wgClose.Add(1) 193} 194 195// IsClosing returns whether or not we're currently closing. 196func (l *Conn) IsClosing() bool { 197 return atomic.LoadUint32(&l.closing) == 1 198} 199 200// setClosing sets the closing value to true 201func (l *Conn) setClosing() bool { 202 return atomic.CompareAndSwapUint32(&l.closing, 0, 1) 203} 204 205// Close closes the connection. 206func (l *Conn) Close() { 207 l.messageMutex.Lock() 208 defer l.messageMutex.Unlock() 209 210 if l.setClosing() { 211 l.Debug.Printf("Sending quit message and waiting for confirmation") 212 l.chanMessage <- &messagePacket{Op: MessageQuit} 213 <-l.chanConfirm 214 close(l.chanMessage) 215 216 l.Debug.Printf("Closing network connection") 217 if err := l.conn.Close(); err != nil { 218 log.Println(err) 219 } 220 221 l.wgClose.Done() 222 } 223 l.wgClose.Wait() 224} 225 226// SetTimeout sets the time after a request is sent that a MessageTimeout triggers 227func (l *Conn) SetTimeout(timeout time.Duration) { 228 if timeout > 0 { 229 atomic.StoreInt64(&l.requestTimeout, int64(timeout)) 230 } 231} 232 233// Returns the next available messageID 234func (l *Conn) nextMessageID() int64 { 235 if messageID, ok := <-l.chanMessageID; ok { 236 return messageID 237 } 238 return 0 239} 240 241// StartTLS sends the command to start a TLS session and then creates a new TLS Client 242func (l *Conn) StartTLS(config *tls.Config) error { 243 if l.isTLS { 244 return NewError(ErrorNetwork, errors.New("ldap: already encrypted")) 245 } 246 247 packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") 248 packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) 249 request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS") 250 request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")) 251 packet.AppendChild(request) 252 l.Debug.PrintPacket(packet) 253 254 msgCtx, err := l.sendMessageWithFlags(packet, startTLS) 255 if err != nil { 256 return err 257 } 258 defer l.finishMessage(msgCtx) 259 260 l.Debug.Printf("%d: waiting for response", msgCtx.id) 261 262 packetResponse, ok := <-msgCtx.responses 263 if !ok { 264 return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) 265 } 266 packet, err = packetResponse.ReadPacket() 267 l.Debug.Printf("%d: got response %p", msgCtx.id, packet) 268 if err != nil { 269 return err 270 } 271 272 if l.Debug { 273 if err := addLDAPDescriptions(packet); err != nil { 274 l.Close() 275 return err 276 } 277 ber.PrintPacket(packet) 278 } 279 280 if err := GetLDAPError(packet); err == nil { 281 conn := tls.Client(l.conn, config) 282 283 if connErr := conn.Handshake(); connErr != nil { 284 l.Close() 285 return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr)) 286 } 287 288 l.isTLS = true 289 l.conn = conn 290 } else { 291 return err 292 } 293 go l.reader() 294 295 return nil 296} 297 298// TLSConnectionState returns the client's TLS connection state. 299// The return values are their zero values if StartTLS did 300// not succeed. 301func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) { 302 tc, ok := l.conn.(*tls.Conn) 303 if !ok { 304 return 305 } 306 return tc.ConnectionState(), true 307} 308 309func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) { 310 return l.sendMessageWithFlags(packet, 0) 311} 312 313func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) { 314 if l.IsClosing() { 315 return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) 316 } 317 l.messageMutex.Lock() 318 l.Debug.Printf("flags&startTLS = %d", flags&startTLS) 319 if l.isStartingTLS { 320 l.messageMutex.Unlock() 321 return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase")) 322 } 323 if flags&startTLS != 0 { 324 if l.outstandingRequests != 0 { 325 l.messageMutex.Unlock() 326 return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests")) 327 } 328 l.isStartingTLS = true 329 } 330 l.outstandingRequests++ 331 332 l.messageMutex.Unlock() 333 334 responses := make(chan *PacketResponse) 335 messageID := packet.Children[0].Value.(int64) 336 message := &messagePacket{ 337 Op: MessageRequest, 338 MessageID: messageID, 339 Packet: packet, 340 Context: &messageContext{ 341 id: messageID, 342 done: make(chan struct{}), 343 responses: responses, 344 }, 345 } 346 l.sendProcessMessage(message) 347 return message.Context, nil 348} 349 350func (l *Conn) finishMessage(msgCtx *messageContext) { 351 close(msgCtx.done) 352 353 if l.IsClosing() { 354 return 355 } 356 357 l.messageMutex.Lock() 358 l.outstandingRequests-- 359 if l.isStartingTLS { 360 l.isStartingTLS = false 361 } 362 l.messageMutex.Unlock() 363 364 message := &messagePacket{ 365 Op: MessageFinish, 366 MessageID: msgCtx.id, 367 } 368 l.sendProcessMessage(message) 369} 370 371func (l *Conn) sendProcessMessage(message *messagePacket) bool { 372 l.messageMutex.Lock() 373 defer l.messageMutex.Unlock() 374 if l.IsClosing() { 375 return false 376 } 377 l.chanMessage <- message 378 return true 379} 380 381func (l *Conn) processMessages() { 382 defer func() { 383 if err := recover(); err != nil { 384 log.Printf("ldap: recovered panic in processMessages: %v", err) 385 } 386 for messageID, msgCtx := range l.messageContexts { 387 // If we are closing due to an error, inform anyone who 388 // is waiting about the error. 389 if l.IsClosing() && l.closeErr.Load() != nil { 390 msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}) 391 } 392 l.Debug.Printf("Closing channel for MessageID %d", messageID) 393 close(msgCtx.responses) 394 delete(l.messageContexts, messageID) 395 } 396 close(l.chanMessageID) 397 close(l.chanConfirm) 398 }() 399 400 var messageID int64 = 1 401 for { 402 select { 403 case l.chanMessageID <- messageID: 404 messageID++ 405 case message := <-l.chanMessage: 406 switch message.Op { 407 case MessageQuit: 408 l.Debug.Printf("Shutting down - quit message received") 409 return 410 case MessageRequest: 411 // Add to message list and write to network 412 l.Debug.Printf("Sending message %d", message.MessageID) 413 414 buf := message.Packet.Bytes() 415 _, err := l.conn.Write(buf) 416 if err != nil { 417 l.Debug.Printf("Error Sending Message: %s", err.Error()) 418 message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}) 419 close(message.Context.responses) 420 break 421 } 422 423 // Only add to messageContexts if we were able to 424 // successfully write the message. 425 l.messageContexts[message.MessageID] = message.Context 426 427 // Add timeout if defined 428 requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout)) 429 if requestTimeout > 0 { 430 go func() { 431 defer func() { 432 if err := recover(); err != nil { 433 log.Printf("ldap: recovered panic in RequestTimeout: %v", err) 434 } 435 }() 436 time.Sleep(requestTimeout) 437 timeoutMessage := &messagePacket{ 438 Op: MessageTimeout, 439 MessageID: message.MessageID, 440 } 441 l.sendProcessMessage(timeoutMessage) 442 }() 443 } 444 case MessageResponse: 445 l.Debug.Printf("Receiving message %d", message.MessageID) 446 if msgCtx, ok := l.messageContexts[message.MessageID]; ok { 447 msgCtx.sendResponse(&PacketResponse{message.Packet, nil}) 448 } else { 449 log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing()) 450 ber.PrintPacket(message.Packet) 451 } 452 case MessageTimeout: 453 // Handle the timeout by closing the channel 454 // All reads will return immediately 455 if msgCtx, ok := l.messageContexts[message.MessageID]; ok { 456 l.Debug.Printf("Receiving message timeout for %d", message.MessageID) 457 msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")}) 458 delete(l.messageContexts, message.MessageID) 459 close(msgCtx.responses) 460 } 461 case MessageFinish: 462 l.Debug.Printf("Finished message %d", message.MessageID) 463 if msgCtx, ok := l.messageContexts[message.MessageID]; ok { 464 delete(l.messageContexts, message.MessageID) 465 close(msgCtx.responses) 466 } 467 } 468 } 469 } 470} 471 472func (l *Conn) reader() { 473 cleanstop := false 474 defer func() { 475 if err := recover(); err != nil { 476 log.Printf("ldap: recovered panic in reader: %v", err) 477 } 478 if !cleanstop { 479 l.Close() 480 } 481 }() 482 483 for { 484 if cleanstop { 485 l.Debug.Printf("reader clean stopping (without closing the connection)") 486 return 487 } 488 packet, err := ber.ReadPacket(l.conn) 489 if err != nil { 490 // A read error is expected here if we are closing the connection... 491 if !l.IsClosing() { 492 l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err)) 493 l.Debug.Printf("reader error: %s", err.Error()) 494 } 495 return 496 } 497 addLDAPDescriptions(packet) 498 if len(packet.Children) == 0 { 499 l.Debug.Printf("Received bad ldap packet") 500 continue 501 } 502 l.messageMutex.Lock() 503 if l.isStartingTLS { 504 cleanstop = true 505 } 506 l.messageMutex.Unlock() 507 message := &messagePacket{ 508 Op: MessageResponse, 509 MessageID: packet.Children[0].Value.(int64), 510 Packet: packet, 511 } 512 if !l.sendProcessMessage(message) { 513 return 514 } 515 } 516} 517