1package exp
2
3import (
4	"reflect"
5	"sort"
6
7	"github.com/doug-martin/goqu/v9/internal/errors"
8	"github.com/doug-martin/goqu/v9/internal/util"
9)
10
11type (
12	update struct {
13		col IdentifierExpression
14		val interface{}
15	}
16)
17
18func set(col IdentifierExpression, val interface{}) UpdateExpression {
19	return update{col: col, val: val}
20}
21
22func NewUpdateExpressions(update interface{}) (updates []UpdateExpression, err error) {
23	if u, ok := update.(UpdateExpression); ok {
24		updates = append(updates, u)
25		return updates, nil
26	}
27	updateValue := reflect.Indirect(reflect.ValueOf(update))
28	switch updateValue.Kind() {
29	case reflect.Map:
30		keys := util.ValueSlice(updateValue.MapKeys())
31		sort.Sort(keys)
32		for _, key := range keys {
33			updates = append(updates, ParseIdentifier(key.String()).Set(updateValue.MapIndex(key).Interface()))
34		}
35	case reflect.Struct:
36		return getUpdateExpressionsStruct(updateValue)
37	default:
38		return nil, errors.New("unsupported update interface type %+v", updateValue.Type())
39	}
40	return updates, nil
41}
42
43func getUpdateExpressionsStruct(value reflect.Value) (updates []UpdateExpression, err error) {
44	r, err := NewRecordFromStruct(value.Interface(), false, true)
45	if err != nil {
46		return updates, err
47	}
48	cols := r.Cols()
49	for _, col := range cols {
50		updates = append(updates, ParseIdentifier(col).Set(r[col]))
51	}
52	return updates, nil
53}
54
55func (u update) Expression() Expression {
56	return u
57}
58
59func (u update) Clone() Expression {
60	return update{col: u.col.Clone().(IdentifierExpression), val: u.val}
61}
62
63func (u update) Col() IdentifierExpression {
64	return u.col
65}
66
67func (u update) Val() interface{} {
68	return u.val
69}
70