1package forwarding
2
3import (
4	"bytes"
5	"crypto/tls"
6	"crypto/x509"
7	"errors"
8	"io"
9	"io/ioutil"
10	"net/http"
11	"net/url"
12	"os"
13
14	"github.com/golang/protobuf/proto"
15	"github.com/hashicorp/vault/sdk/helper/compressutil"
16	"github.com/hashicorp/vault/sdk/helper/jsonutil"
17)
18
// bufCloser turns a *bytes.Buffer into an io.ReadCloser: reads come from the
// embedded buffer and Close is supplied below. This lets an in-memory payload
// serve as an http.Request body.
type bufCloser struct{ *bytes.Buffer }

// Close drops any bytes still unread in the wrapped buffer and reports no
// error.
func (bc bufCloser) Close() error {
	bc.Buffer.Reset()
	return nil
}
27
28// GenerateForwardedRequest generates a new http.Request that contains the
29// original requests's information in the new request's body.
30func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request, error) {
31	fq, err := GenerateForwardedRequest(req)
32	if err != nil {
33		return nil, err
34	}
35
36	var newBody []byte
37	switch os.Getenv("VAULT_MESSAGE_TYPE") {
38	case "json":
39		newBody, err = jsonutil.EncodeJSON(fq)
40	case "json_compress":
41		newBody, err = jsonutil.EncodeJSONAndCompress(fq, &compressutil.CompressionConfig{
42			Type: compressutil.CompressionTypeLZW,
43		})
44	case "proto3":
45		fallthrough
46	default:
47		newBody, err = proto.Marshal(fq)
48	}
49	if err != nil {
50		return nil, err
51	}
52
53	ret, err := http.NewRequest("POST", addr, bytes.NewBuffer(newBody))
54	if err != nil {
55		return nil, err
56	}
57
58	return ret, nil
59}
60
61func GenerateForwardedRequest(req *http.Request) (*Request, error) {
62	var reader io.Reader = req.Body
63	ctx := req.Context()
64	maxRequestSize := ctx.Value("max_request_size")
65	if maxRequestSize != nil {
66		max, ok := maxRequestSize.(int64)
67		if !ok {
68			return nil, errors.New("could not parse max_request_size from request context")
69		}
70		if max > 0 {
71			reader = io.LimitReader(req.Body, max)
72		}
73	}
74
75	body, err := ioutil.ReadAll(reader)
76	if err != nil {
77		return nil, err
78	}
79
80	fq := Request{
81		Method:        req.Method,
82		HeaderEntries: make(map[string]*HeaderEntry, len(req.Header)),
83		Host:          req.Host,
84		RemoteAddr:    req.RemoteAddr,
85		Body:          body,
86	}
87
88	reqURL := req.URL
89	fq.Url = &URL{
90		Scheme:   reqURL.Scheme,
91		Opaque:   reqURL.Opaque,
92		Host:     reqURL.Host,
93		Path:     reqURL.Path,
94		RawPath:  reqURL.RawPath,
95		RawQuery: reqURL.RawQuery,
96		Fragment: reqURL.Fragment,
97	}
98
99	for k, v := range req.Header {
100		fq.HeaderEntries[k] = &HeaderEntry{
101			Values: v,
102		}
103	}
104
105	if req.TLS != nil && req.TLS.PeerCertificates != nil && len(req.TLS.PeerCertificates) > 0 {
106		fq.PeerCertificates = make([][]byte, len(req.TLS.PeerCertificates))
107		for i, cert := range req.TLS.PeerCertificates {
108			fq.PeerCertificates[i] = cert.Raw
109		}
110	}
111
112	return &fq, nil
113}
114
115// ParseForwardedRequest generates a new http.Request that is comprised of the
116// values in the given request's body, assuming it correctly parses into a
117// ForwardedRequest.
118func ParseForwardedHTTPRequest(req *http.Request) (*http.Request, error) {
119	buf := bytes.NewBuffer(nil)
120	_, err := buf.ReadFrom(req.Body)
121	if err != nil {
122		return nil, err
123	}
124
125	fq := new(Request)
126	switch os.Getenv("VAULT_MESSAGE_TYPE") {
127	case "json", "json_compress":
128		err = jsonutil.DecodeJSON(buf.Bytes(), fq)
129	default:
130		err = proto.Unmarshal(buf.Bytes(), fq)
131	}
132	if err != nil {
133		return nil, err
134	}
135
136	return ParseForwardedRequest(fq)
137}
138
139func ParseForwardedRequest(fq *Request) (*http.Request, error) {
140	buf := bufCloser{
141		Buffer: bytes.NewBuffer(fq.Body),
142	}
143
144	ret := &http.Request{
145		Method:     fq.Method,
146		Header:     make(map[string][]string, len(fq.HeaderEntries)),
147		Body:       buf,
148		Host:       fq.Host,
149		RemoteAddr: fq.RemoteAddr,
150	}
151
152	ret.URL = &url.URL{
153		Scheme:   fq.Url.Scheme,
154		Opaque:   fq.Url.Opaque,
155		Host:     fq.Url.Host,
156		Path:     fq.Url.Path,
157		RawPath:  fq.Url.RawPath,
158		RawQuery: fq.Url.RawQuery,
159		Fragment: fq.Url.Fragment,
160	}
161
162	for k, v := range fq.HeaderEntries {
163		ret.Header[k] = v.Values
164	}
165
166	if fq.PeerCertificates != nil && len(fq.PeerCertificates) > 0 {
167		ret.TLS = &tls.ConnectionState{
168			PeerCertificates: make([]*x509.Certificate, len(fq.PeerCertificates)),
169		}
170		for i, certBytes := range fq.PeerCertificates {
171			cert, err := x509.ParseCertificate(certBytes)
172			if err != nil {
173				return nil, err
174			}
175			ret.TLS.PeerCertificates[i] = cert
176		}
177	}
178
179	return ret, nil
180}
181
// RPCResponseWriter is an in-memory http.ResponseWriter. It records the status
// code, headers, and body that a handler writes, and callers read them back via
// StatusCode, Header, and Body.
//
// Construct it with NewRPCResponseWriter. The zero value has a nil header map
// and a nil body buffer, so Write on it panics.
type RPCResponseWriter struct {
	statusCode int
	header     http.Header
	body       *bytes.Buffer
}

