1// Copyright 2018 Istio Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package request
16
17import (
18	"fmt"
19	"io/ioutil"
20	"net/http"
21	"net/http/httptest"
22	"net/url"
23	"sync"
24	"testing"
25
26	"istio.io/istio/tests/util"
27)
28
29type pilotStubHandler struct {
30	sync.Mutex
31	States []pilotStubState
32}
33
// pilotStubState describes one scripted exchange for pilotStubHandler: the
// request the stub expects next, and the reply it sends back when that
// request matches.
type pilotStubState struct {
	// Expected request method, URL path and body.
	wantMethod, wantPath string
	wantBody             []byte

	// Reply written when method, path and body all match.
	StatusCode int
	Response   string
}
41
42func (p *pilotStubHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
43	p.Lock()
44	if r.Method == p.States[0].wantMethod {
45		if r.URL.Path == p.States[0].wantPath {
46			defer r.Body.Close()
47			body, _ := ioutil.ReadAll(r.Body)
48			if err := util.Compare(body, p.States[0].wantBody); err == nil {
49				w.WriteHeader(p.States[0].StatusCode)
50				w.Write([]byte(p.States[0].Response))
51			} else {
52				w.WriteHeader(http.StatusBadRequest)
53				w.Write([]byte(fmt.Sprintf("wanted body %q got %q", string(p.States[0].wantBody), string(body))))
54			}
55		} else {
56			w.WriteHeader(http.StatusBadRequest)
57			w.Write([]byte(fmt.Sprintf("wanted path %q got %q", p.States[0].wantPath, r.URL.Path)))
58		}
59	} else {
60		w.WriteHeader(http.StatusBadRequest)
61		w.Write([]byte(fmt.Sprintf("wanted method %q got %q", p.States[0].wantMethod, r.Method)))
62	}
63	p.States = p.States[1:]
64	p.Unlock()
65}
66
67func Test_command_do(t *testing.T) {
68	tests := []struct {
69		name              string
70		method            string
71		path              string
72		body              string
73		pilotStates       []pilotStubState
74		pilotNotReachable bool
75		wantError         bool
76	}{
77		{
78			name:   "makes a request using passed method, url and body",
79			method: "POST",
80			path:   "/want/path",
81			body:   "body",
82			pilotStates: []pilotStubState{
83				{StatusCode: 200, Response: "fine", wantMethod: "POST", wantPath: "/want/path", wantBody: []byte("body")},
84			},
85		},
86		{
87			name:   "adds / prefix to path if required",
88			method: "POST",
89			path:   "want/path",
90			body:   "body",
91			pilotStates: []pilotStubState{
92				{StatusCode: 200, Response: "fine", wantMethod: "POST", wantPath: "/want/path", wantBody: []byte("body")},
93			},
94		},
95		{
96			name:   "handles empty string body in args",
97			method: "GET",
98			path:   "/want/path",
99			body:   "",
100			pilotStates: []pilotStubState{
101				{StatusCode: 200, Response: "fine", wantMethod: "GET", wantPath: "/want/path", wantBody: nil},
102			},
103		},
104		{
105			name:   "doesn't error on 404",
106			method: "GET",
107			path:   "/want/path",
108			body:   "",
109			pilotStates: []pilotStubState{
110				{StatusCode: 404, Response: "not-found", wantMethod: "GET", wantPath: "/want/path", wantBody: nil},
111			},
112		},
113		{
114			name:              "errors if Pilot is unreachable",
115			method:            "GET",
116			path:              "/want/path",
117			pilotNotReachable: true,
118			wantError:         true,
119		},
120		{
121			name:   "errors if Pilot responds with a non success status",
122			method: "GET",
123			path:   "/not/wanted/path",
124			body:   "",
125			pilotStates: []pilotStubState{
126				{StatusCode: 200, Response: "fine", wantMethod: "GET", wantPath: "/want/path", wantBody: nil},
127			},
128			wantError: true,
129		},
130	}
131	for _, tt := range tests {
132		t.Run(tt.name, func(t *testing.T) {
133			pilotStub := httptest.NewServer(
134				&pilotStubHandler{States: tt.pilotStates},
135			)
136			stubURL, _ := url.Parse(pilotStub.URL)
137			if tt.pilotNotReachable {
138				stubURL, _ = url.Parse("http://notpilot")
139			}
140			c := &Command{
141				Address: stubURL.Host,
142				Client:  &http.Client{},
143			}
144			err := c.Do(tt.method, tt.path, tt.body)
145			if (err == nil) && tt.wantError {
146				t.Errorf("Expected an error but received none")
147			} else if (err != nil) && !tt.wantError {
148				t.Errorf("Unexpected err: %v", err)
149			}
150		})
151	}
152}
153