1// Copyright 2015 go-swagger maintainers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package client
16
17import (
18	"context"
19	"crypto"
20	"crypto/ecdsa"
21	"crypto/rsa"
22	"crypto/tls"
23	"crypto/x509"
24	"encoding/pem"
25	"fmt"
26	"io/ioutil"
27	"mime"
28	"net/http"
29	"net/http/httputil"
30	"strings"
31	"sync"
32	"time"
33
34	"github.com/go-openapi/runtime"
35	"github.com/go-openapi/runtime/logger"
36	"github.com/go-openapi/runtime/middleware"
37	"github.com/go-openapi/strfmt"
38)
39
// TLSClientOptions to configure client authentication with mutual TLS.
//
// It is consumed by TLSClientAuth (and, through it, by TLSTransport and
// TLSClient). Construct it with keyed fields: the trailing blank field
// prevents unkeyed composite literals.
type TLSClientOptions struct {
	// Certificate is the path to a PEM-encoded certificate to be used for
	// client authentication. If set then Key must also be set.
	Certificate string

	// LoadedCertificate is the certificate to be used for client authentication.
	// This field is ignored if Certificate is set. If this field is set, LoadedKey
	// is also required.
	LoadedCertificate *x509.Certificate

	// Key is the path to an unencrypted PEM-encoded private key for client
	// authentication. This field is required if Certificate is set.
	Key string

	// LoadedKey is the key for client authentication. This field is required if
	// LoadedCertificate is set. Only *rsa.PrivateKey and *ecdsa.PrivateKey
	// values are supported; any other type makes TLSClientAuth fail.
	LoadedKey crypto.PrivateKey

	// CA is a path to a PEM-encoded file holding the root certificate(s)
	// to use when validating the TLS certificate presented by the server. If this
	// field (and LoadedCA) is not set, the system certificate pool is used. This
	// field is ignored if LoadedCA is set.
	CA string

	// LoadedCA specifies the root certificate to use when validating the server's TLS certificate.
	// If this field (and CA) is not set, the system certificate pool is used.
	LoadedCA *x509.Certificate

	// ServerName specifies the hostname to use when verifying the server certificate.
	// If this field is set then InsecureSkipVerify will be ignored and treated as
	// false.
	ServerName string

	// InsecureSkipVerify controls whether the certificate chain and hostname presented
	// by the server are validated. If true, any certificate is accepted and the
	// connection is open to man-in-the-middle attacks; use it only for testing.
	InsecureSkipVerify bool

	// VerifyPeerCertificate, if not nil, is called after normal
	// certificate verification. It receives the raw ASN.1 certificates
	// provided by the peer and also any verified chains that normal processing found.
	// If it returns a non-nil error, the handshake is aborted and that error results.
	//
	// If normal verification fails then the handshake will abort before
	// considering this callback. If normal verification is disabled by
	// setting InsecureSkipVerify then this callback will be considered but
	// the verifiedChains argument will always be nil.
	VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error

	// SessionTicketsDisabled may be set to true to disable session ticket and
	// PSK (resumption) support. Note that on clients, session ticket support is
	// also disabled if ClientSessionCache is nil.
	SessionTicketsDisabled bool

	// ClientSessionCache is a cache of ClientSessionState entries for TLS
	// session resumption. It is only used by clients.
	ClientSessionCache tls.ClientSessionCache

	// Prevents callers using unkeyed fields.
	_ struct{}
}
101
102// TLSClientAuth creates a tls.Config for mutual auth
103func TLSClientAuth(opts TLSClientOptions) (*tls.Config, error) {
104	// create client tls config
105	cfg := &tls.Config{}
106
107	// load client cert if specified
108	if opts.Certificate != "" {
109		cert, err := tls.LoadX509KeyPair(opts.Certificate, opts.Key)
110		if err != nil {
111			return nil, fmt.Errorf("tls client cert: %v", err)
112		}
113		cfg.Certificates = []tls.Certificate{cert}
114	} else if opts.LoadedCertificate != nil {
115		block := pem.Block{Type: "CERTIFICATE", Bytes: opts.LoadedCertificate.Raw}
116		certPem := pem.EncodeToMemory(&block)
117
118		var keyBytes []byte
119		switch k := opts.LoadedKey.(type) {
120		case *rsa.PrivateKey:
121			keyBytes = x509.MarshalPKCS1PrivateKey(k)
122		case *ecdsa.PrivateKey:
123			var err error
124			keyBytes, err = x509.MarshalECPrivateKey(k)
125			if err != nil {
126				return nil, fmt.Errorf("tls client priv key: %v", err)
127			}
128		default:
129			return nil, fmt.Errorf("tls client priv key: unsupported key type")
130		}
131
132		block = pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes}
133		keyPem := pem.EncodeToMemory(&block)
134
135		cert, err := tls.X509KeyPair(certPem, keyPem)
136		if err != nil {
137			return nil, fmt.Errorf("tls client cert: %v", err)
138		}
139		cfg.Certificates = []tls.Certificate{cert}
140	}
141
142	cfg.InsecureSkipVerify = opts.InsecureSkipVerify
143
144	cfg.VerifyPeerCertificate = opts.VerifyPeerCertificate
145	cfg.SessionTicketsDisabled = opts.SessionTicketsDisabled
146	cfg.ClientSessionCache = opts.ClientSessionCache
147
148	// When no CA certificate is provided, default to the system cert pool
149	// that way when a request is made to a server known by the system trust store,
150	// the name is still verified
151	if opts.LoadedCA != nil {
152		caCertPool := x509.NewCertPool()
153		caCertPool.AddCert(opts.LoadedCA)
154		cfg.RootCAs = caCertPool
155	} else if opts.CA != "" {
156		// load ca cert
157		caCert, err := ioutil.ReadFile(opts.CA)
158		if err != nil {
159			return nil, fmt.Errorf("tls client ca: %v", err)
160		}
161		caCertPool := x509.NewCertPool()
162		caCertPool.AppendCertsFromPEM(caCert)
163		cfg.RootCAs = caCertPool
164	}
165
166	// apply servername overrride
167	if opts.ServerName != "" {
168		cfg.InsecureSkipVerify = false
169		cfg.ServerName = opts.ServerName
170	}
171
172	cfg.BuildNameToCertificate()
173
174	return cfg, nil
175}
176
177// TLSTransport creates a http client transport suitable for mutual tls auth
178func TLSTransport(opts TLSClientOptions) (http.RoundTripper, error) {
179	cfg, err := TLSClientAuth(opts)
180	if err != nil {
181		return nil, err
182	}
183
184	return &http.Transport{TLSClientConfig: cfg}, nil
185}
186
187// TLSClient creates a http.Client for mutual auth
188func TLSClient(opts TLSClientOptions) (*http.Client, error) {
189	transport, err := TLSTransport(opts)
190	if err != nil {
191		return nil, err
192	}
193	return &http.Client{Transport: transport}, nil
194}
195
// DefaultTimeout is the default request timeout: thirty seconds.
var DefaultTimeout = time.Duration(30) * time.Second
198
// Runtime represents an API client that uses the transport
// to make http requests based on a swagger specification.
//
// Construct it with New or NewWithClient. The zero value has a nil
// clientOnce, so Submit would panic on it.
//
// Exported fields, as used by the code in this file:
//   - DefaultMediaType is the request content type when an operation lists
//     no non-empty ConsumesMediaTypes. It is also the assumed response content
//     type when a response carries no Content-Type header.
//   - DefaultAuthentication is applied when an operation has no AuthInfo.
//   - Consumers are response decoders keyed by media type. An entry under
//     "*/*", if registered, is the fallback for unknown response media types.
//   - Producers are request encoders keyed by media type. Submit rejects a
//     request media type that has no producer, unless it is a multipart or
//     urlencoded form.
//   - Transport and Jar are used to build the default http.Client on first
//     use. They are ignored once a client exists.
//   - Host is set as both the request URL host and the request Host.
//   - BasePath is passed when building each request. New ensures it starts
//     with "/".
//   - Formats is the strfmt registry passed when building each request.
//   - Context is the parent context for operations that carry no Context.
//   - Debug, when true, makes Submit dump requests and responses to the logger.
//
// Unexported state:
//   - logger receives the debug dumps.
//   - clientOnce guards the one-time creation (or injection by NewWithClient)
//     of client.
//   - schemes, when non-empty, take precedence over the schemes declared on
//     each operation.
type Runtime struct {
	DefaultMediaType      string
	DefaultAuthentication runtime.ClientAuthInfoWriter
	Consumers             map[string]runtime.Consumer
	Producers             map[string]runtime.Producer

	Transport http.RoundTripper
	Jar       http.CookieJar
	//Spec      *spec.Document
	Host     string
	BasePath string
	Formats  strfmt.Registry
	Context  context.Context

	Debug  bool
	logger logger.Logger

	clientOnce *sync.Once
	client     *http.Client
	schemes    []string
}
222
223// New creates a new default runtime for a swagger api runtime.Client
224func New(host, basePath string, schemes []string) *Runtime {
225	var rt Runtime
226	rt.DefaultMediaType = runtime.JSONMime
227
228	// TODO: actually infer this stuff from the spec
229	rt.Consumers = map[string]runtime.Consumer{
230		runtime.JSONMime:    runtime.JSONConsumer(),
231		runtime.XMLMime:     runtime.XMLConsumer(),
232		runtime.TextMime:    runtime.TextConsumer(),
233		runtime.HTMLMime:    runtime.TextConsumer(),
234		runtime.CSVMime:     runtime.CSVConsumer(),
235		runtime.DefaultMime: runtime.ByteStreamConsumer(),
236	}
237	rt.Producers = map[string]runtime.Producer{
238		runtime.JSONMime:    runtime.JSONProducer(),
239		runtime.XMLMime:     runtime.XMLProducer(),
240		runtime.TextMime:    runtime.TextProducer(),
241		runtime.HTMLMime:    runtime.TextProducer(),
242		runtime.CSVMime:     runtime.CSVProducer(),
243		runtime.DefaultMime: runtime.ByteStreamProducer(),
244	}
245	rt.Transport = http.DefaultTransport
246	rt.Jar = nil
247	rt.Host = host
248	rt.BasePath = basePath
249	rt.Context = context.Background()
250	rt.clientOnce = new(sync.Once)
251	if !strings.HasPrefix(rt.BasePath, "/") {
252		rt.BasePath = "/" + rt.BasePath
253	}
254
255	rt.Debug = logger.DebugEnabled()
256	rt.logger = logger.StandardLogger{}
257
258	if len(schemes) > 0 {
259		rt.schemes = schemes
260	}
261	return &rt
262}
263
264// NewWithClient allows you to create a new transport with a configured http.Client
265func NewWithClient(host, basePath string, schemes []string, client *http.Client) *Runtime {
266	rt := New(host, basePath, schemes)
267	if client != nil {
268		rt.clientOnce.Do(func() {
269			rt.client = client
270		})
271	}
272	return rt
273}
274
275func (r *Runtime) pickScheme(schemes []string) string {
276	if v := r.selectScheme(r.schemes); v != "" {
277		return v
278	}
279	if v := r.selectScheme(schemes); v != "" {
280		return v
281	}
282	return "http"
283}
284
285func (r *Runtime) selectScheme(schemes []string) string {
286	schLen := len(schemes)
287	if schLen == 0 {
288		return ""
289	}
290
291	scheme := schemes[0]
292	// prefer https, but skip when not possible
293	if scheme != "https" && schLen > 1 {
294		for _, sch := range schemes {
295			if sch == "https" {
296				scheme = sch
297				break
298			}
299		}
300	}
301	return scheme
302}

