1package reflectwalk
2
3import (
4	"fmt"
5	"reflect"
6	"testing"
7)
8
9type TestEnterExitWalker struct {
10	Locs []Location
11}
12
13func (t *TestEnterExitWalker) Enter(l Location) error {
14	if t.Locs == nil {
15		t.Locs = make([]Location, 0, 5)
16	}
17
18	t.Locs = append(t.Locs, l)
19	return nil
20}
21
22func (t *TestEnterExitWalker) Exit(l Location) error {
23	t.Locs = append(t.Locs, l)
24	return nil
25}
26
// TestPointerWalker tracks PointerEnter/PointerExit notifications. It keeps a
// stack of the flags passed to PointerEnter so that PointerExit can check
// exits arrive in matching (last-in, first-out) order. It also counts enters,
// exits, and how many enters were flagged as real pointers.
type TestPointerWalker struct {
	pointers []bool
	count    int
	enters   int
	exits    int
}

// PointerEnter pushes v onto the stack and bumps enters; count only grows
// when v is true.
func (t *TestPointerWalker) PointerEnter(v bool) error {
	t.pointers = append(t.pointers, v)
	t.enters++
	if v {
		t.count++
	}
	return nil
}

// PointerExit pops the flag most recently pushed by PointerEnter. It returns
// an error, leaving the stack untouched, when there is no outstanding enter
// or when v does not match the flag on top of the stack.
func (t *TestPointerWalker) PointerExit(v bool) error {
	t.exits++
	// An exit with no matching enter must fail cleanly. Indexing an empty
	// stack would panic and abort the test binary.
	if len(t.pointers) == 0 {
		return fmt.Errorf("unbalanced pointer exit '%t' at exit %d", v, t.exits)
	}
	top := len(t.pointers) - 1
	if t.pointers[top] != v {
		return fmt.Errorf("bad pointer exit '%t' at exit %d", v, t.exits)
	}
	t.pointers = t.pointers[:top]
	return nil
}
51
52type TestPointerValueWalker struct {
53	skip     bool
54	pointers []reflect.Type
55}
56
57func (t *TestPointerValueWalker) Pointer(v reflect.Value) error {
58	t.pointers = append(t.pointers, v.Type())
59	if t.skip {
60		return SkipEntry
61	}
62
63	return nil
64}
65
// TestPrimitiveWalker remembers only the most recent primitive value visited.
type TestPrimitiveWalker struct {
	Value reflect.Value
}

// Primitive overwrites Value with the visited value; it never fails.
func (t *TestPrimitiveWalker) Primitive(val reflect.Value) (err error) {
	t.Value = val
	return
}
74
// TestPrimitiveCountWalker counts how many primitive values the walk visits.
type TestPrimitiveCountWalker struct {
	Count int
}

// Primitive increments Count once per call; the value itself is not inspected.
func (t *TestPrimitiveCountWalker) Primitive(reflect.Value) error {
	t.Count++
	return nil
}
83
// TestPrimitiveReplaceWalker overwrites every primitive it visits with the
// string "bar". Its Value field is not read or written by the walker itself.
type TestPrimitiveReplaceWalker struct {
	Value reflect.Value
}

// Primitive stores "bar" into v. Per reflect.Value.Set, v must be settable
// and able to hold a string, otherwise reflect panics.
func (t *TestPrimitiveReplaceWalker) Primitive(v reflect.Value) error {
	replacement := reflect.ValueOf("bar")
	v.Set(replacement)
	return nil
}
92
// TestMapWalker remembers the last map it entered and the set of string keys
// and string values seen across all visited entries.
type TestMapWalker struct {
	MapVal reflect.Value
	Keys   map[string]bool
	Values map[string]bool
}

// Map stores m as the most recently entered map.
func (t *TestMapWalker) Map(m reflect.Value) error {
	t.MapVal = m
	return nil
}

// MapElem marks k and v as seen. Both must hold strings, otherwise the type
// assertions panic.
func (t *TestMapWalker) MapElem(m, k, v reflect.Value) error {
	// Keys and Values are always allocated together, so checking Keys is enough.
	if t.Keys == nil {
		t.Keys, t.Values = map[string]bool{}, map[string]bool{}
	}

	key := k.Interface().(string)
	t.Keys[key] = true

	val := v.Interface().(string)
	t.Values[val] = true

	return nil
}
114
// TestMapElemReplaceWalker rewrites each visited map entry in place. The new
// value is whatever ValueFn returns for the old one.
type TestMapElemReplaceWalker struct {
	ValueFn func(v reflect.Value) reflect.Value
}

