1package api
2
3import (
4	"context"
5	"fmt"
6	"io"
7	"net/http"
8
9	"github.com/hashicorp/vault/sdk/helper/consts"
10)
11
// RaftJoinResponse represents the response of the raft join API.
type RaftJoinResponse struct {
	// Joined is decoded from the "joined" key of the response body.
	// NOTE(review): presumably reports whether the node was accepted into
	// the cluster — confirm against the server-side handler.
	Joined bool `json:"joined"`
}
16
// RaftJoinRequest represents the parameters consumed by the raft join API.
// Each field is serialized under the JSON key given in its struct tag.
type RaftJoinRequest struct {
	// LeaderAPIAddr is the leader address identifying the cluster to join.
	LeaderAPIAddr string `json:"leader_api_addr"`
	// LeaderCACert, LeaderClientCert and LeaderClientKey are sent as-is.
	// NOTE(review): presumably TLS material used when contacting the
	// leader — confirm the expected encoding against the server docs.
	LeaderCACert     string `json:"leader_ca_cert"`
	LeaderClientCert string `json:"leader_client_cert"`
	LeaderClientKey  string `json:"leader_client_key"`
	// Retry is sent as "retry".
	// NOTE(review): presumably asks the server to keep retrying the join — confirm.
	Retry bool `json:"retry"`
	// NonVoter is sent as "non_voter".
	// NOTE(review): presumably joins the node as a non-voting member — confirm.
	NonVoter bool `json:"non_voter"`
}
26
27// RaftJoin adds the node from which this call is invoked from to the raft
28// cluster represented by the leader address in the parameter.
29func (c *Sys) RaftJoin(opts *RaftJoinRequest) (*RaftJoinResponse, error) {
30	r := c.c.NewRequest("POST", "/v1/sys/storage/raft/join")
31
32	if err := r.SetJSONBody(opts); err != nil {
33		return nil, err
34	}
35
36	ctx, cancelFunc := context.WithCancel(context.Background())
37	defer cancelFunc()
38	resp, err := c.c.RawRequestWithContext(ctx, r)
39	if err != nil {
40		return nil, err
41	}
42	defer resp.Body.Close()
43
44	var result RaftJoinResponse
45	err = resp.DecodeJSON(&result)
46	return &result, err
47}
48
49// RaftSnapshot invokes the API that takes the snapshot of the raft cluster and
50// writes it to the supplied io.Writer.
51func (c *Sys) RaftSnapshot(snapWriter io.Writer) error {
52	r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot")
53	r.URL.RawQuery = r.Params.Encode()
54
55	req, err := http.NewRequest(http.MethodGet, r.URL.RequestURI(), nil)
56	if err != nil {
57		return err
58	}
59
60	req.URL.User = r.URL.User
61	req.URL.Scheme = r.URL.Scheme
62	req.URL.Host = r.URL.Host
63	req.Host = r.URL.Host
64
65	if r.Headers != nil {
66		for header, vals := range r.Headers {
67			for _, val := range vals {
68				req.Header.Add(header, val)
69			}
70		}
71	}
72
73	if len(r.ClientToken) != 0 {
74		req.Header.Set(consts.AuthHeaderName, r.ClientToken)
75	}
76
77	if len(r.WrapTTL) != 0 {
78		req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL)
79	}
80
81	if len(r.MFAHeaderVals) != 0 {
82		for _, mfaHeaderVal := range r.MFAHeaderVals {
83			req.Header.Add("X-Vault-MFA", mfaHeaderVal)
84		}
85	}
86
87	if r.PolicyOverride {
88		req.Header.Set("X-Vault-Policy-Override", "true")
89	}
90
91	// Avoiding the use of RawRequestWithContext which reads the response body
92	// to determine if the body contains error message.
93	var result *Response
94	resp, err := c.c.config.HttpClient.Do(req)
95	if resp == nil {
96		return nil
97	}
98
99	// Check for a redirect, only allowing for a single redirect
100	if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 {
101		// Parse the updated location
102		respLoc, err := resp.Location()
103		if err != nil {
104			return err
105		}
106
107		// Ensure a protocol downgrade doesn't happen
108		if req.URL.Scheme == "https" && respLoc.Scheme != "https" {
109			return fmt.Errorf("redirect would cause protocol downgrade")
110		}
111
112		// Update the request
113		req.URL = respLoc
114
115		// Retry the request
116		resp, err = c.c.config.HttpClient.Do(req)
117		if err != nil {
118			return err
119		}
120	}
121
122	result = &Response{Response: resp}
123	if err := result.Error(); err != nil {
124		return err
125	}
126
127	_, err = io.Copy(snapWriter, resp.Body)
128	if err != nil {
129		return err
130	}
131
132	return nil
133}
134
135// RaftSnapshotRestore reads the snapshot from the io.Reader and installs that
136// snapshot, returning the cluster to the state defined by it.
137func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error {
138	path := "/v1/sys/storage/raft/snapshot"
139	if force {
140		path = "/v1/sys/storage/raft/snapshot-force"
141	}
142	r := c.c.NewRequest("POST", path)
143
144	r.Body = snapReader
145
146	ctx, cancelFunc := context.WithCancel(context.Background())
147	defer cancelFunc()
148	resp, err := c.c.RawRequestWithContext(ctx, r)
149	if err != nil {
150		return err
151	}
152	defer resp.Body.Close()
153
154	return nil
155}
156