1// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package etcdhttp
16
17import (
18	"context"
19	"encoding/json"
20	"fmt"
21	"io/ioutil"
22	"net/http"
23	"net/http/httptest"
24	"path"
25	"sort"
26	"strings"
27	"testing"
28
29	"go.uber.org/zap"
30
31	"github.com/coreos/go-semver/semver"
32	"go.etcd.io/etcd/etcdserver/api"
33	"go.etcd.io/etcd/etcdserver/api/membership"
34	"go.etcd.io/etcd/etcdserver/api/rafthttp"
35	pb "go.etcd.io/etcd/etcdserver/etcdserverpb"
36	"go.etcd.io/etcd/pkg/testutil"
37	"go.etcd.io/etcd/pkg/types"
38)
39
40type fakeCluster struct {
41	id         uint64
42	clientURLs []string
43	members    map[uint64]*membership.Member
44}
45
46func (c *fakeCluster) ID() types.ID         { return types.ID(c.id) }
47func (c *fakeCluster) ClientURLs() []string { return c.clientURLs }
48func (c *fakeCluster) Members() []*membership.Member {
49	var ms membership.MembersByID
50	for _, m := range c.members {
51		ms = append(ms, m)
52	}
53	sort.Sort(ms)
54	return []*membership.Member(ms)
55}
56func (c *fakeCluster) Member(id types.ID) *membership.Member { return c.members[uint64(id)] }
57func (c *fakeCluster) Version() *semver.Version              { return nil }
58
59type fakeServer struct {
60	cluster api.Cluster
61	alarms  []*pb.AlarmMember
62}
63
64func (s *fakeServer) AddMember(ctx context.Context, memb membership.Member) ([]*membership.Member, error) {
65	return nil, fmt.Errorf("AddMember not implemented in fakeServer")
66}
67func (s *fakeServer) RemoveMember(ctx context.Context, id uint64) ([]*membership.Member, error) {
68	return nil, fmt.Errorf("RemoveMember not implemented in fakeServer")
69}
70func (s *fakeServer) UpdateMember(ctx context.Context, updateMemb membership.Member) ([]*membership.Member, error) {
71	return nil, fmt.Errorf("UpdateMember not implemented in fakeServer")
72}
73func (s *fakeServer) PromoteMember(ctx context.Context, id uint64) ([]*membership.Member, error) {
74	return nil, fmt.Errorf("PromoteMember not implemented in fakeServer")
75}
76func (s *fakeServer) ClusterVersion() *semver.Version { return nil }
77func (s *fakeServer) Cluster() api.Cluster            { return s.cluster }
78func (s *fakeServer) Alarms() []*pb.AlarmMember       { return s.alarms }
79
// fakeRaftHandler stands in for the raft handler in these tests: it answers
// every request with the fixed payload "test data".
var fakeRaftHandler = http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
	const payload = "test data"
	// The write error is deliberately ignored, as in any best-effort test stub.
	_, _ = rw.Write([]byte(payload))
})
83
84// TestNewPeerHandlerOnRaftPrefix tests that NewPeerHandler returns a handler that
85// handles raft-prefix requests well.
86func TestNewPeerHandlerOnRaftPrefix(t *testing.T) {
87	ph := newPeerHandler(zap.NewExample(), &fakeServer{cluster: &fakeCluster{}}, fakeRaftHandler, nil, nil)
88	srv := httptest.NewServer(ph)
89	defer srv.Close()
90
91	tests := []string{
92		rafthttp.RaftPrefix,
93		rafthttp.RaftPrefix + "/hello",
94	}
95	for i, tt := range tests {
96		resp, err := http.Get(srv.URL + tt)
97		if err != nil {
98			t.Fatalf("unexpected http.Get error: %v", err)
99		}
100		body, err := ioutil.ReadAll(resp.Body)
101		if err != nil {
102			t.Fatalf("unexpected ioutil.ReadAll error: %v", err)
103		}
104		if w := "test data"; string(body) != w {
105			t.Errorf("#%d: body = %s, want %s", i, body, w)
106		}
107	}
108}
109
110// TestServeMembersFails ensures peerMembersHandler only accepts GET request
111func TestServeMembersFails(t *testing.T) {
112	tests := []struct {
113		method string
114		wcode  int
115	}{
116		{
117			"POST",
118			http.StatusMethodNotAllowed,
119		},
120		{
121			"PUT",
122			http.StatusMethodNotAllowed,
123		},
124		{
125			"DELETE",
126			http.StatusMethodNotAllowed,
127		},
128		{
129			"BAD",
130			http.StatusMethodNotAllowed,
131		},
132	}
133	for i, tt := range tests {
134		rw := httptest.NewRecorder()
135		h := newPeerMembersHandler(nil, &fakeCluster{})
136		req, err := http.NewRequest(tt.method, "", nil)
137		if err != nil {
138			t.Fatalf("#%d: failed to create http request: %v", i, err)
139		}
140		h.ServeHTTP(rw, req)
141		if rw.Code != tt.wcode {
142			t.Errorf("#%d: code=%d, want %d", i, rw.Code, tt.wcode)
143		}
144	}
145}
146
147func TestServeMembersGet(t *testing.T) {
148	memb1 := membership.Member{ID: 1, Attributes: membership.Attributes{ClientURLs: []string{"http://localhost:8080"}}}
149	memb2 := membership.Member{ID: 2, Attributes: membership.Attributes{ClientURLs: []string{"http://localhost:8081"}}}
150	cluster := &fakeCluster{
151		id:      1,
152		members: map[uint64]*membership.Member{1: &memb1, 2: &memb2},
153	}
154	h := newPeerMembersHandler(nil, cluster)
155	msb, err := json.Marshal([]membership.Member{memb1, memb2})
156	if err != nil {
157		t.Fatal(err)
158	}
159	wms := string(msb) + "\n"
160
161	tests := []struct {
162		path  string
163		wcode int
164		wct   string
165		wbody string
166	}{
167		{peerMembersPath, http.StatusOK, "application/json", wms},
168		{path.Join(peerMembersPath, "bad"), http.StatusBadRequest, "text/plain; charset=utf-8", "bad path\n"},
169	}
170
171	for i, tt := range tests {
172		req, err := http.NewRequest("GET", testutil.MustNewURL(t, tt.path).String(), nil)
173		if err != nil {
174			t.Fatal(err)
175		}
176		rw := httptest.NewRecorder()
177		h.ServeHTTP(rw, req)
178
179		if rw.Code != tt.wcode {
180			t.Errorf("#%d: code=%d, want %d", i, rw.Code, tt.wcode)
181		}
182		if gct := rw.Header().Get("Content-Type"); gct != tt.wct {
183			t.Errorf("#%d: content-type = %s, want %s", i, gct, tt.wct)
184		}
185		if rw.Body.String() != tt.wbody {
186			t.Errorf("#%d: body = %s, want %s", i, rw.Body.String(), tt.wbody)
187		}
188		gcid := rw.Header().Get("X-Etcd-Cluster-ID")
189		wcid := cluster.ID().String()
190		if gcid != wcid {
191			t.Errorf("#%d: cid = %s, want %s", i, gcid, wcid)
192		}
193	}
194}
195
196// TestServeMemberPromoteFails ensures peerMemberPromoteHandler only accepts POST request
197func TestServeMemberPromoteFails(t *testing.T) {
198	tests := []struct {
199		method string
200		wcode  int
201	}{
202		{
203			"GET",
204			http.StatusMethodNotAllowed,
205		},
206		{
207			"PUT",
208			http.StatusMethodNotAllowed,
209		},
210		{
211			"DELETE",
212			http.StatusMethodNotAllowed,
213		},
214		{
215			"BAD",
216			http.StatusMethodNotAllowed,
217		},
218	}
219	for i, tt := range tests {
220		rw := httptest.NewRecorder()
221		h := newPeerMemberPromoteHandler(nil, &fakeServer{cluster: &fakeCluster{}})
222		req, err := http.NewRequest(tt.method, "", nil)
223		if err != nil {
224			t.Fatalf("#%d: failed to create http request: %v", i, err)
225		}
226		h.ServeHTTP(rw, req)
227		if rw.Code != tt.wcode {
228			t.Errorf("#%d: code=%d, want %d", i, rw.Code, tt.wcode)
229		}
230	}
231}
232
233// TestNewPeerHandlerOnMembersPromotePrefix verifies the request with members promote prefix is routed correctly
234func TestNewPeerHandlerOnMembersPromotePrefix(t *testing.T) {
235	ph := newPeerHandler(zap.NewExample(), &fakeServer{cluster: &fakeCluster{}}, fakeRaftHandler, nil, nil)
236	srv := httptest.NewServer(ph)
237	defer srv.Close()
238
239	tests := []struct {
240		path      string
241		wcode     int
242		checkBody bool
243		wKeyWords string
244	}{
245		{
246			// does not contain member id in path
247			peerMemberPromotePrefix,
248			http.StatusNotFound,
249			false,
250			"",
251		},
252		{
253			// try to promote member id = 1
254			peerMemberPromotePrefix + "1",
255			http.StatusInternalServerError,
256			true,
257			"PromoteMember not implemented in fakeServer",
258		},
259	}
260	for i, tt := range tests {
261		req, err := http.NewRequest("POST", srv.URL+tt.path, nil)
262		if err != nil {
263			t.Fatalf("failed to create request: %v", err)
264		}
265		resp, err := http.DefaultClient.Do(req)
266		if err != nil {
267			t.Fatalf("failed to get http response: %v", err)
268		}
269		body, err := ioutil.ReadAll(resp.Body)
270		resp.Body.Close()
271		if err != nil {
272			t.Fatalf("unexpected ioutil.ReadAll error: %v", err)
273		}
274		if resp.StatusCode != tt.wcode {
275			t.Fatalf("#%d: code = %d, want %d", i, resp.StatusCode, tt.wcode)
276		}
277		if tt.checkBody && strings.Contains(string(body), tt.wKeyWords) {
278			t.Errorf("#%d: body: %s, want body to contain keywords: %s", i, string(body), tt.wKeyWords)
279		}
280	}
281}
282