1package apidiff
2
3import (
4	"fmt"
5	"go/types"
6	"reflect"
7)
8
9func (d *differ) checkCompatible(otn *types.TypeName, old, new types.Type) {
10	switch old := old.(type) {
11	case *types.Interface:
12		if new, ok := new.(*types.Interface); ok {
13			d.checkCompatibleInterface(otn, old, new)
14			return
15		}
16
17	case *types.Struct:
18		if new, ok := new.(*types.Struct); ok {
19			d.checkCompatibleStruct(otn, old, new)
20			return
21		}
22
23	case *types.Chan:
24		if new, ok := new.(*types.Chan); ok {
25			d.checkCompatibleChan(otn, old, new)
26			return
27		}
28
29	case *types.Basic:
30		if new, ok := new.(*types.Basic); ok {
31			d.checkCompatibleBasic(otn, old, new)
32			return
33		}
34
35	case *types.Named:
36		panic("unreachable")
37
38	default:
39		d.checkCorrespondence(otn, "", old, new)
40		return
41
42	}
43	// Here if old and new are different kinds of types.
44	d.typeChanged(otn, "", old, new)
45}
46
47func (d *differ) checkCompatibleChan(otn *types.TypeName, old, new *types.Chan) {
48	d.checkCorrespondence(otn, ", element type", old.Elem(), new.Elem())
49	if old.Dir() != new.Dir() {
50		if new.Dir() == types.SendRecv {
51			d.compatible(otn, "", "removed direction")
52		} else {
53			d.incompatible(otn, "", "changed direction")
54		}
55	}
56}
57
58func (d *differ) checkCompatibleBasic(otn *types.TypeName, old, new *types.Basic) {
59	// Certain changes to numeric types are compatible. Approximately, the info must
60	// be the same, and the new values must be a superset of the old.
61	if old.Kind() == new.Kind() {
62		// old and new are identical
63		return
64	}
65	if compatibleBasics[[2]types.BasicKind{old.Kind(), new.Kind()}] {
66		d.compatible(otn, "", "changed from %s to %s", old, new)
67	} else {
68		d.typeChanged(otn, "", old, new)
69	}
70}
71
// compatibleBasics holds every pair (old, new) of basic kinds for which changing
// an underlying basic type from old to new is compatible.
//
// Each chain below lists kinds in widening order. A kind is compatible with
// every kind that follows it in its own chain. int and uint sit between the 32-
// and 64-bit kinds because they are either 32 or 64 bits wide, depending on the
// platform.
var compatibleBasics = func() map[[2]types.BasicKind]bool {
	chains := [][]types.BasicKind{
		{types.Uint8, types.Uint16, types.Uint32, types.Uint, types.Uint64},
		{types.Int8, types.Int16, types.Int32, types.Int, types.Int64},
		{types.Float32, types.Float64},
		{types.Complex64, types.Complex128},
	}
	pairs := map[[2]types.BasicKind]bool{}
	for _, chain := range chains {
		for i, narrow := range chain {
			for _, wide := range chain[i+1:] {
				pairs[[2]types.BasicKind{narrow, wide}] = true
			}
		}
	}
	return pairs
}()
97
98// Interface compatibility:
99// If the old interface has an unexported method, the new interface is compatible
100// if its exported method set is a superset of the old. (Users could not implement,
101// only embed.)
102//
103// If the old interface did not have an unexported method, the new interface is
104// compatible if its exported method set is the same as the old, and it has no
105// unexported methods. (Adding an unexported method makes the interface
106// unimplementable outside the package.)
107//
108// TODO: must also check that if any methods were added or removed, every exposed
109// type in the package that implemented the interface in old still implements it in
110// new. Otherwise external assignments could fail.
111func (d *differ) checkCompatibleInterface(otn *types.TypeName, old, new *types.Interface) {
112	// Method sets are checked in checkCompatibleDefined.
113
114	// Does the old interface have an unexported method?
115	if unexportedMethod(old) != nil {
116		d.checkMethodSet(otn, old, new, additionsCompatible)
117	} else {
118		// Perform an equivalence check, but with more information.
119		d.checkMethodSet(otn, old, new, additionsIncompatible)
120		if u := unexportedMethod(new); u != nil {
121			d.incompatible(otn, u.Name(), "added unexported method")
122		}
123	}
124}
125
// unexportedMethod returns the first unexported method of t, in t.Method index
// order. It returns nil if every method of t is exported.
func unexportedMethod(t *types.Interface) *types.Func {
	var found *types.Func
	for i, n := 0, t.NumMethods(); i < n && found == nil; i++ {
		if m := t.Method(i); !m.Exported() {
			found = m
		}
	}
	return found
}
135
136// We need to check three things for structs:
137// 1. The set of exported fields must be compatible. This ensures that keyed struct
138//    literals continue to compile. (There is no compatibility guarantee for unkeyed
139//    struct literals.)
140// 2. The set of exported *selectable* fields must be compatible. This includes the exported
141//    fields of all embedded structs. This ensures that selections continue to compile.
142// 3. If the old struct is comparable, so must the new one be. This ensures that equality
143//    expressions and uses of struct values as map keys continue to compile.
144//
145// An unexported embedded struct can't appear in a struct literal outside the
146// package, so it doesn't have to be present, or have the same name, in the new
147// struct.
148//
149// Field tags are ignored: they have no compile-time implications.
150func (d *differ) checkCompatibleStruct(obj types.Object, old, new *types.Struct) {
151	d.checkCompatibleObjectSets(obj, exportedFields(old), exportedFields(new))
152	d.checkCompatibleObjectSets(obj, exportedSelectableFields(old), exportedSelectableFields(new))
153	// Removing comparability from a struct is an incompatible change.
154	if types.Comparable(old) && !types.Comparable(new) {
155		d.incompatible(obj, "", "old is comparable, new is not")
156	}
157}
158
// exportedFields returns the exported fields declared directly in s, keyed by
// name. Exported embedded fields are included. These are exactly the keys that
// code outside the package may use in a keyed struct literal.
func exportedFields(s *types.Struct) map[string]types.Object {
	n := s.NumFields()
	fields := make(map[string]types.Object, n)
	for i := 0; i < n; i++ {
		if f := s.Field(i); f.Exported() {
			fields[f.Name()] = f
		}
	}
	return fields
}
171
172// exportedSelectableFields collects all the exported fields of the struct, including
173// exported fields of embedded structs.
174//
175// We traverse the struct breadth-first, because of the rule that a lower-depth field
176// shadows one at a higher depth.
177func exportedSelectableFields(s *types.Struct) map[string]types.Object {
178	var (
179		m    = map[string]types.Object{}
180		next []*types.Struct // embedded structs at the next depth
181		seen []*types.Struct // to handle recursive embedding
182	)
183	for cur := []*types.Struct{s}; len(cur) > 0; cur, next = next, nil {
184		seen = append(seen, cur...)
185		// We only want to consider unambiguous fields. Ambiguous fields (where there
186		// is more than one field of the same name at the same level) are legal, but
187		// cannot be selected.
188		for name, f := range unambiguousFields(cur) {
189			// Record an exported field we haven't seen before. If we have seen it,
190			// it occurred a lower depth, so it shadows this field.
191			if f.Exported() && m[name] == nil {
192				m[name] = f
193			}
194			// Remember embedded structs for processing at the next depth,
195			// but only if we haven't seen the struct at this depth or above.
196			if !f.Anonymous() {
197				continue
198			}
199			t := f.Type().Underlying()
200			if p, ok := t.(*types.Pointer); ok {
201				t = p.Elem().Underlying()
202			}
203			if t, ok := t.(*types.Struct); ok && !contains(seen, t) {
204				next = append(next, t)
205			}
206		}
207	}
208	return m
209}
210
// contains reports whether some struct in ts is identical to t according to
// types.Identical.
func contains(ts []*types.Struct, t *types.Struct) bool {
	for i := range ts {
		if types.Identical(ts[i], t) {
			return true
		}
	}
	return false
}
219
// unambiguousFields takes structs that all sit at the same embedding depth. It
// returns, keyed by name, the fields whose name occurs exactly once across all
// of them. A name shared by two or more fields is ambiguous and is omitted.
func unambiguousFields(structs []*types.Struct) map[string]*types.Var {
	occurrences := map[string]int{}
	unique := map[string]*types.Var{}
	for _, st := range structs {
		for i, n := 0, st.NumFields(); i < n; i++ {
			f := st.Field(i)
			if occurrences[f.Name()] == 0 {
				unique[f.Name()] = f
			}
			occurrences[f.Name()]++
		}
	}
	for name, c := range occurrences {
		if c > 1 {
			delete(unique, name)
		}
	}
	return unique
}
239
240// Anything removed or change from the old set is an incompatible change.
241// Anything added to the new set is a compatible change.
242func (d *differ) checkCompatibleObjectSets(obj types.Object, old, new map[string]types.Object) {
243	for name, oldo := range old {
244		newo := new[name]
245		if newo == nil {
246			d.incompatible(obj, name, "removed")
247		} else {
248			d.checkCorrespondence(obj, name, oldo.Type(), newo.Type())
249		}
250	}
251	for name := range new {
252		if old[name] == nil {
253			d.compatible(obj, name, "added")
254		}
255	}
256}
257
258func (d *differ) checkCompatibleDefined(otn *types.TypeName, old *types.Named, new types.Type) {
259	// We've already checked that old and new correspond.
260	d.checkCompatible(otn, old.Underlying(), new.Underlying())
261	// If there are different kinds of types (e.g. struct and interface), don't bother checking
262	// the method sets.
263	if reflect.TypeOf(old.Underlying()) != reflect.TypeOf(new.Underlying()) {
264		return
265	}
266	// Interface method sets are checked in checkCompatibleInterface.
267	if _, ok := old.Underlying().(*types.Interface); ok {
268		return
269	}
270
271	// A new method set is compatible with an old if the new exported methods are a superset of the old.
272	d.checkMethodSet(otn, old, new, additionsCompatible)
273	d.checkMethodSet(otn, types.NewPointer(old), types.NewPointer(new), additionsCompatible)
274}
275
// Values for the addcompat parameter of checkMethodSet. The parameter says
// whether a method present only in the new method set is reported as a
// compatible or an incompatible change.
const (
	additionsCompatible   = true
	additionsIncompatible = !additionsCompatible
)
280
281func (d *differ) checkMethodSet(otn *types.TypeName, oldt, newt types.Type, addcompat bool) {
282	// TODO: find a way to use checkCompatibleObjectSets for this.
283	oldMethodSet := exportedMethods(oldt)
284	newMethodSet := exportedMethods(newt)
285	msname := otn.Name()
286	if _, ok := oldt.(*types.Pointer); ok {
287		msname = "*" + msname
288	}
289	for name, oldMethod := range oldMethodSet {
290		newMethod := newMethodSet[name]
291		if newMethod == nil {
292			var part string
293			// Due to embedding, it's possible that the method's receiver type is not
294			// the same as the defined type whose method set we're looking at. So for
295			// a type T with removed method M that is embedded in some other type U,
296			// we will generate two "removed" messages for T.M, one for its own type
297			// T and one for the embedded type U. We want both messages to appear,
298			// but the messageSet dedup logic will allow only one message for a given
299			// object. So use the part string to distinguish them.
300			if receiverNamedType(oldMethod).Obj() != otn {
301				part = fmt.Sprintf(", method set of %s", msname)
302			}
303			d.incompatible(oldMethod, part, "removed")
304		} else {
305			obj := oldMethod
306			// If a value method is changed to a pointer method and has a signature
307			// change, then we can get two messages for the same method definition: one
308			// for the value method set that says it's removed, and another for the
309			// pointer method set that says it changed. To keep both messages (since
310			// messageSet dedups), use newMethod for the second. (Slight hack.)
311			if !hasPointerReceiver(oldMethod) && hasPointerReceiver(newMethod) {
312				obj = newMethod
313			}
314			d.checkCorrespondence(obj, "", oldMethod.Type(), newMethod.Type())
315		}
316	}
317
318	// Check for added methods.
319	for name, newMethod := range newMethodSet {
320		if oldMethodSet[name] == nil {
321			if addcompat {
322				d.compatible(newMethod, "", "added")
323			} else {
324				d.incompatible(newMethod, "", "added")
325			}
326		}
327	}
328}
329
// exportedMethods returns the exported methods in the method set of t, keyed
// by method name. The method set is computed with types.NewMethodSet, so for a
// pointer type it includes value-receiver methods too.
func exportedMethods(t types.Type) map[string]types.Object {
	mset := types.NewMethodSet(t)
	out := make(map[string]types.Object)
	for i, n := 0, mset.Len(); i < n; i++ {
		if fn := mset.At(i).Obj(); fn.Exported() {
			out[fn.Name()] = fn
		}
	}
	return out
}
342
// receiverType returns the type of method's receiver. method must be a
// function object whose type is a *types.Signature with a non-nil receiver;
// otherwise this panics.
func receiverType(method types.Object) types.Type {
	sig := method.Type().(*types.Signature)
	recv := sig.Recv()
	return recv.Type()
}
346
347func receiverNamedType(method types.Object) *types.Named {
348	switch t := receiverType(method).(type) {
349	case *types.Pointer:
350		return t.Elem().(*types.Named)
351	case *types.Named:
352		return t
353	default:
354		panic("unreachable")
355	}
356}
357
358func hasPointerReceiver(method types.Object) bool {
359	_, ok := receiverType(method).(*types.Pointer)
360	return ok
361}
362