1package ldap
2
3import (
4	"crypto/tls"
5	"errors"
6	"fmt"
7	"log"
8	"net"
9	"net/url"
10	"sync"
11	"sync/atomic"
12	"time"
13
14	"gopkg.in/asn1-ber.v1"
15)
16
// Operation codes carried in messagePacket.Op and dispatched by the
// processMessages loop. The values are consecutive, starting at zero.
const (
	// MessageQuit tells the processMessages loop to exit.
	MessageQuit = iota
	// MessageRequest asks processMessages to write a request to the server.
	MessageRequest
	// MessageResponse delivers a packet received from the server.
	MessageResponse
	// MessageFinish signals that the client is done with a particular message ID.
	MessageFinish
	// MessageTimeout signals that the client-side timeout for a message ID has expired.
	MessageTimeout
)
29
// DefaultLdapPort is the port DialURL uses for plain-TCP ldap:// URLs that do
// not name a port.
const DefaultLdapPort = "389"

// DefaultLdapsPort is the port DialURL uses for TLS ldaps:// URLs that do not
// name a port.
const DefaultLdapsPort = "636"
36
// PacketResponse contains the packet or error encountered reading a response.
// processMessages sends one of these on a request's responses channel for
// each packet received, or to report a failure (write error, timeout, or the
// connection closing with an error). ReadPacket unpacks it.
//
// NOTE: processMessages builds this with positional fields
// (&PacketResponse{packet, err}), so the field order below matters.
type PacketResponse struct {
	// Packet is the packet read from the server
	Packet *ber.Packet
	// Error is an error encountered while reading
	Error error
}
44
// ReadPacket unpacks the response into its packet and error. A nil receiver,
// or a response carrying neither a packet nor an error, is reported as an
// ErrorNetwork error because nothing usable was received.
func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
	empty := pr == nil || (pr.Packet == nil && pr.Error == nil)
	if !empty {
		return pr.Packet, pr.Error
	}
	return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
}
52
// messageContext tracks one outstanding request: its LDAP message ID and the
// two channels that connect the processMessages loop with the goroutine
// waiting for the request's responses.
type messageContext struct {
	// id is the LDAP message ID of the request.
	id int64
	// close(done) should only be called from finishMessage()
	// Once closed, sendResponse stops trying to deliver packets for this request.
	done chan struct{}
	// close(responses) should only be called from processMessages(), and only sent to from sendResponse()
	responses chan *PacketResponse
}
60
61// sendResponse should only be called within the processMessages() loop which
62// is also responsible for closing the responses channel.
63func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
64	select {
65	case msgCtx.responses <- packet:
66		// Successfully sent packet to message handler.
67	case <-msgCtx.done:
68		// The request handler is done and will not receive more
69		// packets.
70	}
71}
72
// messagePacket is one unit of work queued on Conn.chanMessage for the
// processMessages loop.
type messagePacket struct {
	// Op is one of the Message* constants and selects how the loop handles it.
	Op        int
	// MessageID is the LDAP message ID the operation refers to.
	MessageID int64
	// Packet is the request to write (MessageRequest) or the packet received
	// (MessageResponse); finish, timeout and quit operations leave it nil.
	Packet    *ber.Packet
	// Context is set only for MessageRequest and is registered by the loop
	// once the request has been written.
	Context   *messageContext
}
79
// sendMessageFlags is a bit set of options accepted by sendMessageWithFlags.
type sendMessageFlags uint

