1package reloadutil
2
3import (
4	"crypto/tls"
5	"crypto/x509"
6	"encoding/pem"
7	"errors"
8	"fmt"
9	"io/ioutil"
10	"sync"
11
12	"github.com/hashicorp/errwrap"
13)
14
// ReloadFunc is the signature of a callback invoked when a reload is
// requested. The map argument is supplied by whoever triggers the reload;
// a non-nil error reports that the reload did not succeed.
type ReloadFunc func(args map[string]interface{}) error
17
// CertificateGetter holds a TLS key pair loaded from a certificate file and a
// key file. Its Reload method has the ReloadFunc signature, and its
// GetCertificate method matches the tls.Config.GetCertificate callback
// signature. The file paths and passphrase cannot be changed after
// construction.
type CertificateGetter struct {
	// The embedded RWMutex guards cert: Reload writes it under the write
	// lock, GetCertificate reads it under the read lock.
	sync.RWMutex

	// cert is the most recently loaded key pair; nil until a Reload succeeds.
	cert *tls.Certificate

	// Source locations and the optional key passphrase, set once at creation.
	certFile, keyFile, passphrase string
}
30
31func NewCertificateGetter(certFile, keyFile, passphrase string) *CertificateGetter {
32	return &CertificateGetter{
33		certFile:   certFile,
34		keyFile:    keyFile,
35		passphrase: passphrase,
36	}
37}
38
39func (cg *CertificateGetter) Reload(_ map[string]interface{}) error {
40	certPEMBlock, err := ioutil.ReadFile(cg.certFile)
41	if err != nil {
42		return err
43	}
44	keyPEMBlock, err := ioutil.ReadFile(cg.keyFile)
45	if err != nil {
46		return err
47	}
48
49	// Check for encrypted pem block
50	keyBlock, _ := pem.Decode(keyPEMBlock)
51	if keyBlock == nil {
52		return errors.New("decoded PEM is blank")
53	}
54
55	if x509.IsEncryptedPEMBlock(keyBlock) {
56		keyBlock.Bytes, err = x509.DecryptPEMBlock(keyBlock, []byte(cg.passphrase))
57		if err != nil {
58			return errwrap.Wrapf("Decrypting PEM block failed {{err}}", err)
59		}
60		keyPEMBlock = pem.EncodeToMemory(keyBlock)
61	}
62
63	cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
64	if err != nil {
65		return err
66	}
67
68	cg.Lock()
69	defer cg.Unlock()
70
71	cg.cert = &cert
72
73	return nil
74}
75
76func (cg *CertificateGetter) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
77	cg.RLock()
78	defer cg.RUnlock()
79
80	if cg.cert == nil {
81		return nil, fmt.Errorf("nil certificate")
82	}
83
84	return cg.cert, nil
85}
86