1package assert
2
3import (
4	"fmt"
5	"go/ast"
6	"go/token"
7	"reflect"
8
9	"gotest.tools/v3/assert/cmp"
10	"gotest.tools/v3/internal/format"
11	"gotest.tools/v3/internal/source"
12)
13
type (
	// LogT is the subset of testing.T used by the assert package.
	LogT interface {
		Log(args ...interface{})
	}

	// helperT is the optional capability of a LogT (such as *testing.T) to
	// mark the calling function as a test helper. Functions in this file
	// probe for it with a type assertion and call Helper when present.
	helperT interface {
		Helper()
	}
)

// failureMessage is the prefix placed in front of assertion failure messages.
const failureMessage = "assertion failed: "
24
25// Eval the comparison and print a failure messages if the comparison has failed.
26// nolint: gocyclo
27func Eval(
28	t LogT,
29	argSelector argSelector,
30	comparison interface{},
31	msgAndArgs ...interface{},
32) bool {
33	if ht, ok := t.(helperT); ok {
34		ht.Helper()
35	}
36	var success bool
37	switch check := comparison.(type) {
38	case bool:
39		if check {
40			return true
41		}
42		logFailureFromBool(t, msgAndArgs...)
43
44	// Undocumented legacy comparison without Result type
45	case func() (success bool, message string):
46		success = runCompareFunc(t, check, msgAndArgs...)
47
48	case nil:
49		return true
50
51	case error:
52		msg := failureMsgFromError(check)
53		t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
54
55	case cmp.Comparison:
56		success = RunComparison(t, argSelector, check, msgAndArgs...)
57
58	case func() cmp.Result:
59		success = RunComparison(t, argSelector, check, msgAndArgs...)
60
61	default:
62		t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check))
63	}
64	return success
65}
66
67func runCompareFunc(
68	t LogT,
69	f func() (success bool, message string),
70	msgAndArgs ...interface{},
71) bool {
72	if ht, ok := t.(helperT); ok {
73		ht.Helper()
74	}
75	if success, message := f(); !success {
76		t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
77		return false
78	}
79	return true
80}
81
// logFailureFromBool logs the failure for a bool comparison that was false.
//
// A bare bool carries no context, so this re-reads the source of the
// Assert/Check call to recover the expression passed as the comparison
// argument. It then renders that expression with boolFailureMessage and
// combines it with msgAndArgs via format.WithCustomMessage.
//
// The source lookup uses a fixed call-stack depth (stackIndex). This function
// must therefore be called directly by Eval, and Eval directly by
// Assert/Check; adding or removing a frame in between breaks the lookup.
func logFailureFromBool(t LogT, msgAndArgs ...interface{}) {
	if ht, ok := t.(helperT); ok {
		ht.Helper()
	}
	// Frames to skip to reach the user's call site. Presumably counted by
	// source.CallExprArgs relative to itself — confirm against that helper.
	const stackIndex = 3 // Assert()/Check(), Eval(), logFailureFromBool()
	args, err := source.CallExprArgs(stackIndex)
	if err != nil {
		// NOTE(review): only the lookup error is logged on this path;
		// failureMessage and msgAndArgs are not included.
		t.Log(err.Error())
		return
	}

	// args[0] is t; args[1] is the bool expression being asserted.
	const comparisonArgIndex = 1 // Assert(t, comparison)
	if len(args) <= comparisonArgIndex {
		// NOTE(review): msgAndArgs are not included on this path.
		t.Log(failureMessage + "but assert failed to find the expression to print")
		return
	}

	msg, err := boolFailureMessage(args[comparisonArgIndex])
	if err != nil {
		// The expression could not be formatted: log why, then fall back to a
		// generic message so the failure is still reported.
		t.Log(err.Error())
		msg = "expression is false"
	}

	t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
}
107
// failureMsgFromError describes a non-nil error for an assertion failure.
//
// Normally the message embeds err.Error(). If err is an interface holding a
// nil pointer (a "typed nil"), only the dynamic type is reported. Calling
// Error() on a nil pointer receiver may dereference it.
func failureMsgFromError(err error) string {
	// Short-circuit keeps IsNil from being called on non-pointer kinds,
	// where it would panic.
	if rv := reflect.ValueOf(err); rv.Kind() != reflect.Ptr || !rv.IsNil() {
		return "error is not nil: " + err.Error()
	}
	return fmt.Sprintf("error is not nil: error has type %T", err)
}
116
117func boolFailureMessage(expr ast.Expr) (string, error) {
118	if binaryExpr, ok := expr.(*ast.BinaryExpr); ok && binaryExpr.Op == token.NEQ {
119		x, err := source.FormatNode(binaryExpr.X)
120		if err != nil {
121			return "", err
122		}
123		y, err := source.FormatNode(binaryExpr.Y)
124		if err != nil {
125			return "", err
126		}
127		return x + " is " + y, nil
128	}
129
130	if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT {
131		x, err := source.FormatNode(unaryExpr.X)
132		if err != nil {
133			return "", err
134		}
135		return x + " is true", nil
136	}
137
138	formatted, err := source.FormatNode(expr)
139	if err != nil {
140		return "", err
141	}
142	return "expression is false: " + formatted, nil
143}
144