1package azblob
2
3import (
4	"bytes"
5	"context"
6	"fmt"
7	"net/http"
8	"net/url"
9	"runtime"
10	"strings"
11	"time"
12
13	"github.com/Azure/azure-pipeline-go/pipeline"
14)
15
// RequestLogOptions configures the request logging policy's behavior.
type RequestLogOptions struct {
	// LogWarningIfTryOverThreshold is how long a single try may take before the policy logs a
	// warning for it. Zero selects the default threshold; a negative value (conventionally -1)
	// disables the slow-try warning.
	LogWarningIfTryOverThreshold time.Duration
}

// defaults returns a copy of o in which a zero LogWarningIfTryOverThreshold has been replaced by
// the default threshold of 3 seconds. Any non-zero value is returned unchanged.
func (o RequestLogOptions) defaults() RequestLogOptions {
	// It would be good to relate this to https://azure.microsoft.com/en-us/support/legal/sla/storage/v1_2/
	// But this monitors the time to get the HTTP response; NOT the time to download the response body.
	const defaultTryWarningThreshold = 3 * time.Second

	if o.LogWarningIfTryOverThreshold != 0 {
		return o // The caller made an explicit choice; keep it
	}
	o.LogWarningIfTryOverThreshold = defaultTryWarningThreshold
	return o
}
31
32// NewRequestLogPolicyFactory creates a RequestLogPolicyFactory object configured using the specified options.
33func NewRequestLogPolicyFactory(o RequestLogOptions) pipeline.Factory {
34	o = o.defaults() // Force defaults to be calculated
35	return pipeline.FactoryFunc(func(next pipeline.Policy, po *pipeline.PolicyOptions) pipeline.PolicyFunc {
36		// These variables are per-policy; shared by multiple calls to Do
37		var try int32
38		operationStart := time.Now() // If this is the 1st try, record the operation state time
39		return func(ctx context.Context, request pipeline.Request) (response pipeline.Response, err error) {
40			try++ // The first try is #1 (not #0)
41
42			// Log the outgoing request as informational
43			if po.ShouldLog(pipeline.LogInfo) {
44				b := &bytes.Buffer{}
45				fmt.Fprintf(b, "==> OUTGOING REQUEST (Try=%d)\n", try)
46				pipeline.WriteRequestWithResponse(b, prepareRequestForLogging(request), nil, nil)
47				po.Log(pipeline.LogInfo, b.String())
48			}
49
50			// Set the time for this particular retry operation and then Do the operation.
51			tryStart := time.Now()
52			response, err = next.Do(ctx, request) // Make the request
53			tryEnd := time.Now()
54			tryDuration := tryEnd.Sub(tryStart)
55			opDuration := tryEnd.Sub(operationStart)
56
57			logLevel, forceLog := pipeline.LogInfo, false // Default logging information
58
59			// If the response took too long, we'll upgrade to warning.
60			if o.LogWarningIfTryOverThreshold > 0 && tryDuration > o.LogWarningIfTryOverThreshold {
61				// Log a warning if the try duration exceeded the specified threshold
62				logLevel, forceLog = pipeline.LogWarning, true
63			}
64
65			var sc int
66			if err == nil { // We got a valid response from the service
67				sc = response.Response().StatusCode
68			} else { // We got an error, so we should inspect if we got a response
69				if se, ok := err.(StorageError); ok {
70					if r := se.Response(); r != nil {
71						sc = r.StatusCode
72					}
73				}
74			}
75
76			if sc == 0 || ((sc >= 400 && sc <= 499) && sc != http.StatusNotFound && sc != http.StatusConflict && sc != http.StatusPreconditionFailed && sc != http.StatusRequestedRangeNotSatisfiable) || (sc >= 500 && sc <= 599) {
77				logLevel, forceLog = pipeline.LogError, true // Promote to Error any 4xx (except those listed is an error) or any 5xx
78			} else {
79				// For other status codes, we leave the level as is.
80			}
81
82			if shouldLog := po.ShouldLog(logLevel); forceLog || shouldLog {
83				// We're going to log this; build the string to log
84				b := &bytes.Buffer{}
85				slow := ""
86				if o.LogWarningIfTryOverThreshold > 0 && tryDuration > o.LogWarningIfTryOverThreshold {
87					slow = fmt.Sprintf("[SLOW >%v]", o.LogWarningIfTryOverThreshold)
88				}
89				fmt.Fprintf(b, "==> REQUEST/RESPONSE (Try=%d/%v%s, OpTime=%v) -- ", try, tryDuration, slow, opDuration)
90				if err != nil { // This HTTP request did not get a response from the service
91					fmt.Fprint(b, "REQUEST ERROR\n")
92				} else {
93					if logLevel == pipeline.LogError {
94						fmt.Fprint(b, "RESPONSE STATUS CODE ERROR\n")
95					} else {
96						fmt.Fprint(b, "RESPONSE SUCCESSFULLY RECEIVED\n")
97					}
98				}
99
100				pipeline.WriteRequestWithResponse(b, prepareRequestForLogging(request), response.Response(), err)
101				if logLevel <= pipeline.LogError {
102					b.Write(stack()) // For errors (or lower levels), we append the stack trace (an expensive operation)
103				}
104				msg := b.String()
105
106				if forceLog {
107					pipeline.ForceLog(logLevel, msg)
108				}
109				if shouldLog {
110					po.Log(logLevel, msg)
111				}
112			}
113			return response, err
114		}
115	})
116}
117
// RedactSigQueryParam redacts the 'sig' query parameter in URL's raw query to protect secret.
//
// rawQuery is expected to be a URL's RawQuery, i.e. the text after the '?'. A single leading '?'
// is tolerated. Parameter names are compared case-insensitively, and 'sig' is detected wherever
// it appears in the query, including as the very first parameter.
//
// It returns whether a 'sig' parameter was found and the query to log:
//   - not found: rawQuery exactly as passed in (no allocation for names without escapes);
//   - found: the query re-encoded by url.Values.Encode (parameters sorted by name) with every
//     'sig' value replaced by "REDACTED". All other names and values keep their original case.
//     Pairs that url.ParseQuery cannot decode are dropped from the result.
func RedactSigQueryParam(rawQuery string) (bool, string) {
	query := strings.TrimPrefix(rawQuery, "?")

	// Walk the '&'-separated pairs looking for a parameter named sig. This is done on the raw
	// text (rather than on the ParseQuery result) so that a sig pair that fails to decode is
	// still detected and therefore never returned verbatim.
	sigFound := false
	for rest := query; rest != "" && !sigFound; {
		pair := rest
		if i := strings.IndexByte(rest, '&'); i >= 0 {
			pair, rest = rest[:i], rest[i+1:]
		} else {
			rest = ""
		}
		name := pair
		if i := strings.IndexByte(pair, '='); i >= 0 {
			name = pair[:i]
		}
		if unescaped, err := url.QueryUnescape(name); err == nil {
			name = unescaped // also catches percent-encoded spellings of the name
		}
		sigFound = strings.EqualFold(name, "sig")
	}
	if !sigFound {
		return false, rawQuery // sig not present; return the same rawQuery passed in
	}

	// sig found, redact its value. The parse error is deliberately ignored: ParseQuery still
	// returns every pair it could decode, and dropping a malformed pair can only remove data
	// from the log, never expose the signature.
	values, _ := url.ParseQuery(query)
	for name := range values {
		if strings.EqualFold(name, "sig") {
			values[name] = []string{"REDACTED"}
		}
	}
	return true, values.Encode()
}
137
138func prepareRequestForLogging(request pipeline.Request) *http.Request {
139	req := request
140	if sigFound, rawQuery := RedactSigQueryParam(req.URL.RawQuery); sigFound {
141		// Make copy so we don't destroy the query parameters we actually need to send in the request
142		req = request.Copy()
143		req.Request.URL.RawQuery = rawQuery
144	}
145
146	return prepareRequestForServiceLogging(req)
147}
148
// stack returns the stack trace of the calling goroutine as formatted by runtime.Stack.
// It starts with a 1024-byte buffer and doubles the size until the trace fits with room to
// spare, so the returned slice holds the complete trace.
func stack() []byte {
	for size := 1024; ; size *= 2 {
		trace := make([]byte, size)
		// A completely filled buffer may mean the trace was cut off; retry with a bigger one.
		if written := runtime.Stack(trace, false); written < size {
			return trace[:written]
		}
	}
}
159
160///////////////////////////////////////////////////////////////////////////////////////
161// Redact phase useful for blob and file service only. For other services,
162// this method can directly return request.Request.
163///////////////////////////////////////////////////////////////////////////////////////
164func prepareRequestForServiceLogging(request pipeline.Request) *http.Request {
165	req := request
166	if exist, key := doesHeaderExistCaseInsensitive(req.Header, xMsCopySourceHeader); exist {
167		req = request.Copy()
168		url, err := url.Parse(req.Header.Get(key))
169		if err == nil {
170			if sigFound, rawQuery := RedactSigQueryParam(url.RawQuery); sigFound {
171				url.RawQuery = rawQuery
172				req.Header.Set(xMsCopySourceHeader, url.String())
173			}
174		}
175	}
176	return req.Request
177}
178
179const xMsCopySourceHeader = "x-ms-copy-source"
180
// doesHeaderExistCaseInsensitive reports whether header contains an entry whose name equals key
// under case-insensitive comparison. On a match it also returns the name exactly as it is stored
// in header; otherwise it returns false and the empty string.
func doesHeaderExistCaseInsensitive(header http.Header, key string) (bool, string) {
	for candidate := range header {
		if !strings.EqualFold(candidate, key) {
			continue
		}
		return true, candidate
	}
	return false, ""
}
189