1// Copyright 2015 go-swagger maintainers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package client
16
17import (
18	"context"
19	"crypto"
20	"crypto/ecdsa"
21	"crypto/rsa"
22	"crypto/tls"
23	"crypto/x509"
24	"encoding/pem"
25	"fmt"
26	"io/ioutil"
27	"mime"
28	"net/http"
29	"net/http/httputil"
30	"strings"
31	"sync"
32	"time"
33
34	"github.com/go-openapi/strfmt"
35	"github.com/opentracing/opentracing-go"
36
37	"github.com/go-openapi/runtime"
38	"github.com/go-openapi/runtime/logger"
39	"github.com/go-openapi/runtime/middleware"
40	"github.com/go-openapi/runtime/yamlpc"
41)
42
// TLSClientOptions configures client authentication with mutual TLS.
type TLSClientOptions struct {
	// Certificate is the path to a PEM-encoded certificate to be used for
	// client authentication. If set then Key must also be set.
	Certificate string

	// LoadedCertificate is the certificate to be used for client authentication.
	// This field is ignored if Certificate is set. If this field is set, LoadedKey
	// is also required.
	LoadedCertificate *x509.Certificate

	// Key is the path to an unencrypted PEM-encoded private key for client
	// authentication. This field is required if Certificate is set.
	Key string

	// LoadedKey is the key for client authentication. This field is required if
	// LoadedCertificate is set. Only *rsa.PrivateKey and *ecdsa.PrivateKey are
	// supported.
	LoadedKey crypto.PrivateKey

	// CA is a path to a PEM-encoded certificate that specifies the root certificate
	// to use when validating the TLS certificate presented by the server. If this field
	// (and LoadedCA) is not set, the system certificate pool is used. This field is
	// ignored if LoadedCA is set. The file must contain at least one valid
	// PEM-encoded certificate.
	CA string

	// LoadedCA specifies the root certificate to use when validating the server's TLS certificate.
	// If this field (and CA) is not set, the system certificate pool is used.
	LoadedCA *x509.Certificate

	// LoadedCAPool specifies a pool of RootCAs to use when validating the server's TLS certificate.
	// If set, it will be combined with the other loaded certificates (see LoadedCA and CA).
	// If neither LoadedCA nor CA is set, the provided pool will override the system
	// certificate pool.
	// The caller must not use the supplied pool after calling TLSClientAuth: the
	// certificate(s) from LoadedCA or CA are added to this very pool.
	LoadedCAPool *x509.CertPool

	// ServerName specifies the hostname to use when verifying the server certificate.
	// If this field is set then InsecureSkipVerify will be ignored and treated as
	// false.
	ServerName string

	// InsecureSkipVerify controls whether the certificate chain and hostname presented
	// by the server are validated. If true, any certificate is accepted.
	InsecureSkipVerify bool

	// VerifyPeerCertificate, if not nil, is called after normal
	// certificate verification. It receives the raw ASN.1 certificates
	// provided by the peer and also any verified chains that normal processing found.
	// If it returns a non-nil error, the handshake is aborted and that error results.
	//
	// If normal verification fails then the handshake will abort before
	// considering this callback. If normal verification is disabled by
	// setting InsecureSkipVerify then this callback will be considered but
	// the verifiedChains argument will always be nil.
	VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error

	// SessionTicketsDisabled may be set to true to disable session ticket and
	// PSK (resumption) support. Note that on clients, session ticket support is
	// also disabled if ClientSessionCache is nil.
	SessionTicketsDisabled bool

	// ClientSessionCache is a cache of ClientSessionState entries for TLS
	// session resumption. It is only used by clients.
	ClientSessionCache tls.ClientSessionCache

	// Prevents callers using unkeyed fields.
	_ struct{}
}

