1package rpc
2
3import (
4	"errors"
5	"io"
6	"net"
7	"time"
8
9	"golang.org/x/net/context"
10)
11
12type testConnectionHandler struct{}
13
14var _ ConnectionHandler = testConnectionHandler{}
15
16func (testConnectionHandler) OnConnect(context.Context, *Connection, GenericClient, *Server) error {
17	return nil
18}
19
20func (testConnectionHandler) OnConnectError(err error, reconnectThrottleDuration time.Duration) {
21}
22
23func (testConnectionHandler) OnDoCommandError(err error, nextTime time.Duration) {
24}
25
26func (testConnectionHandler) OnDisconnected(ctx context.Context, status DisconnectStatus) {
27}
28
29func (testConnectionHandler) ShouldRetry(name string, err error) bool {
30	return false
31}
32
33func (testConnectionHandler) ShouldRetryOnConnect(err error) bool {
34	return false
35}
36
37func (testConnectionHandler) HandlerName() string {
38	return "testConnectionHandler"
39}
40
41type singleTransport struct {
42	t Transporter
43}
44
45var _ ConnectionTransport = singleTransport{}
46
47// Dial is an implementation of the ConnectionTransport interface.
48func (st singleTransport) Dial(ctx context.Context) (Transporter, error) {
49	if !st.t.IsConnected() {
50		return nil, io.EOF
51	}
52	return st.t, nil
53}
54
55// IsConnected is an implementation of the ConnectionTransport interface.
56func (st singleTransport) IsConnected() bool {
57	return st.t.IsConnected()
58}
59
60// Finalize is an implementation of the ConnectionTransport interface.
61func (st singleTransport) Finalize() {}
62
63// Close is an implementation of the ConnectionTransport interface.
64func (st singleTransport) Close() {}
65
// testStatus is the error-status value the test helpers send over the wire.
// A Code of 0 is treated by testErrorUnwrapper as "no error"; 15 decodes to a
// throttleError.
type testStatus struct {
	Code int
}

// testWrapError is the WrapErrorFunc used by the test transport and
// connection. It ignores err completely and always yields a freshly
// allocated, zero-valued *testStatus.
func testWrapError(err error) interface{} {
	status := new(testStatus)
	return status
}
73
// testLogTags is the TagsFunc used by test connections. It never supplies
// any tags: the returned map is always nil and the flag is always false.
func testLogTags(ctx context.Context) (tags map[interface{}]string, found bool) {
	return tags, found
}
77
78type throttleError struct {
79	Err error
80}
81
82func (e throttleError) ToStatus() (s testStatus) {
83	s.Code = 15
84	return
85}
86
87func (e throttleError) Error() string {
88	return e.Err.Error()
89}
90
91type testErrorUnwrapper struct{}
92
93var _ ErrorUnwrapper = testErrorUnwrapper{}
94
95func (eu testErrorUnwrapper) Timeout() time.Duration {
96	return 0
97}
98
99func (eu testErrorUnwrapper) MakeArg() interface{} {
100	return &testStatus{}
101}
102
103func (eu testErrorUnwrapper) UnwrapError(arg interface{}) (appError error, dispatchError error) {
104	s, ok := arg.(*testStatus)
105	if !ok {
106		return nil, errors.New("Error converting arg to testStatus object")
107	}
108	if s == nil || s.Code == 0 {
109		return nil, nil
110	}
111
112	switch s.Code {
113	case 15:
114		appError = throttleError{errors.New("throttle")}
115	default:
116		panic("Unknown testing error")
117	}
118	return appError, nil
119}
120
// TestLogger is the minimal logging surface MakeConnectionForTest needs: a
// Helper method and a printf-style Logf. *testing.T satisfies it, as does
// any look-alike that provides both methods.
type TestLogger interface {
	Helper()
	Logf(format string, args ...interface{})
}
127
128const testMaxFrameLength = 1024
129
130// MakeConnectionForTest returns a Connection object, and a net.Conn
131// object representing the other end of that connection.
132func MakeConnectionForTest(t TestLogger) (net.Conn, *Connection) {
133	clientConn, serverConn := net.Pipe()
134	logOutput := testLogOutput{t}
135	logFactory := NewSimpleLogFactory(logOutput, nil)
136	instrumenterStorage := NewMemoryInstrumentationStorage()
137	transporter := NewTransport(clientConn, logFactory,
138		instrumenterStorage, testWrapError, testMaxFrameLength)
139	st := singleTransport{transporter}
140	opts := ConnectionOpts{
141		WrapErrorFunc: testWrapError,
142		TagsFunc:      testLogTags,
143	}
144	conn := NewConnectionWithTransport(testConnectionHandler{}, st,
145		testErrorUnwrapper{}, logOutput, opts)
146	return serverConn, conn
147}
148