1//go:build go1.16
2// +build go1.16
3
4// Copyright (c) Microsoft Corporation. All rights reserved.
5// Licensed under the MIT License.
6
7package recording
8
9import (
10	"errors"
11	"fmt"
12	"io/ioutil"
13	"math/rand"
14	"net/http"
15	"os"
16	"path/filepath"
17	"strconv"
18	"strings"
19	"time"
20
21	"github.com/Azure/azure-sdk-for-go/sdk/internal/uuid"
22	"github.com/dnaeon/go-vcr/cassette"
23	"github.com/dnaeon/go-vcr/recorder"
24	"gopkg.in/yaml.v2"
25)
26
27type Recording struct {
28	SessionName              string
29	RecordingFile            string
30	VariablesFile            string
31	Mode                     RecordMode
32	variables                map[string]*string `yaml:"variables"`
33	previousSessionVariables map[string]*string `yaml:"variables"`
34	recorder                 *recorder.Recorder
35	src                      rand.Source
36	now                      *time.Time
37	Sanitizer                *Sanitizer
38	Matcher                  *RequestMatcher
39	c                        TestContext
40}
41
// ModeEnvironmentVariableName names the environment variable whose value, when
// set, overrides the RecordMode passed to NewRecording.
const ModeEnvironmentVariableName = "AZURE_TEST_MODE"

const (
	// Character sets used by GenerateAlphaNumericID.
	alphanumericBytes          = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
	alphanumericLowercaseBytes = "abcdefghijklmnopqrstuvwxyz1234567890"

	// Keys under which the random seed and the session time are stored in the variables map.
	randomSeedVariableName = "randomSeed"
	nowVariableName        = "now"
)
49
// Constants for GenerateAlphaNumericID, which takes letterIdxBits-bit character
// indices from each 63-bit value produced by the random source.
// Inspired by https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-go
const (
	letterIdxBits = 6                        // bits consumed per character index
	letterIdxMask = (1 << letterIdxBits) - 1 // low letterIdxBits bits set (63)
	letterIdxMax  = 63 / letterIdxBits       // character indices per 63-bit value (10)
)
56
// RecordMode selects how a Recording treats HTTP traffic. Each value is
// translated to a go-vcr recorder.Mode through modeMap.
type RecordMode string

// The supported RecordMode values.
const (
	// Record runs the recorder in recorder.ModeRecording.
	Record RecordMode = "record"
	// Playback runs the recorder in recorder.ModeReplaying.
	Playback RecordMode = "playback"
	// Live maps to recorder.ModeDisabled; Stop also skips saving variables in this mode.
	Live RecordMode = "live"
)
64
// VariableType controls what value GetEnvVar and GetOptionalEnvVar store in the
// recording's variables: the real value, or a sanitized placeholder.
type VariableType string

