1/*
2Copyright 2014 SAP SE
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package driver
18
19import (
20	"bytes"
21	"database/sql"
22	"fmt"
23	"io"
24	"io/ioutil"
25	"log"
26	"math/big"
27	"os"
28	"path/filepath"
29	"reflect"
30	"sync"
31	"testing"
32	"time"
33	"unicode/utf8"
34)
35
36func TestTinyint(t *testing.T) {
37	testDatatype(t, "tinyint", 0, true,
38		uint8(minTinyint),
39		uint8(maxTinyint),
40		sql.NullInt64{Valid: false, Int64: minTinyint},
41		sql.NullInt64{Valid: true, Int64: maxTinyint},
42	)
43}
44
45func TestSmallint(t *testing.T) {
46	testDatatype(t, "smallint", 0, true,
47		int16(minSmallint),
48		int16(maxSmallint),
49		sql.NullInt64{Valid: false, Int64: minSmallint},
50		sql.NullInt64{Valid: true, Int64: maxSmallint},
51	)
52}
53
54func TestInteger(t *testing.T) {
55	testDatatype(t, "integer", 0, true,
56		int32(minInteger),
57		int32(maxInteger),
58		sql.NullInt64{Valid: false, Int64: minInteger},
59		sql.NullInt64{Valid: true, Int64: maxInteger},
60	)
61}
62
63func TestBigint(t *testing.T) {
64	testDatatype(t, "bigint", 0, true,
65		int64(minBigint),
66		int64(maxBigint),
67		sql.NullInt64{Valid: false, Int64: minBigint},
68		sql.NullInt64{Valid: true, Int64: maxBigint},
69	)
70}
71
72func TestReal(t *testing.T) {
73	testDatatype(t, "real", 0, true,
74		float32(-maxReal),
75		float32(maxReal),
76		sql.NullFloat64{Valid: false, Float64: -maxReal},
77		sql.NullFloat64{Valid: true, Float64: maxReal},
78	)
79}
80
81func TestDouble(t *testing.T) {
82	testDatatype(t, "double", 0, true,
83		float64(-maxDouble),
84		float64(maxDouble),
85		sql.NullFloat64{Valid: false, Float64: -maxDouble},
86		sql.NullFloat64{Valid: true, Float64: maxDouble},
87	)
88}
89
// testStringDataASCII holds 7-bit ASCII string samples, plus a NULL and a
// non-NULL sql.NullString, for column types that cannot read back non-ASCII
// data (see TestChar).
var testStringDataASCII = []interface{}{
	"Hello HDB",
	"aaaaaaaaaa",
	sql.NullString{String: "Hello HDB"},              // NULL (Valid defaults to false)
	sql.NullString{String: "Hello HDB", Valid: true}, // non-NULL
}
96
// testStringData holds string samples for varchar/nchar/nvarchar columns,
// plus a NULL and a non-NULL sql.NullString.
//
// Characters outside the Basic Multilingual Plane are written as
// \U0001D11E (MUSICAL SYMBOL G CLEF) escapes rather than literally, so the
// test data cannot be corrupted by a wrong source-file encoding (literal
// replacement characters U+FFFD would be 3 UTF-8 bytes each and overflow the
// column sizes documented below).
var testStringData = []interface{}{
	"Hello HDB",
	// 10 x U+1D11E:
	// varchar: UTF-8 4 bytes per char -> size 40 bytes
	// nvarchar: CESU-8 6 bytes per char -> hdb counts 2 chars per 6 byte encoding -> size 20 bytes
	"\U0001D11E\U0001D11E\U0001D11E\U0001D11E\U0001D11E\U0001D11E\U0001D11E\U0001D11E\U0001D11E\U0001D11E",
	"\U0001D11E\U0001D11Eaa",
	"€€",
	"\U0001D11E\U0001D11E€€",
	"\U0001D11E\U0001D11E\U0001D11E€€",
	"aaaaaaaaaa",
	sql.NullString{Valid: false, String: "Hello HDB"},
	sql.NullString{Valid: true, String: "Hello HDB"},
}
110
111/*
112using unicode (CESU-8) data for char HDB
113- successful insert into table
114- but query table returns
115  SQL HdbError 7 - feature not supported: invalid character encoding: ...
116--> use ASCII test data only
117surprisingly: varchar works with unicode characters
118*/
119func TestChar(t *testing.T) {
120	testDatatype(t, "char", 40, true, testStringDataASCII...)
121}
122
123func TestVarchar(t *testing.T) {
124	testDatatype(t, "varchar", 40, false, testStringData...)
125}
126
127func TestNChar(t *testing.T) {
128	testDatatype(t, "nchar", 20, true, testStringData...)
129}
130
131func TestNVarchar(t *testing.T) {
132	testDatatype(t, "nvarchar", 20, false, testStringData...)
133}
134
135var testBinaryData = []interface{}{
136	[]byte("Hello HDB"),
137	[]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
138	[]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0xff},
139	NullBytes{Valid: false, Bytes: []byte("Hello HDB")},
140	NullBytes{Valid: true, Bytes: []byte("Hello HDB")},
141}
142
143func TestBinary(t *testing.T) {
144	testDatatype(t, "binary", 20, true, testBinaryData...)
145}
146
147func TestVarbinary(t *testing.T) {
148	testDatatype(t, "varbinary", 20, false, testBinaryData...)
149}
150
151var testTimeData = []interface{}{
152	time.Now(),
153	NullTime{Valid: false, Time: time.Now()},
154	NullTime{Valid: true, Time: time.Now()},
155}
156
157func TestDate(t *testing.T) {
158	testDatatype(t, "date", 0, true, testTimeData...)
159}
160
161func TestTime(t *testing.T) {
162	testDatatype(t, "time", 0, true, testTimeData...)
163}
164
165func TestTimestamp(t *testing.T) {
166	testDatatype(t, "timestamp", 0, true, testTimeData...)
167}
168
169func TestLongdate(t *testing.T) {
170	testDatatype(t, "longdate", 0, true, testTimeData...)
171}
172
173func TestSeconddate(t *testing.T) {
174	testDatatype(t, "seconddate", 0, true, testTimeData...)
175}
176
177func TestDaydate(t *testing.T) {
178	testDatatype(t, "daydate", 0, true, testTimeData...)
179}
180
181func TestSecondtime(t *testing.T) {
182	testDatatype(t, "secondtime", 0, true, testTimeData...)
183}
184
185var testDecimalData = []interface{}{
186	(*Decimal)(big.NewRat(0, 1)),
187	(*Decimal)(big.NewRat(1, 1)),
188	(*Decimal)(big.NewRat(-1, 1)),
189	(*Decimal)(big.NewRat(10, 1)),
190	(*Decimal)(big.NewRat(1000, 1)),
191	(*Decimal)(big.NewRat(1, 10)),
192	(*Decimal)(big.NewRat(-1, 10)),
193	(*Decimal)(big.NewRat(1, 1000)),
194	(*Decimal)(new(big.Rat).SetInt(maxDecimal)),
195	NullDecimal{Valid: false, Decimal: (*Decimal)(big.NewRat(1, 1))},
196	NullDecimal{Valid: true, Decimal: (*Decimal)(big.NewRat(1, 1))},
197}
198
199func TestDecimal(t *testing.T) {
200	testDatatype(t, "decimal", 0, true, testDecimalData...)
201}
202
203func TestBoolean(t *testing.T) {
204	testDatatype(t, "boolean", 0, true,
205		true,
206		false,
207		sql.NullBool{Valid: false, Bool: true},
208		sql.NullBool{Valid: true, Bool: false},
209	)
210}
211
212func TestClob(t *testing.T) {
213	testInitLobFiles(t)
214	testLobDataASCII := make([]interface{}, 0, len(testLobFiles))
215	first := true
216	for _, f := range testLobFiles {
217		if f.isASCII {
218			if first {
219				testLobDataASCII = append(testLobDataASCII, NullLob{Valid: false, Lob: &Lob{rd: bytes.NewReader(f.content)}})
220				testLobDataASCII = append(testLobDataASCII, NullLob{Valid: true, Lob: &Lob{rd: bytes.NewReader(f.content)}})
221				first = false
222			}
223			testLobDataASCII = append(testLobDataASCII, Lob{rd: bytes.NewReader(f.content)})
224		}
225	}
226	testDatatype(t, "clob", 0, true, testLobDataASCII...)
227}
228
229func TestNclob(t *testing.T) {
230	testInitLobFiles(t)
231	testLobData := make([]interface{}, 0, len(testLobFiles)+2)
232	for i, f := range testLobFiles {
233		if i == 0 {
234			testLobData = append(testLobData, NullLob{Valid: false, Lob: &Lob{rd: bytes.NewReader(f.content)}})
235			testLobData = append(testLobData, NullLob{Valid: true, Lob: &Lob{rd: bytes.NewReader(f.content)}})
236		}
237		testLobData = append(testLobData, Lob{rd: bytes.NewReader(f.content)})
238	}
239	testDatatype(t, "nclob", 0, true, testLobData...)
240}
241
242func TestBlob(t *testing.T) {
243	testInitLobFiles(t)
244	testLobData := make([]interface{}, 0, len(testLobFiles)+2)
245	for i, f := range testLobFiles {
246		if i == 0 {
247			testLobData = append(testLobData, NullLob{Valid: false, Lob: &Lob{rd: bytes.NewReader(f.content)}})
248			testLobData = append(testLobData, NullLob{Valid: true, Lob: &Lob{rd: bytes.NewReader(f.content)}})
249		}
250		testLobData = append(testLobData, Lob{rd: bytes.NewReader(f.content)})
251	}
252	testDatatype(t, "blob", 0, true, testLobData...)
253}
254
255//
256func testDatatype(t *testing.T, dataType string, dataSize int, fixedSize bool, testData ...interface{}) {
257	db, err := sql.Open(DriverName, TestDSN)
258	if err != nil {
259		t.Fatal(err)
260	}
261	defer db.Close()
262
263	table := RandomIdentifier(fmt.Sprintf("%s_", dataType))
264
265	if dataSize == 0 {
266		if _, err := db.Exec(fmt.Sprintf("create table %s.%s (i integer, x %s)", TestSchema, table, dataType)); err != nil {
267			t.Fatal(err)
268		}
269	} else {
270		if _, err := db.Exec(fmt.Sprintf("create table %s.%s (i integer, x %s(%d))", TestSchema, table, dataType, dataSize)); err != nil {
271			t.Fatal(err)
272		}
273
274	}
275
276	// use trancactions:
277	// SQL Error 596 - LOB streaming is not permitted in auto-commit mode
278	tx, err := db.Begin()
279	if err != nil {
280		t.Fatal(err)
281	}
282
283	stmt, err := tx.Prepare(fmt.Sprintf("insert into %s.%s values(?, ?)", TestSchema, table))
284	if err != nil {
285		t.Fatal(err)
286	}
287
288	for i, in := range testData {
289
290		switch in := in.(type) {
291		case Lob:
292			in.rd.(*bytes.Reader).Seek(0, io.SeekStart)
293		case NullLob:
294			in.Lob.rd.(*bytes.Reader).Seek(0, io.SeekStart)
295		}
296
297		if _, err := stmt.Exec(i, in); err != nil {
298			t.Fatal(err)
299		}
300	}
301
302	if err := tx.Commit(); err != nil {
303		t.Fatal(err)
304	}
305
306	size := len(testData)
307	var i int
308
309	if err := db.QueryRow(fmt.Sprintf("select count(*) from %s.%s", TestSchema, table)).Scan(&i); err != nil {
310		t.Fatal(err)
311	}
312
313	if i != size {
314		t.Fatalf("rows %d - expected %d", i, size)
315	}
316
317	rows, err := db.Query(fmt.Sprintf("select * from %s.%s order by i", TestSchema, table))
318	if err != nil {
319		t.Fatal(err)
320	}
321	defer rows.Close()
322
323	var timestampCheck = equalLongdate
324	if driverDataFormatVersion == 1 {
325		timestampCheck = equalTimestamp
326	}
327
328	i = 0
329	for rows.Next() {
330
331		in := testData[i]
332		out := reflect.New(reflect.TypeOf(in)).Interface()
333
334		switch out := out.(type) {
335		case *NullDecimal:
336			out.Decimal = (*Decimal)(new(big.Rat))
337		case *Lob:
338			out.SetWriter(new(bytes.Buffer))
339		case *NullLob:
340			out.Lob = new(Lob).SetWriter(new(bytes.Buffer))
341		}
342
343		if err := rows.Scan(&i, out); err != nil {
344			log.Fatal(err)
345		}
346
347		switch out := out.(type) {
348		default:
349			t.Fatalf("%d unknown type %T", i, out)
350		case *uint8:
351			if *out != in.(uint8) {
352				t.Fatalf("%d value %v - expected %v", i, *out, in)
353			}
354		case *int16:
355			if *out != in.(int16) {
356				t.Fatalf("%d value %v - expected %v", i, *out, in)
357			}
358		case *int32:
359			if *out != in.(int32) {
360				t.Fatalf("%d value %v - expected %v", i, *out, in)
361			}
362		case *int64:
363			if *out != in.(int64) {
364				t.Fatalf("%d value %v - expected %v", i, *out, in)
365			}
366		case *float32:
367			if *out != in.(float32) {
368				t.Fatalf("%d value %v - expected %v", i, *out, in)
369			}
370		case *float64:
371			if *out != in.(float64) {
372				t.Fatalf("%d value %v - expected %v", i, *out, in)
373			}
374		case *string:
375			if fixedSize {
376				if !compareStringFixSize(in.(string), *out) {
377					t.Fatalf("%d value %v - expected %v", i, *out, in)
378				}
379			} else {
380				if *out != in.(string) {
381					t.Fatalf("%d value %v - expected %v", i, *out, in)
382				}
383			}
384		case *[]byte:
385			if fixedSize {
386				if !compareBytesFixSize(in.([]byte), *out) {
387					t.Fatalf("%d value %v - expected %v", i, *out, in)
388				}
389			} else {
390				if bytes.Compare(*out, in.([]byte)) != 0 {
391					t.Fatalf("%d value %v - expected %v", i, *out, in)
392				}
393			}
394		case *time.Time:
395			switch dataType {
396			default:
397				t.Fatalf("unknown data type %s", dataType)
398			case "date", "daydate":
399				if !equalDate(*out, in.(time.Time)) {
400					t.Fatalf("%d value %v - expected %v", i, *out, in)
401				}
402			case "time", "secondtime":
403				if !equalTime(*out, in.(time.Time)) {
404					t.Fatalf("%d value %v - expected %v", i, *out, in)
405				}
406			case "timestamp", "longdate":
407				if !timestampCheck(*out, in.(time.Time)) {
408					t.Fatalf("%d value %v - expected %v", i, *out, in)
409				}
410			case "seconddate":
411				if !equalDateTime(*out, in.(time.Time)) {
412					t.Fatalf("%d value %v - expected %v", i, *out, in)
413				}
414			}
415		case **Decimal:
416			if ((*big.Rat)(*out)).Cmp((*big.Rat)(in.(*Decimal))) != 0 {
417				t.Fatalf("%d value %s - expected %s", i, (*big.Rat)(*out), (*big.Rat)(in.(*Decimal)))
418			}
419		case *bool:
420			if *out != in.(bool) {
421				t.Fatalf("%d value %v - expected %v", i, *out, in)
422			}
423		case *Lob:
424			inLob := in.(Lob)
425			ok, err := compareLob(&inLob, out)
426			if err != nil {
427				t.Fatal(err)
428			}
429			if !ok {
430				t.Fatalf("%d lob content no equal", i)
431			}
432		case *sql.NullInt64:
433			in := in.(sql.NullInt64)
434			if in.Valid != out.Valid {
435				t.Fatalf("%d value %v - expected %v", i, out, in)
436			}
437			if in.Valid && in.Int64 != out.Int64 {
438				t.Fatalf("%d value %v - expected %v", i, out, in)
439			}
440		case *sql.NullFloat64:
441			in := in.(sql.NullFloat64)
442			if in.Valid != out.Valid {
443				t.Fatalf("%d value %v - expected %v", i, out, in)
444			}
445			if in.Valid && in.Float64 != out.Float64 {
446				t.Fatalf("%d value %v - expected %v", i, out, in)
447			}
448		case *sql.NullString:
449			in := in.(sql.NullString)
450			if in.Valid != out.Valid {
451				t.Fatalf("%d value %v - expected %v", i, out, in)
452			}
453			if in.Valid {
454				if fixedSize {
455					if !compareStringFixSize(in.String, out.String) {
456						t.Fatalf("%d value %v - expected %v", i, *out, in)
457					}
458				} else {
459					if in.String != out.String {
460						t.Fatalf("%d value %v - expected %v", i, out, in)
461					}
462				}
463			}
464		case *NullBytes:
465			in := in.(NullBytes)
466			if in.Valid != out.Valid {
467				t.Fatalf("%d value %v - expected %v", i, out, in)
468			}
469			if in.Valid {
470				if fixedSize {
471					if !compareBytesFixSize(in.Bytes, out.Bytes) {
472						t.Fatalf("%d value %v - expected %v", i, *out, in)
473					}
474				} else {
475					if bytes.Compare(in.Bytes, out.Bytes) != 0 {
476						t.Fatalf("%d value %v - expected %v", i, out, in)
477					}
478				}
479			}
480		case *NullTime:
481			in := in.(NullTime)
482			if in.Valid != out.Valid {
483				t.Fatalf("%d value %v - expected %v", i, out, in)
484			}
485			if in.Valid {
486				switch dataType {
487				default:
488					t.Fatalf("unknown data type %s", dataType)
489				case "date", "daydate":
490					if !equalDate(out.Time, in.Time) {
491						t.Fatalf("%d value %v - expected %v", i, *out, in)
492					}
493				case "time", "secondtime":
494					if !equalTime(out.Time, in.Time) {
495						t.Fatalf("%d value %v - expected %v", i, *out, in)
496					}
497				case "timestamp", "longdate":
498					if !timestampCheck(out.Time, in.Time) {
499						t.Fatalf("%d value %v - expected %v", i, *out, in)
500					}
501				case "seconddate":
502					if !equalDateTime(out.Time, in.Time) {
503						t.Fatalf("%d value %v - expected %v", i, *out, in)
504					}
505				}
506			}
507		case *NullDecimal:
508			in := in.(NullDecimal)
509			if in.Valid != out.Valid {
510				t.Fatalf("%d value %v - expected %v", i, out, in)
511			}
512			if in.Valid {
513				if ((*big.Rat)(in.Decimal)).Cmp((*big.Rat)(out.Decimal)) != 0 {
514					t.Fatalf("%d value %s - expected %s", i, (*big.Rat)(in.Decimal), (*big.Rat)(in.Decimal))
515				}
516			}
517		case *sql.NullBool:
518			in := in.(sql.NullBool)
519			if in.Valid != out.Valid {
520				t.Fatalf("%d value %v - expected %v", i, out, in)
521			}
522			if in.Valid && in.Bool != out.Bool {
523				t.Fatalf("%d value %v - expected %v", i, out, in)
524			}
525		case *NullLob:
526			in := in.(NullLob)
527			if in.Valid != out.Valid {
528				t.Fatalf("%d value %v - expected %v", i, out, in)
529			}
530			if in.Valid {
531				ok, err := compareLob(in.Lob, out.Lob)
532				if err != nil {
533					t.Fatal(err)
534				}
535				if !ok {
536					t.Fatalf("%d lob content no equal", i)
537				}
538			}
539		}
540		i++
541	}
542	if err := rows.Err(); err != nil {
543		log.Fatal(err)
544	}
545}
546
// testLobFile is a file whose content is used as LOB test data.
type testLobFile struct {
	content []byte // complete file content
	isASCII bool   // true if content contains only 7-bit ASCII bytes
}