// Compile-time check that *RPCResponseWriter implements http.ResponseWriter.
var _ http.ResponseWriter = (*RPCResponseWriter)(nil)

// NewRPCResponseWriter returns an initialized RPCResponseWriter with an empty
// header map and an empty body. Its status code starts at 200 (http.StatusOK),
// which is what StatusCode reports if the handler never calls WriteHeader.
func NewRPCResponseWriter() *RPCResponseWriter {
	return &RPCResponseWriter{
		header:     make(http.Header),
		body:       new(bytes.Buffer),
		statusCode: http.StatusOK,
	}
}

// Header returns the writer's live header map. Changes made through it are
// visible to later Header calls.
func (w *RPCResponseWriter) Header() http.Header {
	return w.header
}

// Write appends buf to the recorded body. bytes.Buffer.Write always returns
// len(buf) and a nil error; it panics with bytes.ErrTooLarge instead of
// returning an error.
func (w *RPCResponseWriter) Write(buf []byte) (int, error) {
	return w.body.Write(buf)
}

// WriteHeader records code as the response status.
//
// This differs from a net/http server connection:
//   - Repeated calls are not ignored; the most recent code wins.
//   - A call made after Write still takes effect.
func (w *RPCResponseWriter) WriteHeader(code int) {
	w.statusCode = code
}

// StatusCode returns the most recently recorded status code.
func (w *RPCResponseWriter) StatusCode() int {
	return w.statusCode
}

// Body returns the buffer holding everything passed to Write. The buffer is
// shared with the writer, not copied.
func (w *RPCResponseWriter) Body() *bytes.Buffer {
	return w.body
}
219