1package nsdctl 2 3import ( 4 "bufio" 5 "crypto/tls" 6 "crypto/x509" 7 "errors" 8 "fmt" 9 "io" 10 "io/ioutil" 11 "net" 12 "os" 13 "path" 14 "regexp" 15 "strconv" 16 "strings" 17 "time" 18) 19 20// Constants 21var supportedProtocols = map[string]protocol{ 22 "nsd": {Prefix: "NSDCT", Version: 1, ServerName: "nsd", ErrorStr: "error"}, 23 "unbound": {Prefix: "UBCT", Version: 1, ServerName: "unbound", ErrorStr: "error"}, 24} 25 26var configDefaults = map[string]config{ 27 // Taken from https://www.nlnetlabs.nl/projects/nsd/nsd.conf.5.html 28 "nsd": { 29 port: nsdConfig{ 30 "control-port", 31 "8952", 32 }, 33 caFile: nsdConfig{ 34 "server-cert-file", 35 "/usr/local/etc/nsd/nsd_server.pem", 36 }, 37 keyFile: nsdConfig{ 38 "control-key-file", 39 "/usr/local/etc/nsd/nsd_control.key", 40 }, 41 certFile: nsdConfig{ 42 "control-cert-file", 43 "/usr/local/etc/nsd/nsd_control.pem"}, 44 }, 45 // Taken from https://www.unbound.net/documentation/unbound.conf.html 46 "unbound": { 47 port: nsdConfig{ 48 "control-port", 49 "8953", 50 }, 51 caFile: nsdConfig{ 52 "server-cert-file", 53 "unbound_server.pem", 54 }, 55 keyFile: nsdConfig{ 56 "control-key-file", 57 "unbound_control.key", 58 }, 59 certFile: nsdConfig{ 60 "control-cert-file", 61 "unbound_control.pem", 62 }, 63 }, 64} 65 66// Structs 67 68// config contains the necessary config file values we want 69type config struct { 70 port nsdConfig 71 caFile nsdConfig 72 certFile nsdConfig 73 keyFile nsdConfig 74} 75 76// nsdConfig contains what the config file line looks like and the default 77type nsdConfig struct { 78 Config string 79 Default string 80} 81 82// protocol defines constants for each nsd-like protocol 83type protocol struct { 84 // Prefix defines the command prefix 85 Prefix string 86 // Version defines the command version 87 Version uint 88 // ServerName defines the expected certificate server name 89 ServerName string 90 // ErrorStr defines the response that signifies an error 91 ErrorStr string 92} 93 94// NSDError defines an error type for NSD 95type NSDError struct { 96 err string 97} 98 99func (e *NSDError) Error() string { 100 return e.err 101} 102 103// NSDClient is a client for NSD's control socket 104type NSDClient struct { 105 // TODO: Add in detection of type 106 // HostString is the string used to connect 107 HostString string 108 // Dialer is dialer used to create the connection 109 Dialer *net.Dialer 110 // TLSClientConfig is the tls.Config for the connection 111 TLSClientConfig *tls.Config 112 // Connection is the raw net.Conn for the client 113 Connection net.Conn 114 // protocol is the NSD protocol type (see supportedProtocols) 115 protocol *protocol 116} 117 118// NewClientFromConfig tries to autodetect and create a new NSDClient from an config file 119func NewClientFromConfig(configPath string) (*NSDClient, error) { 120 filename := path.Base(configPath) 121 122 var detectedType string 123 for k := range configDefaults { 124 if strings.Contains(filename, k) { 125 detectedType = k 126 } 127 } 128 if detectedType == "" { 129 fmt.Println("Could not detect type from config file") 130 return nil, &NSDError{"Could not detect type from config file"} 131 } 132 133 file, err := os.Open(configPath) 134 if err != nil { 135 return nil, err 136 } 137 138 // TODO: Rewrite search section. 139 // Minor optimization to precompile regex 140 // More generic way of plugging matches to results 141 // Also, flawed regexes won't match unicode names or other special characters 142 143 conf := configDefaults[detectedType] 144 rePort, err := regexp.Compile(conf.port.Config + ": *([0-9]+)(?:#.*)?") 145 reCAFile, err := regexp.Compile(conf.caFile.Config + ": *([a-zA-Z0-9/]+)(?:#.*)?") 146 reKeyFile, err := regexp.Compile(conf.keyFile.Config + ": *([a-zA-Z0-9/]+)(?:#.*)?") 147 reCertFile, err := regexp.Compile(conf.certFile.Config + ": *([a-zA-Z0-9/]+)(?:#.*)?") 148 149 var port uint 150 var hostString, caFile, keyFile, certFile string 151 152 scanner := bufio.NewScanner(file) 153 for scanner.Scan() { 154 line := scanner.Text() 155 156 if port == 0 { 157 res := rePort.FindStringSubmatch(line) 158 if res != nil { 159 port64, err := strconv.ParseUint(res[1], 10, 16) 160 if err != nil { 161 return nil, err 162 } 163 port = uint(port64) 164 } 165 } 166 167 if caFile == "" { 168 res := reCAFile.FindStringSubmatch(line) 169 if res != nil { 170 caFile = res[1] 171 } 172 } 173 if keyFile == "" { 174 res := reKeyFile.FindStringSubmatch(line) 175 if res != nil { 176 keyFile = res[1] 177 } 178 } 179 if certFile == "" { 180 res := reCertFile.FindStringSubmatch(line) 181 if res != nil { 182 certFile = res[1] 183 } 184 } 185 } 186 err = scanner.Err() 187 if err != nil { 188 return nil, err 189 } 190 191 if port != 0 { 192 hostString = "127.0.0.1:" + string(port) 193 } 194 195 return NewClient(detectedType, hostString, caFile, keyFile, certFile, false) 196} 197 198// NewClient creates a complete new NSDClient and returns any errors encountered 199func NewClient(serverType string, hostString string, caFile string, keyFile string, certFile string, skipVerify bool) (*NSDClient, error) { 200 protocol, ok := supportedProtocols[serverType] 201 if !ok { 202 return nil, errors.New("Server Type not Supported") 203 } 204 205 // Defaults 206 defaults := configDefaults[serverType] 207 if hostString == "" { 208 hostString = "127.0.0.1:" + defaults.port.Default 209 } 210 if caFile == "" { 211 caFile = defaults.caFile.Default 212 } 213 if keyFile == "" { 214 keyFile = defaults.keyFile.Default 215 } 216 if certFile == "" { 217 certFile = defaults.certFile.Default 218 } 219 220 // Set up connection 221 dialer := &net.Dialer{ 222 // TODO: Don't hardcode these 223 Timeout: 1 * time.Second, 224 // NSD 4.1.x doesn't allow more than one connection to the socket 225 // and also closes connection after every command 226 // so keepalive is useless 227 KeepAlive: 0, 228 DualStack: true, 229 } 230 231 client := &NSDClient{ 232 HostString: hostString, 233 Dialer: dialer, 234 protocol: &protocol, 235 } 236 237 clientCertKeyPair, err := tls.LoadX509KeyPair(certFile, keyFile) 238 if err != nil { 239 return nil, err 240 } 241 242 rootCAs, err := x509.SystemCertPool() 243 if err != nil { 244 return nil, err 245 } 246 247 buf, err := ioutil.ReadFile(caFile) 248 if err != nil { 249 fmt.Println("Could not load provided CA certificate(s). Using only system CAs.") 250 } else { 251 ok = rootCAs.AppendCertsFromPEM(buf) 252 if !ok { 253 fmt.Println("Could not load provided CA certificate(s). Using only system CAs.") 254 } 255 } 256 257 client.TLSClientConfig = &tls.Config{ 258 Certificates: []tls.Certificate{clientCertKeyPair}, 259 RootCAs: rootCAs, 260 ServerName: protocol.ServerName, 261 InsecureSkipVerify: skipVerify, 262 } 263 264 r, err := client.Command("status") 265 if err != nil { 266 if r != nil { 267 // Drain rest of reader 268 io.Copy(ioutil.Discard, r) 269 } 270 return nil, err 271 } 272 273 return client, nil 274} 275 276// attempt to build a connection 277// NB!: Assumes connection close 278func (n *NSDClient) attemptConnection() error { 279 // Cleanly close existing connections 280 // NB!: NSD only allows one connection at a time. 281 // Old connection MUST be closed before new one is made. 282 if n.Connection != nil { 283 n.Connection.Close() 284 } 285 286 conn, err := tls.DialWithDialer(n.Dialer, "tcp", n.HostString, n.TLSClientConfig) 287 if err != nil { 288 return err 289 } 290 291 n.Connection = conn 292 return nil 293} 294 295// Command sends a command to the control socket 296// Returns an io.Reader with the results of the command. 297// error will contain any errors encountered (including invalid commands) 298func (n *NSDClient) Command(command string) (io.Reader, error) { 299 //TODO: Currently assumes connection close. 300 // Should check if connection is available to use 301 err := n.attemptConnection() 302 if err != nil { 303 return nil, err 304 } 305 306 // Format and send the command 307 _, err = fmt.Fprintf(n.Connection, "%s%d %s\n", n.protocol.Prefix, n.protocol.Version, command) 308 if err != nil { 309 return nil, err 310 } 311 312 r := bufio.NewReader(n.Connection) 313 err = n.peekError(r) 314 return r, err 315} 316 317func (n *NSDClient) peekError(r *bufio.Reader) error { 318 // Peek the scan buffer 319 preString, err := r.Peek(len(n.protocol.ErrorStr)) 320 if err != nil { 321 return err 322 } 323 324 if string(preString) == n.protocol.ErrorStr { 325 line, _ := r.ReadString('\n') 326 return &NSDError{line[:len(line)-1]} 327 } 328 return nil 329} 330