// Map does nothing; all work happens per entry in MapElem.
func (t *TestMapElemReplaceWalker) Map(reflect.Value) error { return nil }

// MapElem stores ValueFn(v) back into m under key k.
func (t *TestMapElemReplaceWalker) MapElem(m, k, v reflect.Value) error {
	replacement := t.ValueFn(v)
	m.SetMapIndex(k, replacement)
	return nil
}
127
// TestSliceWalker remembers the last slice entered and counts every slice
// element visited during the walk.
type TestSliceWalker struct {
	Count    int
	SliceVal reflect.Value
}

// Slice stores s as the most recently entered slice.
func (t *TestSliceWalker) Slice(s reflect.Value) error {
	t.SliceVal = s
	return nil
}

// SliceElem counts one visited element; index and value are ignored.
func (t *TestSliceWalker) SliceElem(_ int, _ reflect.Value) error {
	t.Count++
	return nil
}
142
// TestArrayWalker remembers the last array entered and counts every array
// element visited during the walk.
type TestArrayWalker struct {
	Count    int
	ArrayVal reflect.Value
}

// Array stores a as the most recently entered array.
func (t *TestArrayWalker) Array(a reflect.Value) error {
	t.ArrayVal = a
	return nil
}

// ArrayElem counts one visited element; index and value are ignored.
func (t *TestArrayWalker) ArrayElem(_ int, _ reflect.Value) error {
	t.Count++
	return nil
}
157
// TestStructWalker records the names of struct fields in the order they are
// visited.
type TestStructWalker struct {
	Fields []string
}

// Struct does nothing; only individual fields are recorded.
func (t *TestStructWalker) Struct(reflect.Value) error { return nil }

