// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package unified

import (
	"context"
	"fmt"

	"go.mongodb.org/mongo-driver/mongo"
)

// ctxKey is used to define keys for values stored in context.Context objects. A dedicated named type keeps these keys
// from colliding with context keys defined by other packages.
type ctxKey string

const (
	// entitiesKey is used to store an EntityMap instance in a Context.
	entitiesKey ctxKey = "test-entities"
	// failPointsKey is used to store a map from a fail point name to the Client instance used to configure it.
	failPointsKey ctxKey = "test-failpoints"
	// targetedFailPointsKey is used to store a map from a fail point name to the host on which the fail point is set.
	targetedFailPointsKey ctxKey = "test-targeted-failpoints"
)

// NewTestContext creates a new Context derived from ctx with values initialized to store the state required for test
// execution: a new EntityMap, an empty fail point map, and an empty targeted fail point map.
30func NewTestContext(ctx context.Context) context.Context { 31 ctx = context.WithValue(ctx, entitiesKey, NewEntityMap()) 32 ctx = context.WithValue(ctx, failPointsKey, make(map[string]*mongo.Client)) 33 ctx = context.WithValue(ctx, targetedFailPointsKey, make(map[string]string)) 34 return ctx 35} 36 37func AddFailPoint(ctx context.Context, failPoint string, client *mongo.Client) error { 38 failPoints := ctx.Value(failPointsKey).(map[string]*mongo.Client) 39 if _, ok := failPoints[failPoint]; ok { 40 return fmt.Errorf("fail point %q already exists in tracked fail points map", failPoint) 41 } 42 43 failPoints[failPoint] = client 44 return nil 45} 46 47func AddTargetedFailPoint(ctx context.Context, failPoint string, host string) error { 48 failPoints := ctx.Value(targetedFailPointsKey).(map[string]string) 49 if _, ok := failPoints[failPoint]; ok { 50 return fmt.Errorf("fail point %q already exists in tracked targeted fail points map", failPoint) 51 } 52 53 failPoints[failPoint] = host 54 return nil 55} 56 57func FailPoints(ctx context.Context) map[string]*mongo.Client { 58 return ctx.Value(failPointsKey).(map[string]*mongo.Client) 59} 60 61func TargetedFailPoints(ctx context.Context) map[string]string { 62 return ctx.Value(targetedFailPointsKey).(map[string]string) 63} 64 65func Entities(ctx context.Context) *EntityMap { 66 return ctx.Value(entitiesKey).(*EntityMap) 67} 68