// transportOrDefault returns left when it is non-nil and right otherwise.
// Only a nil interface counts as missing; a typed nil pointer held in left
// is returned as is.
func transportOrDefault(left, right http.RoundTripper) http.RoundTripper {
	if left != nil {
		return left
	}
	return right
}
309
310// EnableConnectionReuse drains the remaining body from a response
311// so that go will reuse the TCP connections.
312//
313// This is not enabled by default because there are servers where
314// the response never gets closed and that would make the code hang forever.
315// So instead it's provided as a http client middleware that can be used to override
316// any request.
317func (r *Runtime) EnableConnectionReuse() {
318	if r.client == nil {
319		r.Transport = KeepAliveTransport(
320			transportOrDefault(r.Transport, http.DefaultTransport),
321		)
322		return
323	}
324
325	r.client.Transport = KeepAliveTransport(
326		transportOrDefault(r.client.Transport,
327			transportOrDefault(r.Transport, http.DefaultTransport),
328		),
329	)
330}
331
// Submit a request and when there is a body on success it will turn that into the result
// all other things are turned into an api error for swagger which retains the status code
//
// Within this function the steps are:
//  1. Build the request from operation.Method, operation.PathPattern and
//     operation.Params, and set the Accept header from
//     operation.ProducesMediaTypes.
//  2. Choose the request media type and serialize the request.
//  3. Send it with operation.Client, or the runtime's lazily created client.
//  4. Pick a Consumer from the response Content-Type and delegate to
//     operation.Reader.
//
// NOTE(review): how status codes become results or api errors is decided by
// operation.Reader, not here. Confirm the claim above against the reader
// implementations.
//
// Errors are returned for:
//   - request construction failures;
//   - a request media type with no registered producer (multipart and
//     urlencoded forms excepted);
//   - debug dump failures when r.Debug is set;
//   - transport errors;
//   - an unparseable response Content-Type;
//   - a response media type with neither a consumer nor a "*/*" fallback.
func (r *Runtime) Submit(operation *runtime.ClientOperation) (interface{}, error) {
	params, readResponse, auth := operation.Params, operation.Reader, operation.AuthInfo

	request, err := newRequest(operation.Method, operation.PathPattern, params)
	if err != nil {
		return nil, err
	}

	var accept []string
	accept = append(accept, operation.ProducesMediaTypes...)
	if err = request.SetHeaderParam(runtime.HeaderAccept, accept...); err != nil {
		return nil, err
	}

	// Per-operation auth wins; the runtime-wide default is only a fallback.
	if auth == nil && r.DefaultAuthentication != nil {
		auth = r.DefaultAuthentication
	}
	// auth is handed to buildHTTP below (presumably applied there; confirm).
	// The disabled block that follows is dead code.
	//if auth != nil {
	//	if err := auth.AuthenticateRequest(request, r.Formats); err != nil {
	//		return nil, err
	//	}
	//}

	// TODO: pick appropriate media type
	// Request content type: the first non-empty entry of
	// operation.ConsumesMediaTypes, else r.DefaultMediaType.
	cmt := r.DefaultMediaType
	for _, mediaType := range operation.ConsumesMediaTypes {
		// Pick first non-empty media type
		if mediaType != "" {
			cmt = mediaType
			break
		}
	}

	// Multipart and urlencoded forms are accepted without a registered
	// producer (presumably buildHTTP encodes them itself; confirm).
	if _, ok := r.Producers[cmt]; !ok && cmt != runtime.MultipartFormMime && cmt != runtime.URLencodedFormMime {
		return nil, fmt.Errorf("none of producers: %v registered. try %s", r.Producers, cmt)
	}

	req, err := request.buildHTTP(cmt, r.BasePath, r.Producers, r.Formats, auth)
	if err != nil {
		return nil, err
	}
	// r.schemes (when set) take precedence over operation.Schemes; see pickScheme.
	req.URL.Scheme = r.pickScheme(operation.Schemes)
	req.URL.Host = r.Host
	req.Host = r.Host

	// Build the default client on first use. NewWithClient consumes this Once
	// to inject its own client. A zero-value Runtime has a nil clientOnce and
	// panics here.
	r.clientOnce.Do(func() {
		r.client = &http.Client{
			Transport: r.Transport,
			Jar:       r.Jar,
		}
	})

	// A failed dump aborts the request instead of being logged.
	if r.Debug {
		b, err2 := httputil.DumpRequestOut(req, true)
		if err2 != nil {
			return nil, err2
		}
		r.logger.Debugf("%s\n", string(b))
	}

	// Parent context: operation.Context if set, else r.Context, else
	// context.Background(). hasTimeout really means "the caller supplied the
	// context". In that case only a cancel is added and request.timeout is NOT
	// applied, so any deadline must come from the caller's context. Otherwise
	// request.timeout bounds the request.
	var hasTimeout bool
	pctx := operation.Context
	if pctx == nil {
		pctx = r.Context
	} else {
		hasTimeout = true
	}
	if pctx == nil {
		pctx = context.Background()
	}
	var ctx context.Context
	var cancel context.CancelFunc
	if hasTimeout {
		ctx, cancel = context.WithCancel(pctx)
	} else {
		ctx, cancel = context.WithTimeout(pctx, request.timeout)
	}
	// cancel runs when Submit returns, after readResponse below has finished.
	defer cancel()

	// operation.Client, when set, bypasses r.client (and so r.Transport / r.Jar).
	client := operation.Client
	if client == nil {
		client = r.client
	}
	req = req.WithContext(ctx)
	res, err := client.Do(req) // make requests, by default follows 10 redirects before failing
	if err != nil {
		return nil, err
	}
	// The body is closed on return but not drained here; see
	// EnableConnectionReuse for draining.
	defer res.Body.Close()

	if r.Debug {
		b, err2 := httputil.DumpResponse(res, true)
		if err2 != nil {
			return nil, err2
		}
		r.logger.Debugf("%s\n", string(b))
	}

	ct := res.Header.Get(runtime.HeaderContentType)
	if ct == "" { // this should really really never occur
		ct = r.DefaultMediaType
	}

	// Parameters such as charset are dropped; only the bare media type selects
	// the consumer.
	mt, _, err := mime.ParseMediaType(ct)
	if err != nil {
		return nil, fmt.Errorf("parse content type: %s", err)
	}

	// Exact media-type match first, then a consumer registered under "*/*".
	cons, ok := r.Consumers[mt]
	if !ok {
		if cons, ok = r.Consumers["*/*"]; !ok {
			// scream about not knowing what to do
			return nil, fmt.Errorf("no consumer: %q", ct)
		}
	}
	return readResponse.ReadResponse(response{res}, cons)
}
451
452// SetDebug changes the debug flag.
453// It ensures that client and middlewares have the set debug level.
454func (r *Runtime) SetDebug(debug bool) {
455	r.Debug = debug
456	middleware.Debug = debug
457}
458
459// SetLogger changes the logger stream.
460// It ensures that client and middlewares use the same logger.
461func (r *Runtime) SetLogger(logger logger.Logger) {
462	r.logger = logger
463	middleware.Logger = logger
464}
465