// StructField appends field.Name to Fields, allocating the slice lazily.
func (t *TestStructWalker) StructField(field reflect.StructField, _ reflect.Value) error {
	if t.Fields == nil {
		t.Fields = make([]string, 0, 1)
	}
	t.Fields = append(t.Fields, field.Name)
	return nil
}
174
175func TestTestStructs(t *testing.T) {
176	var raw interface{}
177	raw = new(TestEnterExitWalker)
178	if _, ok := raw.(EnterExitWalker); !ok {
179		t.Fatal("EnterExitWalker is bad")
180	}
181
182	raw = new(TestPrimitiveWalker)
183	if _, ok := raw.(PrimitiveWalker); !ok {
184		t.Fatal("PrimitiveWalker is bad")
185	}
186
187	raw = new(TestMapWalker)
188	if _, ok := raw.(MapWalker); !ok {
189		t.Fatal("MapWalker is bad")
190	}
191
192	raw = new(TestSliceWalker)
193	if _, ok := raw.(SliceWalker); !ok {
194		t.Fatal("SliceWalker is bad")
195	}
196
197	raw = new(TestArrayWalker)
198	if _, ok := raw.(ArrayWalker); !ok {
199		t.Fatal("ArrayWalker is bad")
200	}
201
202	raw = new(TestStructWalker)
203	if _, ok := raw.(StructWalker); !ok {
204		t.Fatal("StructWalker is bad")
205	}
206}
207
208func TestWalk_Basic(t *testing.T) {
209	w := new(TestPrimitiveWalker)
210
211	type S struct {
212		Foo string
213	}
214
215	data := &S{
216		Foo: "foo",
217	}
218
219	err := Walk(data, w)
220	if err != nil {
221		t.Fatalf("err: %s", err)
222	}
223
224	if w.Value.Kind() != reflect.String {
225		t.Fatalf("bad: %#v", w.Value)
226	}
227}
228
229func TestWalk_Basic_Replace(t *testing.T) {
230	w := new(TestPrimitiveReplaceWalker)
231
232	type S struct {
233		Foo string
234		Bar []interface{}
235	}
236
237	data := &S{
238		Foo: "foo",
239		Bar: []interface{}{[]string{"what"}},
240	}
241
242	err := Walk(data, w)
243	if err != nil {
244		t.Fatalf("err: %s", err)
245	}
246
247	if data.Foo != "bar" {
248		t.Fatalf("bad: %#v", data.Foo)
249	}
250	if data.Bar[0].([]string)[0] != "bar" {
251		t.Fatalf("bad: %#v", data.Bar)
252	}
253}
254
255func TestWalk_Basic_ReplaceInterface(t *testing.T) {
256	w := new(TestPrimitiveReplaceWalker)
257
258	type S struct {
259		Foo []interface{}
260	}
261
262	data := &S{
263		Foo: []interface{}{"foo"},
264	}
265
266	err := Walk(data, w)
267	if err != nil {
268		t.Fatalf("err: %s", err)
269	}
270}
271
272func TestWalk_EnterExit(t *testing.T) {
273	w := new(TestEnterExitWalker)
274
275	type S struct {
276		A string
277		M map[string]string
278	}
279
280	data := &S{
281		A: "foo",
282		M: map[string]string{
283			"a": "b",
284		},
285	}
286
287	err := Walk(data, w)
288	if err != nil {
289		t.Fatalf("err: %s", err)
290	}
291
292	expected := []Location{
293		WalkLoc,
294		Struct,
295		StructField,
296		StructField,
297		StructField,
298		Map,
299		MapKey,
300		MapKey,
301		MapValue,
302		MapValue,
303		Map,
304		StructField,
305		Struct,
306		WalkLoc,
307	}
308	if !reflect.DeepEqual(w.Locs, expected) {
309		t.Fatalf("Bad: %#v", w.Locs)
310	}
311}
312
313func TestWalk_Interface(t *testing.T) {
314	w := new(TestPrimitiveCountWalker)
315
316	type S struct {
317		Foo string
318		Bar []interface{}
319	}
320
321	var data interface{} = &S{
322		Foo: "foo",
323		Bar: []interface{}{[]string{"bar", "what"}, "baz"},
324	}
325
326	err := Walk(data, w)
327	if err != nil {
328		t.Fatalf("err: %s", err)
329	}
330
331	if w.Count != 4 {
332		t.Fatalf("bad: %#v", w.Count)
333	}
334}
335
336func TestWalk_Interface_nil(t *testing.T) {
337	w := new(TestPrimitiveCountWalker)
338
339	type S struct {
340		Bar interface{}
341	}
342
343	var data interface{} = &S{}
344
345	err := Walk(data, w)
346	if err != nil {
347		t.Fatalf("err: %s", err)
348	}
349}
350
351func TestWalk_Map(t *testing.T) {
352	w := new(TestMapWalker)
353
354	type S struct {
355		Foo map[string]string
356	}
357
358	data := &S{
359		Foo: map[string]string{
360			"foo": "foov",
361			"bar": "barv",
362		},
363	}
364
365	err := Walk(data, w)
366	if err != nil {
367		t.Fatalf("err: %s", err)
368	}
369
370	if !reflect.DeepEqual(w.MapVal.Interface(), data.Foo) {
371		t.Fatalf("Bad: %#v", w.MapVal.Interface())
372	}
373
374	expectedK := map[string]bool{"foo": true, "bar": true}
375	if !reflect.DeepEqual(w.Keys, expectedK) {
376		t.Fatalf("Bad keys: %#v", w.Keys)
377	}
378
379	expectedV := map[string]bool{"foov": true, "barv": true}
380	if !reflect.DeepEqual(w.Values, expectedV) {
381		t.Fatalf("Bad values: %#v", w.Values)
382	}
383}
384
385func TestWalk_Map_ReplaceValue(t *testing.T) {
386	w := &TestMapElemReplaceWalker{
387		ValueFn: func(v reflect.Value) reflect.Value {
388			if v.Type().Kind() == reflect.String {
389				return reflect.ValueOf("replaced")
390			}
391
392			if v.Type().Kind() == reflect.Interface {
393				if elem := v.Elem(); elem.Type() == reflect.TypeOf(map[string]interface{}{}) {
394					newMap := make(map[string]interface{})
395					for _, k := range elem.MapKeys() {
396						newMap[k.String()] = elem.MapIndex(k).Interface()
397					}
398					newMap["extra-replaced"] = "not-replaced"
399					return reflect.ValueOf(newMap)
400				} else if elem.Type().Kind() == reflect.String {
401					return reflect.ValueOf("replaced")
402				}
403			}
404
405			return v
406		},
407	}
408
409	type S struct {
410		Foo map[string]interface{}
411	}
412
413	data := &S{
414		Foo: map[string]interface{}{
415			"foo": map[string]interface{}{
416				"bar": map[string]string{"baz": "should-get-replaced"},
417			},
418		},
419	}
420
421	expected := &S{
422		Foo: map[string]interface{}{
423			"foo": map[string]interface{}{
424				"bar":            map[string]string{"baz": "replaced"},
425				"extra-replaced": "replaced",
426			},
427		},
428	}
429
430	err := Walk(data, w)
431	if err != nil {
432		t.Fatalf("err: %v", err)
433	}
434
435	if !reflect.DeepEqual(data, expected) {
436		t.Fatalf("Values not equal: %#v", data)
437	}
438}
439
440func TestWalk_Pointer(t *testing.T) {
441	w := new(TestPointerWalker)
442
443	type S struct {
444		Foo string
445		Bar *string
446		Baz **string
447	}
448
449	s := ""
450	sp := &s
451
452	data := &S{
453		Baz: &sp,
454	}
455
456	err := Walk(data, w)
457	if err != nil {
458		t.Fatalf("err: %s", err)
459	}
460
461	if w.enters != 5 {
462		t.Fatal("expected 4 values, saw", w.enters)
463	}
464
465	if w.count != 4 {
466		t.Fatal("exptec 3 pointers, saw", w.count)
467	}
468
469	if w.exits != w.enters {
470		t.Fatalf("number of enters (%d) and exits (%d) don't match", w.enters, w.exits)
471	}
472}
473
474func TestWalk_PointerPointer(t *testing.T) {
475	w := new(TestPointerWalker)
476
477	s := ""
478	sp := &s
479	pp := &sp
480
481	err := Walk(pp, w)
482	if err != nil {
483		t.Fatalf("err: %s", err)
484	}
485
486	if w.enters != 2 {
487		t.Fatal("expected 2 values, saw", w.enters)
488	}
489
490	if w.count != 2 {
491		t.Fatal("expected 2 pointers, saw", w.count)
492	}
493
494	if w.exits != w.enters {
495		t.Fatalf("number of enters (%d) and exits (%d) don't match", w.enters, w.exits)
496	}
497}
498
499func TestWalk_PointerValue(t *testing.T) {
500	type X struct{}
501	type T struct {
502		x *X
503	}
504
505	v := &T{x: &X{}}
506
507	expected := []reflect.Type{
508		reflect.TypeOf(v),
509		reflect.TypeOf(v.x),
510	}
511
512	w := new(TestPointerValueWalker)
513	err := Walk(v, w)
514	if err != nil {
515		t.Fatal(err)
516	}
517
518	if !reflect.DeepEqual(expected, w.pointers) {
519		t.Fatalf("unexpected pointer order or length (expected len=%d, actual len=%d)", len(expected), len(w.pointers))
520	}
521}
522
523func TestWalk_PointerValueSkip(t *testing.T) {
524	type T struct{}
525	type W struct {
526		TestPointerValueWalker
527		TestPointerWalker
528	}
529
530	v := &T{}
531	w := &W{
532		TestPointerValueWalker: TestPointerValueWalker{
533			skip: true,
534		},
535	}
536	err := Walk(v, w)
537	if err != nil {
538		t.Fatal(err)
539	}
540
541	if len(w.TestPointerValueWalker.pointers) != 1 {
542		t.Errorf("expected len=1, got len=%d", len(w.TestPointerValueWalker.pointers))
543	}
544
545	if w.TestPointerValueWalker.pointers[0] != reflect.TypeOf(v) {
546		t.Error("pointer value type mismatch")
547	}
548
549	if w.enters != 0 || w.exits != 0 {
550		t.Error("should have been skipped and have zero enters or exits")
551	}
552}
553
554func TestWalk_Slice(t *testing.T) {
555	w := new(TestSliceWalker)
556
557	type S struct {
558		Foo []string
559	}
560
561	data := &S{
562		Foo: []string{"a", "b", "c"},
563	}
564
565	err := Walk(data, w)
566	if err != nil {
567		t.Fatalf("err: %s", err)
568	}
569
570	if !reflect.DeepEqual(w.SliceVal.Interface(), data.Foo) {
571		t.Fatalf("bad: %#v", w.SliceVal.Interface())
572	}
573
574	if w.Count != 3 {
575		t.Fatalf("Bad count: %d", w.Count)
576	}
577}
578
579func TestWalk_SliceWithPtr(t *testing.T) {
580	w := new(TestSliceWalker)
581
582	// This is key, the panic only happened when the slice field was
583	// an interface!
584	type I interface{}
585
586	type S struct {
587		Foo []I
588	}
589
590	type Empty struct{}
591
592	data := &S{
593		Foo: []I{&Empty{}},
594	}
595
596	err := Walk(data, w)
597	if err != nil {
598		t.Fatalf("err: %s", err)
599	}
600
601	if !reflect.DeepEqual(w.SliceVal.Interface(), data.Foo) {
602		t.Fatalf("bad: %#v", w.SliceVal.Interface())
603	}
604
605	if w.Count != 1 {
606		t.Fatalf("Bad count: %d", w.Count)
607	}
608}
609
610func TestWalk_Array(t *testing.T) {
611	w := new(TestArrayWalker)
612
613	type S struct {
614		Foo [3]string
615	}
616
617	data := &S{
618		Foo: [3]string{"a", "b", "c"},
619	}
620
621	err := Walk(data, w)
622	if err != nil {
623		t.Fatalf("err: %s", err)
624	}
625
626	if !reflect.DeepEqual(w.ArrayVal.Interface(), data.Foo) {
627		t.Fatalf("bad: %#v", w.ArrayVal.Interface())
628	}
629
630	if w.Count != 3 {
631		t.Fatalf("Bad count: %d", w.Count)
632	}
633}
634
635func TestWalk_ArrayWithPtr(t *testing.T) {
636	w := new(TestArrayWalker)
637
638	// based on similar slice test
639	type I interface{}
640
641	type S struct {
642		Foo [1]I
643	}
644
645	type Empty struct{}
646
647	data := &S{
648		Foo: [1]I{&Empty{}},
649	}
650
651	err := Walk(data, w)
652	if err != nil {
653		t.Fatalf("err: %s", err)
654	}
655
656	if !reflect.DeepEqual(w.ArrayVal.Interface(), data.Foo) {
657		t.Fatalf("bad: %#v", w.ArrayVal.Interface())
658	}
659
660	if w.Count != 1 {
661		t.Fatalf("Bad count: %d", w.Count)
662	}
663}
664
// testErr is a minimal error implementation. TestWalk_Struct stores it in a
// *error field to exercise walking through a pointer to an interface.
type testErr struct{}

