1package nsdctl
2
3import (
4	"bufio"
5	"crypto/tls"
6	"crypto/x509"
7	"errors"
8	"fmt"
9	"io"
10	"io/ioutil"
11	"net"
12	"os"
13	"path"
14	"regexp"
15	"strconv"
16	"strings"
17	"time"
18)
19
20// Constants
21var supportedProtocols = map[string]protocol{
22	"nsd":     {Prefix: "NSDCT", Version: 1, ServerName: "nsd", ErrorStr: "error"},
23	"unbound": {Prefix: "UBCT", Version: 1, ServerName: "unbound", ErrorStr: "error"},
24}
25
26var configDefaults = map[string]config{
27	// Taken from https://www.nlnetlabs.nl/projects/nsd/nsd.conf.5.html
28	"nsd": {
29		port: nsdConfig{
30			"control-port",
31			"8952",
32		},
33		caFile: nsdConfig{
34			"server-cert-file",
35			"/usr/local/etc/nsd/nsd_server.pem",
36		},
37		keyFile: nsdConfig{
38			"control-key-file",
39			"/usr/local/etc/nsd/nsd_control.key",
40		},
41		certFile: nsdConfig{
42			"control-cert-file",
43			"/usr/local/etc/nsd/nsd_control.pem"},
44	},
45	// Taken from https://www.unbound.net/documentation/unbound.conf.html
46	"unbound": {
47		port: nsdConfig{
48			"control-port",
49			"8953",
50		},
51		caFile: nsdConfig{
52			"server-cert-file",
53			"unbound_server.pem",
54		},
55		keyFile: nsdConfig{
56			"control-key-file",
57			"unbound_control.key",
58		},
59		certFile: nsdConfig{
60			"control-cert-file",
61			"unbound_control.pem",
62		},
63	},
64}
65
66// Structs
67
68// config contains the necessary config file values we want
69type config struct {
70	port     nsdConfig
71	caFile   nsdConfig
72	certFile nsdConfig
73	keyFile  nsdConfig
74}
75
// nsdConfig pairs a configuration-file key with the value used when that key
// is not supplied.
type nsdConfig struct {
	Config  string // setting name as written in the config file, e.g. "control-port"
	Default string // fallback value, e.g. "8952"
}
81
// protocol describes the wire conventions of one nsd-like control socket.
type protocol struct {
	Prefix     string // command prefix, e.g. "NSDCT"
	Version    uint   // command version, written directly after Prefix
	ServerName string // name expected in the server's TLS certificate
	ErrorStr   string // reply prefix that marks a failed command
}
93
// NSDError is the error type this package uses for problems it reports
// itself, such as an error reply from the server or an undetectable config type.
type NSDError struct {
	err string
}

