1// Copyright 2017, The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5package cmpopts
6
7import (
8	"fmt"
9	"reflect"
10	"strings"
11
12	"github.com/google/go-cmp/cmp"
13)
14
15// filterField returns a new Option where opt is only evaluated on paths that
16// include a specific exported field on a single struct type.
17// The struct type is specified by passing in a value of that type.
18//
19// The name may be a dot-delimited string (e.g., "Foo.Bar") to select a
20// specific sub-field that is embedded or nested within the parent struct.
21func filterField(typ interface{}, name string, opt cmp.Option) cmp.Option {
22	// TODO: This is currently unexported over concerns of how helper filters
23	// can be composed together easily.
24	// TODO: Add tests for FilterField.
25
26	sf := newStructFilter(typ, name)
27	return cmp.FilterPath(sf.filter, opt)
28}
29
30type structFilter struct {
31	t  reflect.Type // The root struct type to match on
32	ft fieldTree    // Tree of fields to match on
33}
34
35func newStructFilter(typ interface{}, names ...string) structFilter {
36	// TODO: Perhaps allow * as a special identifier to allow ignoring any
37	// number of path steps until the next field match?
38	// This could be useful when a concrete struct gets transformed into
39	// an anonymous struct where it is not possible to specify that by type,
40	// but the transformer happens to provide guarantees about the names of
41	// the transformed fields.
42
43	t := reflect.TypeOf(typ)
44	if t == nil || t.Kind() != reflect.Struct {
45		panic(fmt.Sprintf("%T must be a non-pointer struct", typ))
46	}
47	var ft fieldTree
48	for _, name := range names {
49		cname, err := canonicalName(t, name)
50		if err != nil {
51			panic(fmt.Sprintf("%s: %v", strings.Join(cname, "."), err))
52		}
53		ft.insert(cname)
54	}
55	return structFilter{t, ft}
56}
57
58func (sf structFilter) filter(p cmp.Path) bool {
59	for i, ps := range p {
60		if ps.Type().AssignableTo(sf.t) && sf.ft.matchPrefix(p[i+1:]) {
61			return true
62		}
63	}
64	return false
65}
66
// fieldTree stores a set of dot-separated field selectors as a trie keyed by
// field name. A node whose ok flag is set marks the end of an inserted
// selector.
//
// For example, after inserting these selectors:
//	Foo
//	Foo.Bar.Baz
//	Foo.Buzz
//	Nuka.Cola.Quantum
//
// the tree has this shape:
//	{sub: {
//		"Foo": {ok: true, sub: {
//			"Bar": {sub: {
//				"Baz": {ok: true},
//			}},
//			"Buzz": {ok: true},
//		}},
//		"Nuka": {sub: {
//			"Cola": {sub: {
//				"Quantum": {ok: true},
//			}},
//		}},
//	}}
type fieldTree struct {
	ok  bool                 // set when an inserted selector ends at this node
	sub map[string]fieldTree // child nodes, keyed by the next field name
}
93
94// insert inserts a sequence of field accesses into the tree.
95func (ft *fieldTree) insert(cname []string) {
96	if ft.sub == nil {
97		ft.sub = make(map[string]fieldTree)
98	}
99	if len(cname) == 0 {
100		ft.ok = true
101		return
102	}
103	sub := ft.sub[cname[0]]
104	sub.insert(cname[1:])
105	ft.sub[cname[0]] = sub
106}
107
108// matchPrefix reports whether any selector in the fieldTree matches
109// the start of path p.
110func (ft fieldTree) matchPrefix(p cmp.Path) bool {
111	for _, ps := range p {
112		switch ps := ps.(type) {
113		case cmp.StructField:
114			ft = ft.sub[ps.Name()]
115			if ft.ok {
116				return true
117			}
118			if len(ft.sub) == 0 {
119				return false
120			}
121		case cmp.Indirect:
122		default:
123			return false
124		}
125	}
126	return false
127}
128
129// canonicalName returns a list of identifiers where any struct field access
130// through an embedded field is expanded to include the names of the embedded
131// types themselves.
132//
133// For example, suppose field "Foo" is not directly in the parent struct,
134// but actually from an embedded struct of type "Bar". Then, the canonical name
135// of "Foo" is actually "Bar.Foo".
136//
137// Suppose field "Foo" is not directly in the parent struct, but actually
138// a field in two different embedded structs of types "Bar" and "Baz".
139// Then the selector "Foo" causes a panic since it is ambiguous which one it
140// refers to. The user must specify either "Bar.Foo" or "Baz.Foo".
141func canonicalName(t reflect.Type, sel string) ([]string, error) {
142	var name string
143	sel = strings.TrimPrefix(sel, ".")
144	if sel == "" {
145		return nil, fmt.Errorf("name must not be empty")
146	}
147	if i := strings.IndexByte(sel, '.'); i < 0 {
148		name, sel = sel, ""
149	} else {
150		name, sel = sel[:i], sel[i:]
151	}
152
153	// Type must be a struct or pointer to struct.
154	if t.Kind() == reflect.Ptr {
155		t = t.Elem()
156	}
157	if t.Kind() != reflect.Struct {
158		return nil, fmt.Errorf("%v must be a struct", t)
159	}
160
161	// Find the canonical name for this current field name.
162	// If the field exists in an embedded struct, then it will be expanded.
163	sf, _ := t.FieldByName(name)
164	if !isExported(name) {
165		// Avoid using reflect.Type.FieldByName for unexported fields due to
166		// buggy behavior with regard to embeddeding and unexported fields.
167		// See https://golang.org/issue/4876 for details.
168		sf = reflect.StructField{}
169		for i := 0; i < t.NumField() && sf.Name == ""; i++ {
170			if t.Field(i).Name == name {
171				sf = t.Field(i)
172			}
173		}
174	}
175	if sf.Name == "" {
176		return []string{name}, fmt.Errorf("does not exist")
177	}
178	var ss []string
179	for i := range sf.Index {
180		ss = append(ss, t.FieldByIndex(sf.Index[:i+1]).Name)
181	}
182	if sel == "" {
183		return ss, nil
184	}
185	ssPost, err := canonicalName(sf.Type, sel)
186	return append(ss, ssPost...), err
187}
188