1package clickhouse
2
3import (
4	"bufio"
5	"database/sql"
6	"database/sql/driver"
7	"fmt"
8	"io"
9	"log"
10	"net/url"
11	"os"
12	"strconv"
13	"strings"
14	"sync"
15	"sync/atomic"
16	"time"
17
18	"github.com/ClickHouse/clickhouse-go/lib/leakypool"
19
20	"github.com/ClickHouse/clickhouse-go/lib/binary"
21	"github.com/ClickHouse/clickhouse-go/lib/data"
22	"github.com/ClickHouse/clickhouse-go/lib/protocol"
23)
24
// Defaults for the identity sent to the server when the DSN leaves it out.
const (
	// DefaultDatabase is the database used when the DSN has no (or an empty)
	// "database" parameter.
	DefaultDatabase = "default"
	// DefaultUsername is the user name used when the DSN has no (or an empty)
	// "username" parameter.
	DefaultUsername = "default"
)

// Timeout defaults; each can be overridden in the DSN with a value given in
// (possibly fractional) seconds.
const (
	// DefaultConnTimeout is used when the DSN has no "timeout" parameter.
	DefaultConnTimeout = 5 * time.Second
	// DefaultReadTimeout is used when the DSN has no "read_timeout" parameter.
	DefaultReadTimeout = time.Minute
	// DefaultWriteTimeout is used when the DSN has no "write_timeout" parameter.
	DefaultWriteTimeout = time.Minute
)
37
var (
	// unixtime is a coarse clock in nanoseconds. The goroutine started in
	// init adds one second to it per tick, and now() reads it. It starts at
	// zero, so it measures time since package initialisation as an offset
	// from the Unix epoch rather than wall-clock time.
	unixtime int64

	// logOutput is where debug loggers created by open write; see SetLogOutput.
	logOutput io.Writer = os.Stdout

	// hostname is this machine's host name. The error from os.Hostname is
	// deliberately ignored; the host name is informational only.
	hostname, _ = os.Hostname()

	// poolInit makes the leakypool.InitBytePool call in open happen once.
	poolInit sync.Once
)
44
45func init() {
46	sql.Register("clickhouse", &bootstrap{})
47	go func() {
48		for tick := time.Tick(time.Second); ; {
49			select {
50			case <-tick:
51				atomic.AddInt64(&unixtime, int64(time.Second))
52			}
53		}
54	}()
55}
56
57func now() time.Time {
58	return time.Unix(0, atomic.LoadInt64(&unixtime))
59}
60
61type bootstrap struct{}
62
63func (d *bootstrap) Open(dsn string) (driver.Conn, error) {
64	return Open(dsn)
65}
66
67// SetLogOutput allows to change output of the default logger
68func SetLogOutput(output io.Writer) {
69	logOutput = output
70}
71
72// Open the connection
73func Open(dsn string) (driver.Conn, error) {
74	clickhouse, err := open(dsn)
75	if err != nil {
76		return nil, err
77	}
78
79	return clickhouse, err
80}
81
82func open(dsn string) (*clickhouse, error) {
83	url, err := url.Parse(dsn)
84	if err != nil {
85		return nil, err
86	}
87	var (
88		hosts            = []string{url.Host}
89		query            = url.Query()
90		secure           = false
91		skipVerify       = false
92		tlsConfigName    = query.Get("tls_config")
93		noDelay          = true
94		compress         = false
95		database         = query.Get("database")
96		username         = query.Get("username")
97		password         = query.Get("password")
98		blockSize        = 1000000
99		connTimeout      = DefaultConnTimeout
100		readTimeout      = DefaultReadTimeout
101		writeTimeout     = DefaultWriteTimeout
102		connOpenStrategy = connOpenRandom
103		poolSize         = 100
104	)
105	if len(database) == 0 {
106		database = DefaultDatabase
107	}
108	if len(username) == 0 {
109		username = DefaultUsername
110	}
111	if v, err := strconv.ParseBool(query.Get("no_delay")); err == nil {
112		noDelay = v
113	}
114	tlsConfig := getTLSConfigClone(tlsConfigName)
115	if tlsConfigName != "" && tlsConfig == nil {
116		return nil, fmt.Errorf("invalid tls_config - no config registered under name %s", tlsConfigName)
117	}
118	secure = tlsConfig != nil
119	if v, err := strconv.ParseBool(query.Get("secure")); err == nil {
120		secure = v
121	}
122	if v, err := strconv.ParseBool(query.Get("skip_verify")); err == nil {
123		skipVerify = v
124	}
125	if duration, err := strconv.ParseFloat(query.Get("timeout"), 64); err == nil {
126		connTimeout = time.Duration(duration * float64(time.Second))
127	}
128	if duration, err := strconv.ParseFloat(query.Get("read_timeout"), 64); err == nil {
129		readTimeout = time.Duration(duration * float64(time.Second))
130	}
131	if duration, err := strconv.ParseFloat(query.Get("write_timeout"), 64); err == nil {
132		writeTimeout = time.Duration(duration * float64(time.Second))
133	}
134	if size, err := strconv.ParseInt(query.Get("block_size"), 10, 64); err == nil {
135		blockSize = int(size)
136	}
137	if size, err := strconv.ParseInt(query.Get("pool_size"), 10, 64); err == nil {
138		poolSize = int(size)
139	}
140	poolInit.Do(func() {
141		leakypool.InitBytePool(poolSize)
142	})
143	if altHosts := strings.Split(query.Get("alt_hosts"), ","); len(altHosts) != 0 {
144		for _, host := range altHosts {
145			if len(host) != 0 {
146				hosts = append(hosts, host)
147			}
148		}
149	}
150	switch query.Get("connection_open_strategy") {
151	case "random":
152		connOpenStrategy = connOpenRandom
153	case "in_order":
154		connOpenStrategy = connOpenInOrder
155	case "time_random":
156		connOpenStrategy = connOpenTimeRandom
157	}
158
159	settings, err := makeQuerySettings(query)
160	if err != nil {
161		return nil, err
162	}
163
164	if v, err := strconv.ParseBool(query.Get("compress")); err == nil {
165		compress = v
166	}
167
168	var (
169		ch = clickhouse{
170			logf:      func(string, ...interface{}) {},
171			settings:  settings,
172			compress:  compress,
173			blockSize: blockSize,
174			ServerInfo: data.ServerInfo{
175				Timezone: time.Local,
176			},
177		}
178		logger = log.New(logOutput, "[clickhouse]", 0)
179	)
180	if debug, err := strconv.ParseBool(url.Query().Get("debug")); err == nil && debug {
181		ch.logf = logger.Printf
182	}
183	ch.logf("host(s)=%s, database=%s, username=%s",
184		strings.Join(hosts, ", "),
185		database,
186		username,
187	)
188	options := connOptions{
189		secure:       secure,
190		tlsConfig:    tlsConfig,
191		skipVerify:   skipVerify,
192		hosts:        hosts,
193		connTimeout:  connTimeout,
194		readTimeout:  readTimeout,
195		writeTimeout: writeTimeout,
196		noDelay:      noDelay,
197		openStrategy: connOpenStrategy,
198		logf:         ch.logf,
199	}
200	if ch.conn, err = dial(options); err != nil {
201		return nil, err
202	}
203	logger.SetPrefix(fmt.Sprintf("[clickhouse][connect=%d]", ch.conn.ident))
204	ch.buffer = bufio.NewWriter(ch.conn)
205
206	ch.decoder = binary.NewDecoderWithCompress(ch.conn)
207	ch.encoder = binary.NewEncoderWithCompress(ch.buffer)
208
209	if err := ch.hello(database, username, password); err != nil {
210		ch.conn.Close()
211		return nil, err
212	}
213	return &ch, nil
214}
215
216func (ch *clickhouse) hello(database, username, password string) error {
217	ch.logf("[hello] -> %s", ch.ClientInfo)
218	{
219		ch.encoder.Uvarint(protocol.ClientHello)
220		if err := ch.ClientInfo.Write(ch.encoder); err != nil {
221			return err
222		}
223		{
224			ch.encoder.String(database)
225			ch.encoder.String(username)
226			ch.encoder.String(password)
227		}
228		if err := ch.encoder.Flush(); err != nil {
229			return err
230		}
231
232	}
233	{
234		packet, err := ch.decoder.Uvarint()
235		if err != nil {
236			return err
237		}
238		switch packet {
239		case protocol.ServerException:
240			return ch.exception()
241		case protocol.ServerHello:
242			if err := ch.ServerInfo.Read(ch.decoder); err != nil {
243				return err
244			}
245		case protocol.ServerEndOfStream:
246			ch.logf("[bootstrap] <- end of stream")
247			return nil
248		default:
249			return fmt.Errorf("[hello] unexpected packet [%d] from server", packet)
250		}
251	}
252	ch.logf("[hello] <- %s", ch.ServerInfo)
253	return nil
254}
255