const (
	// NoSanitization stores the value unchanged. Any VariableType other than the
	// two secret types below behaves the same way.
	NoSanitization VariableType = "default"
	// Secret_String stores a sanitized placeholder instead of the real value.
	Secret_String VariableType = "secret_string"
	// Secret_Base64String stores a sanitized placeholder that is a valid base-64 string.
	Secret_Base64String VariableType = "secret_base64String"
)
75
76// NewRecording initializes a new Recording instance
77func NewRecording(c TestContext, mode RecordMode) (*Recording, error) {
78	// create recorder based on the test name, recordMode, variables, and sanitizers
79	recPath, varPath := getFilePaths(c.Name())
80	rec, err := recorder.NewAsMode(recPath, modeMap[mode], nil)
81	if err != nil {
82		return nil, err
83	}
84
85	// If the mode is set in the environment, let that override the requested mode
86	// This is to enable support for nightly live test pipelines
87	envMode := getOptionalEnv(ModeEnvironmentVariableName, string(mode))
88	mode = RecordMode(*envMode)
89
90	// initialize the Recording
91	recording := &Recording{
92		SessionName:              recPath,
93		RecordingFile:            recPath + ".yaml",
94		VariablesFile:            varPath,
95		Mode:                     mode,
96		variables:                make(map[string]*string),
97		previousSessionVariables: make(map[string]*string),
98		recorder:                 rec,
99		c:                        c,
100	}
101
102	// Try loading the recording if it already exists to hydrate the variables
103	err = recording.initVariables()
104	if err != nil {
105		return nil, err
106	}
107
108	// set the recorder Matcher
109	recording.Matcher = defaultMatcher(c)
110	rec.SetMatcher(recording.matchRequest)
111
112	// wire up the sanitizer
113	recording.Sanitizer = defaultSanitizer(rec)
114
115	return recording, err
116}
117
118// GetEnvVar returns a recorded environment variable. If the variable is not found we return an error.
119// variableType determines how the recorded variable will be saved.
120func (r *Recording) GetEnvVar(name string, variableType VariableType) (string, error) {
121	var err error
122	result, ok := r.previousSessionVariables[name]
123	if !ok || r.Mode == Live {
124
125		result, err = getRequiredEnv(name)
126		if err != nil {
127			r.c.Fail(err.Error())
128			return "", err
129		}
130		r.variables[name] = applyVariableOptions(result, variableType)
131	}
132	return *result, err
133}
134
135// GetOptionalEnvVar returns a recorded environment variable with a fallback default value.
136// default Value configures the fallback value to be returned if the environment variable is not set.
137// variableType determines how the recorded variable will be saved.
138func (r *Recording) GetOptionalEnvVar(name string, defaultValue string, variableType VariableType) string {
139	result, ok := r.previousSessionVariables[name]
140	if !ok || r.Mode == Live {
141		result = getOptionalEnv(name, defaultValue)
142		r.variables[name] = applyVariableOptions(result, variableType)
143	}
144	return *result
145}
146
147// Do satisfies the azcore.Transport interface so that Recording can be used as the transport for recorded requests
148func (r *Recording) Do(req *http.Request) (*http.Response, error) {
149	resp, err := r.recorder.RoundTrip(req)
150	if err == cassette.ErrInteractionNotFound {
151		error := missingRequestError(req)
152		r.c.Fail(error)
153		return nil, errors.New(error)
154	}
155	return resp, err
156}
157
158// Stop stops the recording and saves them, including any captured variables, to disk
159func (r *Recording) Stop() error {
160
161	err := r.recorder.Stop()
162	if err != nil {
163		return err
164	}
165	if r.Mode == Live {
166		return nil
167	}
168
169	if len(r.variables) > 0 {
170		// Merge values from previousVariables that are not in variables to variables
171		for k, v := range r.previousSessionVariables {
172			if _, ok := r.variables[k]; ok {
173				// skip variables that were new in the current session
174				continue
175			}
176			r.variables[k] = v
177		}
178
179		// Marshal to YAML and save variables
180		data, err := yaml.Marshal(r.variables)
181		if err != nil {
182			return err
183		}
184
185		f, err := r.createVariablesFileIfNotExists()
186		if err != nil {
187			return err
188		}
189
190		defer f.Close()
191
192		// http://www.yaml.org/spec/1.2/spec.html#id2760395
193		_, err = f.Write([]byte("---\n"))
194		if err != nil {
195			return err
196		}
197
198		_, err = f.Write(data)
199		if err != nil {
200			return err
201		}
202	}
203	return nil
204}
205
206func (r *Recording) Now() time.Time {
207	r.initNow()
208
209	return *r.now
210}
211
212func (r *Recording) UUID() uuid.UUID {
213	r.initRandomSource()
214	u := uuid.UUID{}
215	// Set all bits to randomly (or pseudo-randomly) chosen values.
216	// math/rand.Read() is no-fail so we omit any error checking.
217	rnd := rand.New(r.src)
218	rnd.Read(u[:])
219	u[8] = (u[8] | 0x40) & 0x7F // u.setVariant(ReservedRFC4122)
220
221	var version byte = 4
222	u[6] = (u[6] & 0xF) | (version << 4) // u.setVersion(4)
223	return u
224}
225
226// GenerateAlphaNumericID will generate a recorded random alpha numeric id
227// if the recording has a randomSeed already set, the value will be generated from that seed, else a new random seed will be used
228func (r *Recording) GenerateAlphaNumericID(prefix string, length int, lowercaseOnly bool) (string, error) {
229
230	if length <= len(prefix) {
231		return "", errors.New("length must be greater than prefix")
232	}
233
234	r.initRandomSource()
235
236	sb := strings.Builder{}
237	sb.Grow(length)
238	sb.WriteString(prefix)
239	i := length - len(prefix) - 1
240	// A src.Int63() generates 63 random bits, enough for letterIdxMax characters!
241	for cache, remain := r.src.Int63(), letterIdxMax; i >= 0; {
242		if remain == 0 {
243			cache, remain = r.src.Int63(), letterIdxMax
244		}
245		if lowercaseOnly {
246			if idx := int(cache & letterIdxMask); idx < len(alphanumericLowercaseBytes) {
247				sb.WriteByte(alphanumericLowercaseBytes[idx])
248				i--
249			}
250		} else {
251			if idx := int(cache & letterIdxMask); idx < len(alphanumericBytes) {
252				sb.WriteByte(alphanumericBytes[idx])
253				i--
254			}
255		}
256		cache >>= letterIdxBits
257		remain--
258	}
259	str := sb.String()
260	return str, nil
261}
262
263// getRequiredEnv gets an environment variable by name and returns an error if it is not found
264func getRequiredEnv(name string) (*string, error) {
265	env, ok := os.LookupEnv(name)
266	if ok {
267		return &env, nil
268	} else {
269		return nil, errors.New(envNotExistsError(name))
270	}
271}
272
// getOptionalEnv returns a pointer to the value of the environment variable name,
// or a pointer to defaultValue if the variable is unset. A variable set to the
// empty string is returned as "". Each call returns a fresh pointer.
func getOptionalEnv(name string, defaultValue string) *string {
	value, found := os.LookupEnv(name)
	if !found {
		value = defaultValue
	}
	return &value
}
282
283func (r *Recording) matchRequest(req *http.Request, rec cassette.Request) bool {
284	isMatch := r.Matcher.compareMethods(req, rec.Method) &&
285		r.Matcher.compareURLs(req, rec.URL) &&
286		r.Matcher.compareHeaders(req, rec) &&
287		r.Matcher.compareBodies(req, rec.Body)
288
289	return isMatch
290}
291
// missingRequestError builds the message reported when no recorded interaction
// matches req. The message includes the request's method and full URL.
func missingRequestError(req *http.Request) string {
	const format = "\nNo matching recorded request found.\nRequest: [%s] %s\n"
	reqURL := req.URL.String()
	return fmt.Sprintf(format, req.Method, reqURL)
}
296
// envNotExistsError builds the message reported when a required environment
// variable named varName is not set.
func envNotExistsError(varName string) string {
	const prefix = "Required environment variable not set: "
	return prefix + varName
}
300
301// applyVariableOptions applies the VariableType transform to the value
302// If variableType is not provided or Default, return result
303// If variableType is Secret_String, return SanitizedValue
304// If variableType isSecret_Base64String return SanitizedBase64Value
305func applyVariableOptions(val *string, variableType VariableType) *string {
306	var ret string
307
308	switch variableType {
309	case Secret_String:
310		ret = SanitizedValue
311		return &ret
312	case Secret_Base64String:
313		ret = SanitizedBase64Value
314		return &ret
315	default:
316		return val
317	}
318}
319
320// initRandomSource initializes the Source to be used for random value creation in this Recording
321func (r *Recording) initRandomSource() {
322	// if we already have a Source generated, return immediately
323	if r.src != nil {
324		return
325	}
326
327	var seed int64
328	var err error
329
330	// check to see if we already have a random seed stored, use that if so
331	seedString, ok := r.previousSessionVariables[randomSeedVariableName]
332	if ok {
333		seed, err = strconv.ParseInt(*seedString, 10, 64)
334	}
335
336	// We did not have a random seed already stored; create a new one
337	if !ok || err != nil || r.Mode == Live {
338		seed = time.Now().Unix()
339		val := strconv.FormatInt(seed, 10)
340		r.variables[randomSeedVariableName] = &val
341	}
342
343	// create a Source with the seed
344	r.src = rand.NewSource(seed)
345}
346
347// initNow initializes the Source to be used for random value creation in this Recording
348func (r *Recording) initNow() {
349	// if we already have a now generated, return immediately
350	if r.now != nil {
351		return
352	}
353
354	var err error
355	var nowStr *string
356	var newNow time.Time
357
358	// check to see if we already have a random seed stored, use that if so
359	nowStr, ok := r.previousSessionVariables[nowVariableName]
360	if ok {
361		newNow, err = time.Parse(time.RFC3339Nano, *nowStr)
362	}
363
364	// We did not have a random seed already stored; create a new one
365	if !ok || err != nil || r.Mode == Live {
366		newNow = time.Now()
367		nowStr = new(string)
368		*nowStr = newNow.Format(time.RFC3339Nano)
369		r.variables[nowVariableName] = nowStr
370	}
371
372	// save the now value.
373	r.now = &newNow
374}
375
// getFilePaths maps a test name to its file paths under the "recordings"
// directory. It returns the recording base path (no extension) and the
// variables file path, which is the base path plus "-variables.yaml".
func getFilePaths(name string) (string, string) {
	recordingPath := fmt.Sprintf("recordings/%s", name)
	variablesPath := recordingPath + "-variables.yaml"
	return recordingPath, variablesPath
}
382
383// createVariablesFileIfNotExists calls os.Create on the VariablesFile and creates it if it or the path does not exist
384// Callers must call Close on the result
385func (r *Recording) createVariablesFileIfNotExists() (*os.File, error) {
386	f, err := os.Create(r.VariablesFile)
387	if err != nil {
388		if !os.IsNotExist(err) {
389			return nil, err
390		}
391		// Create directory for the variables if missing
392		variablesDir := filepath.Dir(r.VariablesFile)
393		if _, err := os.Stat(variablesDir); os.IsNotExist(err) {
394			if err = os.MkdirAll(variablesDir, 0755); err != nil {
395				return nil, err
396			}
397		}
398
399		f, err = os.Create(r.VariablesFile)
400		if err != nil {
401			return nil, err
402		}
403	}
404
405	return f, nil
406}
407
408func (r *Recording) unmarshalVariablesFile(out interface{}) error {
409	data, err := ioutil.ReadFile(r.VariablesFile)
410	if err != nil {
411		// If the file or dir do not exist, this is not an error to report
412		if os.IsNotExist(err) {
413			r.c.Log(fmt.Sprintf("Did not find recording for test '%s'", r.RecordingFile))
414			return nil
415		} else {
416			return err
417		}
418	} else {
419		err = yaml.Unmarshal(data, out)
420		if err != nil {
421			return err
422		}
423	}
424	return nil
425}
426
427func (r *Recording) initVariables() error {
428	return r.unmarshalVariablesFile(r.previousSessionVariables)
429}
430
431var modeMap = map[RecordMode]recorder.Mode{
432	Record:   recorder.ModeRecording,
433	Live:     recorder.ModeDisabled,
434	Playback: recorder.ModeReplaying,
435}
436