1// Unless explicitly stated otherwise all files in this repository are licensed
2// under the Apache License Version 2.0.
3// This product includes software developed at Datadog (https://www.datadoghq.com/).
4// Copyright 2016 Datadog, Inc.
5
6// Copied from package github.com/lib/pq:
7//
8// parseOpts: https://github.com/lib/pq/blob/61fe37aa2ee24fabcdbe5c4ac1d4ac566f88f345/conn.go
9// parseURL: https://github.com/lib/pq/blob/50761b0867bd1d9d069276790bcd4a3bccf2324a/url.go
10
11package internal
12
13import (
14	"fmt"
15	"net"
16	nurl "net/url"
17	"sort"
18	"strings"
19	"unicode"
20)
21
// values holds connection-string options keyed by option name; parseOpts
// stores each keyword/value pair it parses into a values map.
type values map[string]string
23
// scanner implements a tokenizer for libpq-style option strings.
// It walks its input one rune at a time via Next and SkipSpaces.
type scanner struct {
	s []rune // the input, decoded into runes
	i int    // index in s of the next rune Next will return
}
29
30// newScanner returns a new scanner initialized with the option string s.
31func newScanner(s string) *scanner {
32	return &scanner{[]rune(s), 0}
33}
34
35// Next returns the next rune.
36// It returns 0, false if the end of the text has been reached.
37func (s *scanner) Next() (rune, bool) {
38	if s.i >= len(s.s) {
39		return 0, false
40	}
41	r := s.s[s.i]
42	s.i++
43	return r, true
44}
45
46// SkipSpaces returns the next non-whitespace rune.
47// It returns 0, false if the end of the text has been reached.
48func (s *scanner) SkipSpaces() (rune, bool) {
49	r, ok := s.Next()
50	for unicode.IsSpace(r) && ok {
51		r, ok = s.Next()
52	}
53	return r, ok
54}
55
56// parseOpts parses the options from name and adds them to the values.
57// The parsing code is based on conninfo_parse from libpq's fe-connect.c
58func parseOpts(name string, o values) error {
59	s := newScanner(name)
60
61	for {
62		var (
63			keyRunes, valRunes []rune
64			r                  rune
65			ok                 bool
66		)
67
68		if r, ok = s.SkipSpaces(); !ok {
69			break
70		}
71
72		// Scan the key
73		for !unicode.IsSpace(r) && r != '=' {
74			keyRunes = append(keyRunes, r)
75			if r, ok = s.Next(); !ok {
76				break
77			}
78		}
79
80		// Skip any whitespace if we're not at the = yet
81		if r != '=' {
82			r, ok = s.SkipSpaces()
83		}
84
85		// The current character should be =
86		if r != '=' || !ok {
87			return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
88		}
89
90		// Skip any whitespace after the =
91		if r, ok = s.SkipSpaces(); !ok {
92			// If we reach the end here, the last value is just an empty string as per libpq.
93			o[string(keyRunes)] = ""
94			break
95		}
96
97		if r != '\'' {
98			for !unicode.IsSpace(r) {
99				if r == '\\' {
100					if r, ok = s.Next(); !ok {
101						return fmt.Errorf(`missing character after backslash`)
102					}
103				}
104				valRunes = append(valRunes, r)
105
106				if r, ok = s.Next(); !ok {
107					break
108				}
109			}
110		} else {
111		quote:
112			for {
113				if r, ok = s.Next(); !ok {
114					return fmt.Errorf(`unterminated quoted string literal in connection string`)
115				}
116				switch r {
117				case '\'':
118					break quote
119				case '\\':
120					r, _ = s.Next()
121					fallthrough
122				default:
123					valRunes = append(valRunes, r)
124				}
125			}
126		}
127
128		o[string(keyRunes)] = string(valRunes)
129	}
130
131	return nil
132}
133
// parseURL converts a postgres:// or postgresql:// URL into a libpq-style
// keyword/value connection string. For example
//
//	"postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
//
// converts to:
//
//	"dbname=mydb host=1.2.3.4 password=secret port=5432 sslmode=verify-full user=bob"
//
// The pairs are sorted so the output is deterministic. Pairs with empty
// values are omitted, so the minimal URL "postgres://" converts to "".
// Only the first value of each query parameter is used, and query keys are
// copied verbatim. A bracketed IPv6 host without a port ("[::1]") is emitted
// without its brackets.
//
// Values are backslash-escaped so that parseOpts reads them back intact:
// every whitespace rune (unicode.IsSpace, the same set parseOpts treats as a
// value terminator), single quote and backslash gets a leading backslash.
//
// This is a copy of lib/pq's ParseURL. Upstream notes that clients no longer
// need it because sql.Open accepts a URL directly:
//
//	sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full")
//
// It returns an error if url cannot be parsed or if its scheme is neither
// postgres nor postgresql.
func parseURL(url string) (string, error) {
	u, err := nurl.Parse(url)
	if err != nil {
		return "", err
	}

	if u.Scheme != "postgres" && u.Scheme != "postgresql" {
		return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
	}

	// escape inserts a backslash before each rune parseOpts would otherwise
	// treat specially. It copies the original byte spans between escapes, so
	// values containing invalid UTF-8 pass through unchanged.
	escape := func(v string) string {
		var b strings.Builder
		b.Grow(len(v))
		last := 0
		for i, r := range v {
			if unicode.IsSpace(r) || r == '\'' || r == '\\' {
				b.WriteString(v[last:i])
				b.WriteByte('\\')
				last = i
			}
		}
		b.WriteString(v[last:])
		return b.String()
	}

	var kvs []string
	accrue := func(k, v string) {
		if v != "" {
			kvs = append(kvs, k+"="+escape(v))
		}
	}

	if u.User != nil {
		accrue("user", u.User.Username())
		// The second result only reports whether a password was present;
		// an empty password is skipped by accrue either way.
		password, _ := u.User.Password()
		accrue("password", password)
	}

	host, port, err := net.SplitHostPort(u.Host)
	if err != nil {
		// No usable port: keep the whole host, but drop the brackets of a bare
		// IPv6 literal since libpq expects the plain address.
		host, port = u.Host, ""
		if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
			host = host[1 : len(host)-1]
		}
	}
	accrue("host", host)
	accrue("port", port)

	accrue("dbname", strings.TrimPrefix(u.Path, "/"))

	q := u.Query()
	for k := range q {
		accrue(k, q.Get(k))
	}

	sort.Strings(kvs) // Makes testing easier (not a performance concern)
	return strings.Join(kvs, " "), nil
}
200