1// Copyright 2019 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package testutil 16 17import ( 18 "bytes" 19 "context" 20 "errors" 21 "fmt" 22 "log" 23 "os" 24 "strings" 25 26 "google.golang.org/api/option" 27 "google.golang.org/grpc" 28 "google.golang.org/grpc/metadata" 29) 30 31// HeaderChecker defines header checking and validation rules for any outgoing metadata. 32type HeaderChecker struct { 33 // Key is the header name to be checked against e.g. "x-goog-api-client". 34 Key string 35 36 // ValuesValidator validates the header values retrieved from mapping against 37 // Key in the Headers. 38 ValuesValidator func(values ...string) error 39} 40 41// HeadersEnforcer asserts that outgoing RPC headers 42// are present and match expectations. If the expected headers 43// are not present or don't match expectations, it'll invoke OnFailure 44// with the validation error, or instead log.Fatal if OnFailure is nil. 45// 46// It expects that every declared key will be present in the outgoing 47// RPC header and each value will be validated by the validation function. 48type HeadersEnforcer struct { 49 // Checkers maps header keys that are expected to be sent in the metadata 50 // of outgoing gRPC requests, against the values passed into the custom 51 // validation functions. 52 // 53 // If Checkers is nil or empty, only the default header "x-goog-api-client" 54 // will be checked for. 55 // Otherwise, if you supply Matchers, those keys and their respective 56 // validation functions will be checked. 57 Checkers []*HeaderChecker 58 59 // OnFailure is the function that will be invoked after all validation 60 // failures have been composed. If OnFailure is nil, log.Fatal will be 61 // invoked instead. 62 OnFailure func(fmt_ string, args ...interface{}) 63} 64 65// StreamInterceptors returns a list of StreamClientInterceptor functions which 66// enforce the presence and validity of expected headers during streaming RPCs. 67// 68// For client implementations which provide their own StreamClientInterceptor(s) 69// these interceptors should be specified as the final elements to 70// WithChainStreamInterceptor. 71// 72// Alternatively, users may apply gPRC options produced from DialOptions to 73// apply all applicable gRPC interceptors. 74func (h *HeadersEnforcer) StreamInterceptors() []grpc.StreamClientInterceptor { 75 return []grpc.StreamClientInterceptor{h.interceptStream} 76} 77 78// UnaryInterceptors returns a list of UnaryClientInterceptor functions which 79// enforce the presence and validity of expected headers during unary RPCs. 80// 81// For client implementations which provide their own UnaryClientInterceptor(s) 82// these interceptors should be specified as the final elements to 83// WithChainUnaryInterceptor. 84// 85// Alternatively, users may apply gPRC options produced from DialOptions to 86// apply all applicable gRPC interceptors. 87func (h *HeadersEnforcer) UnaryInterceptors() []grpc.UnaryClientInterceptor { 88 return []grpc.UnaryClientInterceptor{h.interceptUnary} 89} 90 91// DialOptions returns gRPC DialOptions consisting of unary and stream interceptors 92// to enforce the presence and validity of expected headers. 93func (h *HeadersEnforcer) DialOptions() []grpc.DialOption { 94 return []grpc.DialOption{ 95 grpc.WithChainStreamInterceptor(h.interceptStream), 96 grpc.WithChainUnaryInterceptor(h.interceptUnary), 97 } 98} 99 100// CallOptions returns ClientOptions consisting of unary and stream interceptors 101// to enforce the presence and validity of expected headers. 102func (h *HeadersEnforcer) CallOptions() (copts []option.ClientOption) { 103 dopts := h.DialOptions() 104 for _, dopt := range dopts { 105 copts = append(copts, option.WithGRPCDialOption(dopt)) 106 } 107 return 108} 109 110func (h *HeadersEnforcer) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 111 h.checkMetadata(ctx, method) 112 return invoker(ctx, method, req, res, cc, opts...) 113} 114 115func (h *HeadersEnforcer) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 116 h.checkMetadata(ctx, method) 117 return streamer(ctx, desc, cc, method, opts...) 118} 119 120// XGoogClientHeaderChecker is a HeaderChecker that ensures that the "x-goog-api-client" 121// header is present on outgoing metadata. 122var XGoogClientHeaderChecker = &HeaderChecker{ 123 Key: "x-goog-api-client", 124 ValuesValidator: func(values ...string) error { 125 if len(values) == 0 { 126 return errors.New("expecting values") 127 } 128 for _, value := range values { 129 switch { 130 case strings.Contains(value, "gl-go/"): 131 // TODO: check for exact version strings. 132 return nil 133 134 default: // Add others here. 135 } 136 } 137 return errors.New("unmatched values") 138 }, 139} 140 141// DefaultHeadersEnforcer returns a HeadersEnforcer that at bare minimum checks that 142// the "x-goog-api-client" key is present in the outgoing metadata headers. On any 143// validation failure, it will invoke log.Fatalf with the error message. 144func DefaultHeadersEnforcer() *HeadersEnforcer { 145 return &HeadersEnforcer{ 146 Checkers: []*HeaderChecker{XGoogClientHeaderChecker}, 147 } 148} 149 150func (h *HeadersEnforcer) checkMetadata(ctx context.Context, method string) { 151 onFailure := h.OnFailure 152 if onFailure == nil { 153 lgr := log.New(os.Stderr, "", 0) // Do not log the time prefix, it is noisy in test failure logs. 154 onFailure = func(fmt_ string, args ...interface{}) { 155 lgr.Fatalf(fmt_, args...) 156 } 157 } 158 159 md, ok := metadata.FromOutgoingContext(ctx) 160 if !ok { 161 onFailure("Missing metadata for method %q", method) 162 return 163 } 164 checkers := h.Checkers 165 if len(checkers) == 0 { 166 // Instead use the default HeaderChecker. 167 checkers = append(checkers, XGoogClientHeaderChecker) 168 } 169 170 errBuf := new(bytes.Buffer) 171 for _, checker := range checkers { 172 hdrKey := checker.Key 173 outHdrValues, ok := md[hdrKey] 174 if !ok { 175 fmt.Fprintf(errBuf, "missing header %q\n", hdrKey) 176 continue 177 } 178 if err := checker.ValuesValidator(outHdrValues...); err != nil { 179 fmt.Fprintf(errBuf, "header %q: %v\n", hdrKey, err) 180 } 181 } 182 183 if errBuf.Len() != 0 { 184 onFailure("For method %q, errors:\n%s", method, errBuf) 185 return 186 } 187} 188