// Error satisfies the error interface with a fixed message.
func (*testErr) Error() string { return "test error" }
670
671func TestWalk_Struct(t *testing.T) {
672	w := new(TestStructWalker)
673
674	// This makes sure we can also walk over pointer-to-pointers, and the ever
675	// so rare pointer-to-interface
676	type S struct {
677		Foo string
678		Bar *string
679		Baz **string
680		Err *error
681	}
682
683	bar := "ptr"
684	baz := &bar
685	e := error(&testErr{})
686
687	data := &S{
688		Foo: "foo",
689		Bar: &bar,
690		Baz: &baz,
691		Err: &e,
692	}
693
694	err := Walk(data, w)
695	if err != nil {
696		t.Fatalf("err: %s", err)
697	}
698
699	expected := []string{"Foo", "Bar", "Baz", "Err"}
700	if !reflect.DeepEqual(w.Fields, expected) {
701		t.Fatalf("bad: %#v", w.Fields)
702	}
703}
704
705// Very similar to above test but used to fail for #2, copied here for
706// regression testing
707func TestWalk_StructWithPtr(t *testing.T) {
708	w := new(TestStructWalker)
709
710	type S struct {
711		Foo string
712		Bar string
713		Baz *int
714	}
715
716	data := &S{
717		Foo: "foo",
718		Bar: "bar",
719	}
720
721	err := Walk(data, w)
722	if err != nil {
723		t.Fatalf("err: %s", err)
724	}
725
726	expected := []string{"Foo", "Bar", "Baz"}
727	if !reflect.DeepEqual(w.Fields, expected) {
728		t.Fatalf("bad: %#v", w.Fields)
729	}
730}
731
// TestInterfaceMapWalker is like TestMapWalker, but values may have any
// comparable dynamic type. Keys must still be strings.
type TestInterfaceMapWalker struct {
	MapVal reflect.Value
	Keys   map[string]bool
	Values map[interface{}]bool
}

