1package kubernetes
2
3import (
4	"fmt"
5	"os"
6	"strconv"
7	"sync"
8
9	"github.com/hashicorp/go-hclog"
10	sr "github.com/hashicorp/vault/serviceregistration"
11	"github.com/hashicorp/vault/serviceregistration/kubernetes/client"
12)
13
// Keys of the labels this service registration maintains in the metadata of
// Vault's own pod.
const (
	labelVaultVersion = "vault-version"
	labelActive       = "vault-active"
	labelSealed       = "vault-sealed"
	labelPerfStandby  = "vault-perf-standby"
	labelInitialized  = "vault-initialized"
)

// pathToLabels is the patch-path prefix under which individual labels live;
// appending one of the label keys above yields the path of that single label.
const pathToLabels = "/metadata/labels/"
25
26func NewServiceRegistration(config map[string]string, logger hclog.Logger, state sr.State, _ string) (sr.ServiceRegistration, error) {
27	namespace, err := getRequiredField(logger, config, client.EnvVarKubernetesNamespace, "namespace")
28	if err != nil {
29		return nil, err
30	}
31	podName, err := getRequiredField(logger, config, client.EnvVarKubernetesPodName, "pod_name")
32	if err != nil {
33		return nil, err
34	}
35	return &serviceRegistration{
36		logger:       logger,
37		namespace:    namespace,
38		podName:      podName,
39		initialState: state,
40		retryHandler: &retryHandler{
41			logger:         logger,
42			namespace:      namespace,
43			podName:        podName,
44			patchesToRetry: make(map[string]*client.Patch),
45		},
46	}, nil
47}
48
49type serviceRegistration struct {
50	logger             hclog.Logger
51	namespace, podName string
52	client             *client.Client
53	initialState       sr.State
54	retryHandler       *retryHandler
55}
56
// Run creates the Kubernetes API client (tied to shutdownCh) and stores it on
// r. It then registers setInitialState with the retry handler and starts the
// handler with the same shutdown channel, wait group and client.
//
// Only an error from creating the client is returned; otherwise Run returns nil.
func (r *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) error {
	kubeClient, err := client.New(r.logger, shutdownCh)
	if err != nil {
		return err
	}
	r.client = kubeClient

	// setInitialState dereferences r.client, so the callback is registered and
	// the handler started only after the client has been stored above.
	r.retryHandler.SetInitialState(r.setInitialState)
	r.retryHandler.Run(shutdownCh, wait, kubeClient)
	return nil
}
69
// setInitialState fetches the configured pod, which checks that it exists and
// that the namespace/pod name are usable. It then writes r.initialState onto
// the pod as labels. Run registers it with the retry handler.
//
// Two patch shapes are used, both built from the same ordered label list so
// their keys and values can never drift apart:
//   - If the pod has no labels map yet, a single "add" patch creates
//     /metadata/labels containing every label.
//   - Otherwise each label gets its own "replace" patch, in a fixed order.
//
// Errors from GetPod and PatchPod are returned unchanged. A pod without
// metadata produces an error describing the pod.
func (r *serviceRegistration) setInitialState() error {
	// Verify that the pod exists and our configuration looks good.
	pod, err := r.client.GetPod(r.namespace, r.podName)
	if err != nil {
		return err
	}
	if pod.Metadata == nil {
		// This should never happen IRL, just being defensive.
		return fmt.Errorf("no pod metadata on %+v", pod)
	}

	// Single source of truth for the initial labels. A slice (rather than a
	// map) keeps the per-label patches below in a deterministic order.
	labels := []struct{ key, value string }{
		{labelVaultVersion, r.initialState.VaultVersion},
		{labelActive, strconv.FormatBool(r.initialState.IsActive)},
		{labelSealed, strconv.FormatBool(r.initialState.IsSealed)},
		{labelPerfStandby, strconv.FormatBool(r.initialState.IsPerformanceStandby)},
		{labelInitialized, strconv.FormatBool(r.initialState.IsInitialized)},
	}

	if pod.Metadata.Labels == nil {
		// Add the labels field and all of the labels as part of one call.
		// The reason a different approach is needed when the labels map is
		// absent is discussed here:
		// https://stackoverflow.com/questions/57480205/error-while-applying-json-patch-to-kubernetes-custom-resource
		all := make(map[string]string, len(labels))
		for _, l := range labels {
			all[l.key] = l.value
		}
		if err := r.client.PatchPod(r.namespace, r.podName, &client.Patch{
			Operation: client.Add,
			Path:      "/metadata/labels",
			Value:     all,
		}); err != nil {
			return err
		}
		return nil
	}

	// The labels map already exists: write each label through its own patch.
	patches := make([]*client.Patch, 0, len(labels))
	for _, l := range labels {
		patches = append(patches, &client.Patch{
			Operation: client.Replace,
			Path:      pathToLabels + l.key,
			Value:     l.value,
		})
	}
	if err := r.client.PatchPod(r.namespace, r.podName, patches...); err != nil {
		return err
	}
	return nil
}
134
// NotifyActiveStateChange queues an update of the "vault-active" pod label to
// isActive. The patch is handed to the retry handler; this method always
// returns nil.
func (r *serviceRegistration) NotifyActiveStateChange(isActive bool) error {
	r.notifyBoolLabel(labelActive, isActive)
	return nil
}

// NotifySealedStateChange queues an update of the "vault-sealed" pod label to
// isSealed. The patch is handed to the retry handler; this method always
// returns nil.
func (r *serviceRegistration) NotifySealedStateChange(isSealed bool) error {
	r.notifyBoolLabel(labelSealed, isSealed)
	return nil
}

// NotifyPerformanceStandbyStateChange queues an update of the
// "vault-perf-standby" pod label to isStandby. The patch is handed to the retry
// handler; this method always returns nil.
func (r *serviceRegistration) NotifyPerformanceStandbyStateChange(isStandby bool) error {
	r.notifyBoolLabel(labelPerfStandby, isStandby)
	return nil
}

// NotifyInitializedStateChange queues an update of the "vault-initialized" pod
// label to isInitialized. The patch is handed to the retry handler; this method
// always returns nil.
func (r *serviceRegistration) NotifyInitializedStateChange(isInitialized bool) error {
	r.notifyBoolLabel(labelInitialized, isInitialized)
	return nil
}

// notifyBoolLabel passes the retry handler a "replace" patch that sets the
// given label to the string form of value ("true"/"false"), together with the
// current r.client (which is nil if Run has not been called yet).
func (r *serviceRegistration) notifyBoolLabel(label string, value bool) {
	patch := &client.Patch{
		Operation: client.Replace,
		Path:      pathToLabels + label,
		Value:     strconv.FormatBool(value),
	}
	r.retryHandler.Notify(r.client, patch)
}
170
// getRequiredField resolves a mandatory setting. The environment variable
// envVar takes precedence; if it is unset or empty, config[configParam] is
// used. If neither yields a non-empty value, an error naming both sources is
// returned. When logger is at debug level, the resolved value is logged.
func getRequiredField(logger hclog.Logger, config map[string]string, envVar, configParam string) (string, error) {
	value := os.Getenv(envVar)
	if value == "" {
		value = config[configParam]
	}
	if value == "" {
		return "", fmt.Errorf(`%s must be provided via %q or the %q config parameter`, configParam, envVar, configParam)
	}

	// Guarded so the Sprintf is only paid for when debug logging is enabled.
	if logger.IsDebug() {
		logger.Debug(fmt.Sprintf("%q: %q", configParam, value))
	}
	return value, nil
}
186