// Error returns the stored message, satisfying the error interface.
func (e *NSDError) Error() string { return e.err }
102
103// NSDClient is a client for NSD's control socket
104type NSDClient struct {
105	// TODO: Add in detection of type
106	// HostString is the string used to connect
107	HostString string
108	// Dialer is dialer used to create the connection
109	Dialer *net.Dialer
110	// TLSClientConfig is the tls.Config for the connection
111	TLSClientConfig *tls.Config
112	// Connection is the raw net.Conn for the client
113	Connection net.Conn
114	// protocol is the NSD protocol type (see supportedProtocols)
115	protocol *protocol
116}
117
118// NewClientFromConfig tries to autodetect and create a new NSDClient from an config file
119func NewClientFromConfig(configPath string) (*NSDClient, error) {
120	filename := path.Base(configPath)
121
122	var detectedType string
123	for k := range configDefaults {
124		if strings.Contains(filename, k) {
125			detectedType = k
126		}
127	}
128	if detectedType == "" {
129		fmt.Println("Could not detect type from config file")
130		return nil, &NSDError{"Could not detect type from config file"}
131	}
132
133	file, err := os.Open(configPath)
134	if err != nil {
135		return nil, err
136	}
137
138	// TODO: Rewrite search section.
139	// Minor optimization to precompile regex
140	// More generic way of plugging matches to results
141	// Also, flawed regexes won't match unicode names or other special characters
142
143	conf := configDefaults[detectedType]
144	rePort, err := regexp.Compile(conf.port.Config + ": *([0-9]+)(?:#.*)?")
145	reCAFile, err := regexp.Compile(conf.caFile.Config + ": *([a-zA-Z0-9/]+)(?:#.*)?")
146	reKeyFile, err := regexp.Compile(conf.keyFile.Config + ": *([a-zA-Z0-9/]+)(?:#.*)?")
147	reCertFile, err := regexp.Compile(conf.certFile.Config + ": *([a-zA-Z0-9/]+)(?:#.*)?")
148
149	var port uint
150	var hostString, caFile, keyFile, certFile string
151
152	scanner := bufio.NewScanner(file)
153	for scanner.Scan() {
154		line := scanner.Text()
155
156		if port == 0 {
157			res := rePort.FindStringSubmatch(line)
158			if res != nil {
159				port64, err := strconv.ParseUint(res[1], 10, 16)
160				if err != nil {
161					return nil, err
162				}
163				port = uint(port64)
164			}
165		}
166
167		if caFile == "" {
168			res := reCAFile.FindStringSubmatch(line)
169			if res != nil {
170				caFile = res[1]
171			}
172		}
173		if keyFile == "" {
174			res := reKeyFile.FindStringSubmatch(line)
175			if res != nil {
176				keyFile = res[1]
177			}
178		}
179		if certFile == "" {
180			res := reCertFile.FindStringSubmatch(line)
181			if res != nil {
182				certFile = res[1]
183			}
184		}
185	}
186	err = scanner.Err()
187	if err != nil {
188		return nil, err
189	}
190
191	if port != 0 {
192		hostString = "127.0.0.1:" + string(port)
193	}
194
195	return NewClient(detectedType, hostString, caFile, keyFile, certFile, false)
196}
197
198// NewClient creates a complete new NSDClient and returns any errors encountered
199func NewClient(serverType string, hostString string, caFile string, keyFile string, certFile string, skipVerify bool) (*NSDClient, error) {
200	protocol, ok := supportedProtocols[serverType]
201	if !ok {
202		return nil, errors.New("Server Type not Supported")
203	}
204
205	// Defaults
206	defaults := configDefaults[serverType]
207	if hostString == "" {
208		hostString = "127.0.0.1:" + defaults.port.Default
209	}
210	if caFile == "" {
211		caFile = defaults.caFile.Default
212	}
213	if keyFile == "" {
214		keyFile = defaults.keyFile.Default
215	}
216	if certFile == "" {
217		certFile = defaults.certFile.Default
218	}
219
220	// Set up connection
221	dialer := &net.Dialer{
222		// TODO: Don't hardcode these
223		Timeout: 1 * time.Second,
224		// NSD 4.1.x doesn't allow more than one connection to the socket
225		// and also closes connection after every command
226		// so keepalive is useless
227		KeepAlive: 0,
228		DualStack: true,
229	}
230
231	client := &NSDClient{
232		HostString: hostString,
233		Dialer:     dialer,
234		protocol:   &protocol,
235	}
236
237	clientCertKeyPair, err := tls.LoadX509KeyPair(certFile, keyFile)
238	if err != nil {
239		return nil, err
240	}
241
242	rootCAs, err := x509.SystemCertPool()
243	if err != nil {
244		return nil, err
245	}
246
247	buf, err := ioutil.ReadFile(caFile)
248	if err != nil {
249		fmt.Println("Could not load provided CA certificate(s). Using only system CAs.")
250	} else {
251		ok = rootCAs.AppendCertsFromPEM(buf)
252		if !ok {
253			fmt.Println("Could not load provided CA certificate(s). Using only system CAs.")
254		}
255	}
256
257	client.TLSClientConfig = &tls.Config{
258		Certificates:       []tls.Certificate{clientCertKeyPair},
259		RootCAs:            rootCAs,
260		ServerName:         protocol.ServerName,
261		InsecureSkipVerify: skipVerify,
262	}
263
264	r, err := client.Command("status")
265	if err != nil {
266		if r != nil {
267			// Drain rest of reader
268			io.Copy(ioutil.Discard, r)
269		}
270		return nil, err
271	}
272
273	return client, nil
274}
275
276// attempt to build a connection
277// NB!: Assumes connection close
278func (n *NSDClient) attemptConnection() error {
279	// Cleanly close existing connections
280	// NB!: NSD only allows one connection at a time.
281	// Old connection MUST be closed before new one is made.
282	if n.Connection != nil {
283		n.Connection.Close()
284	}
285
286	conn, err := tls.DialWithDialer(n.Dialer, "tcp", n.HostString, n.TLSClientConfig)
287	if err != nil {
288		return err
289	}
290
291	n.Connection = conn
292	return nil
293}
294
295// Command sends a command to the control socket
296// Returns an io.Reader with the results of the command.
297// error will contain any errors encountered (including invalid commands)
298func (n *NSDClient) Command(command string) (io.Reader, error) {
299	//TODO: Currently assumes connection close.
300	// Should check if connection is available to use
301	err := n.attemptConnection()
302	if err != nil {
303		return nil, err
304	}
305
306	// Format and send the command
307	_, err = fmt.Fprintf(n.Connection, "%s%d %s\n", n.protocol.Prefix, n.protocol.Version, command)
308	if err != nil {
309		return nil, err
310	}
311
312	r := bufio.NewReader(n.Connection)
313	err = n.peekError(r)
314	return r, err
315}
316
317func (n *NSDClient) peekError(r *bufio.Reader) error {
318	// Peek the scan buffer
319	preString, err := r.Peek(len(n.protocol.ErrorStr))
320	if err != nil {
321		return err
322	}
323
324	if string(preString) == n.protocol.ErrorStr {
325		line, _ := r.ReadString('\n')
326		return &NSDError{line[:len(line)-1]}
327	}
328	return nil
329}
330