1package api
2
3import (
4	"context"
5	"io"
6	"net/http"
7
8	"github.com/hashicorp/vault/sdk/helper/consts"
9)
10
// RaftJoinResponse is the decoded body returned by the raft join endpoint.
// Joined mirrors the "joined" field of that JSON payload.
type RaftJoinResponse struct {
	Joined bool `json:"joined"`
}
15
// RaftJoinRequest represents the parameters consumed by the raft join API.
// It is serialized as the JSON body of the join request; each field's JSON
// key is given by its struct tag.
type RaftJoinRequest struct {
	// LeaderAPIAddr is the API address of the leader to join.
	LeaderAPIAddr string `json:"leader_api_addr"`
	// LeaderCACert is sent as "leader_ca_cert". The tag must be well-formed
	// (no trailing characters after the quoted value) to satisfy
	// reflect.StructTag conventions and `go vet`.
	LeaderCACert string `json:"leader_ca_cert"`
	// LeaderClientCert is sent as "leader_client_cert".
	LeaderClientCert string `json:"leader_client_cert"`
	// LeaderClientKey is sent as "leader_client_key".
	LeaderClientKey string `json:"leader_client_key"`
	// Retry is sent as "retry".
	Retry bool `json:"retry"`
}
24
// RaftJoin adds the node from which this call is invoked to the raft cluster
// represented by the leader address in opts. The options are sent as the
// JSON body of a POST to /v1/sys/storage/raft/join.
//
// On success the decoded response is returned. If the body cannot be
// encoded, the request fails, or the response cannot be decoded, a nil
// response and the error are returned.
func (c *Sys) RaftJoin(opts *RaftJoinRequest) (*RaftJoinResponse, error) {
	r := c.c.NewRequest(http.MethodPost, "/v1/sys/storage/raft/join")

	if err := r.SetJSONBody(opts); err != nil {
		return nil, err
	}

	ctx, cancelFunc := context.WithCancel(context.Background())
	defer cancelFunc()
	resp, err := c.c.RawRequestWithContext(ctx, r)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var result RaftJoinResponse
	if err := resp.DecodeJSON(&result); err != nil {
		// Don't hand back a half-decoded value alongside an error.
		return nil, err
	}
	return &result, nil
}
46
47// RaftSnapshot invokes the API that takes the snapshot of the raft cluster and
48// writes it to the supplied io.Writer.
49func (c *Sys) RaftSnapshot(snapWriter io.Writer) error {
50	r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot")
51	r.URL.RawQuery = r.Params.Encode()
52
53	req, err := http.NewRequest(http.MethodGet, r.URL.RequestURI(), nil)
54	if err != nil {
55		return err
56	}
57
58	req.URL.User = r.URL.User
59	req.URL.Scheme = r.URL.Scheme
60	req.URL.Host = r.URL.Host
61	req.Host = r.URL.Host
62
63	if r.Headers != nil {
64		for header, vals := range r.Headers {
65			for _, val := range vals {
66				req.Header.Add(header, val)
67			}
68		}
69	}
70
71	if len(r.ClientToken) != 0 {
72		req.Header.Set(consts.AuthHeaderName, r.ClientToken)
73	}
74
75	if len(r.WrapTTL) != 0 {
76		req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL)
77	}
78
79	if len(r.MFAHeaderVals) != 0 {
80		for _, mfaHeaderVal := range r.MFAHeaderVals {
81			req.Header.Add("X-Vault-MFA", mfaHeaderVal)
82		}
83	}
84
85	if r.PolicyOverride {
86		req.Header.Set("X-Vault-Policy-Override", "true")
87	}
88
89	// Avoiding the use of RawRequestWithContext which reads the response body
90	// to determine if the body contains error message.
91	var result *Response
92	resp, err := c.c.config.HttpClient.Do(req)
93	if resp == nil {
94		return nil
95	}
96
97	result = &Response{Response: resp}
98	if err := result.Error(); err != nil {
99		return err
100	}
101
102	_, err = io.Copy(snapWriter, resp.Body)
103	if err != nil {
104		return err
105	}
106
107	return nil
108}
109
110// RaftSnapshotRestore reads the snapshot from the io.Reader and installs that
111// snapshot, returning the cluster to the state defined by it.
112func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error {
113	path := "/v1/sys/storage/raft/snapshot"
114	if force {
115		path = "/v1/sys/storage/raft/snapshot-force"
116	}
117	r := c.c.NewRequest("POST", path)
118
119	r.Body = snapReader
120
121	ctx, cancelFunc := context.WithCancel(context.Background())
122	defer cancelFunc()
123	resp, err := c.c.RawRequestWithContext(ctx, r)
124	if err != nil {
125		return err
126	}
127	defer resp.Body.Close()
128
129	return nil
130}
131