1// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package etcdhttp
16
17import (
18	"context"
19	"encoding/json"
20	"fmt"
21	"io/ioutil"
22	"net/http"
23	"net/http/httptest"
24	"path"
25	"sort"
26	"strings"
27	"testing"
28
29	"go.uber.org/zap"
30
31	"github.com/coreos/go-semver/semver"
32	"go.etcd.io/etcd/etcdserver/api"
33	"go.etcd.io/etcd/etcdserver/api/membership"
34	"go.etcd.io/etcd/etcdserver/api/rafthttp"
35	pb "go.etcd.io/etcd/etcdserver/etcdserverpb"
36	"go.etcd.io/etcd/pkg/testutil"
37	"go.etcd.io/etcd/pkg/types"
38)
39
40type fakeCluster struct {
41	id         uint64
42	clientURLs []string
43	members    map[uint64]*membership.Member
44}
45
46func (c *fakeCluster) ID() types.ID         { return types.ID(c.id) }
47func (c *fakeCluster) ClientURLs() []string { return c.clientURLs }
48func (c *fakeCluster) Members() []*membership.Member {
49	var ms membership.MembersByID
50	for _, m := range c.members {
51		ms = append(ms, m)
52	}
53	sort.Sort(ms)
54	return []*membership.Member(ms)
55}
56func (c *fakeCluster) Member(id types.ID) *membership.Member { return c.members[uint64(id)] }
57func (c *fakeCluster) Version() *semver.Version              { return nil }
58
59type fakeServer struct {
60	cluster api.Cluster
61}
62
63func (s *fakeServer) AddMember(ctx context.Context, memb membership.Member) ([]*membership.Member, error) {
64	return nil, fmt.Errorf("AddMember not implemented in fakeServer")
65}
66func (s *fakeServer) RemoveMember(ctx context.Context, id uint64) ([]*membership.Member, error) {
67	return nil, fmt.Errorf("RemoveMember not implemented in fakeServer")
68}
69func (s *fakeServer) UpdateMember(ctx context.Context, updateMemb membership.Member) ([]*membership.Member, error) {
70	return nil, fmt.Errorf("UpdateMember not implemented in fakeServer")
71}
72func (s *fakeServer) PromoteMember(ctx context.Context, id uint64) ([]*membership.Member, error) {
73	return nil, fmt.Errorf("PromoteMember not implemented in fakeServer")
74}
75func (s *fakeServer) ClusterVersion() *semver.Version { return nil }
76func (s *fakeServer) Cluster() api.Cluster            { return s.cluster }
77func (s *fakeServer) Alarms() []*pb.AlarmMember       { return nil }
78
// fakeRaftHandler stands in for the raft transport handler. It answers every
// request with the fixed payload "test data", so tests can see that a
// request reached it.
var fakeRaftHandler = http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
	payload := []byte("test data")
	rw.Write(payload)
})
82
83// TestNewPeerHandlerOnRaftPrefix tests that NewPeerHandler returns a handler that
84// handles raft-prefix requests well.
85func TestNewPeerHandlerOnRaftPrefix(t *testing.T) {
86	ph := newPeerHandler(zap.NewExample(), &fakeServer{cluster: &fakeCluster{}}, fakeRaftHandler, nil, nil)
87	srv := httptest.NewServer(ph)
88	defer srv.Close()
89
90	tests := []string{
91		rafthttp.RaftPrefix,
92		rafthttp.RaftPrefix + "/hello",
93	}
94	for i, tt := range tests {
95		resp, err := http.Get(srv.URL + tt)
96		if err != nil {
97			t.Fatalf("unexpected http.Get error: %v", err)
98		}
99		body, err := ioutil.ReadAll(resp.Body)
100		if err != nil {
101			t.Fatalf("unexpected ioutil.ReadAll error: %v", err)
102		}
103		if w := "test data"; string(body) != w {
104			t.Errorf("#%d: body = %s, want %s", i, body, w)
105		}
106	}
107}
108
109// TestServeMembersFails ensures peerMembersHandler only accepts GET request
110func TestServeMembersFails(t *testing.T) {
111	tests := []struct {
112		method string
113		wcode  int
114	}{
115		{
116			"POST",
117			http.StatusMethodNotAllowed,
118		},
119		{
120			"PUT",
121			http.StatusMethodNotAllowed,
122		},
123		{
124			"DELETE",
125			http.StatusMethodNotAllowed,
126		},
127		{
128			"BAD",
129			http.StatusMethodNotAllowed,
130		},
131	}
132	for i, tt := range tests {
133		rw := httptest.NewRecorder()
134		h := newPeerMembersHandler(nil, &fakeCluster{})
135		req, err := http.NewRequest(tt.method, "", nil)
136		if err != nil {
137			t.Fatalf("#%d: failed to create http request: %v", i, err)
138		}
139		h.ServeHTTP(rw, req)
140		if rw.Code != tt.wcode {
141			t.Errorf("#%d: code=%d, want %d", i, rw.Code, tt.wcode)
142		}
143	}
144}
145
146func TestServeMembersGet(t *testing.T) {
147	memb1 := membership.Member{ID: 1, Attributes: membership.Attributes{ClientURLs: []string{"http://localhost:8080"}}}
148	memb2 := membership.Member{ID: 2, Attributes: membership.Attributes{ClientURLs: []string{"http://localhost:8081"}}}
149	cluster := &fakeCluster{
150		id:      1,
151		members: map[uint64]*membership.Member{1: &memb1, 2: &memb2},
152	}
153	h := newPeerMembersHandler(nil, cluster)
154	msb, err := json.Marshal([]membership.Member{memb1, memb2})
155	if err != nil {
156		t.Fatal(err)
157	}
158	wms := string(msb) + "\n"
159
160	tests := []struct {
161		path  string
162		wcode int
163		wct   string
164		wbody string
165	}{
166		{peerMembersPath, http.StatusOK, "application/json", wms},
167		{path.Join(peerMembersPath, "bad"), http.StatusBadRequest, "text/plain; charset=utf-8", "bad path\n"},
168	}
169
170	for i, tt := range tests {
171		req, err := http.NewRequest("GET", testutil.MustNewURL(t, tt.path).String(), nil)
172		if err != nil {
173			t.Fatal(err)
174		}
175		rw := httptest.NewRecorder()
176		h.ServeHTTP(rw, req)
177
178		if rw.Code != tt.wcode {
179			t.Errorf("#%d: code=%d, want %d", i, rw.Code, tt.wcode)
180		}
181		if gct := rw.Header().Get("Content-Type"); gct != tt.wct {
182			t.Errorf("#%d: content-type = %s, want %s", i, gct, tt.wct)
183		}
184		if rw.Body.String() != tt.wbody {
185			t.Errorf("#%d: body = %s, want %s", i, rw.Body.String(), tt.wbody)
186		}
187		gcid := rw.Header().Get("X-Etcd-Cluster-ID")
188		wcid := cluster.ID().String()
189		if gcid != wcid {
190			t.Errorf("#%d: cid = %s, want %s", i, gcid, wcid)
191		}
192	}
193}
194
195// TestServeMemberPromoteFails ensures peerMemberPromoteHandler only accepts POST request
196func TestServeMemberPromoteFails(t *testing.T) {
197	tests := []struct {
198		method string
199		wcode  int
200	}{
201		{
202			"GET",
203			http.StatusMethodNotAllowed,
204		},
205		{
206			"PUT",
207			http.StatusMethodNotAllowed,
208		},
209		{
210			"DELETE",
211			http.StatusMethodNotAllowed,
212		},
213		{
214			"BAD",
215			http.StatusMethodNotAllowed,
216		},
217	}
218	for i, tt := range tests {
219		rw := httptest.NewRecorder()
220		h := newPeerMemberPromoteHandler(nil, &fakeServer{cluster: &fakeCluster{}})
221		req, err := http.NewRequest(tt.method, "", nil)
222		if err != nil {
223			t.Fatalf("#%d: failed to create http request: %v", i, err)
224		}
225		h.ServeHTTP(rw, req)
226		if rw.Code != tt.wcode {
227			t.Errorf("#%d: code=%d, want %d", i, rw.Code, tt.wcode)
228		}
229	}
230}
231
232// TestNewPeerHandlerOnMembersPromotePrefix verifies the request with members promote prefix is routed correctly
233func TestNewPeerHandlerOnMembersPromotePrefix(t *testing.T) {
234	ph := newPeerHandler(zap.NewExample(), &fakeServer{cluster: &fakeCluster{}}, fakeRaftHandler, nil, nil)
235	srv := httptest.NewServer(ph)
236	defer srv.Close()
237
238	tests := []struct {
239		path      string
240		wcode     int
241		checkBody bool
242		wKeyWords string
243	}{
244		{
245			// does not contain member id in path
246			peerMemberPromotePrefix,
247			http.StatusNotFound,
248			false,
249			"",
250		},
251		{
252			// try to promote member id = 1
253			peerMemberPromotePrefix + "1",
254			http.StatusInternalServerError,
255			true,
256			"PromoteMember not implemented in fakeServer",
257		},
258	}
259	for i, tt := range tests {
260		req, err := http.NewRequest("POST", srv.URL+tt.path, nil)
261		if err != nil {
262			t.Fatalf("failed to create request: %v", err)
263		}
264		resp, err := http.DefaultClient.Do(req)
265		if err != nil {
266			t.Fatalf("failed to get http response: %v", err)
267		}
268		body, err := ioutil.ReadAll(resp.Body)
269		resp.Body.Close()
270		if err != nil {
271			t.Fatalf("unexpected ioutil.ReadAll error: %v", err)
272		}
273		if resp.StatusCode != tt.wcode {
274			t.Fatalf("#%d: code = %d, want %d", i, resp.StatusCode, tt.wcode)
275		}
276		if tt.checkBody && strings.Contains(string(body), tt.wKeyWords) {
277			t.Errorf("#%d: body: %s, want body to contain keywords: %s", i, string(body), tt.wKeyWords)
278		}
279	}
280}
281