1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package apidiff
6
7import (
8	"fmt"
9	"go/types"
10	"reflect"
11)
12
13func (d *differ) checkCompatible(otn *types.TypeName, old, new types.Type) {
14	switch old := old.(type) {
15	case *types.Interface:
16		if new, ok := new.(*types.Interface); ok {
17			d.checkCompatibleInterface(otn, old, new)
18			return
19		}
20
21	case *types.Struct:
22		if new, ok := new.(*types.Struct); ok {
23			d.checkCompatibleStruct(otn, old, new)
24			return
25		}
26
27	case *types.Chan:
28		if new, ok := new.(*types.Chan); ok {
29			d.checkCompatibleChan(otn, old, new)
30			return
31		}
32
33	case *types.Basic:
34		if new, ok := new.(*types.Basic); ok {
35			d.checkCompatibleBasic(otn, old, new)
36			return
37		}
38
39	case *types.Named:
40		panic("unreachable")
41
42	default:
43		d.checkCorrespondence(otn, "", old, new)
44		return
45
46	}
47	// Here if old and new are different kinds of types.
48	d.typeChanged(otn, "", old, new)
49}
50
51func (d *differ) checkCompatibleChan(otn *types.TypeName, old, new *types.Chan) {
52	d.checkCorrespondence(otn, ", element type", old.Elem(), new.Elem())
53	if old.Dir() != new.Dir() {
54		if new.Dir() == types.SendRecv {
55			d.compatible(otn, "", "removed direction")
56		} else {
57			d.incompatible(otn, "", "changed direction")
58		}
59	}
60}
61
62func (d *differ) checkCompatibleBasic(otn *types.TypeName, old, new *types.Basic) {
63	// Certain changes to numeric types are compatible. Approximately, the info must
64	// be the same, and the new values must be a superset of the old.
65	if old.Kind() == new.Kind() {
66		// old and new are identical
67		return
68	}
69	if compatibleBasics[[2]types.BasicKind{old.Kind(), new.Kind()}] {
70		d.compatible(otn, "", "changed from %s to %s", old, new)
71	} else {
72		d.typeChanged(otn, "", old, new)
73	}
74}
75
// compatibleBasics holds all pairs (old, new) of compatible basic types.
//
// The pairs are generated from the chains below: within a chain, a kind is
// paired with every kind that appears after it.
var compatibleBasics = func() map[[2]types.BasicKind]bool {
	chains := [][]types.BasicKind{
		{types.Uint8, types.Uint16, types.Uint32, types.Uint, types.Uint64},
		{types.Int8, types.Int16, types.Int32, types.Int, types.Int64},
		{types.Float32, types.Float64},
		{types.Complex64, types.Complex128},
	}
	pairs := map[[2]types.BasicKind]bool{}
	for _, chain := range chains {
		for i, from := range chain {
			for _, to := range chain[i+1:] {
				pairs[[2]types.BasicKind{from, to}] = true
			}
		}
	}
	return pairs
}()
101
102// Interface compatibility:
103// If the old interface has an unexported method, the new interface is compatible
104// if its exported method set is a superset of the old. (Users could not implement,
105// only embed.)
106//
107// If the old interface did not have an unexported method, the new interface is
108// compatible if its exported method set is the same as the old, and it has no
109// unexported methods. (Adding an unexported method makes the interface
110// unimplementable outside the package.)
111//
112// TODO: must also check that if any methods were added or removed, every exposed
113// type in the package that implemented the interface in old still implements it in
114// new. Otherwise external assignments could fail.
115func (d *differ) checkCompatibleInterface(otn *types.TypeName, old, new *types.Interface) {
116	// Method sets are checked in checkCompatibleDefined.
117
118	// Does the old interface have an unexported method?
119	if unexportedMethod(old) != nil {
120		d.checkMethodSet(otn, old, new, additionsCompatible)
121	} else {
122		// Perform an equivalence check, but with more information.
123		d.checkMethodSet(otn, old, new, additionsIncompatible)
124		if u := unexportedMethod(new); u != nil {
125			d.incompatible(otn, u.Name(), "added unexported method")
126		}
127	}
128}
129
// unexportedMethod returns the first unexported method found while walking
// t's methods in index order, or nil if every method of t is exported.
func unexportedMethod(t *types.Interface) *types.Func {
	count := t.NumMethods()
	for idx := 0; idx < count; idx++ {
		method := t.Method(idx)
		if method.Exported() {
			continue
		}
		return method
	}
	return nil
}
139
140// We need to check three things for structs:
141// 1. The set of exported fields must be compatible. This ensures that keyed struct
142//    literals continue to compile. (There is no compatibility guarantee for unkeyed
143//    struct literals.)
144// 2. The set of exported *selectable* fields must be compatible. This includes the exported
145//    fields of all embedded structs. This ensures that selections continue to compile.
146// 3. If the old struct is comparable, so must the new one be. This ensures that equality
147//    expressions and uses of struct values as map keys continue to compile.
148//
149// An unexported embedded struct can't appear in a struct literal outside the
150// package, so it doesn't have to be present, or have the same name, in the new
151// struct.
152//
153// Field tags are ignored: they have no compile-time implications.
154func (d *differ) checkCompatibleStruct(obj types.Object, old, new *types.Struct) {
155	d.checkCompatibleObjectSets(obj, exportedFields(old), exportedFields(new))
156	d.checkCompatibleObjectSets(obj, exportedSelectableFields(old), exportedSelectableFields(new))
157	// Removing comparability from a struct is an incompatible change.
158	if types.Comparable(old) && !types.Comparable(new) {
159		d.incompatible(obj, "", "old is comparable, new is not")
160	}
161}
162
// exportedFields returns the exported fields declared directly in s, keyed by
// field name. Fields of embedded structs are not included. This is also the set
// of exported keys for keyed struct literals.
func exportedFields(s *types.Struct) map[string]types.Object {
	fields := make(map[string]types.Object)
	for idx, n := 0, s.NumFields(); idx < n; idx++ {
		if field := s.Field(idx); field.Exported() {
			fields[field.Name()] = field
		}
	}
	return fields
}
175
176// exportedSelectableFields collects all the exported fields of the struct, including
177// exported fields of embedded structs.
178//
179// We traverse the struct breadth-first, because of the rule that a lower-depth field
180// shadows one at a higher depth.
181func exportedSelectableFields(s *types.Struct) map[string]types.Object {
182	var (
183		m    = map[string]types.Object{}
184		next []*types.Struct // embedded structs at the next depth
185		seen []*types.Struct // to handle recursive embedding
186	)
187	for cur := []*types.Struct{s}; len(cur) > 0; cur, next = next, nil {
188		seen = append(seen, cur...)
189		// We only want to consider unambiguous fields. Ambiguous fields (where there
190		// is more than one field of the same name at the same level) are legal, but
191		// cannot be selected.
192		for name, f := range unambiguousFields(cur) {
193			// Record an exported field we haven't seen before. If we have seen it,
194			// it occurred a lower depth, so it shadows this field.
195			if f.Exported() && m[name] == nil {
196				m[name] = f
197			}
198			// Remember embedded structs for processing at the next depth,
199			// but only if we haven't seen the struct at this depth or above.
200			if !f.Anonymous() {
201				continue
202			}
203			t := f.Type().Underlying()
204			if p, ok := t.(*types.Pointer); ok {
205				t = p.Elem().Underlying()
206			}
207			if t, ok := t.(*types.Struct); ok && !contains(seen, t) {
208				next = append(next, t)
209			}
210		}
211	}
212	return m
213}
214
// contains reports whether ts holds a struct type that is identical to t, as
// decided by types.Identical.
func contains(ts []*types.Struct, t *types.Struct) bool {
	for i := range ts {
		if !types.Identical(ts[i], t) {
			continue
		}
		return true
	}
	return false
}
223
// unambiguousFields returns, keyed by name, the fields whose names appear
// exactly once among all the fields of the given structs, which are expected to
// be at the same embedding depth.
func unambiguousFields(structs []*types.Struct) map[string]*types.Var {
	occurrences := make(map[string]int)
	unique := make(map[string]*types.Var)
	for _, st := range structs {
		for idx, n := 0, st.NumFields(); idx < n; idx++ {
			field := st.Field(idx)
			key := field.Name()
			occurrences[key]++
			if occurrences[key] == 1 {
				unique[key] = field
				continue
			}
			// Second or later occurrence: the name is ambiguous.
			delete(unique, key)
		}
	}
	return unique
}
243
244// Anything removed or change from the old set is an incompatible change.
245// Anything added to the new set is a compatible change.
246func (d *differ) checkCompatibleObjectSets(obj types.Object, old, new map[string]types.Object) {
247	for name, oldo := range old {
248		newo := new[name]
249		if newo == nil {
250			d.incompatible(obj, name, "removed")
251		} else {
252			d.checkCorrespondence(obj, name, oldo.Type(), newo.Type())
253		}
254	}
255	for name := range new {
256		if old[name] == nil {
257			d.compatible(obj, name, "added")
258		}
259	}
260}
261
262func (d *differ) checkCompatibleDefined(otn *types.TypeName, old *types.Named, new types.Type) {
263	// We've already checked that old and new correspond.
264	d.checkCompatible(otn, old.Underlying(), new.Underlying())
265	// If there are different kinds of types (e.g. struct and interface), don't bother checking
266	// the method sets.
267	if reflect.TypeOf(old.Underlying()) != reflect.TypeOf(new.Underlying()) {
268		return
269	}
270	// Interface method sets are checked in checkCompatibleInterface.
271	if _, ok := old.Underlying().(*types.Interface); ok {
272		return
273	}
274
275	// A new method set is compatible with an old if the new exported methods are a superset of the old.
276	d.checkMethodSet(otn, old, new, additionsCompatible)
277	d.checkMethodSet(otn, types.NewPointer(old), types.NewPointer(new), additionsCompatible)
278}
279
// Named values for the addcompat argument of checkMethodSet, which decides how
// a method that exists only in the new method set is reported.
const (
	additionsCompatible   = true  // added methods are reported as compatible
	additionsIncompatible = false // added methods are reported as incompatible
)
284
285func (d *differ) checkMethodSet(otn *types.TypeName, oldt, newt types.Type, addcompat bool) {
286	// TODO: find a way to use checkCompatibleObjectSets for this.
287	oldMethodSet := exportedMethods(oldt)
288	newMethodSet := exportedMethods(newt)
289	msname := otn.Name()
290	if _, ok := oldt.(*types.Pointer); ok {
291		msname = "*" + msname
292	}
293	for name, oldMethod := range oldMethodSet {
294		newMethod := newMethodSet[name]
295		if newMethod == nil {
296			var part string
297			// Due to embedding, it's possible that the method's receiver type is not
298			// the same as the defined type whose method set we're looking at. So for
299			// a type T with removed method M that is embedded in some other type U,
300			// we will generate two "removed" messages for T.M, one for its own type
301			// T and one for the embedded type U. We want both messages to appear,
302			// but the messageSet dedup logic will allow only one message for a given
303			// object. So use the part string to distinguish them.
304			if receiverNamedType(oldMethod).Obj() != otn {
305				part = fmt.Sprintf(", method set of %s", msname)
306			}
307			d.incompatible(oldMethod, part, "removed")
308		} else {
309			obj := oldMethod
310			// If a value method is changed to a pointer method and has a signature
311			// change, then we can get two messages for the same method definition: one
312			// for the value method set that says it's removed, and another for the
313			// pointer method set that says it changed. To keep both messages (since
314			// messageSet dedups), use newMethod for the second. (Slight hack.)
315			if !hasPointerReceiver(oldMethod) && hasPointerReceiver(newMethod) {
316				obj = newMethod
317			}
318			d.checkCorrespondence(obj, "", oldMethod.Type(), newMethod.Type())
319		}
320	}
321
322	// Check for added methods.
323	for name, newMethod := range newMethodSet {
324		if oldMethodSet[name] == nil {
325			if addcompat {
326				d.compatible(newMethod, "", "added")
327			} else {
328				d.incompatible(newMethod, "", "added")
329			}
330		}
331	}
332}
333
// exportedMethods returns the exported methods in the method set of t, as
// computed by types.NewMethodSet, keyed by method name.
func exportedMethods(t types.Type) map[string]types.Object {
	set := types.NewMethodSet(t)
	exported := make(map[string]types.Object)
	for idx, n := 0, set.Len(); idx < n; idx++ {
		method := set.At(idx).Obj()
		if !method.Exported() {
			continue
		}
		exported[method.Name()] = method
	}
	return exported
}
346
// receiverType returns the type of method's receiver. The type of method must
// be a *types.Signature with a receiver; otherwise the function panics.
func receiverType(method types.Object) types.Type {
	sig := method.Type().(*types.Signature)
	recv := sig.Recv()
	return recv.Type()
}
350
351func receiverNamedType(method types.Object) *types.Named {
352	switch t := receiverType(method).(type) {
353	case *types.Pointer:
354		return t.Elem().(*types.Named)
355	case *types.Named:
356		return t
357	default:
358		panic("unreachable")
359	}
360}
361
362func hasPointerReceiver(method types.Object) bool {
363	_, ok := receiverType(method).(*types.Pointer)
364	return ok
365}
366