1package forwarding 2 3import ( 4 "bytes" 5 "crypto/tls" 6 "crypto/x509" 7 "errors" 8 "io" 9 "io/ioutil" 10 "net/http" 11 "net/url" 12 "os" 13 14 "github.com/golang/protobuf/proto" 15 "github.com/hashicorp/vault/sdk/helper/compressutil" 16 "github.com/hashicorp/vault/sdk/helper/jsonutil" 17) 18 19type bufCloser struct { 20 *bytes.Buffer 21} 22 23func (b bufCloser) Close() error { 24 b.Reset() 25 return nil 26} 27 28// GenerateForwardedRequest generates a new http.Request that contains the 29// original requests's information in the new request's body. 30func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request, error) { 31 fq, err := GenerateForwardedRequest(req) 32 if err != nil { 33 return nil, err 34 } 35 36 var newBody []byte 37 switch os.Getenv("VAULT_MESSAGE_TYPE") { 38 case "json": 39 newBody, err = jsonutil.EncodeJSON(fq) 40 case "json_compress": 41 newBody, err = jsonutil.EncodeJSONAndCompress(fq, &compressutil.CompressionConfig{ 42 Type: compressutil.CompressionTypeLZW, 43 }) 44 case "proto3": 45 fallthrough 46 default: 47 newBody, err = proto.Marshal(fq) 48 } 49 if err != nil { 50 return nil, err 51 } 52 53 ret, err := http.NewRequest("POST", addr, bytes.NewBuffer(newBody)) 54 if err != nil { 55 return nil, err 56 } 57 58 return ret, nil 59} 60 61func GenerateForwardedRequest(req *http.Request) (*Request, error) { 62 var reader io.Reader = req.Body 63 ctx := req.Context() 64 maxRequestSize := ctx.Value("max_request_size") 65 if maxRequestSize != nil { 66 max, ok := maxRequestSize.(int64) 67 if !ok { 68 return nil, errors.New("could not parse max_request_size from request context") 69 } 70 if max > 0 { 71 reader = io.LimitReader(req.Body, max) 72 } 73 } 74 75 body, err := ioutil.ReadAll(reader) 76 if err != nil { 77 return nil, err 78 } 79 80 fq := Request{ 81 Method: req.Method, 82 HeaderEntries: make(map[string]*HeaderEntry, len(req.Header)), 83 Host: req.Host, 84 RemoteAddr: req.RemoteAddr, 85 Body: body, 86 } 87 88 reqURL := req.URL 89 fq.Url = &URL{ 90 Scheme: reqURL.Scheme, 91 Opaque: reqURL.Opaque, 92 Host: reqURL.Host, 93 Path: reqURL.Path, 94 RawPath: reqURL.RawPath, 95 RawQuery: reqURL.RawQuery, 96 Fragment: reqURL.Fragment, 97 } 98 99 for k, v := range req.Header { 100 fq.HeaderEntries[k] = &HeaderEntry{ 101 Values: v, 102 } 103 } 104 105 if req.TLS != nil && req.TLS.PeerCertificates != nil && len(req.TLS.PeerCertificates) > 0 { 106 fq.PeerCertificates = make([][]byte, len(req.TLS.PeerCertificates)) 107 for i, cert := range req.TLS.PeerCertificates { 108 fq.PeerCertificates[i] = cert.Raw 109 } 110 } 111 112 return &fq, nil 113} 114 115// ParseForwardedRequest generates a new http.Request that is comprised of the 116// values in the given request's body, assuming it correctly parses into a 117// ForwardedRequest. 118func ParseForwardedHTTPRequest(req *http.Request) (*http.Request, error) { 119 buf := bytes.NewBuffer(nil) 120 _, err := buf.ReadFrom(req.Body) 121 if err != nil { 122 return nil, err 123 } 124 125 fq := new(Request) 126 switch os.Getenv("VAULT_MESSAGE_TYPE") { 127 case "json", "json_compress": 128 err = jsonutil.DecodeJSON(buf.Bytes(), fq) 129 default: 130 err = proto.Unmarshal(buf.Bytes(), fq) 131 } 132 if err != nil { 133 return nil, err 134 } 135 136 return ParseForwardedRequest(fq) 137} 138 139func ParseForwardedRequest(fq *Request) (*http.Request, error) { 140 buf := bufCloser{ 141 Buffer: bytes.NewBuffer(fq.Body), 142 } 143 144 ret := &http.Request{ 145 Method: fq.Method, 146 Header: make(map[string][]string, len(fq.HeaderEntries)), 147 Body: buf, 148 Host: fq.Host, 149 RemoteAddr: fq.RemoteAddr, 150 } 151 152 ret.URL = &url.URL{ 153 Scheme: fq.Url.Scheme, 154 Opaque: fq.Url.Opaque, 155 Host: fq.Url.Host, 156 Path: fq.Url.Path, 157 RawPath: fq.Url.RawPath, 158 RawQuery: fq.Url.RawQuery, 159 Fragment: fq.Url.Fragment, 160 } 161 162 for k, v := range fq.HeaderEntries { 163 ret.Header[k] = v.Values 164 } 165 166 if fq.PeerCertificates != nil && len(fq.PeerCertificates) > 0 { 167 ret.TLS = &tls.ConnectionState{ 168 PeerCertificates: make([]*x509.Certificate, len(fq.PeerCertificates)), 169 } 170 for i, certBytes := range fq.PeerCertificates { 171 cert, err := x509.ParseCertificate(certBytes) 172 if err != nil { 173 return nil, err 174 } 175 ret.TLS.PeerCertificates[i] = cert 176 } 177 } 178 179 return ret, nil 180} 181 182type RPCResponseWriter struct { 183 statusCode int 184 header http.Header 185 body *bytes.Buffer 186} 187 188// NewRPCResponseWriter returns an initialized RPCResponseWriter 189func NewRPCResponseWriter() *RPCResponseWriter { 190 w := &RPCResponseWriter{ 191 header: make(http.Header), 192 body: new(bytes.Buffer), 193 statusCode: 200, 194 } 195 //w.header.Set("Content-Type", "application/octet-stream") 196 return w 197} 198 199func (w *RPCResponseWriter) Header() http.Header { 200 return w.header 201} 202 203func (w *RPCResponseWriter) Write(buf []byte) (int, error) { 204 w.body.Write(buf) 205 return len(buf), nil 206} 207 208func (w *RPCResponseWriter) WriteHeader(code int) { 209 w.statusCode = code 210} 211 212func (w *RPCResponseWriter) StatusCode() int { 213 return w.statusCode 214} 215 216func (w *RPCResponseWriter) Body() *bytes.Buffer { 217 return w.body 218} 219