1// Copyright 2015 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package mitm provides tooling for MITMing TLS connections. It provides
16// tooling to create CA certs and generate TLS configs that can be used to MITM
17// a TLS connection with a provided CA certificate.
18package mitm
19
20import (
21	"bytes"
22	"crypto/rand"
23	"crypto/rsa"
24	"crypto/sha1"
25	"crypto/tls"
26	"crypto/x509"
27	"crypto/x509/pkix"
28	"errors"
29	"math/big"
30	"net"
31	"net/http"
32	"sync"
33	"time"
34
35	"github.com/google/martian/v3/log"
36)
37
// MaxSerialNumber is the exclusive upper bound used when drawing random
// certificate serial numbers; serials are chosen uniformly from
// [0, MaxSerialNumber).
//
// RFC 5280 section 4.1.2.2 forbids serial numbers longer than 20 octets.
// Serials are DER-encoded as signed INTEGERs, so a value whose most
// significant bit is set needs an extra leading zero octet. The largest value
// that still encodes in 20 octets is therefore 2^159-1 (0x7F followed by
// nineteen 0xFF bytes). Any positive integer no larger than that may be used.
var MaxSerialNumber = big.NewInt(0).SetBytes(append([]byte{0x7f}, bytes.Repeat([]byte{0xff}, 19)...))
42
43// Config is a set of configuration values that are used to build TLS configs
44// capable of MITM.
45type Config struct {
46	ca                     *x509.Certificate
47	capriv                 interface{}
48	priv                   *rsa.PrivateKey
49	keyID                  []byte
50	validity               time.Duration
51	org                    string
52	getCertificate         func(*tls.ClientHelloInfo) (*tls.Certificate, error)
53	roots                  *x509.CertPool
54	skipVerify             bool
55	handshakeErrorCallback func(*http.Request, error)
56
57	certmu sync.RWMutex
58	certs  map[string]*tls.Certificate
59}
60
61// NewAuthority creates a new CA certificate and associated
62// private key.
63func NewAuthority(name, organization string, validity time.Duration) (*x509.Certificate, *rsa.PrivateKey, error) {
64	priv, err := rsa.GenerateKey(rand.Reader, 2048)
65	if err != nil {
66		return nil, nil, err
67	}
68	pub := priv.Public()
69
70	// Subject Key Identifier support for end entity certificate.
71	// https://www.ietf.org/rfc/rfc3280.txt (section 4.2.1.2)
72	pkixpub, err := x509.MarshalPKIXPublicKey(pub)
73	if err != nil {
74		return nil, nil, err
75	}
76	h := sha1.New()
77	h.Write(pkixpub)
78	keyID := h.Sum(nil)
79
80	// TODO: keep a map of used serial numbers to avoid potentially reusing a
81	// serial multiple times.
82	serial, err := rand.Int(rand.Reader, MaxSerialNumber)
83	if err != nil {
84		return nil, nil, err
85	}
86
87	tmpl := &x509.Certificate{
88		SerialNumber: serial,
89		Subject: pkix.Name{
90			CommonName:   name,
91			Organization: []string{organization},
92		},
93		SubjectKeyId:          keyID,
94		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
95		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
96		BasicConstraintsValid: true,
97		NotBefore:             time.Now().Add(-validity),
98		NotAfter:              time.Now().Add(validity),
99		DNSNames:              []string{name},
100		IsCA:                  true,
101	}
102
103	raw, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, pub, priv)
104	if err != nil {
105		return nil, nil, err
106	}
107
108	// Parse certificate bytes so that we have a leaf certificate.
109	x509c, err := x509.ParseCertificate(raw)
110	if err != nil {
111		return nil, nil, err
112	}
113
114	return x509c, priv, nil
115}
116
117// NewConfig creates a MITM config using the CA certificate and
118// private key to generate on-the-fly certificates.
119func NewConfig(ca *x509.Certificate, privateKey interface{}) (*Config, error) {
120	roots := x509.NewCertPool()
121	roots.AddCert(ca)
122
123	priv, err := rsa.GenerateKey(rand.Reader, 2048)
124	if err != nil {
125		return nil, err
126	}
127	pub := priv.Public()
128
129	// Subject Key Identifier support for end entity certificate.
130	// https://www.ietf.org/rfc/rfc3280.txt (section 4.2.1.2)
131	pkixpub, err := x509.MarshalPKIXPublicKey(pub)
132	if err != nil {
133		return nil, err
134	}
135	h := sha1.New()
136	h.Write(pkixpub)
137	keyID := h.Sum(nil)
138
139	return &Config{
140		ca:       ca,
141		capriv:   privateKey,
142		priv:     priv,
143		keyID:    keyID,
144		validity: time.Hour,
145		org:      "Martian Proxy",
146		certs:    make(map[string]*tls.Certificate),
147		roots:    roots,
148	}, nil
149}
150
151// SetValidity sets the validity window around the current time that the
152// certificate is valid for.
153func (c *Config) SetValidity(validity time.Duration) {
154	c.validity = validity
155}
156
157// SkipTLSVerify skips the TLS certification verification check.
158func (c *Config) SkipTLSVerify(skip bool) {
159	c.skipVerify = skip
160}
161
162// SetOrganization sets the organization of the certificate.
163func (c *Config) SetOrganization(org string) {
164	c.org = org
165}
166
167// SetHandshakeErrorCallback sets the handshakeErrorCallback function.
168func (c *Config) SetHandshakeErrorCallback(cb func(*http.Request, error)) {
169	c.handshakeErrorCallback = cb
170}
171
172// HandshakeErrorCallback calls the handshakeErrorCallback function in this
173// Config, if it is non-nil. Request is the connect request that this handshake
174// is being executed through.
175func (c *Config) HandshakeErrorCallback(r *http.Request, err error) {
176	if c.handshakeErrorCallback != nil {
177		c.handshakeErrorCallback(r, err)
178	}
179}
180
181// TLS returns a *tls.Config that will generate certificates on-the-fly using
182// the SNI extension in the TLS ClientHello.
183func (c *Config) TLS() *tls.Config {
184	return &tls.Config{
185		InsecureSkipVerify: c.skipVerify,
186		GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
187			if clientHello.ServerName == "" {
188				return nil, errors.New("mitm: SNI not provided, failed to build certificate")
189			}
190
191			return c.cert(clientHello.ServerName)
192		},
193		NextProtos: []string{"http/1.1"},
194	}
195}
196
197// TLSForHost returns a *tls.Config that will generate certificates on-the-fly
198// using SNI from the connection, or fall back to the provided hostname.
199func (c *Config) TLSForHost(hostname string) *tls.Config {
200	return &tls.Config{
201		InsecureSkipVerify: c.skipVerify,
202		GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
203			host := clientHello.ServerName
204			if host == "" {
205				host = hostname
206			}
207
208			return c.cert(host)
209		},
210		NextProtos: []string{"http/1.1"},
211	}
212}
213
214func (c *Config) cert(hostname string) (*tls.Certificate, error) {
215	// Remove the port if it exists.
216	host, _, err := net.SplitHostPort(hostname)
217	if err == nil {
218		hostname = host
219	}
220
221	c.certmu.RLock()
222	tlsc, ok := c.certs[hostname]
223	c.certmu.RUnlock()
224
225	if ok {
226		log.Debugf("mitm: cache hit for %s", hostname)
227
228		// Check validity of the certificate for hostname match, expiry, etc. In
229		// particular, if the cached certificate has expired, create a new one.
230		if _, err := tlsc.Leaf.Verify(x509.VerifyOptions{
231			DNSName: hostname,
232			Roots:   c.roots,
233		}); err == nil {
234			return tlsc, nil
235		}
236
237		log.Debugf("mitm: invalid certificate in cache for %s", hostname)
238	}
239
240	log.Debugf("mitm: cache miss for %s", hostname)
241
242	serial, err := rand.Int(rand.Reader, MaxSerialNumber)
243	if err != nil {
244		return nil, err
245	}
246
247	tmpl := &x509.Certificate{
248		SerialNumber: serial,
249		Subject: pkix.Name{
250			CommonName:   hostname,
251			Organization: []string{c.org},
252		},
253		SubjectKeyId:          c.keyID,
254		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
255		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
256		BasicConstraintsValid: true,
257		NotBefore:             time.Now().Add(-c.validity),
258		NotAfter:              time.Now().Add(c.validity),
259	}
260
261	if ip := net.ParseIP(hostname); ip != nil {
262		tmpl.IPAddresses = []net.IP{ip}
263	} else {
264		tmpl.DNSNames = []string{hostname}
265	}
266
267	raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.priv.Public(), c.capriv)
268	if err != nil {
269		return nil, err
270	}
271
272	// Parse certificate bytes so that we have a leaf certificate.
273	x509c, err := x509.ParseCertificate(raw)
274	if err != nil {
275		return nil, err
276	}
277
278	tlsc = &tls.Certificate{
279		Certificate: [][]byte{raw, c.ca.Raw},
280		PrivateKey:  c.priv,
281		Leaf:        x509c,
282	}
283
284	c.certmu.Lock()
285	c.certs[hostname] = tlsc
286	c.certmu.Unlock()
287
288	return tlsc, nil
289}
290