1package libtrust
2
3import (
4	"crypto/tls"
5	"crypto/x509"
6	"fmt"
7	"io/ioutil"
8	"net"
9	"os"
10	"path"
11	"sync"
12)
13
14// ClientKeyManager manages client keys on the filesystem
15type ClientKeyManager struct {
16	key        PrivateKey
17	clientFile string
18	clientDir  string
19
20	clientLock sync.RWMutex
21	clients    []PublicKey
22
23	configLock sync.Mutex
24	configs    []*tls.Config
25}
26
27// NewClientKeyManager loads a new manager from a set of key files
28// and managed by the given private key.
29func NewClientKeyManager(trustKey PrivateKey, clientFile, clientDir string) (*ClientKeyManager, error) {
30	m := &ClientKeyManager{
31		key:        trustKey,
32		clientFile: clientFile,
33		clientDir:  clientDir,
34	}
35	if err := m.loadKeys(); err != nil {
36		return nil, err
37	}
38	// TODO Start watching file and directory
39
40	return m, nil
41}
42
43func (c *ClientKeyManager) loadKeys() (err error) {
44	// Load authorized keys file
45	var clients []PublicKey
46	if c.clientFile != "" {
47		clients, err = LoadKeySetFile(c.clientFile)
48		if err != nil {
49			return fmt.Errorf("unable to load authorized keys: %s", err)
50		}
51	}
52
53	// Add clients from authorized keys directory
54	files, err := ioutil.ReadDir(c.clientDir)
55	if err != nil && !os.IsNotExist(err) {
56		return fmt.Errorf("unable to open authorized keys directory: %s", err)
57	}
58	for _, f := range files {
59		if !f.IsDir() {
60			publicKey, err := LoadPublicKeyFile(path.Join(c.clientDir, f.Name()))
61			if err != nil {
62				return fmt.Errorf("unable to load authorized key file: %s", err)
63			}
64			clients = append(clients, publicKey)
65		}
66	}
67
68	c.clientLock.Lock()
69	c.clients = clients
70	c.clientLock.Unlock()
71
72	return nil
73}
74
75// RegisterTLSConfig registers a tls configuration to manager
76// such that any changes to the keys may be reflected in
77// the tls client CA pool
78func (c *ClientKeyManager) RegisterTLSConfig(tlsConfig *tls.Config) error {
79	c.clientLock.RLock()
80	certPool, err := GenerateCACertPool(c.key, c.clients)
81	if err != nil {
82		return fmt.Errorf("CA pool generation error: %s", err)
83	}
84	c.clientLock.RUnlock()
85
86	tlsConfig.ClientCAs = certPool
87
88	c.configLock.Lock()
89	c.configs = append(c.configs, tlsConfig)
90	c.configLock.Unlock()
91
92	return nil
93}
94
95// NewIdentityAuthTLSConfig creates a tls.Config for the server to use for
96// libtrust identity authentication for the domain specified
97func NewIdentityAuthTLSConfig(trustKey PrivateKey, clients *ClientKeyManager, addr string, domain string) (*tls.Config, error) {
98	tlsConfig := newTLSConfig()
99
100	tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
101	if err := clients.RegisterTLSConfig(tlsConfig); err != nil {
102		return nil, err
103	}
104
105	// Generate cert
106	ips, domains, err := parseAddr(addr)
107	if err != nil {
108		return nil, err
109	}
110	// add domain that it expects clients to use
111	domains = append(domains, domain)
112	x509Cert, err := GenerateSelfSignedServerCert(trustKey, domains, ips)
113	if err != nil {
114		return nil, fmt.Errorf("certificate generation error: %s", err)
115	}
116	tlsConfig.Certificates = []tls.Certificate{{
117		Certificate: [][]byte{x509Cert.Raw},
118		PrivateKey:  trustKey.CryptoPrivateKey(),
119		Leaf:        x509Cert,
120	}}
121
122	return tlsConfig, nil
123}
124
125// NewCertAuthTLSConfig creates a tls.Config for the server to use for
126// certificate authentication
127func NewCertAuthTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
128	tlsConfig := newTLSConfig()
129
130	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
131	if err != nil {
132		return nil, fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?", certPath, keyPath, err)
133	}
134	tlsConfig.Certificates = []tls.Certificate{cert}
135
136	// Verify client certificates against a CA?
137	if caPath != "" {
138		certPool := x509.NewCertPool()
139		file, err := ioutil.ReadFile(caPath)
140		if err != nil {
141			return nil, fmt.Errorf("Couldn't read CA certificate: %s", err)
142		}
143		certPool.AppendCertsFromPEM(file)
144
145		tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
146		tlsConfig.ClientCAs = certPool
147	}
148
149	return tlsConfig, nil
150}
151
// newTLSConfig returns a fresh base server tls.Config shared by the
// constructors in this file. It advertises only HTTP/1.1 via ALPN and
// refuses protocol versions older than TLS 1.2.
func newTLSConfig() *tls.Config {
	return &tls.Config{
		NextProtos: []string{"http/1.1"},
		// Avoid fallback on insecure protocols: SSL is never offered, and
		// TLS 1.0/1.1 are deprecated (RFC 8996), so require TLS 1.2+.
		MinVersion: tls.VersionTLS12,
	}
}
159
// parseAddr parses an address into an array of IPs and domains.
//
// addr must be of the form host:port (the port is discarded); otherwise the
// error from net.SplitHostPort is returned unchanged. A host that is a
// literal IP address is returned in ips, and any other host in domains. An
// empty host (e.g. ":2376", listening on all interfaces) yields neither,
// so callers never receive a blank name to put in a certificate.
func parseAddr(addr string) ([]net.IP, []string, error) {
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, nil, err
	}
	if host == "" {
		return nil, nil, nil
	}
	if ip := net.ParseIP(host); ip != nil {
		return []net.IP{ip}, nil, nil
	}
	return nil, []string{host}, nil
}
176