1// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package rafthttp
16
17import (
18	"errors"
19	"fmt"
20	"io"
21	"io/ioutil"
22	"net/http"
23	"sync"
24	"testing"
25	"time"
26
27	stats "go.etcd.io/etcd/etcdserver/api/v2stats"
28	"go.etcd.io/etcd/pkg/testutil"
29	"go.etcd.io/etcd/pkg/types"
30	"go.etcd.io/etcd/raft/raftpb"
31	"go.etcd.io/etcd/version"
32
33	"go.uber.org/zap"
34)
35
36// TestPipelineSend tests that pipeline could send data using roundtripper
37// and increase success count in stats.
38func TestPipelineSend(t *testing.T) {
39	tr := &roundTripperRecorder{rec: testutil.NewRecorderStream()}
40	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
41	tp := &Transport{pipelineRt: tr}
42	p := startTestPipeline(tp, picker)
43
44	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
45	tr.rec.Wait(1)
46	p.stop()
47	if p.followerStats.Counts.Success != 1 {
48		t.Errorf("success = %d, want 1", p.followerStats.Counts.Success)
49	}
50}
51
52// TestPipelineKeepSendingWhenPostError tests that pipeline can keep
53// sending messages if previous messages meet post error.
54func TestPipelineKeepSendingWhenPostError(t *testing.T) {
55	tr := &respRoundTripper{rec: testutil.NewRecorderStream(), err: fmt.Errorf("roundtrip error")}
56	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
57	tp := &Transport{pipelineRt: tr}
58	p := startTestPipeline(tp, picker)
59	defer p.stop()
60
61	for i := 0; i < 50; i++ {
62		p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
63	}
64
65	_, err := tr.rec.Wait(50)
66	if err != nil {
67		t.Errorf("unexpected wait error %v", err)
68	}
69}
70
71func TestPipelineExceedMaximumServing(t *testing.T) {
72	rt := newRoundTripperBlocker()
73	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
74	tp := &Transport{pipelineRt: rt}
75	p := startTestPipeline(tp, picker)
76	defer p.stop()
77
78	// keep the sender busy and make the buffer full
79	// nothing can go out as we block the sender
80	for i := 0; i < connPerPipeline+pipelineBufSize; i++ {
81		select {
82		case p.msgc <- raftpb.Message{}:
83		case <-time.After(time.Second):
84			t.Errorf("failed to send out message")
85		}
86	}
87
88	// try to send a data when we are sure the buffer is full
89	select {
90	case p.msgc <- raftpb.Message{}:
91		t.Errorf("unexpected message sendout")
92	default:
93	}
94
95	// unblock the senders and force them to send out the data
96	rt.unblock()
97
98	// It could send new data after previous ones succeed
99	select {
100	case p.msgc <- raftpb.Message{}:
101	case <-time.After(time.Second):
102		t.Errorf("failed to send out message")
103	}
104}
105
106// TestPipelineSendFailed tests that when send func meets the post error,
107// it increases fail count in stats.
108func TestPipelineSendFailed(t *testing.T) {
109	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
110	rt := newRespRoundTripper(0, errors.New("blah"))
111	rt.rec = testutil.NewRecorderStream()
112	tp := &Transport{pipelineRt: rt}
113	p := startTestPipeline(tp, picker)
114
115	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
116	if _, err := rt.rec.Wait(1); err != nil {
117		t.Fatal(err)
118	}
119
120	p.stop()
121
122	if p.followerStats.Counts.Fail != 1 {
123		t.Errorf("fail = %d, want 1", p.followerStats.Counts.Fail)
124	}
125}
126
127func TestPipelinePost(t *testing.T) {
128	tr := &roundTripperRecorder{rec: &testutil.RecorderBuffered{}}
129	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
130	tp := &Transport{ClusterID: types.ID(1), pipelineRt: tr}
131	p := startTestPipeline(tp, picker)
132	if err := p.post([]byte("some data")); err != nil {
133		t.Fatalf("unexpected post error: %v", err)
134	}
135	act, err := tr.rec.Wait(1)
136	if err != nil {
137		t.Fatal(err)
138	}
139	p.stop()
140
141	req := act[0].Params[0].(*http.Request)
142
143	if g := req.Method; g != "POST" {
144		t.Errorf("method = %s, want %s", g, "POST")
145	}
146	if g := req.URL.String(); g != "http://localhost:2380/raft" {
147		t.Errorf("url = %s, want %s", g, "http://localhost:2380/raft")
148	}
149	if g := req.Header.Get("Content-Type"); g != "application/protobuf" {
150		t.Errorf("content type = %s, want %s", g, "application/protobuf")
151	}
152	if g := req.Header.Get("X-Server-Version"); g != version.Version {
153		t.Errorf("version = %s, want %s", g, version.Version)
154	}
155	if g := req.Header.Get("X-Min-Cluster-Version"); g != version.MinClusterVersion {
156		t.Errorf("min version = %s, want %s", g, version.MinClusterVersion)
157	}
158	if g := req.Header.Get("X-Etcd-Cluster-ID"); g != "1" {
159		t.Errorf("cluster id = %s, want %s", g, "1")
160	}
161	b, err := ioutil.ReadAll(req.Body)
162	if err != nil {
163		t.Fatalf("unexpected ReadAll error: %v", err)
164	}
165	if string(b) != "some data" {
166		t.Errorf("body = %s, want %s", b, "some data")
167	}
168}
169
170func TestPipelinePostBad(t *testing.T) {
171	tests := []struct {
172		u    string
173		code int
174		err  error
175	}{
176		// RoundTrip returns error
177		{"http://localhost:2380", 0, errors.New("blah")},
178		// unexpected response status code
179		{"http://localhost:2380", http.StatusOK, nil},
180		{"http://localhost:2380", http.StatusCreated, nil},
181	}
182	for i, tt := range tests {
183		picker := mustNewURLPicker(t, []string{tt.u})
184		tp := &Transport{pipelineRt: newRespRoundTripper(tt.code, tt.err)}
185		p := startTestPipeline(tp, picker)
186		err := p.post([]byte("some data"))
187		p.stop()
188
189		if err == nil {
190			t.Errorf("#%d: err = nil, want not nil", i)
191		}
192	}
193}
194
195func TestPipelinePostErrorc(t *testing.T) {
196	tests := []struct {
197		u    string
198		code int
199		err  error
200	}{
201		{"http://localhost:2380", http.StatusForbidden, nil},
202	}
203	for i, tt := range tests {
204		picker := mustNewURLPicker(t, []string{tt.u})
205		tp := &Transport{pipelineRt: newRespRoundTripper(tt.code, tt.err)}
206		p := startTestPipeline(tp, picker)
207		p.post([]byte("some data"))
208		p.stop()
209		select {
210		case <-p.errorc:
211		default:
212			t.Fatalf("#%d: cannot receive from errorc", i)
213		}
214	}
215}
216
217func TestStopBlockedPipeline(t *testing.T) {
218	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
219	tp := &Transport{pipelineRt: newRoundTripperBlocker()}
220	p := startTestPipeline(tp, picker)
221	// send many messages that most of them will be blocked in buffer
222	for i := 0; i < connPerPipeline*10; i++ {
223		p.msgc <- raftpb.Message{}
224	}
225
226	done := make(chan struct{})
227	go func() {
228		p.stop()
229		done <- struct{}{}
230	}()
231	select {
232	case <-done:
233	case <-time.After(time.Second):
234		t.Fatalf("failed to stop pipeline in 1s")
235	}
236}
237
// roundTripperBlocker is a test round tripper used to hold back sends until
// unblock is called.
type roundTripperBlocker struct {
	// unblockc is closed by unblock.
	unblockc chan struct{}

	// mu guards cancel.
	mu sync.Mutex

	// cancel maps a request to the channel CancelRequest signals for it.
	cancel map[*http.Request]chan struct{}
}
243
244func newRoundTripperBlocker() *roundTripperBlocker {
245	return &roundTripperBlocker{
246		unblockc: make(chan struct{}),
247		cancel:   make(map[*http.Request]chan struct{}),
248	}
249}
250
251func (t *roundTripperBlocker) unblock() {
252	close(t.unblockc)
253}
254
255func (t *roundTripperBlocker) CancelRequest(req *http.Request) {
256	t.mu.Lock()
257	defer t.mu.Unlock()
258	if c, ok := t.cancel[req]; ok {
259		c <- struct{}{}
260		delete(t.cancel, req)
261	}
262}
263
264type respRoundTripper struct {
265	mu  sync.Mutex
266	rec testutil.Recorder
267
268	code   int
269	header http.Header
270	err    error
271}
272
273func newRespRoundTripper(code int, err error) *respRoundTripper {
274	return &respRoundTripper{code: code, err: err}
275}
276func (t *respRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
277	t.mu.Lock()
278	defer t.mu.Unlock()
279	if t.rec != nil {
280		t.rec.Record(testutil.Action{Name: "req", Params: []interface{}{req}})
281	}
282	return &http.Response{StatusCode: t.code, Header: t.header, Body: &nopReadCloser{}}, t.err
283}
284
285type roundTripperRecorder struct {
286	rec testutil.Recorder
287}
288
289func (t *roundTripperRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
290	if t.rec != nil {
291		t.rec.Record(testutil.Action{Name: "req", Params: []interface{}{req}})
292	}
293	return &http.Response{StatusCode: http.StatusNoContent, Body: &nopReadCloser{}}, nil
294}
295
// nopReadCloser is an always-empty body: reads hit EOF immediately and
// closing never fails.
type nopReadCloser struct{}

// Read reports end of stream without touching the buffer.
func (*nopReadCloser) Read([]byte) (int, error) { return 0, io.EOF }

// Close does nothing and returns nil.
func (*nopReadCloser) Close() error { return nil }
300
301func startTestPipeline(tr *Transport, picker *urlPicker) *pipeline {
302	p := &pipeline{
303		peerID:        types.ID(1),
304		tr:            tr,
305		picker:        picker,
306		status:        newPeerStatus(zap.NewExample(), tr.ID, types.ID(1)),
307		raft:          &fakeRaft{},
308		followerStats: &stats.FollowerStats{},
309		errorc:        make(chan error, 1),
310	}
311	p.start()
312	return p
313}
314