1package xmpp
2
3import (
4	"encoding/hex"
5	"fmt"
6	"io"
7	"net"
8	"sync"
9	"time"
10
11	"gopkg.in/check.v1"
12)
13
// mockConn is a connection stand-in for tests that only care whether Close
// was invoked. The embedded zero-value net.TCPConn supplies the remaining
// net.Conn methods, so *mockConn can be passed where a net.Conn is expected;
// only Close is overridden here.
type mockConn struct {
	calledClose int // how many times Close has run
	net.TCPConn
}

// Close bumps the call counter and always reports success. It never reaches
// the embedded TCPConn.
func (c *mockConn) Close() error {
	c.calledClose++
	return nil
}
23
// mockConnIOReaderWriter is an in-memory, mutex-guarded reader/writer/closer
// for tests. Read serves bytes from read, Write accumulates into write, and a
// single error can be injected through errCount/err. Read and Write share one
// countdown, and the call that finds it at zero is the only one that returns
// err.
type mockConnIOReaderWriter struct {
	read      []byte // data handed out by Read
	readIndex int    // offset of the next unread byte in read
	write     []byte // everything passed to Write so far
	errCount  int    // Read/Write calls that succeed before the one returning err
	err       error  // error returned by the call that sees errCount == 0

	calledClose int // number of Close calls

	lock sync.Mutex
}

// CalledClose reports whether Close has been called at least once.
func (iom *mockConnIOReaderWriter) CalledClose() bool {
	iom.lock.Lock()
	defer iom.lock.Unlock()

	return 0 < iom.calledClose
}

// Written returns a snapshot copy of all bytes written so far. The result is
// a non-nil slice even when nothing has been written.
func (iom *mockConnIOReaderWriter) Written() []byte {
	iom.lock.Lock()
	defer iom.lock.Unlock()

	snapshot := make([]byte, len(iom.write))
	copy(snapshot, iom.write)
	return snapshot
}

// Read copies as much unread data as fits into p and advances the read
// offset. When everything has been consumed it returns io.EOF and leaves the
// error countdown untouched.
func (iom *mockConnIOReaderWriter) Read(p []byte) (n int, err error) {
	iom.lock.Lock()
	defer iom.lock.Unlock()

	if len(iom.read) <= iom.readIndex {
		return 0, io.EOF
	}
	n = copy(p, iom.read[iom.readIndex:])
	iom.readIndex += n
	return n, iom.nextErrLocked()
}

// Write appends p to the recorded output. It always reports len(p) bytes
// written, together with the injected error when the countdown hits zero.
func (iom *mockConnIOReaderWriter) Write(p []byte) (n int, err error) {
	iom.lock.Lock()
	defer iom.lock.Unlock()

	iom.write = append(iom.write, p...)
	return len(p), iom.nextErrLocked()
}

// nextErrLocked advances the shared error countdown. It returns iom.err only
// when the countdown was exactly zero before this call. The caller must hold
// iom.lock.
func (iom *mockConnIOReaderWriter) nextErrLocked() error {
	remaining := iom.errCount
	iom.errCount--
	if remaining != 0 {
		return nil
	}
	return iom.err
}

// Close records the call and always succeeds.
func (iom *mockConnIOReaderWriter) Close() error {
	iom.lock.Lock()
	defer iom.lock.Unlock()

	iom.calledClose++
	return nil
}
90
// mockMultiConnIOReaderWriter is an in-memory reader/writer that serves read
// as a sequence of chunks, never more than one chunk per Read call, and
// records all writes. It does no locking.
type mockMultiConnIOReaderWriter struct {
	read      [][]byte // chunks to serve; a partially read chunk is re-sliced in place
	readIndex int      // index of the chunk the next Read serves from
	write     []byte   // everything passed to Write so far
}

