1package format
2
3import (
4	"testing"
5
6	"github.com/google/go-cmp/cmp"
7	"google.golang.org/grpc/codes"
8	"google.golang.org/grpc/metadata"
9	"google.golang.org/grpc/status"
10)
11
12type formatter struct {
13	FormatHeaderCalled, FormatMessageCalled, FormatStatusCalled, FormatTrailerCalled bool
14}
15
16func (f *formatter) FormatHeader(header metadata.MD) {
17	f.FormatHeaderCalled = true
18}
19
20func (f *formatter) FormatMessage(v interface{}) error {
21	f.FormatMessageCalled = true
22	return nil
23}
24
25func (f *formatter) FormatStatus(status *status.Status) error {
26	f.FormatStatusCalled = true
27	return nil
28}
29
30func (f *formatter) FormatTrailer(trailer metadata.MD) {
31	f.FormatTrailerCalled = true
32}
33
34func (f *formatter) Done() error {
35	return nil
36}
37
38func TestResponseFormatter(t *testing.T) {
39	cases := map[string]struct {
40		enrich bool
41	}{
42		"enrich=true":  {true},
43		"enrich=false": {false},
44	}
45
46	for name, c := range cases {
47		c := c
48		t.Run(name, func(t *testing.T) {
49			impl := &formatter{}
50			f := NewResponseFormatter(impl, c.enrich)
51			f.FormatHeader(metadata.Pairs("key", "val"))
52			if err := f.FormatMessage(struct{}{}); err != nil {
53				t.Fatalf("FormatMessage should not return an error, but got '%s'", err)
54			}
55			if err := f.FormatTrailer(status.New(codes.Internal, "internal error"), metadata.Pairs("key", "val")); err != nil {
56				t.Fatalf("FormatTrailer should not return an error, but got '%s'", err)
57			}
58			if err := f.Done(); err != nil {
59				t.Fatalf("Done should not return an error, but got '%s'", err)
60			}
61
62			res := map[bool]bool{
63				true:  impl.FormatHeaderCalled && impl.FormatMessageCalled && impl.FormatTrailerCalled && impl.FormatStatusCalled,
64				false: impl.FormatMessageCalled,
65			}
66
67			if called, ok := res[c.enrich]; ok && !called {
68				t.Errorf("expected true, but false")
69			}
70
71			t.Run("Format", func(t *testing.T) {
72				impl := &formatter{}
73				f := NewResponseFormatter(impl, c.enrich)
74				err := f.Format(
75					status.New(codes.Internal, "internal error"),
76					metadata.Pairs("key", "val"),
77					metadata.Pairs("key", "val"),
78					struct{}{},
79				)
80				if err != nil {
81					t.Fatalf("Format should not return an error, but got '%s'", err)
82				}
83				if err := f.Done(); err != nil {
84					t.Fatalf("Done should not return an error, but got '%s'", err)
85				}
86
87				FormatRes := map[bool]bool{
88					true:  impl.FormatHeaderCalled && impl.FormatMessageCalled && impl.FormatTrailerCalled && impl.FormatStatusCalled,
89					false: impl.FormatMessageCalled,
90				}
91
92				if diff := cmp.Diff(res, FormatRes); diff != "" {
93					t.Errorf("two results should be equal:\n%s", diff)
94				}
95			})
96		})
97	}
98}
99