1// Copyright 2018 Istio Authors 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package request 16 17import ( 18 "fmt" 19 "io/ioutil" 20 "net/http" 21 "net/http/httptest" 22 "net/url" 23 "sync" 24 "testing" 25 26 "istio.io/istio/tests/util" 27) 28 29type pilotStubHandler struct { 30 sync.Mutex 31 States []pilotStubState 32} 33 34type pilotStubState struct { 35 wantMethod string 36 wantPath string 37 wantBody []byte 38 StatusCode int 39 Response string 40} 41 42func (p *pilotStubHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 43 p.Lock() 44 if r.Method == p.States[0].wantMethod { 45 if r.URL.Path == p.States[0].wantPath { 46 defer r.Body.Close() 47 body, _ := ioutil.ReadAll(r.Body) 48 if err := util.Compare(body, p.States[0].wantBody); err == nil { 49 w.WriteHeader(p.States[0].StatusCode) 50 w.Write([]byte(p.States[0].Response)) 51 } else { 52 w.WriteHeader(http.StatusBadRequest) 53 w.Write([]byte(fmt.Sprintf("wanted body %q got %q", string(p.States[0].wantBody), string(body)))) 54 } 55 } else { 56 w.WriteHeader(http.StatusBadRequest) 57 w.Write([]byte(fmt.Sprintf("wanted path %q got %q", p.States[0].wantPath, r.URL.Path))) 58 } 59 } else { 60 w.WriteHeader(http.StatusBadRequest) 61 w.Write([]byte(fmt.Sprintf("wanted method %q got %q", p.States[0].wantMethod, r.Method))) 62 } 63 p.States = p.States[1:] 64 p.Unlock() 65} 66 67func Test_command_do(t *testing.T) { 68 tests := []struct { 69 name string 70 method string 71 path string 72 body string 73 pilotStates []pilotStubState 74 pilotNotReachable bool 75 wantError bool 76 }{ 77 { 78 name: "makes a request using passed method, url and body", 79 method: "POST", 80 path: "/want/path", 81 body: "body", 82 pilotStates: []pilotStubState{ 83 {StatusCode: 200, Response: "fine", wantMethod: "POST", wantPath: "/want/path", wantBody: []byte("body")}, 84 }, 85 }, 86 { 87 name: "adds / prefix to path if required", 88 method: "POST", 89 path: "want/path", 90 body: "body", 91 pilotStates: []pilotStubState{ 92 {StatusCode: 200, Response: "fine", wantMethod: "POST", wantPath: "/want/path", wantBody: []byte("body")}, 93 }, 94 }, 95 { 96 name: "handles empty string body in args", 97 method: "GET", 98 path: "/want/path", 99 body: "", 100 pilotStates: []pilotStubState{ 101 {StatusCode: 200, Response: "fine", wantMethod: "GET", wantPath: "/want/path", wantBody: nil}, 102 }, 103 }, 104 { 105 name: "doesn't error on 404", 106 method: "GET", 107 path: "/want/path", 108 body: "", 109 pilotStates: []pilotStubState{ 110 {StatusCode: 404, Response: "not-found", wantMethod: "GET", wantPath: "/want/path", wantBody: nil}, 111 }, 112 }, 113 { 114 name: "errors if Pilot is unreachable", 115 method: "GET", 116 path: "/want/path", 117 pilotNotReachable: true, 118 wantError: true, 119 }, 120 { 121 name: "errors if Pilot responds with a non success status", 122 method: "GET", 123 path: "/not/wanted/path", 124 body: "", 125 pilotStates: []pilotStubState{ 126 {StatusCode: 200, Response: "fine", wantMethod: "GET", wantPath: "/want/path", wantBody: nil}, 127 }, 128 wantError: true, 129 }, 130 } 131 for _, tt := range tests { 132 t.Run(tt.name, func(t *testing.T) { 133 pilotStub := httptest.NewServer( 134 &pilotStubHandler{States: tt.pilotStates}, 135 ) 136 stubURL, _ := url.Parse(pilotStub.URL) 137 if tt.pilotNotReachable { 138 stubURL, _ = url.Parse("http://notpilot") 139 } 140 c := &Command{ 141 Address: stubURL.Host, 142 Client: &http.Client{}, 143 } 144 err := c.Do(tt.method, tt.path, tt.body) 145 if (err == nil) && tt.wantError { 146 t.Errorf("Expected an error but received none") 147 } else if (err != nil) && !tt.wantError { 148 t.Errorf("Unexpected err: %v", err) 149 } 150 }) 151 } 152} 153