1// Copyright (c) 2012 The gocql Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package gocql
6
7import (
8	"fmt"
9	"math/big"
10	"net"
11	"reflect"
12	"strings"
13	"time"
14
15	"gopkg.in/inf.v0"
16)
17
// RowData pairs result-set column names with scan destinations, index by
// index: Values[i] is the value that Columns[i] is scanned into. As built by
// Iter.RowData, tuple columns are expanded into one entry per tuple element,
// named with TupleColumnName.
type RowData struct {
	Columns []string      // column names, parallel to Values
	Values  []interface{} // scan destinations, parallel to Columns
}
22
23func goType(t TypeInfo) reflect.Type {
24	switch t.Type() {
25	case TypeVarchar, TypeAscii, TypeInet, TypeText:
26		return reflect.TypeOf(*new(string))
27	case TypeBigInt, TypeCounter:
28		return reflect.TypeOf(*new(int64))
29	case TypeTime:
30		return reflect.TypeOf(*new(time.Duration))
31	case TypeTimestamp:
32		return reflect.TypeOf(*new(time.Time))
33	case TypeBlob:
34		return reflect.TypeOf(*new([]byte))
35	case TypeBoolean:
36		return reflect.TypeOf(*new(bool))
37	case TypeFloat:
38		return reflect.TypeOf(*new(float32))
39	case TypeDouble:
40		return reflect.TypeOf(*new(float64))
41	case TypeInt:
42		return reflect.TypeOf(*new(int))
43	case TypeSmallInt:
44		return reflect.TypeOf(*new(int16))
45	case TypeTinyInt:
46		return reflect.TypeOf(*new(int8))
47	case TypeDecimal:
48		return reflect.TypeOf(*new(*inf.Dec))
49	case TypeUUID, TypeTimeUUID:
50		return reflect.TypeOf(*new(UUID))
51	case TypeList, TypeSet:
52		return reflect.SliceOf(goType(t.(CollectionType).Elem))
53	case TypeMap:
54		return reflect.MapOf(goType(t.(CollectionType).Key), goType(t.(CollectionType).Elem))
55	case TypeVarint:
56		return reflect.TypeOf(*new(*big.Int))
57	case TypeTuple:
58		// what can we do here? all there is to do is to make a list of interface{}
59		tuple := t.(TupleTypeInfo)
60		return reflect.TypeOf(make([]interface{}, len(tuple.Elems)))
61	case TypeUDT:
62		return reflect.TypeOf(make(map[string]interface{}))
63	case TypeDate:
64		return reflect.TypeOf(*new(time.Time))
65	case TypeDuration:
66		return reflect.TypeOf(*new(Duration))
67	default:
68		return nil
69	}
70}
71
// dereference unwraps exactly one level of pointer: if i holds a pointer it
// returns the pointed-to value, otherwise it returns the value held in i.
// Like reflect.Indirect, it panics on a nil pointer or a nil interface.
func dereference(i interface{}) interface{} {
	v := reflect.ValueOf(i)
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	return v.Interface()
}
75
76func getCassandraBaseType(name string) Type {
77	switch name {
78	case "ascii":
79		return TypeAscii
80	case "bigint":
81		return TypeBigInt
82	case "blob":
83		return TypeBlob
84	case "boolean":
85		return TypeBoolean
86	case "counter":
87		return TypeCounter
88	case "decimal":
89		return TypeDecimal
90	case "double":
91		return TypeDouble
92	case "float":
93		return TypeFloat
94	case "int":
95		return TypeInt
96	case "tinyint":
97		return TypeTinyInt
98	case "time":
99		return TypeTime
100	case "timestamp":
101		return TypeTimestamp
102	case "uuid":
103		return TypeUUID
104	case "varchar":
105		return TypeVarchar
106	case "text":
107		return TypeText
108	case "varint":
109		return TypeVarint
110	case "timeuuid":
111		return TypeTimeUUID
112	case "inet":
113		return TypeInet
114	case "MapType":
115		return TypeMap
116	case "ListType":
117		return TypeList
118	case "SetType":
119		return TypeSet
120	case "TupleType":
121		return TypeTuple
122	default:
123		return TypeCustom
124	}
125}
126
127func getCassandraType(name string) TypeInfo {
128	if strings.HasPrefix(name, "frozen<") {
129		return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"))
130	} else if strings.HasPrefix(name, "set<") {
131		return CollectionType{
132			NativeType: NativeType{typ: TypeSet},
133			Elem:       getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<")),
134		}
135	} else if strings.HasPrefix(name, "list<") {
136		return CollectionType{
137			NativeType: NativeType{typ: TypeList},
138			Elem:       getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<")),
139		}
140	} else if strings.HasPrefix(name, "map<") {
141		names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<"))
142		if len(names) != 2 {
143			Logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names))
144			return NativeType{
145				typ: TypeCustom,
146			}
147		}
148		return CollectionType{
149			NativeType: NativeType{typ: TypeMap},
150			Key:        getCassandraType(names[0]),
151			Elem:       getCassandraType(names[1]),
152		}
153	} else if strings.HasPrefix(name, "tuple<") {
154		names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<"))
155		types := make([]TypeInfo, len(names))
156
157		for i, name := range names {
158			types[i] = getCassandraType(name)
159		}
160
161		return TupleTypeInfo{
162			NativeType: NativeType{typ: TypeTuple},
163			Elems:      types,
164		}
165	} else {
166		return NativeType{
167			typ: getCassandraBaseType(name),
168		}
169	}
170}
171
// splitCompositeTypes splits the type-parameter list of a composite CQL type
// (the text between its outer '<' and '>') into its top-level members. For
// example, "text, frozen<map<int, text>>" yields
// ["text", "frozen<map<int, text>>"].
//
// Commas nested inside '<'...'>' do not split. Each member is trimmed of
// surrounding whitespace and empty members are dropped, so the result is the
// same whether or not a space follows each comma. An empty or all-whitespace
// input yields a nil slice.
func splitCompositeTypes(name string) []string {
	var parts []string
	// add appends one member, trimmed, unless it is empty.
	add := func(member string) {
		if member = strings.TrimSpace(member); member != "" {
			parts = append(parts, member)
		}
	}

	// Byte-wise scanning is safe for UTF-8 input: the delimiters are ASCII and
	// never occur inside a multi-byte sequence.
	depth, start := 0, 0 // '<' nesting level; start of the current member
	for i := 0; i < len(name); i++ {
		switch name[i] {
		case '<':
			depth++
		case '>':
			depth--
		case ',':
			if depth == 0 {
				add(name[start:i])
				start = i + 1
			}
		}
	}
	add(name[start:])
	return parts
}
199
200func apacheToCassandraType(t string) string {
201	t = strings.Replace(t, apacheCassandraTypePrefix, "", -1)
202	t = strings.Replace(t, "(", "<", -1)
203	t = strings.Replace(t, ")", ">", -1)
204	types := strings.FieldsFunc(t, func(r rune) bool {
205		return r == '<' || r == '>' || r == ','
206	})
207	for _, typ := range types {
208		t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1)
209	}
210	// This is done so it exactly matches what Cassandra returns
211	return strings.Replace(t, ",", ", ", -1)
212}
213
214func getApacheCassandraType(class string) Type {
215	switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
216	case "AsciiType":
217		return TypeAscii
218	case "LongType":
219		return TypeBigInt
220	case "BytesType":
221		return TypeBlob
222	case "BooleanType":
223		return TypeBoolean
224	case "CounterColumnType":
225		return TypeCounter
226	case "DecimalType":
227		return TypeDecimal
228	case "DoubleType":
229		return TypeDouble
230	case "FloatType":
231		return TypeFloat
232	case "Int32Type":
233		return TypeInt
234	case "ShortType":
235		return TypeSmallInt
236	case "ByteType":
237		return TypeTinyInt
238	case "TimeType":
239		return TypeTime
240	case "DateType", "TimestampType":
241		return TypeTimestamp
242	case "UUIDType", "LexicalUUIDType":
243		return TypeUUID
244	case "UTF8Type":
245		return TypeVarchar
246	case "IntegerType":
247		return TypeVarint
248	case "TimeUUIDType":
249		return TypeTimeUUID
250	case "InetAddressType":
251		return TypeInet
252	case "MapType":
253		return TypeMap
254	case "ListType":
255		return TypeList
256	case "SetType":
257		return TypeSet
258	case "TupleType":
259		return TypeTuple
260	case "DurationType":
261		return TypeDuration
262	default:
263		return TypeCustom
264	}
265}
266
267func typeCanBeNull(typ TypeInfo) bool {
268	switch typ.(type) {
269	case CollectionType, UDTTypeInfo, TupleTypeInfo:
270		return false
271	}
272
273	return true
274}
275
276func (r *RowData) rowMap(m map[string]interface{}) {
277	for i, column := range r.Columns {
278		val := dereference(r.Values[i])
279		if valVal := reflect.ValueOf(val); valVal.Kind() == reflect.Slice {
280			valCopy := reflect.MakeSlice(valVal.Type(), valVal.Len(), valVal.Cap())
281			reflect.Copy(valCopy, valVal)
282			m[column] = valCopy.Interface()
283		} else {
284			m[column] = val
285		}
286	}
287}
288
// TupleColumnName returns the key "c[n]" under which element n of the tuple
// column c is stored. SliceMap and MapScan expand each tuple column into one
// map entry per element, so use this to extract a specific tuple element from
// the maps they return.
func TupleColumnName(c string, n int) string {
	return fmt.Sprintf("%s[%d]", c, n)
}
295
296func (iter *Iter) RowData() (RowData, error) {
297	if iter.err != nil {
298		return RowData{}, iter.err
299	}
300
301	columns := make([]string, 0, len(iter.Columns()))
302	values := make([]interface{}, 0, len(iter.Columns()))
303
304	for _, column := range iter.Columns() {
305		if c, ok := column.TypeInfo.(TupleTypeInfo); !ok {
306			val := column.TypeInfo.New()
307			columns = append(columns, column.Name)
308			values = append(values, val)
309		} else {
310			for i, elem := range c.Elems {
311				columns = append(columns, TupleColumnName(column.Name, i))
312				values = append(values, elem.New())
313			}
314		}
315	}
316
317	rowData := RowData{
318		Columns: columns,
319		Values:  values,
320	}
321
322	return rowData, nil
323}
324
325// TODO(zariel): is it worth exporting this?
326func (iter *Iter) rowMap() (map[string]interface{}, error) {
327	if iter.err != nil {
328		return nil, iter.err
329	}
330
331	rowData, _ := iter.RowData()
332	iter.Scan(rowData.Values...)
333	m := make(map[string]interface{}, len(rowData.Columns))
334	rowData.rowMap(m)
335	return m, nil
336}
337
338// SliceMap is a helper function to make the API easier to use
339// returns the data from the query in the form of []map[string]interface{}
340func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
341	if iter.err != nil {
342		return nil, iter.err
343	}
344
345	// Not checking for the error because we just did
346	rowData, _ := iter.RowData()
347	dataToReturn := make([]map[string]interface{}, 0)
348	for iter.Scan(rowData.Values...) {
349		m := make(map[string]interface{}, len(rowData.Columns))
350		rowData.rowMap(m)
351		dataToReturn = append(dataToReturn, m)
352	}
353	if iter.err != nil {
354		return nil, iter.err
355	}
356	return dataToReturn, nil
357}
358
359// MapScan takes a map[string]interface{} and populates it with a row
360// that is returned from cassandra.
361//
362// Each call to MapScan() must be called with a new map object.
363// During the call to MapScan() any pointers in the existing map
364// are replaced with non pointer types before the call returns
365//
366//	iter := session.Query(`SELECT * FROM mytable`).Iter()
367//	for {
368//		// New map each iteration
369//		row = make(map[string]interface{})
370//		if !iter.MapScan(row) {
371//			break
372//		}
373//		// Do things with row
374//		if fullname, ok := row["fullname"]; ok {
375//			fmt.Printf("Full Name: %s\n", fullname)
376//		}
377//	}
378//
379// You can also pass pointers in the map before each call
380//
381//	var fullName FullName // Implements gocql.Unmarshaler and gocql.Marshaler interfaces
382//	var address net.IP
383//	var age int
384//	iter := session.Query(`SELECT * FROM scan_map_table`).Iter()
385//	for {
386//		// New map each iteration
387//		row := map[string]interface{}{
388//			"fullname": &fullName,
389//			"age":      &age,
390//			"address":  &address,
391//		}
392//		if !iter.MapScan(row) {
393//			break
394//		}
395//		fmt.Printf("First: %s Age: %d Address: %q\n", fullName.FirstName, age, address)
396//	}
397func (iter *Iter) MapScan(m map[string]interface{}) bool {
398	if iter.err != nil {
399		return false
400	}
401
402	// Not checking for the error because we just did
403	rowData, _ := iter.RowData()
404
405	for i, col := range rowData.Columns {
406		if dest, ok := m[col]; ok {
407			rowData.Values[i] = dest
408		}
409	}
410
411	if iter.Scan(rowData.Values...) {
412		rowData.rowMap(m)
413		return true
414	}
415	return false
416}
417
// copyBytes returns a newly allocated copy of p with length and capacity
// len(p). The result shares no memory with p and is never nil, even when p is
// nil or empty.
func copyBytes(p []byte) []byte {
	return append(make([]byte, 0, len(p)), p...)
}
423
// failDNS, when true, makes LookupIP fail immediately with an empty
// *net.DNSError instead of performing a lookup.
// NOTE(review): appears to be a hook for simulating DNS failures in tests.
var failDNS bool

// LookupIP resolves host to its IP addresses using net.LookupIP. When failDNS
// is set, it instead returns a nil slice and an empty *net.DNSError without
// performing any lookup.
func LookupIP(host string) ([]net.IP, error) {
	if failDNS {
		return nil, &net.DNSError{}
	}
	return net.LookupIP(host)
}
433