// testLobFiles holds the files collected by testInitLobFiles.
var testLobFiles []*testLobFile
554
555var testInitLobFilesOnce sync.Once
556
557func testInitLobFiles(t *testing.T) {
558	testInitLobFilesOnce.Do(func() {
559		filter := func(name string) bool {
560			for _, ext := range []string{".go"} {
561				if filepath.Ext(name) == ext {
562					return true
563				}
564			}
565			return false
566		}
567
568		walk := func(path string, info os.FileInfo, err error) error {
569			if !info.IsDir() && filter(info.Name()) {
570				content, err := ioutil.ReadFile(path)
571				if err != nil {
572					t.Fatal(err)
573				}
574				testLobFiles = append(testLobFiles, &testLobFile{isASCII: isASCII(content), content: content})
575			}
576			return nil
577		}
578
579		root, err := os.Getwd()
580		if err != nil {
581			t.Fatal(err)
582		}
583		filepath.Walk(root, walk)
584	})
585}
586
// isASCII reports whether every byte of content is below utf8.RuneSelf
// (0x80), i.e. content is pure 7-bit ASCII. An empty slice is ASCII.
func isASCII(content []byte) bool {
	for idx := range content {
		if content[idx] >= utf8.RuneSelf {
			return false
		}
	}
	return true
}
595
596func compareLob(in, out *Lob) (bool, error) {
597	in.rd.(*bytes.Reader).Seek(0, io.SeekStart)
598	content, err := ioutil.ReadAll(in.rd)
599	if err != nil {
600		return false, err
601	}
602	if !bytes.Equal(content, out.wr.(*bytes.Buffer).Bytes()) {
603		return false, nil
604	}
605	return true, nil
606}
607
// compareStringFixSize reports whether out equals in followed only by blanks,
// as read back from a fixed-length character column. It returns false
// (instead of panicking with a slice-bounds error) when out is shorter than in.
func compareStringFixSize(in, out string) bool {
	if len(out) < len(in) || out[:len(in)] != in {
		return false
	}
	for _, r := range out[len(in):] {
		if r != ' ' {
			return false
		}
	}
	return true
}
619
// compareBytesFixSize reports whether out equals in followed only by zero
// bytes, as read back from a fixed-length binary column. It returns false
// (instead of panicking with a slice-bounds error) when out is shorter than in.
func compareBytesFixSize(in, out []byte) bool {
	if len(out) < len(in) || !bytes.Equal(in, out[:len(in)]) {
		return false
	}
	for _, b := range out[len(in):] {
		if b != 0 {
			return false
		}
	}
	return true
}
631
// equalDate reports whether t1 and t2 fall on the same calendar day in UTC.
func equalDate(t1, t2 time.Time) bool {
	y1, m1, d1 := t1.UTC().Date()
	y2, m2, d2 := t2.UTC().Date()
	return y1 == y2 && m1 == m2 && d1 == d2
}
637
// equalTime reports whether t1 and t2 have the same UTC hour, minute and
// second; the date and any sub-second part are ignored.
func equalTime(t1, t2 time.Time) bool {
	h1, min1, s1 := t1.UTC().Clock()
	h2, min2, s2 := t2.UTC().Clock()
	return h1 == h2 && min1 == min2 && s1 == s2
}
643
644func equalDateTime(t1, t2 time.Time) bool {
645	return equalDate(t1, t2) && equalTime(t1, t2)
646}
647
// equalMilliSecond reports whether the fractional-second parts of t1 and t2
// are equal after rounding both times to the nearest millisecond. Only the
// fractional part is compared; a rounding carry into the next second is not
// examined here.
func equalMilliSecond(t1, t2 time.Time) bool {
	fraction := func(tm time.Time) int {
		return tm.UTC().Round(time.Millisecond).Nanosecond()
	}
	return fraction(t1) == fraction(t2)
}
654
655func equalTimestamp(t1, t2 time.Time) bool {
656	return equalDate(t1, t2) && equalTime(t1, t2) && equalMilliSecond(t1, t2)
657}
658
659func equalLongdate(t1, t2 time.Time) bool {
660	//HDB: nanosecond 7-digit precision
661	return equalDate(t1, t2) && equalTime(t1, t2) && (t1.Nanosecond()/100) == (t2.Nanosecond()/100)
662}
663