1package ldap
2
import (
	"crypto/tls"
	"errors"

	"gopkg.in/ldap.v3"
)
8
9type searchFunc = func(request *ldap.SearchRequest) (*ldap.SearchResult, error)
10
11// MockConnection struct for testing
12type MockConnection struct {
13	SearchFunc       searchFunc
14	SearchCalled     bool
15	SearchAttributes []string
16
17	AddParams *ldap.AddRequest
18	AddCalled bool
19
20	DelParams *ldap.DelRequest
21	DelCalled bool
22
23	CloseCalled bool
24
25	UnauthenticatedBindCalled bool
26	BindCalled                bool
27
28	BindProvider                func(username, password string) error
29	UnauthenticatedBindProvider func() error
30}
31
32// Bind mocks Bind connection function
33func (c *MockConnection) Bind(username, password string) error {
34	c.BindCalled = true
35
36	if c.BindProvider != nil {
37		return c.BindProvider(username, password)
38	}
39
40	return nil
41}
42
43// UnauthenticatedBind mocks UnauthenticatedBind connection function
44func (c *MockConnection) UnauthenticatedBind(username string) error {
45	c.UnauthenticatedBindCalled = true
46
47	if c.UnauthenticatedBindProvider != nil {
48		return c.UnauthenticatedBindProvider()
49	}
50
51	return nil
52}
53
54// Close mocks Close connection function
55func (c *MockConnection) Close() {
56	c.CloseCalled = true
57}
58
59func (c *MockConnection) setSearchResult(result *ldap.SearchResult) {
60	c.SearchFunc = func(request *ldap.SearchRequest) (*ldap.SearchResult, error) {
61		return result, nil
62	}
63}
64
65func (c *MockConnection) setSearchError(err error) {
66	c.SearchFunc = func(request *ldap.SearchRequest) (*ldap.SearchResult, error) {
67		return nil, err
68	}
69}
70
71func (c *MockConnection) setSearchFunc(fn searchFunc) {
72	c.SearchFunc = fn
73}
74
75// Search mocks Search connection function
76func (c *MockConnection) Search(sr *ldap.SearchRequest) (*ldap.SearchResult, error) {
77	c.SearchCalled = true
78	c.SearchAttributes = sr.Attributes
79
80	return c.SearchFunc(sr)
81}
82
83// Add mocks Add connection function
84func (c *MockConnection) Add(request *ldap.AddRequest) error {
85	c.AddCalled = true
86	c.AddParams = request
87	return nil
88}
89
90// Del mocks Del connection function
91func (c *MockConnection) Del(request *ldap.DelRequest) error {
92	c.DelCalled = true
93	c.DelParams = request
94	return nil
95}
96
97// StartTLS mocks StartTLS connection function
98func (c *MockConnection) StartTLS(*tls.Config) error {
99	return nil
100}
101