// Read copies the current chunk into p. If p cannot hold the whole chunk, the
// unread tail is kept by re-slicing read[readIndex] and is returned by the
// next call, instead of being silently discarded. Only a fully delivered chunk
// advances readIndex. This updates the element in the caller-supplied outer
// slice. Read returns io.EOF once every chunk has been delivered.
func (iom *mockMultiConnIOReaderWriter) Read(p []byte) (n int, err error) {
	if iom.readIndex >= len(iom.read) {
		return 0, io.EOF
	}
	chunk := iom.read[iom.readIndex]
	n = copy(p, chunk)
	if n < len(chunk) {
		// Short destination buffer: remember the remainder for the next Read.
		iom.read[iom.readIndex] = chunk[n:]
		return n, nil
	}
	iom.readIndex++
	return n, nil
}

// Write appends p to the recorded output and always succeeds.
func (iom *mockMultiConnIOReaderWriter) Write(p []byte) (n int, err error) {
	iom.write = append(iom.write, p...)
	return len(p), nil
}
110
// fullMockedConn is a net.Conn whose Read and Write are forwarded to rw.
// Every other method is inert: Close and the deadline setters report
// success, and both address getters return nil.
type fullMockedConn struct {
	rw io.ReadWriter
}

// Read delegates to the wrapped ReadWriter.
func (c *fullMockedConn) Read(b []byte) (int, error) { return c.rw.Read(b) }

// Write delegates to the wrapped ReadWriter.
func (c *fullMockedConn) Write(b []byte) (int, error) { return c.rw.Write(b) }

// Close is a no-op; the wrapped ReadWriter is left untouched.
func (c *fullMockedConn) Close() error { return nil }

// LocalAddr has no address to report and returns nil.
func (c *fullMockedConn) LocalAddr() net.Addr { return nil }

// RemoteAddr has no address to report and returns nil.
func (c *fullMockedConn) RemoteAddr() net.Addr { return nil }

// SetDeadline ignores the deadline and reports success.
func (c *fullMockedConn) SetDeadline(_ time.Time) error { return nil }

// SetReadDeadline ignores the deadline and reports success.
func (c *fullMockedConn) SetReadDeadline(_ time.Time) error { return nil }

// SetWriteDeadline ignores the deadline and reports success.
func (c *fullMockedConn) SetWriteDeadline(_ time.Time) error { return nil }
146
// fixedRandReader is a deterministic io.Reader, presumably used in place of a
// random source in tests. Each Read call decodes the next hex string in data.
type fixedRandReader struct {
	data []string // hex-encoded values, one consumed per Read call
	at   int      // index of the next value in data
}

// fixedRand returns a reader that yields the hex-decoded entries of data, one
// entry per Read call, followed by io.EOF once they are exhausted.
func fixedRand(data []string) io.Reader {
	return &fixedRandReader{data: data}
}

// bytesFromHex decodes a hex string. The decode error is deliberately
// discarded (inputs are presumably hard-coded test fixtures). For malformed
// input the result holds only the bytes decoded before the error, as
// hex.DecodeString documents.
func bytesFromHex(s string) []byte {
	decoded, _ := hex.DecodeString(s)
	return decoded
}

// byteStringFromHex is bytesFromHex with the result converted to a string.
func byteStringFromHex(s string) string {
	return string(bytesFromHex(s))
}

// Read decodes the next entry and copies as much of it as fits into p.
// Decoded bytes that do not fit are dropped rather than carried over, so
// every call starts on a fresh entry. Tests may depend on that alignment.
func (frr *fixedRandReader) Read(p []byte) (n int, err error) {
	if frr.at >= len(frr.data) {
		return 0, io.EOF
	}
	value := bytesFromHex(frr.data[frr.at])
	frr.at++
	return copy(p, value), nil
}
175
// createTeeConn wraps c so that the bytes moved by each Read or Write are
// also written to w as a hex dump.
func createTeeConn(c net.Conn, w io.Writer) net.Conn {
	return &teeConn{c: c, w: w}
}

// teeConn is a net.Conn that forwards every call to c. For each Read or
// Write that transfers at least one byte, it also emits a "READ: <hex>" or
// "WRITE: <hex>" line to w, including when that call also returns an error.
type teeConn struct {
	c net.Conn  // underlying connection
	w io.Writer // destination of the hex trace
}