// startTLS marks a request as the StartTLS extended operation. Such a request
// is only accepted when nothing else is outstanding on the connection.
const startTLS sendMessageFlags = 1
85
// Conn represents an LDAP Connection.
//
// Create one with Dial, DialTLS or DialURL (which also call Start), or with
// NewConn followed by Start. Start launches the reader goroutine, which reads
// packets from the server, and the processMessages goroutine, which writes
// requests and routes responses to the waiting callers.
type Conn struct {
	// requestTimeout is loaded atomically
	// so we need to ensure 64-bit alignment on 32-bit platforms.
	// It holds a time.Duration written by SetTimeout; 0 means no per-request
	// timeout is applied.
	requestTimeout      int64
	// conn is the network connection; StartTLS replaces it with a TLS client
	// connection layered over the original one.
	conn                net.Conn
	// isTLS is true when the Conn was created over TLS or StartTLS succeeded.
	isTLS               bool
	// closing is 0 until the first Close call flips it to 1 (see setClosing).
	closing             uint32
	// closeErr holds the read error that made the reader close the connection,
	// so processMessages can forward it to requests still waiting.
	closeErr            atomic.Value
	// isStartingTLS is true while a StartTLS request is in flight; guarded by
	// messageMutex.
	isStartingTLS       bool
	// Debug controls the output of Debug.Printf / Debug.PrintPacket.
	Debug               debugging
	// chanConfirm is closed by processMessages when it exits; Close waits on it.
	chanConfirm         chan struct{}
	// messageContexts maps message IDs to pending requests; it is only
	// touched by the processMessages goroutine.
	messageContexts     map[int64]*messageContext
	// chanMessage queues operations for processMessages; closed by Close.
	chanMessage         chan *messagePacket
	// chanMessageID delivers fresh message IDs from processMessages.
	chanMessageID       chan int64
	// wgClose lets every Close call wait until the first one has finished.
	wgClose             sync.WaitGroup
	// outstandingRequests counts requests sent but not yet finished; guarded
	// by messageMutex.
	outstandingRequests uint
	// messageMutex guards isStartingTLS and outstandingRequests and serializes
	// sends on chanMessage with Close.
	messageMutex        sync.Mutex
}

// Compile-time check that *Conn implements the Client interface.
var _ Client = &Conn{}
107
// DefaultTimeout bounds how long Dial and DialTLS (and therefore DialURL)
// wait while establishing the network connection.
//
// WARNING: since this is a package-level variable, setting this value from
// multiple places will probably result in undesired behaviour.
var DefaultTimeout = time.Minute
114
// Dial opens a plain connection to addr on the given network (bounded by
// DefaultTimeout), wraps it in a Conn and starts its goroutines. Dial failures
// are returned as ErrorNetwork errors.
func Dial(network, addr string) (*Conn, error) {
	netConn, err := net.DialTimeout(network, addr, DefaultTimeout)
	if err != nil {
		return nil, NewError(ErrorNetwork, err)
	}
	ldapConn := NewConn(netConn, false)
	ldapConn.Start()
	return ldapConn, nil
}
126
// DialTLS opens a TLS connection to addr on the given network using config,
// with DefaultTimeout bounding the dial. The result is wrapped in a Conn
// that is marked as encrypted, and its goroutines are started. Dial or
// handshake failures are returned as ErrorNetwork errors.
func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
	dialer := &net.Dialer{Timeout: DefaultTimeout}
	tlsConn, err := tls.DialWithDialer(dialer, network, addr, config)
	if err != nil {
		return nil, NewError(ErrorNetwork, err)
	}
	ldapConn := NewConn(tlsConn, true)
	ldapConn.Start()
	return ldapConn, nil
}
138
// DialURL connects to the LDAP server named by the URL addr.
//
//   - ldap:// URLs are dialed over plain TCP with Dial (default port
//     DefaultLdapPort).
//   - ldaps:// URLs are dialed with DialTLS (default port DefaultLdapsPort),
//     with tls.Config.ServerName set to the URL's host name.
//
// On success a started Conn is returned. A malformed URL or an unsupported
// scheme is reported as an ErrorNetwork error.
func DialURL(addr string) (*Conn, error) {
	lurl, err := url.Parse(addr)
	if err != nil {
		return nil, NewError(ErrorNetwork, err)
	}

	// Hostname and Port strip the brackets from IPv6 literals such as
	// "[::1]". JoinHostPort below then adds them back exactly once, and the
	// TLS ServerName is a bare address. Splitting lurl.Host by hand left the
	// brackets in place when no port was given, producing "[[::1]]:389".
	host := lurl.Hostname()
	port := lurl.Port()

	switch lurl.Scheme {
	case "ldap":
		if port == "" {
			port = DefaultLdapPort
		}
		return Dial("tcp", net.JoinHostPort(host, port))
	case "ldaps":
		if port == "" {
			port = DefaultLdapsPort
		}
		tlsConf := &tls.Config{
			ServerName: host,
		}
		return DialTLS("tcp", net.JoinHostPort(host, port), tlsConf)
	}

	return nil, NewError(ErrorNetwork, fmt.Errorf("Unknown scheme '%s'", lurl.Scheme))
}
174
// NewConn wraps an already-established network connection in a Conn. isTLS
// records whether conn is already encrypted; StartTLS refuses to run on such
// a Conn. No goroutines are started here: call Start before use (the Dial
// helpers do this for you). No per-request timeout is set initially.
func NewConn(conn net.Conn, isTLS bool) *Conn {
	c := &Conn{conn: conn, isTLS: isTLS}
	c.chanConfirm = make(chan struct{})
	c.chanMessageID = make(chan int64)
	c.chanMessage = make(chan *messagePacket, 10)
	c.messageContexts = make(map[int64]*messageContext)
	return c
}
187
188// Start initializes goroutines to read responses and process messages
189func (l *Conn) Start() {
190	go l.reader()
191	go l.processMessages()
192	l.wgClose.Add(1)
193}
194
// IsClosing reports whether Close has begun shutting the connection down.
// It is safe to call from any goroutine.
func (l *Conn) IsClosing() bool {
	state := atomic.LoadUint32(&l.closing)
	return state == 1
}
199
// setClosing atomically moves the connection into the closing state. It
// returns true only for the single caller that performed the transition;
// every later call returns false.
func (l *Conn) setClosing() bool {
	swapped := atomic.CompareAndSwapUint32(&l.closing, 0, 1)
	return swapped
}
204
// Close closes the connection.
//
// The first call performs the shutdown:
//   - queue MessageQuit;
//   - wait for processMessages to exit (it closes chanConfirm);
//   - close chanMessage and the network connection;
//   - mark wgClose done.
//
// Concurrent and later calls skip that work and wait on wgClose until the
// first call has finished. An error from closing the network connection is
// only logged.
//
// NOTE(review): this relies on Start having been called. processMessages is
// what closes chanConfirm, and Start is what does wgClose.Add(1). Calling
// Close on a Conn that was never started appears to block forever.
func (l *Conn) Close() {
	// Holding messageMutex for the whole shutdown keeps sendProcessMessage
	// (which takes the same lock and re-checks IsClosing) from sending on
	// chanMessage after it has been closed.
	l.messageMutex.Lock()
	defer l.messageMutex.Unlock()

	if l.setClosing() {
		l.Debug.Printf("Sending quit message and waiting for confirmation")
		l.chanMessage <- &messagePacket{Op: MessageQuit}
		<-l.chanConfirm
		// processMessages has exited and sendProcessMessage can no longer
		// send, so nobody uses chanMessage any more.
		close(l.chanMessage)

		l.Debug.Printf("Closing network connection")
		if err := l.conn.Close(); err != nil {
			log.Println(err)
		}

		l.wgClose.Done()
	}
	l.wgClose.Wait()
}
225
// SetTimeout sets how long after a request is sent a MessageTimeout is
// triggered for it. The value applies to requests sent afterwards;
// non-positive durations are ignored and leave the current setting unchanged.
func (l *Conn) SetTimeout(timeout time.Duration) {
	if timeout <= 0 {
		return
	}
	atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}
