1package assert
2
3import (
4	"fmt"
5	"reflect"
6)
7
// CompareType is the outcome of ordering one value against another and is
// always one of compareLess, compareEqual or compareGreater.
type CompareType int

// The three possible CompareType outcomes, numbered -1, 0 and +1 so that the
// sign of the value matches the direction of the ordering.
const (
	compareLess    CompareType = -1
	compareEqual   CompareType = 0
	compareGreater CompareType = 1
)
15
// Canonical reflect.Type values for the builtin types that compare converts
// its operands to before ordering them.
var (
	// Signed integers.
	intType   = reflect.TypeOf(0)
	int8Type  = reflect.TypeOf(int8(0))
	int16Type = reflect.TypeOf(int16(0))
	int32Type = reflect.TypeOf(int32(0))
	int64Type = reflect.TypeOf(int64(0))

	// Unsigned integers.
	uintType   = reflect.TypeOf(uint(0))
	uint8Type  = reflect.TypeOf(uint8(0))
	uint16Type = reflect.TypeOf(uint16(0))
	uint32Type = reflect.TypeOf(uint32(0))
	uint64Type = reflect.TypeOf(uint64(0))

	// Floating point.
	float32Type = reflect.TypeOf(float32(0))
	float64Type = reflect.TypeOf(0.0)

	// Strings.
	stringType = reflect.TypeOf("")
)
34
35func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
36	obj1Value := reflect.ValueOf(obj1)
37	obj2Value := reflect.ValueOf(obj2)
38
39	// throughout this switch we try and avoid calling .Convert() if possible,
40	// as this has a pretty big performance impact
41	switch kind {
42	case reflect.Int:
43		{
44			intobj1, ok := obj1.(int)
45			if !ok {
46				intobj1 = obj1Value.Convert(intType).Interface().(int)
47			}
48			intobj2, ok := obj2.(int)
49			if !ok {
50				intobj2 = obj2Value.Convert(intType).Interface().(int)
51			}
52			if intobj1 > intobj2 {
53				return compareGreater, true
54			}
55			if intobj1 == intobj2 {
56				return compareEqual, true
57			}
58			if intobj1 < intobj2 {
59				return compareLess, true
60			}
61		}
62	case reflect.Int8:
63		{
64			int8obj1, ok := obj1.(int8)
65			if !ok {
66				int8obj1 = obj1Value.Convert(int8Type).Interface().(int8)
67			}
68			int8obj2, ok := obj2.(int8)
69			if !ok {
70				int8obj2 = obj2Value.Convert(int8Type).Interface().(int8)
71			}
72			if int8obj1 > int8obj2 {
73				return compareGreater, true
74			}
75			if int8obj1 == int8obj2 {
76				return compareEqual, true
77			}
78			if int8obj1 < int8obj2 {
79				return compareLess, true
80			}
81		}
82	case reflect.Int16:
83		{
84			int16obj1, ok := obj1.(int16)
85			if !ok {
86				int16obj1 = obj1Value.Convert(int16Type).Interface().(int16)
87			}
88			int16obj2, ok := obj2.(int16)
89			if !ok {
90				int16obj2 = obj2Value.Convert(int16Type).Interface().(int16)
91			}
92			if int16obj1 > int16obj2 {
93				return compareGreater, true
94			}
95			if int16obj1 == int16obj2 {
96				return compareEqual, true
97			}
98			if int16obj1 < int16obj2 {
99				return compareLess, true
100			}
101		}
102	case reflect.Int32:
103		{
104			int32obj1, ok := obj1.(int32)
105			if !ok {
106				int32obj1 = obj1Value.Convert(int32Type).Interface().(int32)
107			}
108			int32obj2, ok := obj2.(int32)
109			if !ok {
110				int32obj2 = obj2Value.Convert(int32Type).Interface().(int32)
111			}
112			if int32obj1 > int32obj2 {
113				return compareGreater, true
114			}
115			if int32obj1 == int32obj2 {
116				return compareEqual, true
117			}
118			if int32obj1 < int32obj2 {
119				return compareLess, true
120			}
121		}
122	case reflect.Int64:
123		{
124			int64obj1, ok := obj1.(int64)
125			if !ok {
126				int64obj1 = obj1Value.Convert(int64Type).Interface().(int64)
127			}
128			int64obj2, ok := obj2.(int64)
129			if !ok {
130				int64obj2 = obj2Value.Convert(int64Type).Interface().(int64)
131			}
132			if int64obj1 > int64obj2 {
133				return compareGreater, true
134			}
135			if int64obj1 == int64obj2 {
136				return compareEqual, true
137			}
138			if int64obj1 < int64obj2 {
139				return compareLess, true
140			}
141		}
142	case reflect.Uint:
143		{
144			uintobj1, ok := obj1.(uint)
145			if !ok {
146				uintobj1 = obj1Value.Convert(uintType).Interface().(uint)
147			}
148			uintobj2, ok := obj2.(uint)
149			if !ok {
150				uintobj2 = obj2Value.Convert(uintType).Interface().(uint)
151			}
152			if uintobj1 > uintobj2 {
153				return compareGreater, true
154			}
155			if uintobj1 == uintobj2 {
156				return compareEqual, true
157			}
158			if uintobj1 < uintobj2 {
159				return compareLess, true
160			}
161		}
162	case reflect.Uint8:
163		{
164			uint8obj1, ok := obj1.(uint8)
165			if !ok {
166				uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8)
167			}
168			uint8obj2, ok := obj2.(uint8)
169			if !ok {
170				uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8)
171			}
172			if uint8obj1 > uint8obj2 {
173				return compareGreater, true
174			}
175			if uint8obj1 == uint8obj2 {
176				return compareEqual, true
177			}
178			if uint8obj1 < uint8obj2 {
179				return compareLess, true
180			}
181		}
182	case reflect.Uint16:
183		{
184			uint16obj1, ok := obj1.(uint16)
185			if !ok {
186				uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16)
187			}
188			uint16obj2, ok := obj2.(uint16)
189			if !ok {
190				uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16)
191			}
192			if uint16obj1 > uint16obj2 {
193				return compareGreater, true
194			}
195			if uint16obj1 == uint16obj2 {
196				return compareEqual, true
197			}
198			if uint16obj1 < uint16obj2 {
199				return compareLess, true
200			}
201		}
202	case reflect.Uint32:
203		{
204			uint32obj1, ok := obj1.(uint32)
205			if !ok {
206				uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32)
207			}
208			uint32obj2, ok := obj2.(uint32)
209			if !ok {
210				uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32)
211			}
212			if uint32obj1 > uint32obj2 {
213				return compareGreater, true
214			}
215			if uint32obj1 == uint32obj2 {
216				return compareEqual, true
217			}
218			if uint32obj1 < uint32obj2 {
219				return compareLess, true
220			}
221		}
222	case reflect.Uint64:
223		{
224			uint64obj1, ok := obj1.(uint64)
225			if !ok {
226				uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64)
227			}
228			uint64obj2, ok := obj2.(uint64)
229			if !ok {
230				uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64)
231			}
232			if uint64obj1 > uint64obj2 {
233				return compareGreater, true
234			}
235			if uint64obj1 == uint64obj2 {
236				return compareEqual, true
237			}
238			if uint64obj1 < uint64obj2 {
239				return compareLess, true
240			}
241		}
242	case reflect.Float32:
243		{
244			float32obj1, ok := obj1.(float32)
245			if !ok {
246				float32obj1 = obj1Value.Convert(float32Type).Interface().(float32)
247			}
248			float32obj2, ok := obj2.(float32)
249			if !ok {
250				float32obj2 = obj2Value.Convert(float32Type).Interface().(float32)
251			}
252			if float32obj1 > float32obj2 {
253				return compareGreater, true
254			}
255			if float32obj1 == float32obj2 {
256				return compareEqual, true
257			}
258			if float32obj1 < float32obj2 {
259				return compareLess, true
260			}
261		}
262	case reflect.Float64:
263		{
264			float64obj1, ok := obj1.(float64)
265			if !ok {
266				float64obj1 = obj1Value.Convert(float64Type).Interface().(float64)
267			}
268			float64obj2, ok := obj2.(float64)
269			if !ok {
270				float64obj2 = obj2Value.Convert(float64Type).Interface().(float64)
271			}
272			if float64obj1 > float64obj2 {
273				return compareGreater, true
274			}
275			if float64obj1 == float64obj2 {
276				return compareEqual, true
277			}
278			if float64obj1 < float64obj2 {
279				return compareLess, true
280			}
281		}
282	case reflect.String:
283		{
284			stringobj1, ok := obj1.(string)
285			if !ok {
286				stringobj1 = obj1Value.Convert(stringType).Interface().(string)
287			}
288			stringobj2, ok := obj2.(string)
289			if !ok {
290				stringobj2 = obj2Value.Convert(stringType).Interface().(string)
291			}
292			if stringobj1 > stringobj2 {
293				return compareGreater, true
294			}
295			if stringobj1 == stringobj2 {
296				return compareEqual, true
297			}
298			if stringobj1 < stringobj2 {
299				return compareLess, true
300			}
301		}
302	}
303
304	return compareEqual, false
305}
306
307// Greater asserts that the first element is greater than the second
308//
309//    assert.Greater(t, 2, 1)
310//    assert.Greater(t, float64(2), float64(1))
311//    assert.Greater(t, "b", "a")
312func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
313	return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs)
314}
315
316// GreaterOrEqual asserts that the first element is greater than or equal to the second
317//
318//    assert.GreaterOrEqual(t, 2, 1)
319//    assert.GreaterOrEqual(t, 2, 2)
320//    assert.GreaterOrEqual(t, "b", "a")
321//    assert.GreaterOrEqual(t, "b", "b")
322func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
323	return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs)
324}
325
326// Less asserts that the first element is less than the second
327//
328//    assert.Less(t, 1, 2)
329//    assert.Less(t, float64(1), float64(2))
330//    assert.Less(t, "a", "b")
331func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
332	return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs)
333}
334
335// LessOrEqual asserts that the first element is less than or equal to the second
336//
337//    assert.LessOrEqual(t, 1, 2)
338//    assert.LessOrEqual(t, 2, 2)
339//    assert.LessOrEqual(t, "a", "b")
340//    assert.LessOrEqual(t, "b", "b")
341func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
342	return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs)
343}
344
345func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
346	if h, ok := t.(tHelper); ok {
347		h.Helper()
348	}
349
350	e1Kind := reflect.ValueOf(e1).Kind()
351	e2Kind := reflect.ValueOf(e2).Kind()
352	if e1Kind != e2Kind {
353		return Fail(t, "Elements should be the same type", msgAndArgs...)
354	}
355
356	compareResult, isComparable := compare(e1, e2, e1Kind)
357	if !isComparable {
358		return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
359	}
360
361	if !containsValue(allowedComparesResults, compareResult) {
362		return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...)
363	}
364
365	return true
366}
367
368func containsValue(values []CompareType, value CompareType) bool {
369	for _, v := range values {
370		if v == value {
371			return true
372		}
373	}
374
375	return false
376}
377