// Read reads from the underlying connection and traces whatever arrived.
func (c *teeConn) Read(b []byte) (n int, err error) {
	n, err = c.c.Read(b)
	c.trace("READ: %x\n", b, n)
	return n, err
}

// Write writes to the underlying connection and traces whatever was sent.
func (c *teeConn) Write(b []byte) (n int, err error) {
	n, err = c.c.Write(b)
	c.trace("WRITE: %x\n", b, n)
	return n, err
}

// trace formats the first n bytes of b into c.w and does nothing when n is
// not positive. Errors from the trace writer are ignored.
func (c *teeConn) trace(format string, b []byte, n int) {
	if n <= 0 {
		return
	}
	fmt.Fprintf(c.w, format, b[:n])
}

// Close closes the underlying connection.
func (c *teeConn) Close() error { return c.c.Close() }

// LocalAddr reports the underlying connection's local address.
func (c *teeConn) LocalAddr() net.Addr { return c.c.LocalAddr() }

// RemoteAddr reports the underlying connection's remote address.
func (c *teeConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }

// SetDeadline forwards to the underlying connection.
func (c *teeConn) SetDeadline(t time.Time) error { return c.c.SetDeadline(t) }

// SetReadDeadline forwards to the underlying connection.
func (c *teeConn) SetReadDeadline(t time.Time) error { return c.c.SetReadDeadline(t) }

// SetWriteDeadline forwards to the underlying connection.
func (c *teeConn) SetWriteDeadline(t time.Time) error { return c.c.SetWriteDeadline(t) }
224
// dialCall is the signature of a scripted response to mockProxy.Dial.
type dialCall func(string, string) (c net.Conn, e error)

// dialCallExp is one expected Dial call. It holds the response to produce and
// records whether Dial has consumed it.
type dialCallExp struct {
	f      dialCall
	called bool
}

// mockProxy is a scripted dialer for tests. Each Dial consumes the next
// expectation registered with Expects, in registration order. The embedded
// mutex guards called and calls.
type mockProxy struct {
	called int           // number of expectations consumed by Dial
	calls  []dialCallExp // expectations in registration order
	sync.Mutex
}

// Dial runs the next registered expectation with network and addr and
// returns its result. When every expectation has already been used, it
// returns an error and consumes nothing.
func (p *mockProxy) Dial(network, addr string) (net.Conn, error) {
	f, ok := p.nextCall()
	if !ok {
		return nil, fmt.Errorf("unexpected call to Dial: %s, %s", network, addr)
	}
	// The scripted response runs without the lock held, so it may call back
	// into p (for example Expects) without deadlocking.
	return f(network, addr)
}

// nextCall consumes the next expectation under the lock. It marks the stored
// entry (not a copy) as called and returns its response. It reports false
// when no expectation is left.
func (p *mockProxy) nextCall() (dialCall, bool) {
	p.Lock()
	defer p.Unlock()

	if p.called >= len(p.calls) {
		return nil, false
	}
	exp := &p.calls[p.called]
	exp.called = true
	p.called++
	return exp.f, true
}

// Expects registers f as the response for the next unscripted Dial call.
func (p *mockProxy) Expects(f dialCall) {
	p.Lock()
	defer p.Unlock()

	// append handles a nil slice, so no explicit initialization is needed.
	p.calls = append(p.calls, dialCallExp{f: f})
}
262
263var MatchesExpectations check.Checker = &allExpectations{
264	&check.CheckerInfo{Name: "IsNil", Params: []string{"value"}},
265}
266
267type allExpectations struct {
268	*check.CheckerInfo
269}
270
271func (checker *allExpectations) Check(params []interface{}, names []string) (result bool, error string) {
272	p := params[0].(*mockProxy)
273
274	if p.called != len(p.calls) {
275		return false, fmt.Sprintf("expected: %d calls, got: %d", len(p.calls), p.called)
276	}
277
278	return true, ""
279}
280