1package sqlbuilder 2 3import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "strings" 9 10 db "upper.io/db.v3" 11 "upper.io/db.v3/internal/immutable" 12 "upper.io/db.v3/internal/sqladapter/exql" 13) 14 15type selectorQuery struct { 16 table *exql.Columns 17 tableArgs []interface{} 18 19 distinct bool 20 21 where *exql.Where 22 whereArgs []interface{} 23 24 groupBy *exql.GroupBy 25 groupByArgs []interface{} 26 27 orderBy *exql.OrderBy 28 orderByArgs []interface{} 29 30 limit exql.Limit 31 offset exql.Offset 32 33 columns *exql.Columns 34 columnsArgs []interface{} 35 36 joins []*exql.Join 37 joinsArgs []interface{} 38 39 amendFn func(string) string 40} 41 42func (sq *selectorQuery) and(b *sqlBuilder, terms ...interface{}) error { 43 where, whereArgs := b.t.toWhereWithArguments(terms) 44 45 if sq.where == nil { 46 sq.where, sq.whereArgs = &exql.Where{}, []interface{}{} 47 } 48 sq.where.Append(&where) 49 sq.whereArgs = append(sq.whereArgs, whereArgs...) 50 51 return nil 52} 53 54func (sq *selectorQuery) arguments() []interface{} { 55 return joinArguments( 56 sq.columnsArgs, 57 sq.tableArgs, 58 sq.joinsArgs, 59 sq.whereArgs, 60 sq.groupByArgs, 61 sq.orderByArgs, 62 ) 63} 64 65func (sq *selectorQuery) statement() *exql.Statement { 66 stmt := &exql.Statement{ 67 Type: exql.Select, 68 Table: sq.table, 69 Columns: sq.columns, 70 Distinct: sq.distinct, 71 Limit: sq.limit, 72 Offset: sq.offset, 73 Where: sq.where, 74 OrderBy: sq.orderBy, 75 GroupBy: sq.groupBy, 76 } 77 78 if len(sq.joins) > 0 { 79 stmt.Joins = exql.JoinConditions(sq.joins...) 80 } 81 82 stmt.SetAmendment(sq.amendFn) 83 84 return stmt 85} 86 87func (sq *selectorQuery) pushJoin(t string, tables []interface{}) error { 88 fragments, args, err := columnFragments(tables) 89 if err != nil { 90 return err 91 } 92 93 if sq.joins == nil { 94 sq.joins = []*exql.Join{} 95 } 96 sq.joins = append(sq.joins, 97 &exql.Join{ 98 Type: t, 99 Table: exql.JoinColumns(fragments...), 100 }, 101 ) 102 103 sq.joinsArgs = append(sq.joinsArgs, args...) 104 105 return nil 106} 107 108type selector struct { 109 builder *sqlBuilder 110 111 fn func(*selectorQuery) error 112 prev *selector 113} 114 115var _ = immutable.Immutable(&selector{}) 116 117func (sel *selector) SQLBuilder() *sqlBuilder { 118 if sel.prev == nil { 119 return sel.builder 120 } 121 return sel.prev.SQLBuilder() 122} 123 124func (sel *selector) String() string { 125 s, err := sel.Compile() 126 if err != nil { 127 panic(err.Error()) 128 } 129 return prepareQueryForDisplay(s) 130} 131 132func (sel *selector) frame(fn func(*selectorQuery) error) *selector { 133 return &selector{prev: sel, fn: fn} 134} 135 136func (sel *selector) clone() Selector { 137 return sel.frame(func(*selectorQuery) error { 138 return nil 139 }) 140} 141 142func (sel *selector) From(tables ...interface{}) Selector { 143 return sel.frame( 144 func(sq *selectorQuery) error { 145 fragments, args, err := columnFragments(tables) 146 if err != nil { 147 return err 148 } 149 sq.table = exql.JoinColumns(fragments...) 150 sq.tableArgs = args 151 return nil 152 }, 153 ) 154} 155 156func (sel *selector) setColumns(columns ...interface{}) Selector { 157 return sel.frame(func(sq *selectorQuery) error { 158 sq.columns = nil 159 return sq.pushColumns(columns...) 160 }) 161} 162 163func (sel *selector) Columns(columns ...interface{}) Selector { 164 return sel.frame(func(sq *selectorQuery) error { 165 return sq.pushColumns(columns...) 166 }) 167} 168 169func (sq *selectorQuery) pushColumns(columns ...interface{}) error { 170 f, args, err := columnFragments(columns) 171 if err != nil { 172 return err 173 } 174 175 c := exql.JoinColumns(f...) 176 177 if sq.columns != nil { 178 sq.columns.Append(c) 179 } else { 180 sq.columns = c 181 } 182 183 sq.columnsArgs = append(sq.columnsArgs, args...) 184 return nil 185} 186 187func (sel *selector) Distinct(exps ...interface{}) Selector { 188 return sel.frame(func(sq *selectorQuery) error { 189 sq.distinct = true 190 return sq.pushColumns(exps...) 191 }) 192} 193 194func (sel *selector) Where(terms ...interface{}) Selector { 195 return sel.frame(func(sq *selectorQuery) error { 196 if len(terms) == 1 && terms[0] == nil { 197 sq.where, sq.whereArgs = &exql.Where{}, []interface{}{} 198 return nil 199 } 200 return sq.and(sel.SQLBuilder(), terms...) 201 }) 202} 203 204func (sel *selector) And(terms ...interface{}) Selector { 205 return sel.frame(func(sq *selectorQuery) error { 206 return sq.and(sel.SQLBuilder(), terms...) 207 }) 208} 209 210func (sel *selector) Amend(fn func(string) string) Selector { 211 return sel.frame(func(sq *selectorQuery) error { 212 sq.amendFn = fn 213 return nil 214 }) 215} 216 217func (sel *selector) Arguments() []interface{} { 218 sq, err := sel.build() 219 if err != nil { 220 return nil 221 } 222 return sq.arguments() 223} 224 225func (sel *selector) GroupBy(columns ...interface{}) Selector { 226 return sel.frame(func(sq *selectorQuery) error { 227 fragments, args, err := columnFragments(columns) 228 if err != nil { 229 return err 230 } 231 232 if fragments != nil { 233 sq.groupBy = exql.GroupByColumns(fragments...) 234 } 235 sq.groupByArgs = args 236 237 return nil 238 }) 239} 240 241func (sel *selector) OrderBy(columns ...interface{}) Selector { 242 return sel.frame(func(sq *selectorQuery) error { 243 244 if len(columns) == 1 && columns[0] == nil { 245 sq.orderBy = nil 246 sq.orderByArgs = nil 247 return nil 248 } 249 250 var sortColumns exql.SortColumns 251 252 for i := range columns { 253 var sort *exql.SortColumn 254 255 switch value := columns[i].(type) { 256 case db.RawValue: 257 query, args := Preprocess(value.Raw(), value.Arguments()) 258 sort = &exql.SortColumn{ 259 Column: exql.RawValue(query), 260 } 261 sq.orderByArgs = append(sq.orderByArgs, args...) 262 case db.Function: 263 fnName, fnArgs := value.Name(), value.Arguments() 264 if len(fnArgs) == 0 { 265 fnName = fnName + "()" 266 } else { 267 fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" 268 } 269 fnName, fnArgs = Preprocess(fnName, fnArgs) 270 sort = &exql.SortColumn{ 271 Column: exql.RawValue(fnName), 272 } 273 sq.orderByArgs = append(sq.orderByArgs, fnArgs...) 274 case string: 275 if strings.HasPrefix(value, "-") { 276 sort = &exql.SortColumn{ 277 Column: exql.ColumnWithName(value[1:]), 278 Order: exql.Descendent, 279 } 280 } else { 281 chunks := strings.SplitN(value, " ", 2) 282 283 order := exql.Ascendent 284 if len(chunks) > 1 && strings.ToUpper(chunks[1]) == "DESC" { 285 order = exql.Descendent 286 } 287 288 sort = &exql.SortColumn{ 289 Column: exql.ColumnWithName(chunks[0]), 290 Order: order, 291 } 292 } 293 default: 294 return fmt.Errorf("Can't sort by type %T", value) 295 } 296 sortColumns.Columns = append(sortColumns.Columns, sort) 297 } 298 299 sq.orderBy = &exql.OrderBy{ 300 SortColumns: &sortColumns, 301 } 302 return nil 303 }) 304} 305 306func (sel *selector) Using(columns ...interface{}) Selector { 307 return sel.frame(func(sq *selectorQuery) error { 308 309 joins := len(sq.joins) 310 if joins == 0 { 311 return errors.New(`cannot use Using() without a preceding Join() expression`) 312 } 313 314 lastJoin := sq.joins[joins-1] 315 if lastJoin.On != nil { 316 return errors.New(`cannot use Using() and On() with the same Join() expression`) 317 } 318 319 fragments, args, err := columnFragments(columns) 320 if err != nil { 321 return err 322 } 323 324 sq.joinsArgs = append(sq.joinsArgs, args...) 325 lastJoin.Using = exql.UsingColumns(fragments...) 326 327 return nil 328 }) 329} 330 331func (sel *selector) FullJoin(tables ...interface{}) Selector { 332 return sel.frame(func(sq *selectorQuery) error { 333 return sq.pushJoin("FULL", tables) 334 }) 335} 336 337func (sel *selector) CrossJoin(tables ...interface{}) Selector { 338 return sel.frame(func(sq *selectorQuery) error { 339 return sq.pushJoin("CROSS", tables) 340 }) 341} 342 343func (sel *selector) RightJoin(tables ...interface{}) Selector { 344 return sel.frame(func(sq *selectorQuery) error { 345 return sq.pushJoin("RIGHT", tables) 346 }) 347} 348 349func (sel *selector) LeftJoin(tables ...interface{}) Selector { 350 return sel.frame(func(sq *selectorQuery) error { 351 return sq.pushJoin("LEFT", tables) 352 }) 353} 354 355func (sel *selector) Join(tables ...interface{}) Selector { 356 return sel.frame(func(sq *selectorQuery) error { 357 return sq.pushJoin("", tables) 358 }) 359} 360 361func (sel *selector) On(terms ...interface{}) Selector { 362 return sel.frame(func(sq *selectorQuery) error { 363 joins := len(sq.joins) 364 365 if joins == 0 { 366 return errors.New(`cannot use On() without a preceding Join() expression`) 367 } 368 369 lastJoin := sq.joins[joins-1] 370 if lastJoin.On != nil { 371 return errors.New(`cannot use Using() and On() with the same Join() expression`) 372 } 373 374 w, a := sel.SQLBuilder().t.toWhereWithArguments(terms) 375 o := exql.On(w) 376 377 lastJoin.On = &o 378 379 sq.joinsArgs = append(sq.joinsArgs, a...) 380 381 return nil 382 }) 383} 384 385func (sel *selector) Limit(n int) Selector { 386 return sel.frame(func(sq *selectorQuery) error { 387 if n < 0 { 388 n = 0 389 } 390 sq.limit = exql.Limit(n) 391 return nil 392 }) 393} 394 395func (sel *selector) Offset(n int) Selector { 396 return sel.frame(func(sq *selectorQuery) error { 397 if n < 0 { 398 n = 0 399 } 400 sq.offset = exql.Offset(n) 401 return nil 402 }) 403} 404 405func (sel *selector) template() *exql.Template { 406 return sel.SQLBuilder().t.Template 407} 408 409func (sel *selector) As(alias string) Selector { 410 return sel.frame(func(sq *selectorQuery) error { 411 if sq.table == nil { 412 return errors.New("Cannot use As() without a preceding From() expression") 413 } 414 last := len(sq.table.Columns) - 1 415 if raw, ok := sq.table.Columns[last].(*exql.Raw); ok { 416 compiled, err := exql.ColumnWithName(alias).Compile(sel.template()) 417 if err != nil { 418 return err 419 } 420 sq.table.Columns[last] = exql.RawValue(raw.Value + " AS " + compiled) 421 } 422 return nil 423 }) 424} 425 426func (sel *selector) statement() *exql.Statement { 427 sq, _ := sel.build() 428 return sq.statement() 429} 430 431func (sel *selector) QueryRow() (*sql.Row, error) { 432 return sel.QueryRowContext(sel.SQLBuilder().sess.Context()) 433} 434 435func (sel *selector) QueryRowContext(ctx context.Context) (*sql.Row, error) { 436 sq, err := sel.build() 437 if err != nil { 438 return nil, err 439 } 440 441 return sel.SQLBuilder().sess.StatementQueryRow(ctx, sq.statement(), sq.arguments()...) 442} 443 444func (sel *selector) Prepare() (*sql.Stmt, error) { 445 return sel.PrepareContext(sel.SQLBuilder().sess.Context()) 446} 447 448func (sel *selector) PrepareContext(ctx context.Context) (*sql.Stmt, error) { 449 sq, err := sel.build() 450 if err != nil { 451 return nil, err 452 } 453 return sel.SQLBuilder().sess.StatementPrepare(ctx, sq.statement()) 454} 455 456func (sel *selector) Query() (*sql.Rows, error) { 457 return sel.QueryContext(sel.SQLBuilder().sess.Context()) 458} 459 460func (sel *selector) QueryContext(ctx context.Context) (*sql.Rows, error) { 461 sq, err := sel.build() 462 if err != nil { 463 return nil, err 464 } 465 return sel.SQLBuilder().sess.StatementQuery(ctx, sq.statement(), sq.arguments()...) 466} 467 468func (sel *selector) Iterator() Iterator { 469 return sel.IteratorContext(sel.SQLBuilder().sess.Context()) 470} 471 472func (sel *selector) IteratorContext(ctx context.Context) Iterator { 473 sess := sel.SQLBuilder().sess 474 sq, err := sel.build() 475 if err != nil { 476 return &iterator{sess, nil, err} 477 } 478 479 rows, err := sess.StatementQuery(ctx, sq.statement(), sq.arguments()...) 480 return &iterator{sess, rows, err} 481} 482 483func (sel *selector) Paginate(pageSize uint) Paginator { 484 return newPaginator(sel.clone(), pageSize) 485} 486 487func (sel *selector) All(destSlice interface{}) error { 488 return sel.Iterator().All(destSlice) 489} 490 491func (sel *selector) One(dest interface{}) error { 492 return sel.Iterator().One(dest) 493} 494 495func (sel *selector) build() (*selectorQuery, error) { 496 sq, err := immutable.FastForward(sel) 497 if err != nil { 498 return nil, err 499 } 500 return sq.(*selectorQuery), nil 501} 502 503func (sel *selector) Compile() (string, error) { 504 return sel.statement().Compile(sel.template()) 505} 506 507func (sel *selector) Prev() immutable.Immutable { 508 if sel == nil { 509 return nil 510 } 511 return sel.prev 512} 513 514func (sel *selector) Fn(in interface{}) error { 515 if sel.fn == nil { 516 return nil 517 } 518 return sel.fn(in.(*selectorQuery)) 519} 520 521func (sel *selector) Base() interface{} { 522 return &selectorQuery{} 523} 524