// TLSClientAuth creates a tls.Config for mutual auth.
//
// The client certificate comes from the Certificate/Key files, or, when
// Certificate is empty, from LoadedCertificate/LoadedKey. Root CAs come from
// LoadedCA, otherwise from the CA file; either is added to LoadedCAPool when
// that pool is set. When no CA option is set, RootCAs stays nil so that
// crypto/tls uses the system trust store.
//
// An error is returned when the client key pair cannot be loaded or does not
// match, when LoadedKey is missing or of an unsupported type, or when the CA
// file cannot be read or contains no valid PEM certificate.
func TLSClientAuth(opts TLSClientOptions) (*tls.Config, error) {
	cfg := &tls.Config{}

	// Client certificate: file paths take precedence over pre-loaded values.
	if opts.Certificate != "" {
		cert, err := tls.LoadX509KeyPair(opts.Certificate, opts.Key)
		if err != nil {
			return nil, fmt.Errorf("tls client cert: %v", err)
		}
		cfg.Certificates = []tls.Certificate{cert}
	} else if opts.LoadedCertificate != nil {
		certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: opts.LoadedCertificate.Raw})

		// Round-trip the key through PEM so tls.X509KeyPair both builds the
		// tls.Certificate and checks that the key matches the certificate.
		// The PEM label matches the DER encoding that is actually produced.
		var keyBlock *pem.Block
		switch k := opts.LoadedKey.(type) {
		case nil:
			return nil, fmt.Errorf("tls client priv key: LoadedKey is required when LoadedCertificate is set")
		case *rsa.PrivateKey:
			keyBlock = &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}
		case *ecdsa.PrivateKey:
			der, err := x509.MarshalECPrivateKey(k)
			if err != nil {
				return nil, fmt.Errorf("tls client priv key: %v", err)
			}
			keyBlock = &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}
		default:
			return nil, fmt.Errorf("tls client priv key: unsupported key type %T", k)
		}

		cert, err := tls.X509KeyPair(certPem, pem.EncodeToMemory(keyBlock))
		if err != nil {
			return nil, fmt.Errorf("tls client cert: %v", err)
		}
		cfg.Certificates = []tls.Certificate{cert}
	}

	cfg.InsecureSkipVerify = opts.InsecureSkipVerify

	cfg.VerifyPeerCertificate = opts.VerifyPeerCertificate
	cfg.SessionTicketsDisabled = opts.SessionTicketsDisabled
	cfg.ClientSessionCache = opts.ClientSessionCache

	// Root CAs. When no CA option is provided, RootCAs is left nil so that
	// crypto/tls falls back to the system cert pool; a request to a server
	// known by the system trust store still has its name verified.
	switch {
	case opts.LoadedCA != nil:
		caCertPool := basePool(opts.LoadedCAPool)
		caCertPool.AddCert(opts.LoadedCA)
		cfg.RootCAs = caCertPool
	case opts.CA != "":
		caCert, err := ioutil.ReadFile(opts.CA)
		if err != nil {
			return nil, fmt.Errorf("tls client ca: %v", err)
		}
		caCertPool := basePool(opts.LoadedCAPool)
		// Fail loudly instead of returning a config whose every handshake
		// would later fail with an opaque "unknown authority" error.
		if !caCertPool.AppendCertsFromPEM(caCert) {
			return nil, fmt.Errorf("tls client ca: no valid PEM certificate found in %q", opts.CA)
		}
		cfg.RootCAs = caCertPool
	case opts.LoadedCAPool != nil:
		cfg.RootCAs = opts.LoadedCAPool
	}

	// Apply the server name override; setting it re-enables verification.
	if opts.ServerName != "" {
		cfg.InsecureSkipVerify = false
		cfg.ServerName = opts.ServerName
	}

	// Kept for compatibility with callers relying on NameToCertificate.
	cfg.BuildNameToCertificate()

	return cfg, nil
}

