1/* 2 * Copyright © 2019-2020 A Bunch Tell LLC. 3 * 4 * This file is part of WriteFreely. 5 * 6 * WriteFreely is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Affero General Public License, included 8 * in the LICENSE file in this source code package. 9 */ 10 11package db 12 13import ( 14 "fmt" 15 "strings" 16) 17 18type ColumnType int 19 20type OptionalInt struct { 21 Set bool 22 Value int 23} 24 25type OptionalString struct { 26 Set bool 27 Value string 28} 29 30type SQLBuilder interface { 31 ToSQL() (string, error) 32} 33 34type Column struct { 35 Dialect DialectType 36 Name string 37 Nullable bool 38 Default OptionalString 39 Type ColumnType 40 Size OptionalInt 41 PrimaryKey bool 42} 43 44type CreateTableSqlBuilder struct { 45 Dialect DialectType 46 Name string 47 IfNotExists bool 48 ColumnOrder []string 49 Columns map[string]*Column 50 Constraints []string 51} 52 53const ( 54 ColumnTypeBool ColumnType = iota 55 ColumnTypeSmallInt ColumnType = iota 56 ColumnTypeInteger ColumnType = iota 57 ColumnTypeChar ColumnType = iota 58 ColumnTypeVarChar ColumnType = iota 59 ColumnTypeText ColumnType = iota 60 ColumnTypeDateTime ColumnType = iota 61) 62 63var _ SQLBuilder = &CreateTableSqlBuilder{} 64 65var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} 66var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} 67 68func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { 69 if dialect != DialectMySQL && dialect != DialectSQLite { 70 return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) 71 } 72 switch d { 73 case ColumnTypeSmallInt: 74 { 75 if dialect == DialectSQLite { 76 return "INTEGER", nil 77 } 78 mod := "" 79 if size.Set { 80 mod = fmt.Sprintf("(%d)", size.Value) 81 } 82 return "SMALLINT" + mod, nil 83 } 84 case ColumnTypeInteger: 85 { 86 if dialect == DialectSQLite { 87 return "INTEGER", nil 88 } 89 mod := "" 90 if size.Set { 91 mod = fmt.Sprintf("(%d)", size.Value) 92 } 93 return "INT" + mod, nil 94 } 95 case ColumnTypeChar: 96 { 97 if dialect == DialectSQLite { 98 return "TEXT", nil 99 } 100 mod := "" 101 if size.Set { 102 mod = fmt.Sprintf("(%d)", size.Value) 103 } 104 return "CHAR" + mod, nil 105 } 106 case ColumnTypeVarChar: 107 { 108 if dialect == DialectSQLite { 109 return "TEXT", nil 110 } 111 mod := "" 112 if size.Set { 113 mod = fmt.Sprintf("(%d)", size.Value) 114 } 115 return "VARCHAR" + mod, nil 116 } 117 case ColumnTypeBool: 118 { 119 if dialect == DialectSQLite { 120 return "INTEGER", nil 121 } 122 return "TINYINT(1)", nil 123 } 124 case ColumnTypeDateTime: 125 return "DATETIME", nil 126 case ColumnTypeText: 127 return "TEXT", nil 128 } 129 return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) 130} 131 132func (c *Column) SetName(name string) *Column { 133 c.Name = name 134 return c 135} 136 137func (c *Column) SetNullable(nullable bool) *Column { 138 c.Nullable = nullable 139 return c 140} 141 142func (c *Column) SetPrimaryKey(pk bool) *Column { 143 c.PrimaryKey = pk 144 return c 145} 146 147func (c *Column) SetDefault(value string) *Column { 148 c.Default = OptionalString{Set: true, Value: value} 149 return c 150} 151 152func (c *Column) SetDefaultCurrentTimestamp() *Column { 153 def := "NOW()" 154 if c.Dialect == DialectSQLite { 155 def = "CURRENT_TIMESTAMP" 156 } 157 c.Default = OptionalString{Set: true, Value: def} 158 return c 159} 160 161func (c *Column) SetType(t ColumnType) *Column { 162 c.Type = t 163 return c 164} 165 166func (c *Column) SetSize(size int) *Column { 167 c.Size = OptionalInt{Set: true, Value: size} 168 return c 169} 170 171func (c *Column) String() (string, error) { 172 var str strings.Builder 173 174 str.WriteString(c.Name) 175 176 str.WriteString(" ") 177 typeStr, err := c.Type.Format(c.Dialect, c.Size) 178 if err != nil { 179 return "", err 180 } 181 182 str.WriteString(typeStr) 183 184 if !c.Nullable { 185 str.WriteString(" NOT NULL") 186 } 187 188 if c.Default.Set { 189 str.WriteString(" DEFAULT ") 190 val := c.Default.Value 191 if val == "" { 192 val = "''" 193 } 194 str.WriteString(val) 195 } 196 197 if c.PrimaryKey { 198 str.WriteString(" PRIMARY KEY") 199 } 200 201 return str.String(), nil 202} 203 204func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder { 205 if b.Columns == nil { 206 b.Columns = make(map[string]*Column) 207 } 208 b.Columns[column.Name] = column 209 b.ColumnOrder = append(b.ColumnOrder, column.Name) 210 return b 211} 212 213func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder { 214 for _, column := range columns { 215 if _, ok := b.Columns[column]; !ok { 216 // This fails silently. 217 return b 218 } 219 } 220 b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ","))) 221 return b 222} 223 224func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder { 225 b.IfNotExists = ine 226 return b 227} 228 229func (b *CreateTableSqlBuilder) ToSQL() (string, error) { 230 var str strings.Builder 231 232 str.WriteString("CREATE TABLE ") 233 if b.IfNotExists { 234 str.WriteString("IF NOT EXISTS ") 235 } 236 str.WriteString(b.Name) 237 238 var things []string 239 for _, columnName := range b.ColumnOrder { 240 column, ok := b.Columns[columnName] 241 if !ok { 242 return "", fmt.Errorf("column not found: %s", columnName) 243 } 244 columnStr, err := column.String() 245 if err != nil { 246 return "", err 247 } 248 things = append(things, columnStr) 249 } 250 for _, constraint := range b.Constraints { 251 things = append(things, constraint) 252 } 253 254 if thingLen := len(things); thingLen > 0 { 255 str.WriteString(" ( ") 256 for i, thing := range things { 257 str.WriteString(thing) 258 if i < thingLen-1 { 259 str.WriteString(", ") 260 } 261 } 262 str.WriteString(" )") 263 } 264 265 return str.String(), nil 266} 267