1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package unified
8
9import (
10	"context"
11	"fmt"
12
13	"go.mongodb.org/mongo-driver/mongo"
14)
15
// ctxKey is the type of the keys this package uses to store test-runner state in context.Context values.
// Because the type is unexported, keys of this type cannot collide with context keys defined by any other package,
// even if the underlying string values happen to match.
type ctxKey string
18
19const (
20	// entitiesKey is used to store an EntityMap instance in a Context.
21	entitiesKey ctxKey = "test-entities"
22	// failPointsKey is used to store a map from a fail point name to the Client instance used to configure it.
23	failPointsKey ctxKey = "test-failpoints"
24	// targetedFailPointsKey is used to store a map from a fail point name to the host on which the fail point is set.
25	targetedFailPointsKey ctxKey = "test-targeted-failpoints"
26)
27
28// NewTestContext creates a new Context derived from ctx with values initialized to store the state required for test
29// execution.
30func NewTestContext(ctx context.Context) context.Context {
31	ctx = context.WithValue(ctx, entitiesKey, NewEntityMap())
32	ctx = context.WithValue(ctx, failPointsKey, make(map[string]*mongo.Client))
33	ctx = context.WithValue(ctx, targetedFailPointsKey, make(map[string]string))
34	return ctx
35}
36
37func AddFailPoint(ctx context.Context, failPoint string, client *mongo.Client) error {
38	failPoints := ctx.Value(failPointsKey).(map[string]*mongo.Client)
39	if _, ok := failPoints[failPoint]; ok {
40		return fmt.Errorf("fail point %q already exists in tracked fail points map", failPoint)
41	}
42
43	failPoints[failPoint] = client
44	return nil
45}
46
47func AddTargetedFailPoint(ctx context.Context, failPoint string, host string) error {
48	failPoints := ctx.Value(targetedFailPointsKey).(map[string]string)
49	if _, ok := failPoints[failPoint]; ok {
50		return fmt.Errorf("fail point %q already exists in tracked targeted fail points map", failPoint)
51	}
52
53	failPoints[failPoint] = host
54	return nil
55}
56
57func FailPoints(ctx context.Context) map[string]*mongo.Client {
58	return ctx.Value(failPointsKey).(map[string]*mongo.Client)
59}
60
61func TargetedFailPoints(ctx context.Context) map[string]string {
62	return ctx.Value(targetedFailPointsKey).(map[string]string)
63}
64
65func Entities(ctx context.Context) *EntityMap {
66	return ctx.Value(entitiesKey).(*EntityMap)
67}
68