1/*
2 * Copyright © 2019-2020 A Bunch Tell LLC.
3 *
4 * This file is part of WriteFreely.
5 *
6 * WriteFreely is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Affero General Public License, included
8 * in the LICENSE file in this source code package.
9 */
10
11package db
12
13import (
14	"fmt"
15	"strings"
16)
17
18type ColumnType int
19
20type OptionalInt struct {
21	Set   bool
22	Value int
23}
24
25type OptionalString struct {
26	Set   bool
27	Value string
28}
29
30type SQLBuilder interface {
31	ToSQL() (string, error)
32}
33
34type Column struct {
35	Dialect    DialectType
36	Name       string
37	Nullable   bool
38	Default    OptionalString
39	Type       ColumnType
40	Size       OptionalInt
41	PrimaryKey bool
42}
43
44type CreateTableSqlBuilder struct {
45	Dialect     DialectType
46	Name        string
47	IfNotExists bool
48	ColumnOrder []string
49	Columns     map[string]*Column
50	Constraints []string
51}
52
53const (
54	ColumnTypeBool     ColumnType = iota
55	ColumnTypeSmallInt ColumnType = iota
56	ColumnTypeInteger  ColumnType = iota
57	ColumnTypeChar     ColumnType = iota
58	ColumnTypeVarChar  ColumnType = iota
59	ColumnTypeText     ColumnType = iota
60	ColumnTypeDateTime ColumnType = iota
61)
62
63var _ SQLBuilder = &CreateTableSqlBuilder{}
64
65var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
66var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
67
68func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
69	if dialect != DialectMySQL && dialect != DialectSQLite {
70		return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
71	}
72	switch d {
73	case ColumnTypeSmallInt:
74		{
75			if dialect == DialectSQLite {
76				return "INTEGER", nil
77			}
78			mod := ""
79			if size.Set {
80				mod = fmt.Sprintf("(%d)", size.Value)
81			}
82			return "SMALLINT" + mod, nil
83		}
84	case ColumnTypeInteger:
85		{
86			if dialect == DialectSQLite {
87				return "INTEGER", nil
88			}
89			mod := ""
90			if size.Set {
91				mod = fmt.Sprintf("(%d)", size.Value)
92			}
93			return "INT" + mod, nil
94		}
95	case ColumnTypeChar:
96		{
97			if dialect == DialectSQLite {
98				return "TEXT", nil
99			}
100			mod := ""
101			if size.Set {
102				mod = fmt.Sprintf("(%d)", size.Value)
103			}
104			return "CHAR" + mod, nil
105		}
106	case ColumnTypeVarChar:
107		{
108			if dialect == DialectSQLite {
109				return "TEXT", nil
110			}
111			mod := ""
112			if size.Set {
113				mod = fmt.Sprintf("(%d)", size.Value)
114			}
115			return "VARCHAR" + mod, nil
116		}
117	case ColumnTypeBool:
118		{
119			if dialect == DialectSQLite {
120				return "INTEGER", nil
121			}
122			return "TINYINT(1)", nil
123		}
124	case ColumnTypeDateTime:
125		return "DATETIME", nil
126	case ColumnTypeText:
127		return "TEXT", nil
128	}
129	return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
130}
131
132func (c *Column) SetName(name string) *Column {
133	c.Name = name
134	return c
135}
136
137func (c *Column) SetNullable(nullable bool) *Column {
138	c.Nullable = nullable
139	return c
140}
141
142func (c *Column) SetPrimaryKey(pk bool) *Column {
143	c.PrimaryKey = pk
144	return c
145}
146
147func (c *Column) SetDefault(value string) *Column {
148	c.Default = OptionalString{Set: true, Value: value}
149	return c
150}
151
152func (c *Column) SetDefaultCurrentTimestamp() *Column {
153	def := "NOW()"
154	if c.Dialect == DialectSQLite {
155		def = "CURRENT_TIMESTAMP"
156	}
157	c.Default = OptionalString{Set: true, Value: def}
158	return c
159}
160
161func (c *Column) SetType(t ColumnType) *Column {
162	c.Type = t
163	return c
164}
165
166func (c *Column) SetSize(size int) *Column {
167	c.Size = OptionalInt{Set: true, Value: size}
168	return c
169}
170
171func (c *Column) String() (string, error) {
172	var str strings.Builder
173
174	str.WriteString(c.Name)
175
176	str.WriteString(" ")
177	typeStr, err := c.Type.Format(c.Dialect, c.Size)
178	if err != nil {
179		return "", err
180	}
181
182	str.WriteString(typeStr)
183
184	if !c.Nullable {
185		str.WriteString(" NOT NULL")
186	}
187
188	if c.Default.Set {
189		str.WriteString(" DEFAULT ")
190		val := c.Default.Value
191		if val == "" {
192			val = "''"
193		}
194		str.WriteString(val)
195	}
196
197	if c.PrimaryKey {
198		str.WriteString(" PRIMARY KEY")
199	}
200
201	return str.String(), nil
202}
203
204func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder {
205	if b.Columns == nil {
206		b.Columns = make(map[string]*Column)
207	}
208	b.Columns[column.Name] = column
209	b.ColumnOrder = append(b.ColumnOrder, column.Name)
210	return b
211}
212
213func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder {
214	for _, column := range columns {
215		if _, ok := b.Columns[column]; !ok {
216			// This fails silently.
217			return b
218		}
219	}
220	b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ",")))
221	return b
222}
223
224func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder {
225	b.IfNotExists = ine
226	return b
227}
228
229func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
230	var str strings.Builder
231
232	str.WriteString("CREATE TABLE ")
233	if b.IfNotExists {
234		str.WriteString("IF NOT EXISTS ")
235	}
236	str.WriteString(b.Name)
237
238	var things []string
239	for _, columnName := range b.ColumnOrder {
240		column, ok := b.Columns[columnName]
241		if !ok {
242			return "", fmt.Errorf("column not found: %s", columnName)
243		}
244		columnStr, err := column.String()
245		if err != nil {
246			return "", err
247		}
248		things = append(things, columnStr)
249	}
250	for _, constraint := range b.Constraints {
251		things = append(things, constraint)
252	}
253
254	if thingLen := len(things); thingLen > 0 {
255		str.WriteString(" ( ")
256		for i, thing := range things {
257			str.WriteString(thing)
258			if i < thingLen-1 {
259				str.WriteString(", ")
260			}
261		}
262		str.WriteString(" )")
263	}
264
265	return str.String(), nil
266}
267