1package ldap
2
3import (
4	"bytes"
5	"errors"
6	"io"
7	"net"
8	"net/http"
9	"net/http/httptest"
10	"runtime"
11	"sync"
12	"testing"
13	"time"
14
15	"gopkg.in/asn1-ber.v1"
16)
17
18func TestUnresponsiveConnection(t *testing.T) {
19	// The do-nothing server that accepts requests and does nothing
20	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21	}))
22	defer ts.Close()
23	c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
24	if err != nil {
25		t.Fatalf("error connecting to localhost tcp: %v", err)
26	}
27
28	// Create an Ldap connection
29	conn := NewConn(c, false)
30	conn.SetTimeout(time.Millisecond)
31	conn.Start()
32	defer conn.Close()
33
34	// Mock a packet
35	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
36	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, conn.nextMessageID(), "MessageID"))
37	bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
38	bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
39	packet.AppendChild(bindRequest)
40
41	// Send packet and test response
42	msgCtx, err := conn.sendMessage(packet)
43	if err != nil {
44		t.Fatalf("error sending message: %v", err)
45	}
46	defer conn.finishMessage(msgCtx)
47
48	packetResponse, ok := <-msgCtx.responses
49	if !ok {
50		t.Fatalf("no PacketResponse in response channel")
51	}
52	packet, err = packetResponse.ReadPacket()
53	if err == nil {
54		t.Fatalf("expected timeout error")
55	}
56	if err.Error() != "ldap: connection timed out" {
57		t.Fatalf("unexpected error: %v", err)
58	}
59}
60
61// TestFinishMessage tests that we do not enter deadlock when a goroutine makes
62// a request but does not handle all responses from the server.
63func TestFinishMessage(t *testing.T) {
64	ptc := newPacketTranslatorConn()
65	defer ptc.Close()
66
67	conn := NewConn(ptc, false)
68	conn.Start()
69
70	// Test sending 5 different requests in series. Ensure that we can
71	// get a response packet from the underlying connection and also
72	// ensure that we can gracefully ignore unhandled responses.
73	for i := 0; i < 5; i++ {
74		t.Logf("serial request %d", i)
75		// Create a message and make sure we can receive responses.
76		msgCtx := testSendRequest(t, ptc, conn)
77		testReceiveResponse(t, ptc, msgCtx)
78
79		// Send a few unhandled responses and finish the message.
80		testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
81		t.Logf("serial request %d done", i)
82	}
83
84	// Test sending 5 different requests in parallel.
85	var wg sync.WaitGroup
86	for i := 0; i < 5; i++ {
87		wg.Add(1)
88		go func(i int) {
89			defer wg.Done()
90			t.Logf("parallel request %d", i)
91			// Create a message and make sure we can receive responses.
92			msgCtx := testSendRequest(t, ptc, conn)
93			testReceiveResponse(t, ptc, msgCtx)
94
95			// Send a few unhandled responses and finish the message.
96			testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
97			t.Logf("parallel request %d done", i)
98		}(i)
99	}
100	wg.Wait()
101
102	// We cannot run Close() in a defer because t.FailNow() will run it and
103	// it will block if the processMessage Loop is in a deadlock.
104	conn.Close()
105}
106
107func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCtx *messageContext) {
108	var msgID int64
109	runWithTimeout(t, time.Second, func() {
110		msgID = conn.nextMessageID()
111	})
112
113	requestPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
114	requestPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgID, "MessageID"))
115
116	var err error
117
118	runWithTimeout(t, time.Second, func() {
119		msgCtx, err = conn.sendMessage(requestPacket)
120		if err != nil {
121			t.Fatalf("unable to send request message: %s", err)
122		}
123	})
124
125	// We should now be able to get this request packet out from the other
126	// side.
127	runWithTimeout(t, time.Second, func() {
128		if _, err = ptc.ReceiveRequest(); err != nil {
129			t.Fatalf("unable to receive request packet: %s", err)
130		}
131	})
132
133	return msgCtx
134}
135
136func testReceiveResponse(t *testing.T, ptc *packetTranslatorConn, msgCtx *messageContext) {
137	// Send a mock response packet.
138	responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
139	responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))
140
141	runWithTimeout(t, time.Second, func() {
142		if err := ptc.SendResponse(responsePacket); err != nil {
143			t.Fatalf("unable to send response packet: %s", err)
144		}
145	})
146
147	// We should be able to receive the packet from the connection.
148	runWithTimeout(t, time.Second, func() {
149		if _, ok := <-msgCtx.responses; !ok {
150			t.Fatal("response channel closed")
151		}
152	})
153}
154
155func testSendUnhandledResponsesAndFinish(t *testing.T, ptc *packetTranslatorConn, conn *Conn, msgCtx *messageContext, numResponses int) {
156	// Send a mock response packet.
157	responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
158	responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))
159
160	// Send extra responses but do not attempt to receive them on the
161	// client side.
162	for i := 0; i < numResponses; i++ {
163		runWithTimeout(t, time.Second, func() {
164			if err := ptc.SendResponse(responsePacket); err != nil {
165				t.Fatalf("unable to send response packet: %s", err)
166			}
167		})
168	}
169
170	// Finally, attempt to finish this message.
171	runWithTimeout(t, time.Second, func() {
172		conn.finishMessage(msgCtx)
173	})
174}
175
// runWithTimeout runs f on a new goroutine and waits up to timeout for it to
// return. If f does not return in time, the test fails with the file:line of
// runWithTimeout's caller. f's goroutine is abandoned and keeps running.
//
// If f ends early through runtime.Goexit, which is what t.FailNow, t.Fatal
// and t.Fatalf do, only f's goroutine stops. runWithTimeout detects this and
// calls t.FailNow on its own goroutine. The caller therefore stops right away
// with f's real failure message, instead of waiting out the timeout and then
// reporting a misleading "timed out".
func runWithTimeout(t *testing.T, timeout time.Duration, f func()) {
	done := make(chan struct{})
	returned := false
	go func() {
		// Deferred so that done is closed even when f exits via runtime.Goexit.
		defer close(done)
		f()
		// Only reached when f returns normally. This write happens before the
		// deferred close, so reading it after <-done is race-free.
		returned = true
	}()

	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-done:
		if !returned {
			// f already recorded its failure; stop the calling goroutine too.
			t.FailNow()
		}
	case <-timer.C:
		_, file, line, _ := runtime.Caller(1)
		t.Fatalf("%s:%d timed out", file, line)
	}
}
190
// packetTranslatorConn is an in-memory net.Conn test double for use
// underneath a *ldap.Conn in this package's tests. Read and Write exchange raw
// bytes with two internal buffers, and Close wakes anything blocked on them.
// The remaining net.Conn methods do nothing.
//
// Two extra methods let a test play the LDAP server. SendResponse queues a
// ber-encoded response packet for the client to Read. ReceiveRequest blocks
// until a complete ber-encoded request packet has been Written by the client
// and returns it.
type packetTranslatorConn struct {
	// lock guards every other field and is the Locker for both condition
	// variables (wired up in newPacketTranslatorConn).
	lock sync.Mutex
	// isClosed becomes true once Close is called.
	isClosed bool

	// responseCond is broadcast when responseBuf gains bytes or on Close.
	responseCond sync.Cond
	// requestCond is broadcast when requestBuf gains bytes or on Close.
	requestCond sync.Cond

	// responseBuf holds bytes queued by SendResponse and drained by Read.
	responseBuf bytes.Buffer
	// requestBuf holds bytes appended by Write and consumed by ReceiveRequest.
	requestBuf bytes.Buffer
}
211
var (
	// errPacketTranslatorConnClosed is returned by Read, Write, SendResponse
	// and ReceiveRequest once the packetTranslatorConn has been closed.
	errPacketTranslatorConnClosed = errors.New("connection closed")
)
213
214func newPacketTranslatorConn() *packetTranslatorConn {
215	conn := &packetTranslatorConn{}
216	conn.responseCond = sync.Cond{L: &conn.lock}
217	conn.requestCond = sync.Cond{L: &conn.lock}
218
219	return conn
220}
221
222// Read is called by the reader() loop to receive response packets. It will
223// block until there are more packet bytes available or this connection is
224// closed.
225func (c *packetTranslatorConn) Read(b []byte) (n int, err error) {
226	c.lock.Lock()
227	defer c.lock.Unlock()
228
229	for !c.isClosed {
230		// Attempt to read data from the response buffer. If it fails
231		// with an EOF, wait and try again.
232		n, err = c.responseBuf.Read(b)
233		if err != io.EOF {
234			return n, err
235		}
236
237		c.responseCond.Wait()
238	}
239
240	return 0, errPacketTranslatorConnClosed
241}
242
243// SendResponse writes the given response packet to the response buffer for
244// this connection, signalling any goroutine waiting to read a response.
245func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error {
246	c.lock.Lock()
247	defer c.lock.Unlock()
248
249	if c.isClosed {
250		return errPacketTranslatorConnClosed
251	}
252
253	// Signal any goroutine waiting to read a response.
254	defer c.responseCond.Broadcast()
255
256	// Writes to the buffer should always succeed.
257	c.responseBuf.Write(packet.Bytes())
258
259	return nil
260}
261
262// Write is called by the processMessages() loop to send request packets.
263func (c *packetTranslatorConn) Write(b []byte) (n int, err error) {
264	c.lock.Lock()
265	defer c.lock.Unlock()
266
267	if c.isClosed {
268		return 0, errPacketTranslatorConnClosed
269	}
270
271	// Signal any goroutine waiting to read a request.
272	defer c.requestCond.Broadcast()
273
274	// Writes to the buffer should always succeed.
275	return c.requestBuf.Write(b)
276}
277
278// ReceiveRequest attempts to read a request packet from this connection. It
279// will block until it is able to read a full request packet or until this
280// connection is closed.
281func (c *packetTranslatorConn) ReceiveRequest() (*ber.Packet, error) {
282	c.lock.Lock()
283	defer c.lock.Unlock()
284
285	for !c.isClosed {
286		// Attempt to parse a request packet from the request buffer.
287		// If it fails with an unexpected EOF, wait and try again.
288		requestReader := bytes.NewReader(c.requestBuf.Bytes())
289		packet, err := ber.ReadPacket(requestReader)
290		switch err {
291		case io.EOF, io.ErrUnexpectedEOF:
292			c.requestCond.Wait()
293		case nil:
294			// Advance the request buffer by the number of bytes
295			// read to decode the request packet.
296			c.requestBuf.Next(c.requestBuf.Len() - requestReader.Len())
297			return packet, nil
298		default:
299			return nil, err
300		}
301	}
302
303	return nil, errPacketTranslatorConnClosed
304}
305
306// Close closes this connection causing Read() and Write() calls to fail.
307func (c *packetTranslatorConn) Close() error {
308	c.lock.Lock()
309	defer c.lock.Unlock()
310
311	c.isClosed = true
312	c.responseCond.Broadcast()
313	c.requestCond.Broadcast()
314
315	return nil
316}
317
318func (c *packetTranslatorConn) LocalAddr() net.Addr {
319	return (*net.TCPAddr)(nil)
320}
321
322func (c *packetTranslatorConn) RemoteAddr() net.Addr {
323	return (*net.TCPAddr)(nil)
324}
325
326func (c *packetTranslatorConn) SetDeadline(t time.Time) error {
327	return nil
328}
329
330func (c *packetTranslatorConn) SetReadDeadline(t time.Time) error {
331	return nil
332}
333
334func (c *packetTranslatorConn) SetWriteDeadline(t time.Time) error {
335	return nil
336}
337