1package sqlbuilder 2 3import ( 4 "database/sql/driver" 5 "fmt" 6 "strings" 7 8 db "upper.io/db.v3" 9 "upper.io/db.v3/internal/sqladapter/exql" 10) 11 12type templateWithUtils struct { 13 *exql.Template 14} 15 16func newTemplateWithUtils(template *exql.Template) *templateWithUtils { 17 return &templateWithUtils{template} 18} 19 20func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []interface{}) { 21 switch t := in.(type) { 22 case db.RawValue: 23 return exql.RawValue(t.String()), t.Arguments() 24 case db.Function: 25 fnName := t.Name() 26 fnArgs := []interface{}{} 27 args, _ := toInterfaceArguments(t.Arguments()) 28 fragments := []string{} 29 for i := range args { 30 frag, args := tu.PlaceholderValue(args[i]) 31 fragment, err := frag.Compile(tu.Template) 32 if err == nil { 33 fragments = append(fragments, fragment) 34 fnArgs = append(fnArgs, args...) 35 } 36 } 37 return exql.RawValue(fnName + `(` + strings.Join(fragments, `, `) + `)`), fnArgs 38 default: 39 // Value must be escaped. 40 return sqlPlaceholder, []interface{}{in} 41 } 42} 43 44// toWhereWithArguments converts the given parameters into a exql.Where 45// value. 46func (tu *templateWithUtils) toWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) { 47 args = []interface{}{} 48 49 switch t := term.(type) { 50 case []interface{}: 51 if len(t) > 0 { 52 if s, ok := t[0].(string); ok { 53 if strings.ContainsAny(s, "?") || len(t) == 1 { 54 s, args = Preprocess(s, t[1:]) 55 where.Conditions = []exql.Fragment{exql.RawValue(s)} 56 } else { 57 var val interface{} 58 key := s 59 if len(t) > 2 { 60 val = t[1:] 61 } else { 62 val = t[1] 63 } 64 cv, v := tu.toColumnValues(db.NewConstraint(key, val)) 65 args = append(args, v...) 66 for i := range cv.ColumnValues { 67 where.Conditions = append(where.Conditions, cv.ColumnValues[i]) 68 } 69 } 70 return 71 } 72 } 73 for i := range t { 74 w, v := tu.toWhereWithArguments(t[i]) 75 if len(w.Conditions) == 0 { 76 continue 77 } 78 args = append(args, v...) 79 where.Conditions = append(where.Conditions, w.Conditions...) 80 } 81 return 82 case db.RawValue: 83 r, v := Preprocess(t.Raw(), t.Arguments()) 84 where.Conditions = []exql.Fragment{exql.RawValue(r)} 85 args = append(args, v...) 86 return 87 case db.Constraints: 88 for _, c := range t.Constraints() { 89 w, v := tu.toWhereWithArguments(c) 90 if len(w.Conditions) == 0 { 91 continue 92 } 93 args = append(args, v...) 94 where.Conditions = append(where.Conditions, w.Conditions...) 95 } 96 return 97 case db.Compound: 98 var cond exql.Where 99 100 for _, c := range t.Sentences() { 101 w, v := tu.toWhereWithArguments(c) 102 if len(w.Conditions) == 0 { 103 continue 104 } 105 args = append(args, v...) 106 cond.Conditions = append(cond.Conditions, w.Conditions...) 107 } 108 109 if len(cond.Conditions) > 0 { 110 var frag exql.Fragment 111 switch t.Operator() { 112 case db.OperatorNone, db.OperatorAnd: 113 q := exql.And(cond) 114 frag = &q 115 case db.OperatorOr: 116 q := exql.Or(cond) 117 frag = &q 118 default: 119 panic(fmt.Sprintf("Unknown type %T", t)) 120 } 121 where.Conditions = append(where.Conditions, frag) 122 } 123 124 return 125 case db.Constraint: 126 cv, v := tu.toColumnValues(t) 127 args = append(args, v...) 128 where.Conditions = append(where.Conditions, cv.ColumnValues...) 129 return where, args 130 } 131 132 panic(fmt.Sprintf("Unknown condition type %T", term)) 133} 134 135func (tu *templateWithUtils) comparisonOperatorMapper(t db.ComparisonOperator) string { 136 if t == db.ComparisonOperatorNone { 137 return "" 138 } 139 if tu.ComparisonOperator != nil { 140 if op, ok := tu.ComparisonOperator[t]; ok { 141 return op 142 } 143 } 144 if op, ok := comparisonOperators[t]; ok { 145 return op 146 } 147 panic(fmt.Sprintf("unsupported comparison operator %v", t)) 148} 149 150func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnValues, args []interface{}) { 151 args = []interface{}{} 152 153 switch t := term.(type) { 154 case db.Constraint: 155 columnValue := exql.ColumnValue{} 156 157 // Getting column and operator. 158 if column, ok := t.Key().(string); ok { 159 chunks := strings.SplitN(strings.TrimSpace(column), " ", 2) 160 columnValue.Column = exql.ColumnWithName(chunks[0]) 161 if len(chunks) > 1 { 162 columnValue.Operator = chunks[1] 163 } 164 } else { 165 if rawValue, ok := t.Key().(db.RawValue); ok { 166 columnValue.Column = exql.RawValue(rawValue.Raw()) 167 args = append(args, rawValue.Arguments()...) 168 } else { 169 columnValue.Column = exql.RawValue(fmt.Sprintf("%v", t.Key())) 170 } 171 } 172 173 switch value := t.Value().(type) { 174 case db.Function: 175 fnName, fnArgs := value.Name(), value.Arguments() 176 if len(fnArgs) == 0 { 177 // A function with no arguments. 178 fnName = fnName + "()" 179 } else { 180 // A function with one or more arguments. 181 fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" 182 } 183 fnName, fnArgs = Preprocess(fnName, fnArgs) 184 columnValue.Value = exql.RawValue(fnName) 185 args = append(args, fnArgs...) 186 case db.RawValue: 187 q, a := Preprocess(value.Raw(), value.Arguments()) 188 columnValue.Value = exql.RawValue(q) 189 args = append(args, a...) 190 case driver.Valuer: 191 columnValue.Value = exql.RawValue("?") 192 args = append(args, value) 193 case db.Comparison: 194 wrapper := &operatorWrapper{ 195 tu: tu, 196 cv: &columnValue, 197 op: value, 198 } 199 200 q, a := wrapper.preprocess() 201 q, a = Preprocess(q, a) 202 203 columnValue = exql.ColumnValue{ 204 Column: exql.RawValue(q), 205 } 206 if a != nil { 207 args = append(args, a...) 208 } 209 210 cv.ColumnValues = append(cv.ColumnValues, &columnValue) 211 return cv, args 212 default: 213 wrapper := &operatorWrapper{ 214 tu: tu, 215 cv: &columnValue, 216 v: value, 217 } 218 219 q, a := wrapper.preprocess() 220 q, a = Preprocess(q, a) 221 222 columnValue = exql.ColumnValue{ 223 Column: exql.RawValue(q), 224 } 225 if a != nil { 226 args = append(args, a...) 227 } 228 229 cv.ColumnValues = append(cv.ColumnValues, &columnValue) 230 return cv, args 231 } 232 233 if columnValue.Operator == "" { 234 columnValue.Operator = tu.comparisonOperatorMapper(db.ComparisonOperatorEqual) 235 } 236 cv.ColumnValues = append(cv.ColumnValues, &columnValue) 237 return cv, args 238 case db.RawValue: 239 columnValue := exql.ColumnValue{} 240 p, q := Preprocess(t.Raw(), t.Arguments()) 241 columnValue.Column = exql.RawValue(p) 242 cv.ColumnValues = append(cv.ColumnValues, &columnValue) 243 args = append(args, q...) 244 return cv, args 245 case db.Constraints: 246 for _, constraint := range t.Constraints() { 247 p, q := tu.toColumnValues(constraint) 248 cv.ColumnValues = append(cv.ColumnValues, p.ColumnValues...) 249 args = append(args, q...) 250 } 251 return cv, args 252 } 253 254 panic(fmt.Sprintf("Unknown term type %T.", term)) 255} 256 257func (tu *templateWithUtils) setColumnValues(term interface{}) (cv exql.ColumnValues, args []interface{}) { 258 args = []interface{}{} 259 260 switch t := term.(type) { 261 case []interface{}: 262 l := len(t) 263 for i := 0; i < l; i++ { 264 column, isString := t[i].(string) 265 266 if !isString { 267 p, q := tu.setColumnValues(t[i]) 268 cv.ColumnValues = append(cv.ColumnValues, p.ColumnValues...) 269 args = append(args, q...) 270 continue 271 } 272 273 if !strings.ContainsAny(column, tu.AssignmentOperator) { 274 column = column + " " + tu.AssignmentOperator + " ?" 275 } 276 277 chunks := strings.SplitN(column, tu.AssignmentOperator, 2) 278 279 column = chunks[0] 280 format := strings.TrimSpace(chunks[1]) 281 282 columnValue := exql.ColumnValue{ 283 Column: exql.ColumnWithName(column), 284 Operator: tu.AssignmentOperator, 285 Value: exql.RawValue(format), 286 } 287 288 ps := strings.Count(format, "?") 289 if i+ps < l { 290 for j := 0; j < ps; j++ { 291 args = append(args, t[i+j+1]) 292 } 293 i = i + ps 294 } else { 295 panic(fmt.Sprintf("Format string %q has more placeholders than given arguments.", format)) 296 } 297 298 cv.ColumnValues = append(cv.ColumnValues, &columnValue) 299 } 300 return cv, args 301 case db.RawValue: 302 columnValue := exql.ColumnValue{} 303 p, q := Preprocess(t.Raw(), t.Arguments()) 304 columnValue.Column = exql.RawValue(p) 305 cv.ColumnValues = append(cv.ColumnValues, &columnValue) 306 args = append(args, q...) 307 return cv, args 308 } 309 310 panic(fmt.Sprintf("Unknown term type %T.", term)) 311} 312