1package xmpp 2 3import ( 4 "encoding/hex" 5 "fmt" 6 "io" 7 "net" 8 "sync" 9 "time" 10 11 "gopkg.in/check.v1" 12) 13 14type mockConn struct { 15 calledClose int 16 net.TCPConn 17} 18 19func (c *mockConn) Close() error { 20 c.calledClose++ 21 return nil 22} 23 24type mockConnIOReaderWriter struct { 25 read []byte 26 readIndex int 27 write []byte 28 errCount int 29 err error 30 31 calledClose int 32 33 lock sync.Mutex 34} 35 36func (iom *mockConnIOReaderWriter) CalledClose() bool { 37 iom.lock.Lock() 38 defer iom.lock.Unlock() 39 return iom.calledClose > 0 40} 41 42func (iom *mockConnIOReaderWriter) Written() []byte { 43 iom.lock.Lock() 44 defer iom.lock.Unlock() 45 46 var res []byte 47 l := len(iom.write) 48 res = make([]byte, l, l) 49 copy(res, iom.write) 50 return res 51} 52 53func (iom *mockConnIOReaderWriter) Read(p []byte) (n int, err error) { 54 iom.lock.Lock() 55 defer iom.lock.Unlock() 56 57 if iom.readIndex >= len(iom.read) { 58 return 0, io.EOF 59 } 60 i := copy(p, iom.read[iom.readIndex:]) 61 iom.readIndex += i 62 var e error 63 if iom.errCount == 0 { 64 e = iom.err 65 } 66 iom.errCount-- 67 return i, e 68} 69 70func (iom *mockConnIOReaderWriter) Write(p []byte) (n int, err error) { 71 iom.lock.Lock() 72 defer iom.lock.Unlock() 73 74 iom.write = append(iom.write, p...) 75 var e error 76 if iom.errCount == 0 { 77 e = iom.err 78 } 79 iom.errCount-- 80 return len(p), e 81} 82 83func (iom *mockConnIOReaderWriter) Close() error { 84 iom.lock.Lock() 85 defer iom.lock.Unlock() 86 87 iom.calledClose++ 88 return nil 89} 90 91type mockMultiConnIOReaderWriter struct { 92 read [][]byte 93 readIndex int 94 write []byte 95} 96 97func (iom *mockMultiConnIOReaderWriter) Read(p []byte) (n int, err error) { 98 if iom.readIndex >= len(iom.read) { 99 return 0, io.EOF 100 } 101 i := copy(p, iom.read[iom.readIndex]) 102 iom.readIndex++ 103 return i, nil 104} 105 106func (iom *mockMultiConnIOReaderWriter) Write(p []byte) (n int, err error) { 107 iom.write = append(iom.write, p...) 108 return len(p), nil 109} 110 111type fullMockedConn struct { 112 rw io.ReadWriter 113} 114 115func (c *fullMockedConn) Read(b []byte) (n int, err error) { 116 return c.rw.Read(b) 117} 118 119func (c *fullMockedConn) Write(b []byte) (n int, err error) { 120 return c.rw.Write(b) 121} 122 123func (c *fullMockedConn) Close() error { 124 return nil 125} 126 127func (c *fullMockedConn) LocalAddr() net.Addr { 128 return nil 129} 130 131func (c *fullMockedConn) RemoteAddr() net.Addr { 132 return nil 133} 134 135func (c *fullMockedConn) SetDeadline(t time.Time) error { 136 return nil 137} 138 139func (c *fullMockedConn) SetReadDeadline(t time.Time) error { 140 return nil 141} 142 143func (c *fullMockedConn) SetWriteDeadline(t time.Time) error { 144 return nil 145} 146 147type fixedRandReader struct { 148 data []string 149 at int 150} 151 152func fixedRand(data []string) io.Reader { 153 return &fixedRandReader{data, 0} 154} 155 156func bytesFromHex(s string) []byte { 157 val, _ := hex.DecodeString(s) 158 return val 159} 160 161func byteStringFromHex(s string) string { 162 val, _ := hex.DecodeString(s) 163 return string(val) 164} 165 166func (frr *fixedRandReader) Read(p []byte) (n int, err error) { 167 if frr.at < len(frr.data) { 168 plainBytes := bytesFromHex(frr.data[frr.at]) 169 frr.at++ 170 n = copy(p, plainBytes) 171 return 172 } 173 return 0, io.EOF 174} 175 176func createTeeConn(c net.Conn, w io.Writer) net.Conn { 177 return &teeConn{c, w} 178} 179 180type teeConn struct { 181 c net.Conn 182 w io.Writer 183} 184 185func (c *teeConn) Read(b []byte) (n int, err error) { 186 n, err = c.c.Read(b) 187 if n > 0 { 188 fmt.Fprintf(c.w, "READ: %x\n", b[:n]) 189 } 190 return 191} 192 193func (c *teeConn) Write(b []byte) (n int, err error) { 194 n, err = c.c.Write(b) 195 if n > 0 { 196 fmt.Fprintf(c.w, "WRITE: %x\n", b[:n]) 197 } 198 return n, err 199} 200 201func (c *teeConn) Close() error { 202 return c.c.Close() 203} 204 205func (c *teeConn) LocalAddr() net.Addr { 206 return c.c.LocalAddr() 207} 208 209func (c *teeConn) RemoteAddr() net.Addr { 210 return c.c.RemoteAddr() 211} 212 213func (c *teeConn) SetDeadline(t time.Time) error { 214 return c.c.SetDeadline(t) 215} 216 217func (c *teeConn) SetReadDeadline(t time.Time) error { 218 return c.c.SetReadDeadline(t) 219} 220 221func (c *teeConn) SetWriteDeadline(t time.Time) error { 222 return c.c.SetWriteDeadline(t) 223} 224 225type dialCall func(string, string) (c net.Conn, e error) 226type dialCallExp struct { 227 f dialCall 228 called bool 229} 230 231type mockProxy struct { 232 called int 233 calls []dialCallExp 234 sync.Mutex 235} 236 237func (p *mockProxy) Dial(network, addr string) (net.Conn, error) { 238 if len(p.calls)-1 < p.called { 239 return nil, fmt.Errorf("unexpected call to Dial: %s, %s", network, addr) 240 } 241 242 p.Lock() 243 defer p.Unlock() 244 245 fn := p.calls[p.called] 246 p.called = p.called + 1 247 248 fn.called = true 249 return fn.f(network, addr) 250} 251 252func (p *mockProxy) Expects(f dialCall) { 253 p.Lock() 254 defer p.Unlock() 255 256 if p.calls == nil { 257 p.calls = []dialCallExp{} 258 } 259 260 p.calls = append(p.calls, dialCallExp{f: f}) 261} 262 263var MatchesExpectations check.Checker = &allExpectations{ 264 &check.CheckerInfo{Name: "IsNil", Params: []string{"value"}}, 265} 266 267type allExpectations struct { 268 *check.CheckerInfo 269} 270 271func (checker *allExpectations) Check(params []interface{}, names []string) (result bool, error string) { 272 p := params[0].(*mockProxy) 273 274 if p.called != len(p.calls) { 275 return false, fmt.Sprintf("expected: %d calls, got: %d", len(p.calls), p.called) 276 } 277 278 return true, "" 279} 280