// Map stores m as the most recently entered map.
func (t *TestInterfaceMapWalker) Map(m reflect.Value) error {
	t.MapVal = m
	return nil
}

// MapElem marks k (asserted to string) and v's dynamic value as seen.
func (t *TestInterfaceMapWalker) MapElem(m, k, v reflect.Value) error {
	// Keys and Values are always allocated together, so checking Keys is enough.
	if t.Keys == nil {
		t.Keys, t.Values = map[string]bool{}, map[interface{}]bool{}
	}

	key := k.Interface().(string)
	t.Keys[key] = true

	val := v.Interface()
	t.Values[val] = true

	return nil
}
753
754func TestWalk_MapWithPointers(t *testing.T) {
755	w := new(TestInterfaceMapWalker)
756
757	type S struct {
758		Foo map[string]interface{}
759	}
760
761	a := "a"
762	b := "b"
763
764	data := &S{
765		Foo: map[string]interface{}{
766			"foo": &a,
767			"bar": &b,
768			"baz": 11,
769			"zab": (*int)(nil),
770		},
771	}
772
773	err := Walk(data, w)
774	if err != nil {
775		t.Fatalf("err: %s", err)
776	}
777
778	if !reflect.DeepEqual(w.MapVal.Interface(), data.Foo) {
779		t.Fatalf("Bad: %#v", w.MapVal.Interface())
780	}
781
782	expectedK := map[string]bool{"foo": true, "bar": true, "baz": true, "zab": true}
783	if !reflect.DeepEqual(w.Keys, expectedK) {
784		t.Fatalf("Bad keys: %#v", w.Keys)
785	}
786
787	expectedV := map[interface{}]bool{&a: true, &b: true, 11: true, (*int)(nil): true}
788	if !reflect.DeepEqual(w.Values, expectedV) {
789		t.Fatalf("Bad values: %#v", w.Values)
790	}
791}
792
793type TestStructWalker_fieldSkip struct {
794	Skip   bool
795	Fields int
796}
797
798func (t *TestStructWalker_fieldSkip) Enter(l Location) error {
799	if l == StructField {
800		t.Fields++
801	}
802
803	return nil
804}
805
806func (t *TestStructWalker_fieldSkip) Exit(Location) error {
807	return nil
808}
809
810func (t *TestStructWalker_fieldSkip) Struct(v reflect.Value) error {
811	return nil
812}
813
814func (t *TestStructWalker_fieldSkip) StructField(sf reflect.StructField, v reflect.Value) error {
815	if t.Skip && sf.Name[0] == '_' {
816		return SkipEntry
817	}
818
819	return nil
820}
821
822func TestWalk_StructWithSkipEntry(t *testing.T) {
823	data := &struct {
824		Foo, _Bar int
825	}{
826		Foo:  1,
827		_Bar: 2,
828	}
829
830	{
831		var s TestStructWalker_fieldSkip
832		if err := Walk(data, &s); err != nil {
833			t.Fatalf("err: %s", err)
834		}
835
836		if s.Fields != 2 {
837			t.Fatalf("bad: %d", s.Fields)
838		}
839	}
840
841	{
842		var s TestStructWalker_fieldSkip
843		s.Skip = true
844		if err := Walk(data, &s); err != nil {
845			t.Fatalf("err: %s", err)
846		}
847
848		if s.Fields != 1 {
849			t.Fatalf("bad: %d", s.Fields)
850		}
851	}
852}
853
854type TestStructWalker_valueSkip struct {
855	Skip   bool
856	Fields int
857}
858
859func (t *TestStructWalker_valueSkip) Enter(l Location) error {
860	if l == StructField {
861		t.Fields++
862	}
863
864	return nil
865}
866
867func (t *TestStructWalker_valueSkip) Exit(Location) error {
868	return nil
869}
870
871func (t *TestStructWalker_valueSkip) Struct(v reflect.Value) error {
872	if t.Skip {
873		return SkipEntry
874	}
875
876	return nil
877}
878
879func (t *TestStructWalker_valueSkip) StructField(sf reflect.StructField, v reflect.Value) error {
880	return nil
881}
882
883func TestWalk_StructParentWithSkipEntry(t *testing.T) {
884	data := &struct {
885		Foo, _Bar int
886	}{
887		Foo:  1,
888		_Bar: 2,
889	}
890
891	{
892		var s TestStructWalker_valueSkip
893		if err := Walk(data, &s); err != nil {
894			t.Fatalf("err: %s", err)
895		}
896
897		if s.Fields != 2 {
898			t.Fatalf("bad: %d", s.Fields)
899		}
900	}
901
902	{
903		var s TestStructWalker_valueSkip
904		s.Skip = true
905		if err := Walk(data, &s); err != nil {
906			t.Fatalf("err: %s", err)
907		}
908
909		if s.Fields != 0 {
910			t.Fatalf("bad: %d", s.Fields)
911		}
912	}
913}
914