1// Copyright (C) MongoDB, Inc. 2018-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7// This file contains code adapted from the MongoDB Go Driver.
8
// Package tlsgo provides an mgo connection using Go's native TLS library.
10package tlsgo
11
12import (
13	"crypto/tls"
14	"crypto/x509"
15	"encoding/asn1"
16	"encoding/hex"
17	"encoding/pem"
18	"fmt"
19	"io/ioutil"
20	"strings"
21)
22
// TLSConfig contains options for configuring an SSL connection to the server.
// Populate it with SetInsecure, AddCaCertFromFile and AddClientCertFromFile,
// then call MakeConfig to obtain the resulting *tls.Config.
type TLSConfig struct {
	// caCert, when set, is the only root CA placed in the pool built by
	// MakeConfig; when nil, MakeConfig uses loadSystemCAs instead.
	caCert *x509.Certificate

	// clientCert, when set, is offered to the server via tls.Config.Certificates.
	clientCert *tls.Certificate

	// insecure is copied to tls.Config.InsecureSkipVerify by MakeConfig.
	insecure bool
}
29
30// NewTLSConfig creates a new TLSConfig.
31func NewTLSConfig() *TLSConfig {
32	cfg := &TLSConfig{}
33
34	return cfg
35}
36
37// SetInsecure sets whether the client should verify the server's certificate chain and hostnames.
38func (c *TLSConfig) SetInsecure(allow bool) {
39	c.insecure = allow
40}
41
42// AddClientCertFromFile adds a client certificate to the configuration given a path to the
43// containing file and returns the certificate's subject name.
44func (c *TLSConfig) AddClientCertFromFile(clientFile, password string) (string, error) {
45	data, err := ioutil.ReadFile(clientFile)
46	if err != nil {
47		return "", err
48	}
49
50	certPEM, err := loadPEMBlock(data, "CERTIFICATE")
51	if err != nil {
52		return "", err
53	}
54
55	keyPEM, err := loadPEMBlock(data, "PRIVATE KEY")
56	if err != nil {
57		return "", err
58	}
59	// This check only covers encrypted PEM data with a DEK-Info header. It
60	// does not detect unencrypted PEM containing PKCS#8 format data with an
61	// encrypted private key.
62	if x509.IsEncryptedPEMBlock(keyPEM) {
63		if password == "" {
64			return "", fmt.Errorf("No password provided to decrypt private key")
65		}
66		decrypted, err := x509.DecryptPEMBlock(keyPEM, []byte(password))
67		if err != nil {
68			return "", err
69		}
70		keyPEM = &pem.Block{Bytes: decrypted, Type: keyPEM.Type}
71	}
72
73	if strings.Contains(keyPEM.Type, "ENCRYPTED") {
74		return "", fmt.Errorf("PKCS#8 encrypted private keys are not supported")
75	}
76
77	cert, err := tls.X509KeyPair(pem.EncodeToMemory(certPEM), pem.EncodeToMemory(keyPEM))
78	if err != nil {
79		return "", err
80	}
81
82	c.clientCert = &cert
83
84	// The documentation for the tls.X509KeyPair indicates that the Leaf
85	// certificate is not retained. Because there isn't any way of creating a
86	// tls.Certificate from an x509.Certificate short of calling X509KeyPair
87	// on the raw bytes, we're forced to parse the certificate over again to
88	// get the subject name.
89	crt, err := x509.ParseCertificate(certPEM.Bytes)
90	if err != nil {
91		return "", err
92	}
93
94	return x509CertSubject(crt), nil
95}
96
97// AddCaCertFromFile adds a root CA certificate to the configuration given a path to the containing file.
98func (c *TLSConfig) AddCaCertFromFile(caFile string) error {
99	data, err := ioutil.ReadFile(caFile)
100	if err != nil {
101		return err
102	}
103
104	certBytes, err := loadCertBytes(data)
105	if err != nil {
106		return err
107	}
108
109	cert, err := x509.ParseCertificate(certBytes)
110	if err != nil {
111		return err
112	}
113
114	c.caCert = cert
115
116	return nil
117}
118
119// MakeConfig constructs a new tls.Config from the configuration specified.
120func (c *TLSConfig) MakeConfig() (*tls.Config, error) {
121	cfg := &tls.Config{}
122
123	if c.clientCert != nil {
124		cfg.Certificates = []tls.Certificate{*c.clientCert}
125	}
126
127	if c.caCert == nil {
128		roots, err := loadSystemCAs()
129		if err != nil {
130			return nil, err
131		}
132		cfg.RootCAs = roots
133	} else {
134		cfg.RootCAs = x509.NewCertPool()
135		cfg.RootCAs.AddCert(c.caCert)
136	}
137
138	cfg.InsecureSkipVerify = c.insecure
139
140	return cfg, nil
141}
142
143func loadCertBytes(data []byte) ([]byte, error) {
144	b, err := loadPEMBlock(data, "CERTIFICATE")
145	if err != nil {
146		return nil, err
147	}
148	return b.Bytes, nil
149}
150
// loadPEMBlock scans the PEM-encoded data and returns the first block whose
// type contains blocktype as a substring. For example, "PRIVATE KEY" also
// matches "RSA PRIVATE KEY", "EC PRIVATE KEY" and "ENCRYPTED PRIVATE KEY".
//
// Blocks after the first match are ignored, including further matches.
// Duplicates are deliberately not rejected, because doing so would refuse
// certificate-chain and CA-bundle files; such files yield their first
// certificate.
//
// Errors:
//   - "no block of type <blocktype> found in .pem file": data (possibly nil)
//     runs out without a match.
//   - "invalid .pem file": before a match is found, the remaining data
//     contains no decodable PEM block (for example non-PEM input, or trailing
//     text after the last block).
func loadPEMBlock(data []byte, blocktype string) (*pem.Block, error) {
	for len(data) > 0 {
		block, rest := pem.Decode(data)
		if block == nil {
			return nil, fmt.Errorf("invalid .pem file")
		}

		if strings.Contains(block.Type, blocktype) {
			return block, nil
		}

		data = rest
	}

	return nil, fmt.Errorf("no block of type %s found in .pem file", blocktype)
}
176
// x509CertSubject formats cert's subject as a distinguished-name string.
//
// Because the functionality to convert a pkix.Name to a string wasn't added
// until Go 1.10, this reproduces that implementation (together with the
// attributeTypeNames table below) so older toolchains give the same output.
// RDNs are written in reverse order of cert.Subject.ToRDNSequence() and
// separated by ','; the attributes inside a single RDN are joined with '+'.
func x509CertSubject(cert *x509.Certificate) string {
	rdns := cert.Subject.ToRDNSequence()

	var buf []byte
	for i := len(rdns) - 1; i >= 0; i-- {
		if i < len(rdns)-1 {
			buf = append(buf, ',')
		}
		for j, atv := range rdns[i] {
			if j > 0 {
				buf = append(buf, '+')
			}
			buf = append(buf, subjectAttributeString(atv.Type, atv.Value)...)
		}
	}
	return string(buf)
}

