1package assertion
2
3import (
4	"fmt"
5	"reflect"
6
7	"github.com/onsi/gomega/types"
8)
9
// Assertion holds the state behind one assertion: the value under test plus
// what is needed to report a failure. It is created by New and evaluated by
// Should, ShouldNot, To, ToNot and NotTo.
type Assertion struct {
	// actualInput is the value passed to matcher.Match and to the matcher's
	// FailureMessage / NegatedFailureMessage methods.
	actualInput interface{}
	// failWrapper receives failures via Fail and exposes TWithHelper.Helper().
	failWrapper *types.GomegaFailWrapper
	// offset is added to the base caller skip of 2 that is passed to
	// failWrapper.Fail.
	offset int
	// extra holds additional values that must all be nil or zero-valued;
	// otherwise vetExtras fails the assertion before the matcher runs.
	extra []interface{}
}
16
17func New(actualInput interface{}, failWrapper *types.GomegaFailWrapper, offset int, extra ...interface{}) *Assertion {
18	return &Assertion{
19		actualInput: actualInput,
20		failWrapper: failWrapper,
21		offset:      offset,
22		extra:       extra,
23	}
24}
25
26func (assertion *Assertion) Should(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
27	assertion.failWrapper.TWithHelper.Helper()
28	return assertion.vetExtras(optionalDescription...) && assertion.match(matcher, true, optionalDescription...)
29}
30
31func (assertion *Assertion) ShouldNot(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
32	assertion.failWrapper.TWithHelper.Helper()
33	return assertion.vetExtras(optionalDescription...) && assertion.match(matcher, false, optionalDescription...)
34}
35
36func (assertion *Assertion) To(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
37	assertion.failWrapper.TWithHelper.Helper()
38	return assertion.vetExtras(optionalDescription...) && assertion.match(matcher, true, optionalDescription...)
39}
40
41func (assertion *Assertion) ToNot(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
42	assertion.failWrapper.TWithHelper.Helper()
43	return assertion.vetExtras(optionalDescription...) && assertion.match(matcher, false, optionalDescription...)
44}
45
46func (assertion *Assertion) NotTo(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
47	assertion.failWrapper.TWithHelper.Helper()
48	return assertion.vetExtras(optionalDescription...) && assertion.match(matcher, false, optionalDescription...)
49}
50
51func (assertion *Assertion) buildDescription(optionalDescription ...interface{}) string {
52	switch len(optionalDescription) {
53	case 0:
54		return ""
55	case 1:
56		if describe, ok := optionalDescription[0].(func() string); ok {
57			return describe() + "\n"
58		}
59	}
60	return fmt.Sprintf(optionalDescription[0].(string), optionalDescription[1:]...) + "\n"
61}
62
63func (assertion *Assertion) match(matcher types.GomegaMatcher, desiredMatch bool, optionalDescription ...interface{}) bool {
64	matches, err := matcher.Match(assertion.actualInput)
65	assertion.failWrapper.TWithHelper.Helper()
66	if err != nil {
67		description := assertion.buildDescription(optionalDescription...)
68		assertion.failWrapper.Fail(description+err.Error(), 2+assertion.offset)
69		return false
70	}
71	if matches != desiredMatch {
72		var message string
73		if desiredMatch {
74			message = matcher.FailureMessage(assertion.actualInput)
75		} else {
76			message = matcher.NegatedFailureMessage(assertion.actualInput)
77		}
78		description := assertion.buildDescription(optionalDescription...)
79		assertion.failWrapper.Fail(description+message, 2+assertion.offset)
80		return false
81	}
82
83	return true
84}
85
86func (assertion *Assertion) vetExtras(optionalDescription ...interface{}) bool {
87	success, message := vetExtras(assertion.extra)
88	if success {
89		return true
90	}
91
92	description := assertion.buildDescription(optionalDescription...)
93	assertion.failWrapper.TWithHelper.Helper()
94	assertion.failWrapper.Fail(description+message, 2+assertion.offset)
95	return false
96}
97
// vetExtras reports whether every element of extras is nil or deeply equal to
// the zero value of its dynamic type.
//
// On the first offending element it returns false and a message naming it. The
// reported index is the slice position plus one, which presumably counts the
// actual value itself as index 0. When every element passes, it returns true
// and "".
func vetExtras(extras []interface{}) (bool, string) {
	for idx := range extras {
		candidate := extras[idx]
		if candidate == nil {
			continue
		}
		zero := reflect.Zero(reflect.TypeOf(candidate)).Interface()
		if reflect.DeepEqual(zero, candidate) {
			continue
		}
		return false, fmt.Sprintf("Unexpected non-nil/non-zero extra argument at index %d:\n\t<%T>: %#v", idx+1, candidate, candidate)
	}
	return true, ""
}
110