1/* 2 * Copyright 2016 gRPC authors. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package test 18 19import ( 20 "bytes" 21 "errors" 22 "io" 23 "strings" 24 "testing" 25 "time" 26 27 "golang.org/x/net/http2" 28 "golang.org/x/net/http2/hpack" 29) 30 31// This is a subset of http2's serverTester type. 32// 33// serverTester wraps a io.ReadWriter (acting like the underlying 34// network connection) and provides utility methods to read and write 35// http2 frames. 36// 37// NOTE(bradfitz): this could eventually be exported somewhere. Others 38// have asked for it too. For now I'm still experimenting with the 39// API and don't feel like maintaining a stable testing API. 40 41type serverTester struct { 42 cc io.ReadWriteCloser // client conn 43 t testing.TB 44 fr *http2.Framer 45 46 // writing headers: 47 headerBuf bytes.Buffer 48 hpackEnc *hpack.Encoder 49 50 // reading frames: 51 frc chan http2.Frame 52 frErrc chan error 53} 54 55func newServerTesterFromConn(t testing.TB, cc io.ReadWriteCloser) *serverTester { 56 st := &serverTester{ 57 t: t, 58 cc: cc, 59 frc: make(chan http2.Frame, 1), 60 frErrc: make(chan error, 1), 61 } 62 st.hpackEnc = hpack.NewEncoder(&st.headerBuf) 63 st.fr = http2.NewFramer(cc, cc) 64 st.fr.ReadMetaHeaders = hpack.NewDecoder(4096 /*initialHeaderTableSize*/, nil) 65 66 return st 67} 68 69func (st *serverTester) readFrame() (http2.Frame, error) { 70 go func() { 71 fr, err := st.fr.ReadFrame() 72 if err != nil { 73 st.frErrc <- err 74 } else { 75 st.frc <- fr 76 } 77 }() 78 t := time.NewTimer(2 * time.Second) 79 defer t.Stop() 80 select { 81 case f := <-st.frc: 82 return f, nil 83 case err := <-st.frErrc: 84 return nil, err 85 case <-t.C: 86 return nil, errors.New("timeout waiting for frame") 87 } 88} 89 90// greet initiates the client's HTTP/2 connection into a state where 91// frames may be sent. 92func (st *serverTester) greet() { 93 st.writePreface() 94 st.writeInitialSettings() 95 st.wantSettings() 96 st.writeSettingsAck() 97 for { 98 f, err := st.readFrame() 99 if err != nil { 100 st.t.Fatal(err) 101 } 102 switch f := f.(type) { 103 case *http2.WindowUpdateFrame: 104 // grpc's transport/http2_server sends this 105 // before the settings ack. The Go http2 106 // server uses a setting instead. 107 case *http2.SettingsFrame: 108 if f.IsAck() { 109 return 110 } 111 st.t.Fatalf("during greet, got non-ACK settings frame") 112 default: 113 st.t.Fatalf("during greet, unexpected frame type %T", f) 114 } 115 } 116} 117 118func (st *serverTester) writePreface() { 119 n, err := st.cc.Write([]byte(http2.ClientPreface)) 120 if err != nil { 121 st.t.Fatalf("Error writing client preface: %v", err) 122 } 123 if n != len(http2.ClientPreface) { 124 st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(http2.ClientPreface)) 125 } 126} 127 128func (st *serverTester) writeInitialSettings() { 129 if err := st.fr.WriteSettings(); err != nil { 130 st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err) 131 } 132} 133 134func (st *serverTester) writeSettingsAck() { 135 if err := st.fr.WriteSettingsAck(); err != nil { 136 st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err) 137 } 138} 139 140func (st *serverTester) wantSettings() *http2.SettingsFrame { 141 f, err := st.readFrame() 142 if err != nil { 143 st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err) 144 } 145 sf, ok := f.(*http2.SettingsFrame) 146 if !ok { 147 st.t.Fatalf("got a %T; want *SettingsFrame", f) 148 } 149 return sf 150} 151 152// wait for any activity from the server 153func (st *serverTester) wantAnyFrame() http2.Frame { 154 f, err := st.fr.ReadFrame() 155 if err != nil { 156 st.t.Fatal(err) 157 } 158 return f 159} 160 161func (st *serverTester) encodeHeaderField(k, v string) { 162 err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) 163 if err != nil { 164 st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) 165 } 166} 167 168// encodeHeader encodes headers and returns their HPACK bytes. headers 169// must contain an even number of key/value pairs. There may be 170// multiple pairs for keys (e.g. "cookie"). The :method, :path, and 171// :scheme headers default to GET, / and https. 172func (st *serverTester) encodeHeader(headers ...string) []byte { 173 if len(headers)%2 == 1 { 174 panic("odd number of kv args") 175 } 176 177 st.headerBuf.Reset() 178 179 if len(headers) == 0 { 180 // Fast path, mostly for benchmarks, so test code doesn't pollute 181 // profiles when we're looking to improve server allocations. 182 st.encodeHeaderField(":method", "GET") 183 st.encodeHeaderField(":path", "/") 184 st.encodeHeaderField(":scheme", "https") 185 return st.headerBuf.Bytes() 186 } 187 188 if len(headers) == 2 && headers[0] == ":method" { 189 // Another fast path for benchmarks. 190 st.encodeHeaderField(":method", headers[1]) 191 st.encodeHeaderField(":path", "/") 192 st.encodeHeaderField(":scheme", "https") 193 return st.headerBuf.Bytes() 194 } 195 196 pseudoCount := map[string]int{} 197 keys := []string{":method", ":path", ":scheme"} 198 vals := map[string][]string{ 199 ":method": {"GET"}, 200 ":path": {"/"}, 201 ":scheme": {"https"}, 202 } 203 for len(headers) > 0 { 204 k, v := headers[0], headers[1] 205 headers = headers[2:] 206 if _, ok := vals[k]; !ok { 207 keys = append(keys, k) 208 } 209 if strings.HasPrefix(k, ":") { 210 pseudoCount[k]++ 211 if pseudoCount[k] == 1 { 212 vals[k] = []string{v} 213 } else { 214 // Allows testing of invalid headers w/ dup pseudo fields. 215 vals[k] = append(vals[k], v) 216 } 217 } else { 218 vals[k] = append(vals[k], v) 219 } 220 } 221 for _, k := range keys { 222 for _, v := range vals[k] { 223 st.encodeHeaderField(k, v) 224 } 225 } 226 return st.headerBuf.Bytes() 227} 228 229func (st *serverTester) writeHeadersGRPC(streamID uint32, path string) { 230 st.writeHeaders(http2.HeadersFrameParam{ 231 StreamID: streamID, 232 BlockFragment: st.encodeHeader( 233 ":method", "POST", 234 ":path", path, 235 "content-type", "application/grpc", 236 "te", "trailers", 237 ), 238 EndStream: false, 239 EndHeaders: true, 240 }) 241} 242 243func (st *serverTester) writeHeaders(p http2.HeadersFrameParam) { 244 if err := st.fr.WriteHeaders(p); err != nil { 245 st.t.Fatalf("Error writing HEADERS: %v", err) 246 } 247} 248 249func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) { 250 if err := st.fr.WriteData(streamID, endStream, data); err != nil { 251 st.t.Fatalf("Error writing DATA: %v", err) 252 } 253} 254 255func (st *serverTester) writeRSTStream(streamID uint32, code http2.ErrCode) { 256 if err := st.fr.WriteRSTStream(streamID, code); err != nil { 257 st.t.Fatalf("Error writing RST_STREAM: %v", err) 258 } 259} 260