// subjectAttributeString renders a single type=value pair of a subject DN.
//
// Types with a short name in attributeTypeNames render as NAME=escaped-value.
// Other types render as the dotted OID followed by "=#" and the hex of the
// value's DER encoding, which needs no escaping. If the value cannot be
// DER-encoded, the result falls back to OID=escaped-value.
func subjectAttributeString(oid asn1.ObjectIdentifier, value interface{}) string {
	oidString := oid.String()

	name, known := attributeTypeNames[oidString]
	if !known {
		if der, err := asn1.Marshal(value); err == nil {
			return oidString + "=#" + hex.EncodeToString(der)
		}
		name = oidString
	}

	return name + "=" + escapeSubjectValue(fmt.Sprint(value))
}

// escapeSubjectValue backslash-escapes the DN special characters , + " \ < > ;
// anywhere in v, a space at the first or last byte position, and a leading '#'.
func escapeSubjectValue(v string) string {
	last := len(v) - 1
	out := make([]rune, 0, len(v))

	for pos, r := range v {
		special := false
		switch r {
		case ',', '+', '"', '\\', '<', '>', ';':
			special = true
		case ' ':
			special = pos == 0 || pos == last
		case '#':
			special = pos == 0
		}

		if special {
			out = append(out, '\\')
		}
		out = append(out, r)
	}

	return string(out)
}

// attributeTypeNames maps dotted OIDs of common X.520 attribute types (the
// 2.5.4 arc) to the short names used when formatting a subject. Types not
// listed here are rendered by OID.
var attributeTypeNames = map[string]string{
	"2.5.4.3":  "CN",
	"2.5.4.5":  "SERIALNUMBER",
	"2.5.4.6":  "C",
	"2.5.4.7":  "L",
	"2.5.4.8":  "ST",
	"2.5.4.9":  "STREET",
	"2.5.4.10": "O",
	"2.5.4.11": "OU",
	"2.5.4.17": "POSTALCODE",
}
247