// basePool returns pool itself when it is non-nil (so it is modified in
// place by the caller), or a new empty pool otherwise.
func basePool(pool *x509.CertPool) *x509.CertPool {
	if pool == nil {
		return x509.NewCertPool()
	}
	return pool
}
195
196// TLSTransport creates a http client transport suitable for mutual tls auth
197func TLSTransport(opts TLSClientOptions) (http.RoundTripper, error) {
198	cfg, err := TLSClientAuth(opts)
199	if err != nil {
200		return nil, err
201	}
202
203	return &http.Transport{TLSClientConfig: cfg}, nil
204}
205
206// TLSClient creates a http.Client for mutual auth
207func TLSClient(opts TLSClientOptions) (*http.Client, error) {
208	transport, err := TLSTransport(opts)
209	if err != nil {
210		return nil, err
211	}
212	return &http.Client{Transport: transport}, nil
213}
214
// DefaultTimeout is the default request timeout (30 seconds).
var DefaultTimeout = time.Second * 30
217
218// Runtime represents an API client that uses the transport
219// to make http requests based on a swagger specification.
220type Runtime struct {
221	DefaultMediaType      string
222	DefaultAuthentication runtime.ClientAuthInfoWriter
223	Consumers             map[string]runtime.Consumer
224	Producers             map[string]runtime.Producer
225
226	Transport http.RoundTripper
227	Jar       http.CookieJar
228	//Spec      *spec.Document
229	Host     string
230	BasePath string
231	Formats  strfmt.Registry
232	Context  context.Context
233
234	Debug  bool
235	logger logger.Logger
236
237	clientOnce *sync.Once
238	client     *http.Client
239	schemes    []string
240}
241
242// New creates a new default runtime for a swagger api runtime.Client
243func New(host, basePath string, schemes []string) *Runtime {
244	var rt Runtime
245	rt.DefaultMediaType = runtime.JSONMime
246
247	// TODO: actually infer this stuff from the spec
248	rt.Consumers = map[string]runtime.Consumer{
249		runtime.YAMLMime:    yamlpc.YAMLConsumer(),
250		runtime.JSONMime:    runtime.JSONConsumer(),
251		runtime.XMLMime:     runtime.XMLConsumer(),
252		runtime.TextMime:    runtime.TextConsumer(),
253		runtime.HTMLMime:    runtime.TextConsumer(),
254		runtime.CSVMime:     runtime.CSVConsumer(),
255		runtime.DefaultMime: runtime.ByteStreamConsumer(),
256	}
257	rt.Producers = map[string]runtime.Producer{
258		runtime.YAMLMime:    yamlpc.YAMLProducer(),
259		runtime.JSONMime:    runtime.JSONProducer(),
260		runtime.XMLMime:     runtime.XMLProducer(),
261		runtime.TextMime:    runtime.TextProducer(),
262		runtime.HTMLMime:    runtime.TextProducer(),
263		runtime.CSVMime:     runtime.CSVProducer(),
264		runtime.DefaultMime: runtime.ByteStreamProducer(),
265	}
266	rt.Transport = http.DefaultTransport
267	rt.Jar = nil
268	rt.Host = host
269	rt.BasePath = basePath
270	rt.Context = context.Background()
271	rt.clientOnce = new(sync.Once)
272	if !strings.HasPrefix(rt.BasePath, "/") {
273		rt.BasePath = "/" + rt.BasePath
274	}
275
276	rt.Debug = logger.DebugEnabled()
277	rt.logger = logger.StandardLogger{}
278
279	if len(schemes) > 0 {
280		rt.schemes = schemes
281	}
282	return &rt
283}
284
285// NewWithClient allows you to create a new transport with a configured http.Client
286func NewWithClient(host, basePath string, schemes []string, client *http.Client) *Runtime {
287	rt := New(host, basePath, schemes)
288	if client != nil {
289		rt.clientOnce.Do(func() {
290			rt.client = client
291		})
292	}
293	return rt
294}
295
296// WithOpenTracing adds opentracing support to the provided runtime.
297// A new client span is created for each request.
298// If the context of the client operation does not contain an active span, no span is created.
299// The provided opts are applied to each spans - for example to add global tags.
300func (r *Runtime) WithOpenTracing(opts ...opentracing.StartSpanOption) runtime.ClientTransport {
301	return newOpenTracingTransport(r, r.Host, opts)
302}
303
304func (r *Runtime) pickScheme(schemes []string) string {
305	if v := r.selectScheme(r.schemes); v != "" {
306		return v
307	}
308	if v := r.selectScheme(schemes); v != "" {
309		return v
310	}
311	return "http"
312}
313
314func (r *Runtime) selectScheme(schemes []string) string {
315	schLen := len(schemes)
316	if schLen == 0 {
317		return ""
318	}
319
320	scheme := schemes[0]
321	// prefer https, but skip when not possible
322	if scheme != "https" && schLen > 1 {
323		for _, sch := range schemes {
324			if sch == "https" {
325				scheme = sch
326				break
327			}
328		}
329	}
330	return scheme
331}
// transportOrDefault returns left unless it is a nil interface, in which case
// it returns right.
func transportOrDefault(left, right http.RoundTripper) http.RoundTripper {
	if left != nil {
		return left
	}
	return right
}
338
339// EnableConnectionReuse drains the remaining body from a response
340// so that go will reuse the TCP connections.
341//
342// This is not enabled by default because there are servers where
343// the response never gets closed and that would make the code hang forever.
344// So instead it's provided as a http client middleware that can be used to override
345// any request.
346func (r *Runtime) EnableConnectionReuse() {
347	if r.client == nil {
348		r.Transport = KeepAliveTransport(
349			transportOrDefault(r.Transport, http.DefaultTransport),
350		)
351		return
352	}
353
354	r.client.Transport = KeepAliveTransport(
355		transportOrDefault(r.client.Transport,
356			transportOrDefault(r.Transport, http.DefaultTransport),
357		),
358	)
359}
360
// Submit sends the request described by operation and returns whatever
// operation.Reader produces from the response.
//
// Request construction:
//   - the Accept header is set from operation.ProducesMediaTypes;
//   - the request media type is the first non-empty entry of
//     operation.ConsumesMediaTypes, or r.DefaultMediaType when there is none,
//     and it must have a registered producer unless it is a multipart or
//     URL-encoded form;
//   - operation.AuthInfo (or r.DefaultAuthentication when nil) is handed to
//     request.buildHTTP together with r.Formats;
//   - the scheme comes from pickScheme and the host from r.Host.
//
// Context: when operation.Context is set, only a cancel func is added to it
// (no timeout is imposed here); otherwise r.Context, or context.Background()
// when that is nil, is wrapped with request.timeout.
//
// The request goes through operation.Client when set, otherwise through
// r.client. The response consumer is picked from the response media type,
// falling back to the "*/*" consumer. Mapping status codes to results or API
// errors is left to operation.Reader (not visible here).
func (r *Runtime) Submit(operation *runtime.ClientOperation) (interface{}, error) {
	params, readResponse, auth := operation.Params, operation.Reader, operation.AuthInfo

	request, err := newRequest(operation.Method, operation.PathPattern, params)
	if err != nil {
		return nil, err
	}

	// accept is a fresh copy of operation.ProducesMediaTypes.
	var accept []string
	accept = append(accept, operation.ProducesMediaTypes...)
	if err = request.SetHeaderParam(runtime.HeaderAccept, accept...); err != nil {
		return nil, err
	}

	if auth == nil && r.DefaultAuthentication != nil {
		auth = r.DefaultAuthentication
	}
	// auth is not applied here directly; it is passed to request.buildHTTP below.

	// TODO: pick appropriate media type
	cmt := r.DefaultMediaType
	for _, mediaType := range operation.ConsumesMediaTypes {
		// Pick first non-empty media type
		if mediaType != "" {
			cmt = mediaType
			break
		}
	}

	// Form media types are accepted without a registered producer.
	if _, ok := r.Producers[cmt]; !ok && cmt != runtime.MultipartFormMime && cmt != runtime.URLencodedFormMime {
		return nil, fmt.Errorf("none of producers: %v registered. try %s", r.Producers, cmt)
	}

	req, err := request.buildHTTP(cmt, r.BasePath, r.Producers, r.Formats, auth)
	if err != nil {
		return nil, err
	}
	req.URL.Scheme = r.pickScheme(operation.Schemes)
	req.URL.Host = r.Host
	req.Host = r.Host

	// Build the default client on first use; this does nothing when
	// NewWithClient already installed a client through clientOnce.
	r.clientOnce.Do(func() {
		r.client = &http.Client{
			Transport: r.Transport,
			Jar:       r.Jar,
		}
	})

	// Dump the outgoing request, body included.
	if r.Debug {
		b, err2 := httputil.DumpRequestOut(req, true)
		if err2 != nil {
			return nil, err2
		}
		r.logger.Debugf("%s\n", string(b))
	}

	// hasTimeout is true when the caller supplied operation.Context; the name
	// suggests that context is expected to carry its own deadline, so only a
	// cancel func is added. Otherwise request.timeout is applied.
	var hasTimeout bool
	pctx := operation.Context
	if pctx == nil {
		pctx = r.Context
	} else {
		hasTimeout = true
	}
	if pctx == nil {
		pctx = context.Background()
	}
	var ctx context.Context
	var cancel context.CancelFunc
	if hasTimeout {
		ctx, cancel = context.WithCancel(pctx)
	} else {
		ctx, cancel = context.WithTimeout(pctx, request.timeout)
	}
	defer cancel()

	client := operation.Client
	if client == nil {
		client = r.client
	}
	req = req.WithContext(ctx)
	res, err := client.Do(req) // make requests, by default follows 10 redirects before failing
	if err != nil {
		return nil, err
	}
	// The body is closed when Submit returns, so readResponse must finish
	// with it before returning.
	defer res.Body.Close()

	ct := res.Header.Get(runtime.HeaderContentType)
	if ct == "" { // this should really really never occur
		ct = r.DefaultMediaType
	}

	if r.Debug {
		// NOTE(review): this compares the raw header value, so an octet-stream
		// Content-Type carrying parameters would still have its body dumped.
		printBody := true
		if ct == runtime.DefaultMime {
			printBody = false // Spare the terminal from a binary blob.
		}
		b, err2 := httputil.DumpResponse(res, printBody)
		if err2 != nil {
			return nil, err2
		}
		r.logger.Debugf("%s\n", string(b))
	}

	// Parameters such as charset are dropped before the consumer lookup.
	mt, _, err := mime.ParseMediaType(ct)
	if err != nil {
		return nil, fmt.Errorf("parse content type: %s", err)
	}

	cons, ok := r.Consumers[mt]
	if !ok {
		if cons, ok = r.Consumers["*/*"]; !ok {
			// scream about not knowing what to do
			return nil, fmt.Errorf("no consumer: %q", ct)
		}
	}
	return readResponse.ReadResponse(response{res}, cons)
}
484
485// SetDebug changes the debug flag.
486// It ensures that client and middlewares have the set debug level.
487func (r *Runtime) SetDebug(debug bool) {
488	r.Debug = debug
489	middleware.Debug = debug
490}
491
492// SetLogger changes the logger stream.
493// It ensures that client and middlewares use the same logger.
494func (r *Runtime) SetLogger(logger logger.Logger) {
495	r.logger = logger
496	middleware.Logger = logger
497}
498