1package cmp
2
3import (
4	"bytes"
5	"fmt"
6	"go/ast"
7	"text/template"
8
9	"gotest.tools/internal/source"
10)
11
// Result is the outcome of a Comparison. Success reports whether the
// comparison passed.
type Result interface {
	// Success returns true when the comparison passed.
	Success() bool
}
16
// result is the plain Result implementation: a success flag plus, for a
// failed comparison, a fixed failure message.
type result struct {
	success bool
	message string
}

// Success reports whether the comparison passed.
func (r result) Success() bool {
	return r.success
}

// FailureMessage returns the message describing why the comparison failed.
// It is the empty string for ResultSuccess.
func (r result) FailureMessage() string {
	return r.message
}

// ResultSuccess is the Result returned by a ComparisonWithResult to indicate
// success. Go has no struct constants, so this is a package-level variable;
// treat it as read-only and never reassign it.
var ResultSuccess = result{success: true}
33
34// ResultFailure returns a failed Result with a failure message.
35func ResultFailure(message string) Result {
36	return result{message: message}
37}
38
39// ResultFromError returns ResultSuccess if err is nil. Otherwise ResultFailure
40// is returned with the error message as the failure message.
41func ResultFromError(err error) Result {
42	if err == nil {
43		return ResultSuccess
44	}
45	return ResultFailure(err.Error())
46}
47
// templatedResult is a Result whose failure message is not stored directly.
// Instead, template is rendered against data (and the comparison's call
// arguments) when the message is requested.
type templatedResult struct {
	success  bool
	template string
	data     map[string]interface{}
}

// Success reports whether the comparison passed.
func (r templatedResult) Success() bool { return r.success }
57
58func (r templatedResult) FailureMessage(args []ast.Expr) string {
59	msg, err := renderMessage(r, args)
60	if err != nil {
61		return fmt.Sprintf("failed to render failure message: %s", err)
62	}
63	return msg
64}
65
66// ResultFailureTemplate returns a Result with a template string and data which
67// can be used to format a failure message. The template may access data from .Data,
68// the comparison args with the callArg function, and the formatNode function may
69// be used to format the call args.
70func ResultFailureTemplate(template string, data map[string]interface{}) Result {
71	return templatedResult{template: template, data: data}
72}
73
74func renderMessage(result templatedResult, args []ast.Expr) (string, error) {
75	tmpl := template.New("failure").Funcs(template.FuncMap{
76		"formatNode": source.FormatNode,
77		"callArg": func(index int) ast.Expr {
78			if index >= len(args) {
79				return nil
80			}
81			return args[index]
82		},
83	})
84	var err error
85	tmpl, err = tmpl.Parse(result.template)
86	if err != nil {
87		return "", err
88	}
89	buf := new(bytes.Buffer)
90	err = tmpl.Execute(buf, map[string]interface{}{
91		"Data": result.data,
92	})
93	return buf.String(), err
94}
95