1// Copyright (c) 2012-present The upper.io/db authors. All rights reserved.
2//
3// Permission is hereby granted, free of charge, to any person obtaining
4// a copy of this software and associated documentation files (the
5// "Software"), to deal in the Software without restriction, including
6// without limitation the rights to use, copy, modify, merge, publish,
7// distribute, sublicense, and/or sell copies of the Software, and to
8// permit persons to whom the Software is furnished to do so, subject to
9// the following conditions:
10//
11// The above copyright notice and this permission notice shall be
12// included in all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21
// Package sqlite wraps the github.com/mattn/go-sqlite3 SQLite driver. See
23// https://upper.io/db.v3/sqlite for documentation, particularities and
24// usage examples.
25package sqlite
26
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"sync"
	"sync/atomic"

	_ "github.com/mattn/go-sqlite3" // SQLite3 driver.
	db "upper.io/db.v3"
	"upper.io/db.v3/internal/sqladapter"
	"upper.io/db.v3/internal/sqladapter/compat"
	"upper.io/db.v3/internal/sqladapter/exql"
	"upper.io/db.v3/lib/sqlbuilder"
)
42
// database is the SQLite implementation of both sqlbuilder.Database and
// sqladapter.Database.
type database struct {
	// BaseDatabase provides the shared session/transaction plumbing. It is
	// bound by open() through sqladapter.NewBaseDatabase and by clone()
	// through NewClone.
	sqladapter.BaseDatabase

	// SQLBuilder is the query builder bound to BaseDatabase using this
	// package's template (set in open() and clone()).
	sqlbuilder.SQLBuilder

	// connURL holds the settings whose String() form is handed to sql.Open.
	connURL db.ConnectionURL
	// mu serializes StatementExec on this session and is held by
	// NewDatabaseTx (on the clone) while a transaction is being begun.
	mu      sync.Mutex
}
52
53var (
54	_ = sqlbuilder.Database(&database{})
55	_ = sqladapter.Database(&database{})
56)
57
var (
	// maxOpenFiles is the upper bound on SQLite files this process keeps
	// open at once through this adapter.
	maxOpenFiles int32 = 100

	// fileOpenCount tracks how many files are currently open; it is only
	// touched through sync/atomic.
	fileOpenCount int32

	// errTooManyOpenFiles is returned when fileOpenCount has reached
	// maxOpenFiles; Err maps it to db.ErrTooManyClients.
	errTooManyOpenFiles = errors.New(`Too many open database files.`)
)
63
64// newDatabase creates a new *database session for internal use.
65func newDatabase(settings db.ConnectionURL) *database {
66	return &database{
67		connURL: settings,
68	}
69}
70
71// CleanUp cleans up the session.
72func (d *database) CleanUp() error {
73	if atomic.AddInt32(&fileOpenCount, -1) < 0 {
74		return errors.New(`Close() without Open()?`)
75	}
76	return nil
77}
78
79// ConnectionURL returns this database's ConnectionURL.
80func (d *database) ConnectionURL() db.ConnectionURL {
81	return d.connURL
82}
83
84// Open attempts to open a connection to the database server.
85func (d *database) Open(connURL db.ConnectionURL) error {
86	if connURL == nil {
87		return db.ErrMissingConnURL
88	}
89	d.connURL = connURL
90	return d.open()
91}
92
93// NewTx starts a transaction block.
94func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) {
95	nTx, err := d.NewDatabaseTx(ctx)
96	if err != nil {
97		return nil, err
98	}
99	return &tx{DatabaseTx: nTx}, nil
100}
101
102// Collections returns a list of non-system tables from the database.
103func (d *database) Collections() (collections []string, err error) {
104	q := d.Select("tbl_name").
105		From("sqlite_master").
106		Where("type = ?", "table")
107
108	iter := q.Iterator()
109	defer iter.Close()
110
111	for iter.Next() {
112		var tableName string
113		if err := iter.Scan(&tableName); err != nil {
114			return nil, err
115		}
116		collections = append(collections, tableName)
117	}
118
119	return collections, nil
120}
121
122func (d *database) open() error {
123	// Binding with sqladapter's logic.
124	d.BaseDatabase = sqladapter.NewBaseDatabase(d)
125
126	// Binding with sqlbuilder.
127	d.SQLBuilder = sqlbuilder.WithSession(d.BaseDatabase, template)
128
129	openFn := func() error {
130		openFiles := atomic.LoadInt32(&fileOpenCount)
131		if openFiles < maxOpenFiles {
132			sess, err := sql.Open("sqlite3", d.ConnectionURL().String())
133			if err == nil {
134				if err := d.BaseDatabase.BindSession(sess); err != nil {
135					return err
136				}
137				atomic.AddInt32(&fileOpenCount, 1)
138				return nil
139			}
140			return err
141		}
142		return errTooManyOpenFiles
143	}
144
145	if err := d.BaseDatabase.WaitForConnection(openFn); err != nil {
146		return err
147	}
148
149	return nil
150}
151
152func (d *database) clone(ctx context.Context, checkConn bool) (*database, error) {
153	clone := newDatabase(d.connURL)
154
155	var err error
156	clone.BaseDatabase, err = d.NewClone(clone, checkConn)
157	if err != nil {
158		return nil, err
159	}
160
161	clone.SetContext(ctx)
162
163	clone.SQLBuilder = sqlbuilder.WithSession(clone.BaseDatabase, template)
164
165	return clone, nil
166}
167
168// CompileStatement allows sqladapter to compile the given statement into the
169// format SQLite expects.
170func (d *database) CompileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) {
171	compiled, err := stmt.Compile(template)
172	if err != nil {
173		panic(err.Error())
174	}
175	return sqlbuilder.Preprocess(compiled, args)
176}
177
178// Err allows sqladapter to translate some known errors into generic errors.
179func (d *database) Err(err error) error {
180	if err != nil {
181		if err == errTooManyOpenFiles {
182			return db.ErrTooManyClients
183		}
184	}
185	return err
186}
187
188// StatementExec wraps the statement to execute around a transaction.
189func (d *database) StatementExec(ctx context.Context, query string, args ...interface{}) (res sql.Result, err error) {
190	d.mu.Lock()
191	defer d.mu.Unlock()
192
193	if d.Transaction() != nil {
194		return compat.ExecContext(d.Driver().(*sql.Tx), ctx, query, args)
195	}
196
197	sqlTx, err := compat.BeginTx(d.Session(), ctx, d.TxOptions())
198	if err != nil {
199		return nil, err
200	}
201
202	if res, err = compat.ExecContext(sqlTx, ctx, query, args); err != nil {
203		return nil, err
204	}
205
206	if err = sqlTx.Commit(); err != nil {
207		return nil, err
208	}
209
210	return res, err
211}
212
213// NewCollection allows sqladapter create a local db.Collection.
214func (d *database) NewCollection(name string) db.Collection {
215	return newTable(d, name)
216}
217
218// Tx creates a transaction and passes it to the given function, if if the
219// function returns no error then the transaction is commited.
220func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error {
221	return sqladapter.RunTx(d, ctx, fn)
222}
223
224// NewDatabaseTx allows sqladapter start a transaction block.
225func (d *database) NewDatabaseTx(ctx context.Context) (sqladapter.DatabaseTx, error) {
226	clone, err := d.clone(ctx, true)
227	if err != nil {
228		return nil, err
229	}
230	clone.mu.Lock()
231	defer clone.mu.Unlock()
232
233	openFn := func() error {
234		//sqlTx, err := compat.BeginTx(clone.BaseDatabase.Session(), ctx, nil) // Temporal fix.
235		sqlTx, err := clone.BaseDatabase.Session().Begin()
236		if err == nil {
237			return clone.BindTx(ctx, sqlTx)
238		}
239		return err
240	}
241
242	if err := d.BaseDatabase.WaitForConnection(openFn); err != nil {
243		return nil, err
244	}
245
246	return sqladapter.NewDatabaseTx(clone), nil
247}
248
249// LookupName allows sqladapter look up the database's name.
250func (d *database) LookupName() (string, error) {
251	connURL, err := ParseURL(d.ConnectionURL().String())
252	if err != nil {
253		return "", err
254	}
255	return connURL.Database, nil
256}
257
258// TableExists allows sqladapter check whether a table exists and returns an
259// error in case it doesn't.
260func (d *database) TableExists(name string) error {
261	q := d.Select("tbl_name").
262		From("sqlite_master").
263		Where("type = 'table' AND tbl_name = ?", name)
264
265	iter := q.Iterator()
266	defer iter.Close()
267
268	if iter.Next() {
269		var name string
270		if err := iter.Scan(&name); err != nil {
271			return err
272		}
273		return nil
274	}
275	return db.ErrCollectionDoesNotExist
276}
277
278// PrimaryKeys allows sqladapter find a table's primary keys.
279func (d *database) PrimaryKeys(tableName string) ([]string, error) {
280	pk := make([]string, 0, 1)
281
282	stmt := exql.RawSQL(fmt.Sprintf("PRAGMA TABLE_INFO('%s')", tableName))
283
284	rows, err := d.Query(stmt)
285	if err != nil {
286		return nil, err
287	}
288
289	columns := []struct {
290		Name string `db:"name"`
291		PK   int    `db:"pk"`
292	}{}
293
294	if err := sqlbuilder.NewIterator(rows).All(&columns); err != nil {
295		return nil, err
296	}
297
298	maxValue := -1
299
300	for _, column := range columns {
301		if column.PK > 0 && column.PK > maxValue {
302			maxValue = column.PK
303		}
304	}
305
306	if maxValue > 0 {
307		for _, column := range columns {
308			if column.PK > 0 {
309				pk = append(pk, column.Name)
310			}
311		}
312	}
313
314	return pk, nil
315}
316
317// WithContext creates a copy of the session on the given context.
318func (d *database) WithContext(ctx context.Context) sqlbuilder.Database {
319	newDB, _ := d.clone(ctx, false)
320	return newDB
321}
322