1package ldap 2 3import ( 4 "bytes" 5 "errors" 6 "io" 7 "net" 8 "net/http" 9 "net/http/httptest" 10 "runtime" 11 "sync" 12 "testing" 13 "time" 14 15 "gopkg.in/asn1-ber.v1" 16) 17 18func TestUnresponsiveConnection(t *testing.T) { 19 // The do-nothing server that accepts requests and does nothing 20 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 21 })) 22 defer ts.Close() 23 c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String()) 24 if err != nil { 25 t.Fatalf("error connecting to localhost tcp: %v", err) 26 } 27 28 // Create an Ldap connection 29 conn := NewConn(c, false) 30 conn.SetTimeout(time.Millisecond) 31 conn.Start() 32 defer conn.Close() 33 34 // Mock a packet 35 packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") 36 packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, conn.nextMessageID(), "MessageID")) 37 bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") 38 bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) 39 packet.AppendChild(bindRequest) 40 41 // Send packet and test response 42 msgCtx, err := conn.sendMessage(packet) 43 if err != nil { 44 t.Fatalf("error sending message: %v", err) 45 } 46 defer conn.finishMessage(msgCtx) 47 48 packetResponse, ok := <-msgCtx.responses 49 if !ok { 50 t.Fatalf("no PacketResponse in response channel") 51 } 52 packet, err = packetResponse.ReadPacket() 53 if err == nil { 54 t.Fatalf("expected timeout error") 55 } 56 if err.Error() != "ldap: connection timed out" { 57 t.Fatalf("unexpected error: %v", err) 58 } 59} 60 61// TestFinishMessage tests that we do not enter deadlock when a goroutine makes 62// a request but does not handle all responses from the server. 63func TestFinishMessage(t *testing.T) { 64 ptc := newPacketTranslatorConn() 65 defer ptc.Close() 66 67 conn := NewConn(ptc, false) 68 conn.Start() 69 70 // Test sending 5 different requests in series. Ensure that we can 71 // get a response packet from the underlying connection and also 72 // ensure that we can gracefully ignore unhandled responses. 73 for i := 0; i < 5; i++ { 74 t.Logf("serial request %d", i) 75 // Create a message and make sure we can receive responses. 76 msgCtx := testSendRequest(t, ptc, conn) 77 testReceiveResponse(t, ptc, msgCtx) 78 79 // Send a few unhandled responses and finish the message. 80 testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5) 81 t.Logf("serial request %d done", i) 82 } 83 84 // Test sending 5 different requests in parallel. 85 var wg sync.WaitGroup 86 for i := 0; i < 5; i++ { 87 wg.Add(1) 88 go func(i int) { 89 defer wg.Done() 90 t.Logf("parallel request %d", i) 91 // Create a message and make sure we can receive responses. 92 msgCtx := testSendRequest(t, ptc, conn) 93 testReceiveResponse(t, ptc, msgCtx) 94 95 // Send a few unhandled responses and finish the message. 96 testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5) 97 t.Logf("parallel request %d done", i) 98 }(i) 99 } 100 wg.Wait() 101 102 // We cannot run Close() in a defer because t.FailNow() will run it and 103 // it will block if the processMessage Loop is in a deadlock. 104 conn.Close() 105} 106 107func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCtx *messageContext) { 108 var msgID int64 109 runWithTimeout(t, time.Second, func() { 110 msgID = conn.nextMessageID() 111 }) 112 113 requestPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") 114 requestPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgID, "MessageID")) 115 116 var err error 117 118 runWithTimeout(t, time.Second, func() { 119 msgCtx, err = conn.sendMessage(requestPacket) 120 if err != nil { 121 t.Fatalf("unable to send request message: %s", err) 122 } 123 }) 124 125 // We should now be able to get this request packet out from the other 126 // side. 127 runWithTimeout(t, time.Second, func() { 128 if _, err = ptc.ReceiveRequest(); err != nil { 129 t.Fatalf("unable to receive request packet: %s", err) 130 } 131 }) 132 133 return msgCtx 134} 135 136func testReceiveResponse(t *testing.T, ptc *packetTranslatorConn, msgCtx *messageContext) { 137 // Send a mock response packet. 138 responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") 139 responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID")) 140 141 runWithTimeout(t, time.Second, func() { 142 if err := ptc.SendResponse(responsePacket); err != nil { 143 t.Fatalf("unable to send response packet: %s", err) 144 } 145 }) 146 147 // We should be able to receive the packet from the connection. 148 runWithTimeout(t, time.Second, func() { 149 if _, ok := <-msgCtx.responses; !ok { 150 t.Fatal("response channel closed") 151 } 152 }) 153} 154 155func testSendUnhandledResponsesAndFinish(t *testing.T, ptc *packetTranslatorConn, conn *Conn, msgCtx *messageContext, numResponses int) { 156 // Send a mock response packet. 157 responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") 158 responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID")) 159 160 // Send extra responses but do not attempt to receive them on the 161 // client side. 162 for i := 0; i < numResponses; i++ { 163 runWithTimeout(t, time.Second, func() { 164 if err := ptc.SendResponse(responsePacket); err != nil { 165 t.Fatalf("unable to send response packet: %s", err) 166 } 167 }) 168 } 169 170 // Finally, attempt to finish this message. 171 runWithTimeout(t, time.Second, func() { 172 conn.finishMessage(msgCtx) 173 }) 174} 175 176func runWithTimeout(t *testing.T, timeout time.Duration, f func()) { 177 done := make(chan struct{}) 178 go func() { 179 f() 180 close(done) 181 }() 182 183 select { 184 case <-done: // Success! 185 case <-time.After(timeout): 186 _, file, line, _ := runtime.Caller(1) 187 t.Fatalf("%s:%d timed out", file, line) 188 } 189} 190 191// packetTranslatorConn is a helpful type which can be used with various tests 192// in this package. It implements the net.Conn interface to be used as an 193// underlying connection for a *ldap.Conn. Most methods are no-ops but the 194// Read() and Write() methods are able to translate ber-encoded packets for 195// testing LDAP requests and responses. 196// 197// Test cases can simulate an LDAP server sending a response by calling the 198// SendResponse() method with a ber-encoded LDAP response packet. Test cases 199// can simulate an LDAP server receiving a request from a client by calling the 200// ReceiveRequest() method which returns a ber-encoded LDAP request packet. 201type packetTranslatorConn struct { 202 lock sync.Mutex 203 isClosed bool 204 205 responseCond sync.Cond 206 requestCond sync.Cond 207 208 responseBuf bytes.Buffer 209 requestBuf bytes.Buffer 210} 211 212var errPacketTranslatorConnClosed = errors.New("connection closed") 213 214func newPacketTranslatorConn() *packetTranslatorConn { 215 conn := &packetTranslatorConn{} 216 conn.responseCond = sync.Cond{L: &conn.lock} 217 conn.requestCond = sync.Cond{L: &conn.lock} 218 219 return conn 220} 221 222// Read is called by the reader() loop to receive response packets. It will 223// block until there are more packet bytes available or this connection is 224// closed. 225func (c *packetTranslatorConn) Read(b []byte) (n int, err error) { 226 c.lock.Lock() 227 defer c.lock.Unlock() 228 229 for !c.isClosed { 230 // Attempt to read data from the response buffer. If it fails 231 // with an EOF, wait and try again. 232 n, err = c.responseBuf.Read(b) 233 if err != io.EOF { 234 return n, err 235 } 236 237 c.responseCond.Wait() 238 } 239 240 return 0, errPacketTranslatorConnClosed 241} 242 243// SendResponse writes the given response packet to the response buffer for 244// this connection, signalling any goroutine waiting to read a response. 245func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error { 246 c.lock.Lock() 247 defer c.lock.Unlock() 248 249 if c.isClosed { 250 return errPacketTranslatorConnClosed 251 } 252 253 // Signal any goroutine waiting to read a response. 254 defer c.responseCond.Broadcast() 255 256 // Writes to the buffer should always succeed. 257 c.responseBuf.Write(packet.Bytes()) 258 259 return nil 260} 261 262// Write is called by the processMessages() loop to send request packets. 263func (c *packetTranslatorConn) Write(b []byte) (n int, err error) { 264 c.lock.Lock() 265 defer c.lock.Unlock() 266 267 if c.isClosed { 268 return 0, errPacketTranslatorConnClosed 269 } 270 271 // Signal any goroutine waiting to read a request. 272 defer c.requestCond.Broadcast() 273 274 // Writes to the buffer should always succeed. 275 return c.requestBuf.Write(b) 276} 277 278// ReceiveRequest attempts to read a request packet from this connection. It 279// will block until it is able to read a full request packet or until this 280// connection is closed. 281func (c *packetTranslatorConn) ReceiveRequest() (*ber.Packet, error) { 282 c.lock.Lock() 283 defer c.lock.Unlock() 284 285 for !c.isClosed { 286 // Attempt to parse a request packet from the request buffer. 287 // If it fails with an unexpected EOF, wait and try again. 288 requestReader := bytes.NewReader(c.requestBuf.Bytes()) 289 packet, err := ber.ReadPacket(requestReader) 290 switch err { 291 case io.EOF, io.ErrUnexpectedEOF: 292 c.requestCond.Wait() 293 case nil: 294 // Advance the request buffer by the number of bytes 295 // read to decode the request packet. 296 c.requestBuf.Next(c.requestBuf.Len() - requestReader.Len()) 297 return packet, nil 298 default: 299 return nil, err 300 } 301 } 302 303 return nil, errPacketTranslatorConnClosed 304} 305 306// Close closes this connection causing Read() and Write() calls to fail. 307func (c *packetTranslatorConn) Close() error { 308 c.lock.Lock() 309 defer c.lock.Unlock() 310 311 c.isClosed = true 312 c.responseCond.Broadcast() 313 c.requestCond.Broadcast() 314 315 return nil 316} 317 318func (c *packetTranslatorConn) LocalAddr() net.Addr { 319 return (*net.TCPAddr)(nil) 320} 321 322func (c *packetTranslatorConn) RemoteAddr() net.Addr { 323 return (*net.TCPAddr)(nil) 324} 325 326func (c *packetTranslatorConn) SetDeadline(t time.Time) error { 327 return nil 328} 329 330func (c *packetTranslatorConn) SetReadDeadline(t time.Time) error { 331 return nil 332} 333 334func (c *packetTranslatorConn) SetWriteDeadline(t time.Time) error { 335 return nil 336} 337