1package assert
2
3import (
4	"bytes"
5	"reflect"
6	"testing"
7)
8
// NoError fails the test immediately when err is non-nil, printing the error
// with the %+v verb.
func NoError(t testing.TB, err error) {
	if err == nil {
		return
	}
	t.Helper()
	t.Fatalf("%+v", err)
}
15
// Error fails the test immediately when err is nil, i.e. when an operation
// that was expected to return an error did not.
func Error(t testing.TB, err error) {
	if err != nil {
		return
	}
	t.Helper()
	t.Fatal("expected an error")
}
22
23func Equal(t testing.TB, a, b interface{}) {
24	if ta, tb := reflect.TypeOf(a), reflect.TypeOf(b); ta != nil && tb != nil {
25		if ta.Comparable() && tb.Comparable() {
26			if a == b || literalConvert(a) == literalConvert(b) {
27				return
28			}
29		}
30	}
31
32	if deepEqual(a, b) {
33		return
34	}
35
36	t.Helper()
37	t.Fatalf("%#v != %#v", a, b)
38}
39
40func DeepEqual(t testing.TB, a, b interface{}) {
41	if !deepEqual(a, b) {
42		t.Helper()
43		t.Fatalf("%#v != %#v", a, b)
44	}
45}
46
// That fails the test immediately unless the condition v holds. Its behavior
// is identical to True.
func That(t testing.TB, v bool) {
	if v {
		return
	}
	t.Helper()
	t.Fatal("expected condition failed")
}
53
// True fails the test immediately when v is false.
func True(t testing.TB, v bool) {
	if v {
		return
	}
	t.Helper()
	t.Fatal("expected condition failed")
}
60
// False fails the test immediately when v is true.
func False(t testing.TB, v bool) {
	if !v {
		return
	}
	t.Helper()
	t.Fatal("expected condition failed")
}
67
68func Nil(t testing.TB, a interface{}) {
69	if a == nil {
70		return
71	}
72
73	rv := reflect.ValueOf(a)
74	if !canNil(rv) {
75		t.Helper()
76		t.Fatalf("%#v cannot be nil", a)
77	}
78	if !rv.IsNil() {
79		t.Helper()
80		t.Fatalf("%#v != nil", a)
81	}
82}
83
84func NotNil(t testing.TB, a interface{}) {
85	if a == nil {
86		t.Helper()
87		t.Fatal("expected not nil")
88	}
89
90	rv := reflect.ValueOf(a)
91	if !canNil(rv) {
92		return
93	}
94	if rv.IsNil() {
95		t.Helper()
96		t.Fatalf("%#v == nil", a)
97	}
98}
99
// deepEqual reports whether a and b are deeply equal. When both are exactly
// of type []byte they are compared with bytes.Equal, which treats nil and
// empty slices as equal; all other values go through reflect.DeepEqual.
func deepEqual(a, b interface{}) bool {
	if x, ok := a.([]byte); ok {
		if y, ok := b.([]byte); ok {
			return bytes.Equal(x, y)
		}
	}
	return reflect.DeepEqual(a, b)
}
108
// canNil reports whether rv's kind is one this package treats as nillable:
// chan, func, interface, map, pointer or slice. Any other kind, including the
// Invalid kind of a zero reflect.Value, reports false.
func canNil(rv reflect.Value) bool {
	k := rv.Kind()
	return k == reflect.Chan || k == reflect.Func || k == reflect.Interface ||
		k == reflect.Map || k == reflect.Ptr || k == reflect.Slice
}
116
// literalConvert normalizes values of basic kinds to one canonical Go type so
// that Equal can treat literals of different types as equal:
//   - bool kinds become bool and string kinds become string;
//   - float32/float64 become float64, complex64/complex128 become complex128;
//   - negative signed integers become int64, while non-negative signed
//     integers and all unsigned integers become uint64, so int(5) and
//     uint8(5) both map to uint64(5).
//
// Any other kind (including uintptr, and a nil val) is not converted; its
// reflect.Value is returned as-is.
func literalConvert(val interface{}) interface{} {
	rv := reflect.ValueOf(val)
	switch rv.Kind() {
	case reflect.Bool:
		return rv.Bool()
	case reflect.String:
		return rv.String()
	case reflect.Float32, reflect.Float64:
		return rv.Float()
	case reflect.Complex64, reflect.Complex128:
		return rv.Complex()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		n := rv.Int()
		if n >= 0 {
			// Non-negative signed values share uint64 with the unsigned kinds.
			return uint64(n)
		}
		return n
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return rv.Uint()
	}
	return rv
}
144