1/*
2Copyright 2014 SAP SE
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package driver
18
19import (
20	"context"
21	"crypto/tls"
22	"crypto/x509"
23	"database/sql/driver"
24	"fmt"
25	"io/ioutil"
26	"net/url"
27	"strconv"
28	"sync"
29)
30
// SessionVariables maps session variable names to their values.
//
// All defined session variables will be set once after a database connection
// is opened.
type SessionVariables map[string]string
36
37/*
38A Connector represents a hdb driver in a fixed configuration.
39A Connector can be passed to sql.OpenDB (starting from go 1.10) allowing users to bypass a string based data source name.
40*/
41type Connector struct {
42	mu                             sync.RWMutex
43	host, username, password       string
44	locale                         string
45	bufferSize, fetchSize, timeout int
46	tlsConfig                      *tls.Config
47	sessionVariables               SessionVariables
48}
49
50func newConnector() *Connector {
51	return &Connector{
52		fetchSize: DefaultFetchSize,
53		timeout:   DefaultTimeout,
54	}
55}
56
57// NewBasicAuthConnector creates a connector for basic authentication.
58func NewBasicAuthConnector(host, username, password string) *Connector {
59	c := newConnector()
60	c.host = host
61	c.username = username
62	c.password = password
63	return c
64}
65
66// NewDSNConnector creates a connector from a data source name.
67func NewDSNConnector(dsn string) (*Connector, error) {
68	c := newConnector()
69
70	url, err := url.Parse(dsn)
71	if err != nil {
72		return nil, err
73	}
74
75	c.host = url.Host
76
77	if url.User != nil {
78		c.username = url.User.Username()
79		c.password, _ = url.User.Password()
80	}
81
82	var certPool *x509.CertPool
83
84	for k, v := range url.Query() {
85		switch k {
86
87		default:
88			return nil, fmt.Errorf("URL parameter %s is not supported", k)
89
90		case DSNFetchSize:
91			if len(v) == 0 {
92				continue
93			}
94			fetchSize, err := strconv.Atoi(v[0])
95			if err != nil {
96				return nil, fmt.Errorf("failed to parse fetchSize: %s", v[0])
97			}
98			if fetchSize < minFetchSize {
99				c.fetchSize = minFetchSize
100			} else {
101				c.fetchSize = fetchSize
102			}
103
104		case DSNTimeout:
105			if len(v) == 0 {
106				continue
107			}
108			timeout, err := strconv.Atoi(v[0])
109			if err != nil {
110				return nil, fmt.Errorf("failed to parse timeout: %s", v[0])
111			}
112			if timeout < minTimeout {
113				c.timeout = minTimeout
114			} else {
115				c.timeout = timeout
116			}
117
118		case DSNLocale:
119			if len(v) == 0 {
120				continue
121			}
122			c.locale = v[0]
123
124		case DSNTLSServerName:
125			if len(v) == 0 {
126				continue
127			}
128			if c.tlsConfig == nil {
129				c.tlsConfig = &tls.Config{}
130			}
131			c.tlsConfig.ServerName = v[0]
132
133		case DSNTLSInsecureSkipVerify:
134			if len(v) == 0 {
135				continue
136			}
137			var err error
138			b := true
139			if v[0] != "" {
140				b, err = strconv.ParseBool(v[0])
141				if err != nil {
142					return nil, fmt.Errorf("failed to parse InsecureSkipVerify (bool): %s", v[0])
143				}
144			}
145			if c.tlsConfig == nil {
146				c.tlsConfig = &tls.Config{}
147			}
148			c.tlsConfig.InsecureSkipVerify = b
149
150		case DSNTLSRootCAFile:
151			for _, fn := range v {
152				rootPEM, err := ioutil.ReadFile(fn)
153				if err != nil {
154					return nil, err
155				}
156				if certPool == nil {
157					certPool = x509.NewCertPool()
158				}
159				if ok := certPool.AppendCertsFromPEM(rootPEM); !ok {
160					return nil, fmt.Errorf("failed to parse root certificate - filename: %s", fn)
161				}
162			}
163			if certPool != nil {
164				if c.tlsConfig == nil {
165					c.tlsConfig = &tls.Config{}
166				}
167				c.tlsConfig.RootCAs = certPool
168			}
169		}
170	}
171	return c, nil
172}
173
174// Host returns the host of the connector.
175func (c *Connector) Host() string {
176	return c.host
177}
178
179// Username returns the username of the connector.
180func (c *Connector) Username() string {
181	return c.username
182}
183
184// Password returns the password of the connector.
185func (c *Connector) Password() string {
186	return c.password
187}
188
189// Locale returns the locale of the connector.
190func (c *Connector) Locale() string {
191	c.mu.RLock()
192	defer c.mu.RUnlock()
193	return c.locale
194}
195
196/*
197SetLocale sets the locale of the connector.
198
199For more information please see DSNLocale.
200*/
201func (c *Connector) SetLocale(locale string) {
202	c.mu.Lock()
203	c.locale = locale
204	c.mu.Unlock()
205}
206
207// FetchSize returns the fetchSize of the connector.
208func (c *Connector) FetchSize() int {
209	c.mu.RLock()
210	defer c.mu.RUnlock()
211	return c.fetchSize
212}
213
214/*
215SetFetchSize sets the fetchSize of the connector.
216
217For more information please see DSNFetchSize.
218*/
219func (c *Connector) SetFetchSize(fetchSize int) error {
220	c.mu.Lock()
221	defer c.mu.Unlock()
222	if fetchSize < minFetchSize {
223		fetchSize = minFetchSize
224	}
225	c.fetchSize = fetchSize
226	return nil
227}
228
229// Timeout returns the timeout of the connector.
230func (c *Connector) Timeout() int {
231	c.mu.RLock()
232	defer c.mu.RUnlock()
233	return c.timeout
234}
235
236/*
237SetTimeout sets the timeout of the connector.
238
239For more information please see DSNTimeout.
240*/
241func (c *Connector) SetTimeout(timeout int) error {
242	c.mu.Lock()
243	defer c.mu.Unlock()
244	if timeout < minTimeout {
245		timeout = minTimeout
246	}
247	c.timeout = timeout
248	return nil
249}
250
251// TLSConfig returns the TLS configuration of the connector.
252func (c *Connector) TLSConfig() *tls.Config {
253	c.mu.RLock()
254	defer c.mu.RUnlock()
255	return c.tlsConfig
256}
257
258// SetTLSConfig sets the TLS configuration of the connector.
259func (c *Connector) SetTLSConfig(tlsConfig *tls.Config) error {
260	c.mu.Lock()
261	defer c.mu.Unlock()
262	c.tlsConfig = tlsConfig
263	return nil
264}
265
266// SessionVariables returns the session variables stored in connector.
267func (c *Connector) SessionVariables() SessionVariables {
268	c.mu.RLock()
269	defer c.mu.RUnlock()
270	return c.sessionVariables
271}
272
273// SetSessionVariables sets the session varibles of the connector.
274func (c *Connector) SetSessionVariables(sessionVariables SessionVariables) error {
275	c.mu.Lock()
276	defer c.mu.Unlock()
277	c.sessionVariables = sessionVariables
278	return nil
279}
280
281// BasicAuthDSN return the connector DSN for basic authentication.
282func (c *Connector) BasicAuthDSN() string {
283	values := url.Values{}
284	if c.locale != "" {
285		values.Set(DSNLocale, c.locale)
286	}
287	if c.fetchSize != 0 {
288		values.Set(DSNFetchSize, fmt.Sprintf("%d", c.fetchSize))
289	}
290	if c.timeout != 0 {
291		values.Set(DSNTimeout, fmt.Sprintf("%d", c.timeout))
292	}
293	return (&url.URL{
294		Scheme:   DriverName,
295		User:     url.UserPassword(c.username, c.password),
296		Host:     c.host,
297		RawQuery: values.Encode(),
298	}).String()
299}
300
301// Connect implements the database/sql/driver/Connector interface.
302func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
303	return newConn(ctx, c)
304}
305
306// Driver implements the database/sql/driver/Connector interface.
307func (c *Connector) Driver() driver.Driver {
308	return drv
309}
310