232
233// Returns the next available messageID
234func (l *Conn) nextMessageID() int64 {
235	if messageID, ok := <-l.chanMessageID; ok {
236		return messageID
237	}
238	return 0
239}
240
// StartTLS sends the StartTLS extended operation (OID 1.3.6.1.4.1.1466.20037)
// and, if the server accepts it, performs a TLS handshake over the existing
// connection using config.
//
// It fails immediately if the Conn is already encrypted. sendMessageWithFlags
// also rejects it while other requests are outstanding.
//
// Outcomes:
//   - Handshake fails: the connection is closed.
//   - Server rejects the request: its LDAP error is returned and the
//     connection stays usable in plaintext.
//   - Debug annotation of the response fails (only when Debug is on): the
//     connection is closed.
func (l *Conn) StartTLS(config *tls.Config) error {
	if l.isTLS {
		return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
	}

	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
	request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
	request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
	packet.AppendChild(request)
	l.Debug.PrintPacket(packet)

	msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
	if err != nil {
		return err
	}
	defer l.finishMessage(msgCtx)

	l.Debug.Printf("%d: waiting for response", msgCtx.id)

	packetResponse, ok := <-msgCtx.responses
	if !ok {
		return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
	}
	packet, err = packetResponse.ReadPacket()
	l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
	if err != nil {
		return err
	}

	if l.Debug {
		if err := addLDAPDescriptions(packet); err != nil {
			l.Close()
			return err
		}
		ber.PrintPacket(packet)
	}

	if err = GetLDAPError(packet); err != nil {
		// The reader goroutine stops after delivering this response (it saw
		// isStartingTLS) so that it does not consume TLS handshake bytes. The
		// server declined, so the connection stays in plaintext and needs a
		// reader again; otherwise no later response would ever be delivered.
		go l.reader()
		return err
	}

	conn := tls.Client(l.conn, config)
	if connErr := conn.Handshake(); connErr != nil {
		l.Close()
		return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr))
	}

	l.isTLS = true
	l.conn = conn
	// Resume reading responses, now from the TLS connection.
	go l.reader()

	return nil
}
297
// TLSConnectionState reports the TLS state of the underlying connection.
// ok is false, and state is the zero value, when the connection is not a
// *tls.Conn (i.e. it was not dialed with TLS and StartTLS has not succeeded).
func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) {
	if tc, isTLSConn := l.conn.(*tls.Conn); isTLSConn {
		return tc.ConnectionState(), true
	}
	return state, false
}
308
309func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
310	return l.sendMessageWithFlags(packet, 0)
311}
312
// sendMessageWithFlags records a new outstanding request and queues packet for
// processMessages to write to the server. packet's first child must hold the
// int64 message ID (as obtained from nextMessageID).
//
// With the startTLS flag, the request is only accepted when no other request
// is outstanding. Further requests are then rejected until finishMessage
// ends the StartTLS phase.
//
// On success, the returned context's responses channel yields the server's
// replies. Callers (e.g. StartTLS) call finishMessage when done with it.
// An ErrorNetwork error is returned if the connection is closing or a
// StartTLS phase prevents sending.
func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
	if l.IsClosing() {
		return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
	}
	l.messageMutex.Lock()
	l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
	if l.isStartingTLS {
		l.messageMutex.Unlock()
		return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
	}
	if flags&startTLS != 0 {
		if l.outstandingRequests != 0 {
			l.messageMutex.Unlock()
			return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
		}
		l.isStartingTLS = true
	}
	l.outstandingRequests++

	l.messageMutex.Unlock()

	responses := make(chan *PacketResponse)
	messageID := packet.Children[0].Value.(int64)
	message := &messagePacket{
		Op:        MessageRequest,
		MessageID: messageID,
		Packet:    packet,
		Context: &messageContext{
			id:        messageID,
			done:      make(chan struct{}),
			responses: responses,
		},
	}
	if !l.sendProcessMessage(message) {
		// The connection started closing after the check above. The message
		// was never queued, so processMessages will never close its responses
		// channel. Returning the context would leave the caller blocked on it
		// forever.
		return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
	}
	return message.Context, nil
}
349
350func (l *Conn) finishMessage(msgCtx *messageContext) {
351	close(msgCtx.done)
352
353	if l.IsClosing() {
354		return
355	}
356
357	l.messageMutex.Lock()
358	l.outstandingRequests--
359	if l.isStartingTLS {
360		l.isStartingTLS = false
361	}
362	l.messageMutex.Unlock()
363
364	message := &messagePacket{
365		Op:        MessageFinish,
366		MessageID: msgCtx.id,
367	}
368	l.sendProcessMessage(message)
369}
370
371func (l *Conn) sendProcessMessage(message *messagePacket) bool {
372	l.messageMutex.Lock()
373	defer l.messageMutex.Unlock()
374	if l.IsClosing() {
375		return false
376	}
377	l.chanMessage <- message
378	return true
379}
380
// processMessages is the goroutine (started by Start) that owns
// messageContexts and writes every request packet to l.conn. It hands out
// increasing message IDs, starting at 1, on chanMessageID. It handles the
// operations queued on chanMessage:
//
//   - MessageRequest: write the packet; on success register its context and,
//     if SetTimeout configured a timeout, start a goroutine that posts
//     MessageTimeout for it after that duration.
//   - MessageResponse: deliver a packet read by reader to the waiting request.
//   - MessageTimeout: fail the request with a timeout error and close its
//     response channel.
//   - MessageFinish: forget a request its caller has finished with.
//   - MessageQuit: exit.
//
// On exit (including after a recovered panic) it closes every remaining
// response channel, first sending closeErr when the connection is closing
// and an error was recorded. Then it closes chanMessageID and chanConfirm;
// the latter lets Close proceed.
func (l *Conn) processMessages() {
	defer func() {
		if err := recover(); err != nil {
			log.Printf("ldap: recovered panic in processMessages: %v", err)
		}
		// Release every request that is still waiting for responses.
		for messageID, msgCtx := range l.messageContexts {
			// If we are closing due to an error, inform anyone who
			// is waiting about the error.
			if l.IsClosing() && l.closeErr.Load() != nil {
				msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
			}
			l.Debug.Printf("Closing channel for MessageID %d", messageID)
			close(msgCtx.responses)
			delete(l.messageContexts, messageID)
		}
		// From here on nextMessageID receives the zero value (0).
		close(l.chanMessageID)
		close(l.chanConfirm)
	}()

	var messageID int64 = 1
	for {
		select {
		case l.chanMessageID <- messageID:
			messageID++
		case message := <-l.chanMessage:
			switch message.Op {
			case MessageQuit:
				l.Debug.Printf("Shutting down - quit message received")
				return
			case MessageRequest:
				// Add to message list and write to network
				l.Debug.Printf("Sending message %d", message.MessageID)

				buf := message.Packet.Bytes()
				_, err := l.conn.Write(buf)
				if err != nil {
					l.Debug.Printf("Error Sending Message: %s", err.Error())
					message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
					close(message.Context.responses)
					// break only leaves the switch; the request is never
					// registered in messageContexts.
					break
				}

				// Only add to messageContexts if we were able to
				// successfully write the message.
				l.messageContexts[message.MessageID] = message.Context

				// Add timeout if defined
				requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
				if requestTimeout > 0 {
					go func() {
						defer func() {
							if err := recover(); err != nil {
								log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
							}
						}()
						time.Sleep(requestTimeout)
						// Posted even if the request already finished; the
						// MessageTimeout case ignores unknown IDs, and
						// sendProcessMessage drops it once closing.
						timeoutMessage := &messagePacket{
							Op:        MessageTimeout,
							MessageID: message.MessageID,
						}
						l.sendProcessMessage(timeoutMessage)
					}()
				}
			case MessageResponse:
				l.Debug.Printf("Receiving message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					// Blocks until the requester reads it or closes done.
					msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
				} else {
					log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing())
					ber.PrintPacket(message.Packet)
				}
			case MessageTimeout:
				// Handle the timeout by closing the channel
				// All reads will return immediately
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
					msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			case MessageFinish:
				l.Debug.Printf("Finished message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			}
		}
	}
}
471
// reader runs as a goroutine, reading LDAP packets from l.conn and forwarding
// each one to processMessages as a MessageResponse.
//
// If a packet arrives while a StartTLS request is in progress, reader forwards
// it and then returns without closing the connection (a "clean stop"). That
// lets StartTLS use l.conn for the TLS handshake and start a new reader
// afterwards.
//
// On any other exit (a read error, Close already in progress, or a recovered
// panic) reader calls Close. A read error not caused by Close is stored in
// closeErr so processMessages can pass it to requests still waiting.
func (l *Conn) reader() {
	cleanstop := false
	defer func() {
		if err := recover(); err != nil {
			log.Printf("ldap: recovered panic in reader: %v", err)
		}
		if !cleanstop {
			l.Close()
		}
	}()

	for {
		if cleanstop {
			l.Debug.Printf("reader clean stopping (without closing the connection)")
			return
		}
		packet, err := ber.ReadPacket(l.conn)
		if err != nil {
			// A read error is expected here if we are closing the connection...
			if !l.IsClosing() {
				l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
				l.Debug.Printf("reader error: %s", err.Error())
			}
			return
		}
		if err := addLDAPDescriptions(packet); err != nil {
			// Failing to annotate the packet is not fatal here; it is still
			// dispatched below, but the failure is no longer silently lost.
			l.Debug.Printf("descriptions error: %s", err)
		}
		if len(packet.Children) == 0 {
			l.Debug.Printf("Received bad ldap packet")
			continue
		}
		l.messageMutex.Lock()
		if l.isStartingTLS {
			cleanstop = true
		}
		l.messageMutex.Unlock()
		message := &messagePacket{
			Op:        MessageResponse,
			MessageID: packet.Children[0].Value.(int64),
			Packet:    packet,
		}
		if !l.sendProcessMessage(message) {
			return
		}
	}
}
517