1/*
2 *
3 * Copyright 2021 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19// Package fault implements the Envoy Fault Injection HTTP filter.
20package fault
21
22import (
23	"context"
24	"errors"
25	"fmt"
26	"io"
27	"strconv"
28	"sync/atomic"
29	"time"
30
31	"github.com/golang/protobuf/proto"
32	"github.com/golang/protobuf/ptypes"
33	"google.golang.org/grpc/codes"
34	"google.golang.org/grpc/internal/grpcrand"
35	iresolver "google.golang.org/grpc/internal/resolver"
36	"google.golang.org/grpc/metadata"
37	"google.golang.org/grpc/status"
38	"google.golang.org/grpc/xds/internal/httpfilter"
39	"google.golang.org/protobuf/types/known/anypb"
40
41	cpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/common/fault/v3"
42	fpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/fault/v3"
43	tpb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
44)
45
// Outgoing-metadata keys consulted when a fault is configured to take its
// parameters from request headers rather than from the filter config.
const (
	// Abort: HTTP status, gRPC status, and percentage numerator.
	headerAbortHTTPStatus = "x-envoy-fault-abort-request"
	headerAbortGRPCStatus = "x-envoy-fault-abort-grpc-request"
	headerAbortPercentage = "x-envoy-fault-abort-request-percentage"

	// Delay: percentage numerator and duration (parsed as milliseconds).
	headerDelayPercentage = "x-envoy-fault-delay-request-percentage"
	headerDelayDuration   = "x-envoy-fault-delay-request"
)
52
// statusMap translates the HTTP status codes that have a dedicated gRPC
// equivalent into that gRPC code. It is consulted by grpcFromHTTP, which maps
// every other in-range HTTP status to codes.Unknown. No entry may map to
// codes.OK, because grpcFromHTTP treats the zero value as "not present".
var statusMap = map[int]codes.Code{
	400: codes.Internal,
	401: codes.Unauthenticated,
	403: codes.PermissionDenied,
	404: codes.Unimplemented,
	429: codes.Unavailable,
	502: codes.Unavailable,
	503: codes.Unavailable,
	504: codes.Unavailable,
}
63
// init hands a builder value to httpfilter.Register so the fault filter is
// registered as soon as this package is linked in.
func init() {
	httpfilter.Register(builder{})
}
67
// builder is the stateless type registered with httpfilter. Its methods parse
// fault filter configuration and construct client interceptors from it.
type builder struct {
}
70
// config is the httpfilter.FilterConfig produced by parseConfig for both the
// base and the override configuration.
type config struct {
	// Embedded so that config satisfies the httpfilter.FilterConfig interface.
	httpfilter.FilterConfig
	// config is the HTTPFault message unmarshalled from the Any that was
	// supplied to parseConfig.
	config *fpb.HTTPFault
}
75
76func (builder) TypeURLs() []string {
77	return []string{"type.googleapis.com/envoy.extensions.filters.http.fault.v3.HTTPFault"}
78}
79
80// Parsing is the same for the base config and the override config.
81func parseConfig(cfg proto.Message) (httpfilter.FilterConfig, error) {
82	if cfg == nil {
83		return nil, fmt.Errorf("fault: nil configuration message provided")
84	}
85	any, ok := cfg.(*anypb.Any)
86	if !ok {
87		return nil, fmt.Errorf("fault: error parsing config %v: unknown type %T", cfg, cfg)
88	}
89	msg := new(fpb.HTTPFault)
90	if err := ptypes.UnmarshalAny(any, msg); err != nil {
91		return nil, fmt.Errorf("fault: error parsing config %v: %v", cfg, err)
92	}
93	return config{config: msg}, nil
94}
95
// ParseFilterConfig parses the filter's base configuration. It delegates to
// parseConfig and returns its result unchanged.
func (builder) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) {
	return parseConfig(cfg)
}
99
// ParseFilterConfigOverride parses an override configuration. Overrides use
// the same message type as the base config, so this also delegates to
// parseConfig.
func (builder) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) {
	return parseConfig(override)
}
103
// IsTerminal always reports false for the fault filter.
// NOTE(review): presumably this means the filter need not be the last one in
// a filter chain — confirm against the httpfilter package's contract.
func (builder) IsTerminal() bool {
	return false
}
107
// Compile-time check that builder implements
// httpfilter.ClientInterceptorBuilder.
var _ httpfilter.ClientInterceptorBuilder = builder{}
109
110func (builder) BuildClientInterceptor(cfg, override httpfilter.FilterConfig) (iresolver.ClientInterceptor, error) {
111	if cfg == nil {
112		return nil, fmt.Errorf("fault: nil config provided")
113	}
114
115	c, ok := cfg.(config)
116	if !ok {
117		return nil, fmt.Errorf("fault: incorrect config type provided (%T): %v", cfg, cfg)
118	}
119
120	if override != nil {
121		// override completely replaces the listener configuration; but we
122		// still validate the listener config type.
123		c, ok = override.(config)
124		if !ok {
125			return nil, fmt.Errorf("fault: incorrect override config type provided (%T): %v", override, override)
126		}
127	}
128
129	icfg := c.config
130	if (icfg.GetMaxActiveFaults() != nil && icfg.GetMaxActiveFaults().GetValue() == 0) ||
131		(icfg.GetDelay() == nil && icfg.GetAbort() == nil) {
132		return nil, nil
133	}
134	return &interceptor{config: icfg}, nil
135}
136
// interceptor is the client interceptor returned by
// builder.BuildClientInterceptor. Its NewStream method applies the delay and
// abort faults described by config before the real stream is created.
type interceptor struct {
	// config is the effective fault configuration (the override if one was
	// given, otherwise the base config).
	config *fpb.HTTPFault
}
140
// activeFaults counts the calls currently inside interceptor.NewStream for
// interceptors whose config sets max_active_faults. It is package-level and
// therefore shared by every interceptor; access it only through sync/atomic.
var activeFaults uint32 // global active faults; accessed atomically
142
143func (i *interceptor) NewStream(ctx context.Context, ri iresolver.RPCInfo, done func(), newStream func(ctx context.Context, done func()) (iresolver.ClientStream, error)) (iresolver.ClientStream, error) {
144	if maxAF := i.config.GetMaxActiveFaults(); maxAF != nil {
145		defer atomic.AddUint32(&activeFaults, ^uint32(0)) // decrement counter
146		if af := atomic.AddUint32(&activeFaults, 1); af > maxAF.GetValue() {
147			// Would exceed maximum active fault limit.
148			return newStream(ctx, done)
149		}
150	}
151
152	if err := injectDelay(ctx, i.config.GetDelay()); err != nil {
153		return nil, err
154	}
155
156	if err := injectAbort(ctx, i.config.GetAbort()); err != nil {
157		if err == errOKStream {
158			return &okStream{ctx: ctx}, nil
159		}
160		return nil, err
161	}
162	return newStream(ctx, done)
163}
164
165// For overriding in tests
166var randIntn = grpcrand.Intn
167var newTimer = time.NewTimer
168
169func injectDelay(ctx context.Context, delayCfg *cpb.FaultDelay) error {
170	numerator, denominator := splitPct(delayCfg.GetPercentage())
171	var delay time.Duration
172	switch delayType := delayCfg.GetFaultDelaySecifier().(type) {
173	case *cpb.FaultDelay_FixedDelay:
174		delay = delayType.FixedDelay.AsDuration()
175	case *cpb.FaultDelay_HeaderDelay_:
176		md, _ := metadata.FromOutgoingContext(ctx)
177		v := md[headerDelayDuration]
178		if v == nil {
179			// No delay configured for this RPC.
180			return nil
181		}
182		ms, ok := parseIntFromMD(v)
183		if !ok {
184			// Malformed header; no delay.
185			return nil
186		}
187		delay = time.Duration(ms) * time.Millisecond
188		if v := md[headerDelayPercentage]; v != nil {
189			if num, ok := parseIntFromMD(v); ok && num < numerator {
190				numerator = num
191			}
192		}
193	}
194	if delay == 0 || randIntn(denominator) >= numerator {
195		return nil
196	}
197	t := newTimer(delay)
198	select {
199	case <-t.C:
200	case <-ctx.Done():
201		t.Stop()
202		return ctx.Err()
203	}
204	return nil
205}
206
207func injectAbort(ctx context.Context, abortCfg *fpb.FaultAbort) error {
208	numerator, denominator := splitPct(abortCfg.GetPercentage())
209	code := codes.OK
210	okCode := false
211	switch errType := abortCfg.GetErrorType().(type) {
212	case *fpb.FaultAbort_HttpStatus:
213		code, okCode = grpcFromHTTP(int(errType.HttpStatus))
214	case *fpb.FaultAbort_GrpcStatus:
215		code, okCode = sanitizeGRPCCode(codes.Code(errType.GrpcStatus)), true
216	case *fpb.FaultAbort_HeaderAbort_:
217		md, _ := metadata.FromOutgoingContext(ctx)
218		if v := md[headerAbortHTTPStatus]; v != nil {
219			// HTTP status has priority over gRPC status.
220			if httpStatus, ok := parseIntFromMD(v); ok {
221				code, okCode = grpcFromHTTP(httpStatus)
222			}
223		} else if v := md[headerAbortGRPCStatus]; v != nil {
224			if grpcStatus, ok := parseIntFromMD(v); ok {
225				code, okCode = sanitizeGRPCCode(codes.Code(grpcStatus)), true
226			}
227		}
228		if v := md[headerAbortPercentage]; v != nil {
229			if num, ok := parseIntFromMD(v); ok && num < numerator {
230				numerator = num
231			}
232		}
233	}
234	if !okCode || randIntn(denominator) >= numerator {
235		return nil
236	}
237	if code == codes.OK {
238		return errOKStream
239	}
240	return status.Errorf(code, "RPC terminated due to fault injection")
241}
242
// errOKStream is the sentinel returned by injectAbort when the selected abort
// status is codes.OK. interceptor.NewStream recognises it and returns an
// okStream rather than an error.
var errOKStream = errors.New("stream terminated early with OK status")
244
// parseIntFromMD parses the last value of a metadata header as a base-10
// integer. The boolean result is true only when the header has at least one
// value and that last value parses cleanly; when it is false the integer
// result must not be used.
func parseIntFromMD(header []string) (int, bool) {
	count := len(header)
	if count < 1 {
		return 0, false
	}
	parsed, err := strconv.Atoi(header[count-1])
	if err != nil {
		return parsed, false
	}
	return parsed, true
}
254
255func splitPct(fp *tpb.FractionalPercent) (num int, den int) {
256	if fp == nil {
257		return 0, 100
258	}
259	num = int(fp.GetNumerator())
260	switch fp.GetDenominator() {
261	case tpb.FractionalPercent_HUNDRED:
262		return num, 100
263	case tpb.FractionalPercent_TEN_THOUSAND:
264		return num, 10 * 1000
265	case tpb.FractionalPercent_MILLION:
266		return num, 1000 * 1000
267	}
268	return num, 100
269}
270
271func grpcFromHTTP(httpStatus int) (codes.Code, bool) {
272	if httpStatus < 200 || httpStatus >= 600 {
273		// Malformed; ignore this fault type.
274		return codes.OK, false
275	}
276	if c := statusMap[httpStatus]; c != codes.OK {
277		// OK = 0/the default for the map.
278		return c, true
279	}
280	// All undefined HTTP status codes convert to Unknown. HTTP status of 200
281	// is "success", but gRPC converts to Unknown due to missing grpc status.
282	return codes.Unknown, true
283}
284
285func sanitizeGRPCCode(c codes.Code) codes.Code {
286	if c > codes.Code(16) {
287		return codes.Unknown
288	}
289	return c
290}
291
// okStream is the stream returned by interceptor.NewStream when an abort
// fault with OK status is injected. No real stream is created: its methods
// return no metadata and report io.EOF for sends and receives.
type okStream struct {
	// ctx is the RPC's context, returned as-is from Context.
	ctx context.Context
}
295
296func (*okStream) Header() (metadata.MD, error) { return nil, nil }
297func (*okStream) Trailer() metadata.MD         { return nil }
298func (*okStream) CloseSend() error             { return nil }
299func (o *okStream) Context() context.Context   { return o.ctx }
300func (*okStream) SendMsg(m interface{}) error  { return io.EOF }
301func (*okStream) RecvMsg(m interface{}) error  { return io.EOF }
302