1package assert 2 3import ( 4 "fmt" 5 "go/ast" 6 "go/token" 7 "reflect" 8 9 "gotest.tools/v3/assert/cmp" 10 "gotest.tools/v3/internal/format" 11 "gotest.tools/v3/internal/source" 12) 13 14// LogT is the subset of testing.T used by the assert package. 15type LogT interface { 16 Log(args ...interface{}) 17} 18 19type helperT interface { 20 Helper() 21} 22 23const failureMessage = "assertion failed: " 24 25// Eval the comparison and print a failure messages if the comparison has failed. 26// nolint: gocyclo 27func Eval( 28 t LogT, 29 argSelector argSelector, 30 comparison interface{}, 31 msgAndArgs ...interface{}, 32) bool { 33 if ht, ok := t.(helperT); ok { 34 ht.Helper() 35 } 36 var success bool 37 switch check := comparison.(type) { 38 case bool: 39 if check { 40 return true 41 } 42 logFailureFromBool(t, msgAndArgs...) 43 44 // Undocumented legacy comparison without Result type 45 case func() (success bool, message string): 46 success = runCompareFunc(t, check, msgAndArgs...) 47 48 case nil: 49 return true 50 51 case error: 52 msg := failureMsgFromError(check) 53 t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...)) 54 55 case cmp.Comparison: 56 success = RunComparison(t, argSelector, check, msgAndArgs...) 57 58 case func() cmp.Result: 59 success = RunComparison(t, argSelector, check, msgAndArgs...) 60 61 default: 62 t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check)) 63 } 64 return success 65} 66 67func runCompareFunc( 68 t LogT, 69 f func() (success bool, message string), 70 msgAndArgs ...interface{}, 71) bool { 72 if ht, ok := t.(helperT); ok { 73 ht.Helper() 74 } 75 if success, message := f(); !success { 76 t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...)) 77 return false 78 } 79 return true 80} 81 82func logFailureFromBool(t LogT, msgAndArgs ...interface{}) { 83 if ht, ok := t.(helperT); ok { 84 ht.Helper() 85 } 86 const stackIndex = 3 // Assert()/Check(), assert(), logFailureFromBool() 87 args, err := source.CallExprArgs(stackIndex) 88 if err != nil { 89 t.Log(err.Error()) 90 return 91 } 92 93 const comparisonArgIndex = 1 // Assert(t, comparison) 94 if len(args) <= comparisonArgIndex { 95 t.Log(failureMessage + "but assert failed to find the expression to print") 96 return 97 } 98 99 msg, err := boolFailureMessage(args[comparisonArgIndex]) 100 if err != nil { 101 t.Log(err.Error()) 102 msg = "expression is false" 103 } 104 105 t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...)) 106} 107 108func failureMsgFromError(err error) string { 109 // Handle errors with non-nil types 110 v := reflect.ValueOf(err) 111 if v.Kind() == reflect.Ptr && v.IsNil() { 112 return fmt.Sprintf("error is not nil: error has type %T", err) 113 } 114 return "error is not nil: " + err.Error() 115} 116 117func boolFailureMessage(expr ast.Expr) (string, error) { 118 if binaryExpr, ok := expr.(*ast.BinaryExpr); ok && binaryExpr.Op == token.NEQ { 119 x, err := source.FormatNode(binaryExpr.X) 120 if err != nil { 121 return "", err 122 } 123 y, err := source.FormatNode(binaryExpr.Y) 124 if err != nil { 125 return "", err 126 } 127 return x + " is " + y, nil 128 } 129 130 if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT { 131 x, err := source.FormatNode(unaryExpr.X) 132 if err != nil { 133 return "", err 134 } 135 return x + " is true", nil 136 } 137 138 formatted, err := source.FormatNode(expr) 139 if err != nil { 140 return "", err 141 } 142 return "expression is false: " + formatted, nil 143} 144