1// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package testutil
16
17import (
18	"bytes"
19	"context"
20	"errors"
21	"fmt"
22	"log"
23	"os"
24	"strings"
25
26	"google.golang.org/api/option"
27	"google.golang.org/grpc"
28	"google.golang.org/grpc/metadata"
29)
30
// HeaderChecker defines header checking and validation rules for any outgoing metadata.
//
// A HeadersEnforcer looks Key up in the outgoing gRPC metadata of each
// intercepted RPC and hands every value found under it to ValuesValidator.
type HeaderChecker struct {
	// Key is the header name to be checked against, e.g. "x-goog-api-client".
	// NOTE(review): grpc-go normalizes metadata keys to lower case, so Key
	// presumably should be lower case to match reliably — confirm.
	Key string

	// ValuesValidator receives all header values stored under Key in the
	// outgoing metadata. Returning a non-nil error marks the header as invalid;
	// the error text is included in the enforcer's failure report.
	ValuesValidator func(values ...string) error
}
40
41// HeadersEnforcer asserts that outgoing RPC headers
42// are present and match expectations. If the expected headers
43// are not present or don't match expectations, it'll invoke OnFailure
44// with the validation error, or instead log.Fatal if OnFailure is nil.
45//
46// It expects that every declared key will be present in the outgoing
47// RPC header and each value will be validated by the validation function.
48type HeadersEnforcer struct {
49	// Checkers maps header keys that are expected to be sent in the metadata
50	// of outgoing gRPC requests, against the values passed into the custom
51	// validation functions.
52	//
53	// If Checkers is nil or empty, only the default header "x-goog-api-client"
54	// will be checked for.
55	// Otherwise, if you supply Matchers, those keys and their respective
56	// validation functions will be checked.
57	Checkers []*HeaderChecker
58
59	// OnFailure is the function that will be invoked after all validation
60	// failures have been composed. If OnFailure is nil, log.Fatal will be
61	// invoked instead.
62	OnFailure func(fmt_ string, args ...interface{})
63}
64
65// StreamInterceptors returns a list of StreamClientInterceptor functions which
66// enforce the presence and validity of expected headers during streaming RPCs.
67//
68// For client implementations which provide their own StreamClientInterceptor(s)
69// these interceptors should be specified as the final elements to
70// WithChainStreamInterceptor.
71//
72// Alternatively, users may apply gPRC options produced from DialOptions to
73// apply all applicable gRPC interceptors.
74func (h *HeadersEnforcer) StreamInterceptors() []grpc.StreamClientInterceptor {
75	return []grpc.StreamClientInterceptor{h.interceptStream}
76}
77
78// UnaryInterceptors returns a list of UnaryClientInterceptor functions which
79// enforce the presence and validity of expected headers during unary RPCs.
80//
81// For client implementations which provide their own UnaryClientInterceptor(s)
82// these interceptors should be specified as the final elements to
83// WithChainUnaryInterceptor.
84//
85// Alternatively, users may apply gPRC options produced from DialOptions to
86// apply all applicable gRPC interceptors.
87func (h *HeadersEnforcer) UnaryInterceptors() []grpc.UnaryClientInterceptor {
88	return []grpc.UnaryClientInterceptor{h.interceptUnary}
89}
90
91// DialOptions returns gRPC DialOptions consisting of unary and stream interceptors
92// to enforce the presence and validity of expected headers.
93func (h *HeadersEnforcer) DialOptions() []grpc.DialOption {
94	return []grpc.DialOption{
95		grpc.WithChainStreamInterceptor(h.interceptStream),
96		grpc.WithChainUnaryInterceptor(h.interceptUnary),
97	}
98}
99
100// CallOptions returns ClientOptions consisting of unary and stream interceptors
101// to enforce the presence and validity of expected headers.
102func (h *HeadersEnforcer) CallOptions() (copts []option.ClientOption) {
103	dopts := h.DialOptions()
104	for _, dopt := range dopts {
105		copts = append(copts, option.WithGRPCDialOption(dopt))
106	}
107	return
108}
109
110func (h *HeadersEnforcer) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
111	h.checkMetadata(ctx, method)
112	return invoker(ctx, method, req, res, cc, opts...)
113}
114
115func (h *HeadersEnforcer) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
116	h.checkMetadata(ctx, method)
117	return streamer(ctx, desc, cc, method, opts...)
118}
119
120// XGoogClientHeaderChecker is a HeaderChecker that ensures that the "x-goog-api-client"
121// header is present on outgoing metadata.
122var XGoogClientHeaderChecker = &HeaderChecker{
123	Key: "x-goog-api-client",
124	ValuesValidator: func(values ...string) error {
125		if len(values) == 0 {
126			return errors.New("expecting values")
127		}
128		for _, value := range values {
129			switch {
130			case strings.Contains(value, "gl-go/"):
131				// TODO: check for exact version strings.
132				return nil
133
134			default: // Add others here.
135			}
136		}
137		return errors.New("unmatched values")
138	},
139}
140
141// DefaultHeadersEnforcer returns a HeadersEnforcer that at bare minimum checks that
142// the "x-goog-api-client" key is present in the outgoing metadata headers. On any
143// validation failure, it will invoke log.Fatalf with the error message.
144func DefaultHeadersEnforcer() *HeadersEnforcer {
145	return &HeadersEnforcer{
146		Checkers: []*HeaderChecker{XGoogClientHeaderChecker},
147	}
148}
149
150func (h *HeadersEnforcer) checkMetadata(ctx context.Context, method string) {
151	onFailure := h.OnFailure
152	if onFailure == nil {
153		lgr := log.New(os.Stderr, "", 0) // Do not log the time prefix, it is noisy in test failure logs.
154		onFailure = func(fmt_ string, args ...interface{}) {
155			lgr.Fatalf(fmt_, args...)
156		}
157	}
158
159	md, ok := metadata.FromOutgoingContext(ctx)
160	if !ok {
161		onFailure("Missing metadata for method %q", method)
162		return
163	}
164	checkers := h.Checkers
165	if len(checkers) == 0 {
166		// Instead use the default HeaderChecker.
167		checkers = append(checkers, XGoogClientHeaderChecker)
168	}
169
170	errBuf := new(bytes.Buffer)
171	for _, checker := range checkers {
172		hdrKey := checker.Key
173		outHdrValues, ok := md[hdrKey]
174		if !ok {
175			fmt.Fprintf(errBuf, "missing header %q\n", hdrKey)
176			continue
177		}
178		if err := checker.ValuesValidator(outHdrValues...); err != nil {
179			fmt.Fprintf(errBuf, "header %q: %v\n", hdrKey, err)
180		}
181	}
182
183	if errBuf.Len() != 0 {
184		onFailure("For method %q, errors:\n%s", method, errBuf)
185		return
186	}
187}
188