1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// This file implements type unification.
6
7package types2
8
9import (
10	"bytes"
11	"fmt"
12)
13
14// The unifier maintains two separate sets of type parameters x and y
15// which are used to resolve type parameters in the x and y arguments
16// provided to the unify call. For unidirectional unification, only
17// one of these sets (say x) is provided, and then type parameters are
18// only resolved for the x argument passed to unify, not the y argument
// (even if that argument possibly contains the same type parameters). This
20// is crucial to infer the type parameters of self-recursive calls:
21//
22//	func f[P any](a P) { f(a) }
23//
24// For the call f(a) we want to infer that the type argument for P is P.
25// During unification, the parameter type P must be resolved to the type
26// parameter P ("x" side), but the argument type P must be left alone so
27// that unification resolves the type parameter P to P.
28//
// For bidirectional unification, both sets are provided. This enables
30// unification to go from argument to parameter type and vice versa.
31// For constraint type inference, we use bidirectional unification
32// where both the x and y type parameters are identical. This is done
33// by setting up one of them (using init) and then assigning its value
34// to the other.
35
36// A unifier maintains the current type parameters for x and y
37// and the respective types inferred for each type parameter.
38// A unifier is created by calling newUnifier.
39type unifier struct {
40	exact bool
41	x, y  tparamsList // x and y must initialized via tparamsList.init
42	types []Type      // inferred types, shared by x and y
43}
44
45// newUnifier returns a new unifier.
46// If exact is set, unification requires unified types to match
47// exactly. If exact is not set, a named type's underlying type
48// is considered if unification would fail otherwise, and the
49// direction of channels is ignored.
50func newUnifier(exact bool) *unifier {
51	u := &unifier{exact: exact}
52	u.x.unifier = u
53	u.y.unifier = u
54	return u
55}
56
57// unify attempts to unify x and y and reports whether it succeeded.
58func (u *unifier) unify(x, y Type) bool {
59	return u.nify(x, y, nil)
60}
61
// A tparamsList describes a list of type parameters and the types inferred for them.
// The inferred types themselves live in the types slice of the associated unifier;
// the list only records, per type parameter, which type slot (if any) it refers to.
type tparamsList struct {
	// unifier is the unifier whose types slice holds the type slots referred to by indices.
	unifier *unifier
	// tparams is the list of type parameters described by this list (set by init).
	tparams []*TypeParam
	// For each tparams element, there is a corresponding type slot index in indices.
	// index  < 0: unifier.types[-index-1] == nil
	// index == 0: no type slot allocated yet
	// index  > 0: unifier.types[index-1] == typ
	// Joined tparams elements share the same type slot and thus have the same index.
	// By using a negative index for nil types we don't need to check unifier.types
	// to see if we have a type or not.
	indices []int // len(d.indices) == len(d.tparams)
}
75
76// String returns a string representation for a tparamsList. For debugging.
77func (d *tparamsList) String() string {
78	var buf bytes.Buffer
79	w := newTypeWriter(&buf, nil)
80	w.byte('[')
81	for i, tpar := range d.tparams {
82		if i > 0 {
83			w.string(", ")
84		}
85		w.typ(tpar)
86		w.string(": ")
87		w.typ(d.at(i))
88	}
89	w.byte(']')
90	return buf.String()
91}
92
93// init initializes d with the given type parameters.
94// The type parameters must be in the order in which they appear in their declaration
95// (this ensures that the tparams indices match the respective type parameter index).
96func (d *tparamsList) init(tparams []*TypeParam) {
97	if len(tparams) == 0 {
98		return
99	}
100	if debug {
101		for i, tpar := range tparams {
102			assert(i == tpar.index)
103		}
104	}
105	d.tparams = tparams
106	d.indices = make([]int, len(tparams))
107}
108
109// join unifies the i'th type parameter of x with the j'th type parameter of y.
110// If both type parameters already have a type associated with them and they are
111// not joined, join fails and returns false.
112func (u *unifier) join(i, j int) bool {
113	ti := u.x.indices[i]
114	tj := u.y.indices[j]
115	switch {
116	case ti == 0 && tj == 0:
117		// Neither type parameter has a type slot associated with them.
118		// Allocate a new joined nil type slot (negative index).
119		u.types = append(u.types, nil)
120		u.x.indices[i] = -len(u.types)
121		u.y.indices[j] = -len(u.types)
122	case ti == 0:
123		// The type parameter for x has no type slot yet. Use slot of y.
124		u.x.indices[i] = tj
125	case tj == 0:
126		// The type parameter for y has no type slot yet. Use slot of x.
127		u.y.indices[j] = ti
128
129	// Both type parameters have a slot: ti != 0 && tj != 0.
130	case ti == tj:
131		// Both type parameters already share the same slot. Nothing to do.
132		break
133	case ti > 0 && tj > 0:
134		// Both type parameters have (possibly different) inferred types. Cannot join.
135		// TODO(gri) Should we check if types are identical? Investigate.
136		return false
137	case ti > 0:
138		// Only the type parameter for x has an inferred type. Use x slot for y.
139		u.y.setIndex(j, ti)
140	// This case is handled like the default case.
141	// case tj > 0:
142	// 	// Only the type parameter for y has an inferred type. Use y slot for x.
143	// 	u.x.setIndex(i, tj)
144	default:
145		// Neither type parameter has an inferred type. Use y slot for x
146		// (or x slot for y, it doesn't matter).
147		u.x.setIndex(i, tj)
148	}
149	return true
150}
151
152// If typ is a type parameter of d, index returns the type parameter index.
153// Otherwise, the result is < 0.
154func (d *tparamsList) index(typ Type) int {
155	if tpar, ok := typ.(*TypeParam); ok {
156		return tparamIndex(d.tparams, tpar)
157	}
158	return -1
159}
160
161// If tpar is a type parameter in list, tparamIndex returns the type parameter index.
162// Otherwise, the result is < 0. tpar must not be nil.
163func tparamIndex(list []*TypeParam, tpar *TypeParam) int {
164	// Once a type parameter is bound its index is >= 0. However, there are some
165	// code paths (namely tracing and type hashing) by which it is possible to
166	// arrive here with a type parameter that has not been bound, hence the check
167	// for 0 <= i below.
168	// TODO(rfindley): investigate a better approach for guarding against using
169	// unbound type parameters.
170	if i := tpar.index; 0 <= i && i < len(list) && list[i] == tpar {
171		return i
172	}
173	return -1
174}
175
176// setIndex sets the type slot index for the i'th type parameter
177// (and all its joined parameters) to tj. The type parameter
178// must have a (possibly nil) type slot associated with it.
179func (d *tparamsList) setIndex(i, tj int) {
180	ti := d.indices[i]
181	assert(ti != 0 && tj != 0)
182	for k, tk := range d.indices {
183		if tk == ti {
184			d.indices[k] = tj
185		}
186	}
187}
188
189// at returns the type set for the i'th type parameter; or nil.
190func (d *tparamsList) at(i int) Type {
191	if ti := d.indices[i]; ti > 0 {
192		return d.unifier.types[ti-1]
193	}
194	return nil
195}
196
197// set sets the type typ for the i'th type parameter;
198// typ must not be nil and it must not have been set before.
199func (d *tparamsList) set(i int, typ Type) {
200	assert(typ != nil)
201	u := d.unifier
202	switch ti := d.indices[i]; {
203	case ti < 0:
204		u.types[-ti-1] = typ
205		d.setIndex(i, -ti)
206	case ti == 0:
207		u.types = append(u.types, typ)
208		d.indices[i] = len(u.types)
209	default:
210		panic("type already set")
211	}
212}
213
214// types returns the list of inferred types (via unification) for the type parameters
215// described by d, and an index. If all types were inferred, the returned index is < 0.
216// Otherwise, it is the index of the first type parameter which couldn't be inferred;
217// i.e., for which list[index] is nil.
218func (d *tparamsList) types() (list []Type, index int) {
219	list = make([]Type, len(d.tparams))
220	index = -1
221	for i := range d.tparams {
222		t := d.at(i)
223		list[i] = t
224		if index < 0 && t == nil {
225			index = i
226		}
227	}
228	return
229}
230
231func (u *unifier) nifyEq(x, y Type, p *ifacePair) bool {
232	return x == y || u.nify(x, y, p)
233}
234
// nify implements the core unification algorithm which is an
// adapted version of Checker.identical. For changes to that
// code the corresponding changes should be made here.
// Must not be called directly from outside the unifier.
//
// nify reports whether x and y unify. As a side effect, it may join
// type parameters of u.x and u.y or record inferred types for them;
// nify itself does not roll back such updates when it returns false.
//
// p is the stack of interface pairs currently being compared; it is
// used to detect cycles when unifying interfaces (see the *Interface
// case below). Top-level callers pass nil.
func (u *unifier) nify(x, y Type, p *ifacePair) bool {
	if !u.exact {
		// If exact unification is known to fail because we attempt to
		// match a type name against an unnamed type literal, consider
		// the underlying type of the named type.
		// (We use !hasName to exclude any type with a name, including
		// basic types and type parameters; the rest are unnamed types.)
		if nx, _ := x.(*Named); nx != nil && !hasName(y) {
			return u.nify(nx.under(), y, p)
		} else if ny, _ := y.(*Named); ny != nil && !hasName(x) {
			return u.nify(x, ny.under(), p)
		}
	}

	// Cases where at least one of x or y is a type parameter.
	// (i is the index of x in u.x, j is the index of y in u.y;
	// an index is < 0 if the type is not a type parameter of the
	// respective list.)
	switch i, j := u.x.index(x), u.y.index(y); {
	case i >= 0 && j >= 0:
		// both x and y are type parameters
		if u.join(i, j) {
			return true
		}
		// both x and y have an inferred type - they must match
		return u.nifyEq(u.x.at(i), u.y.at(j), p)

	case i >= 0:
		// x is a type parameter, y is not
		if tx := u.x.at(i); tx != nil {
			return u.nifyEq(tx, y, p)
		}
		// otherwise, infer type from y
		u.x.set(i, y)
		return true

	case j >= 0:
		// y is a type parameter, x is not
		if ty := u.y.at(j); ty != nil {
			return u.nifyEq(x, ty, p)
		}
		// otherwise, infer type from x
		u.y.set(j, x)
		return true
	}

	// For type unification, do not shortcut (x == y) for identical
	// types. Instead keep comparing them element-wise to unify the
	// matching (and equal type parameter types). A simple test case
	// where this matters is: func f[P any](a P) { f(a) } .

	// From here on, neither x nor y is a type parameter of u.x or u.y,
	// respectively. Each case below succeeds only if y has the same
	// kind of type as x; otherwise control falls through to the final
	// return false.
	switch x := x.(type) {
	case *Basic:
		// Basic types are singletons except for the rune and byte
		// aliases, thus we cannot solely rely on the x == y check
		// above. See also comment in TypeName.IsAlias.
		// NOTE(review): nify itself has no x == y check above; this
		// remark presumably stems from the parallel code in
		// Checker.identical (or refers to the check in nifyEq) - confirm.
		if y, ok := y.(*Basic); ok {
			return x.kind == y.kind
		}

	case *Array:
		// Two array types are identical if they have identical element types
		// and the same array length.
		if y, ok := y.(*Array); ok {
			// If one or both array lengths are unknown (< 0) due to some error,
			// assume they are the same to avoid spurious follow-on errors.
			return (x.len < 0 || y.len < 0 || x.len == y.len) && u.nify(x.elem, y.elem, p)
		}

	case *Slice:
		// Two slice types are identical if they have identical element types.
		if y, ok := y.(*Slice); ok {
			return u.nify(x.elem, y.elem, p)
		}

	case *Struct:
		// Two struct types are identical if they have the same sequence of fields,
		// and if corresponding fields have the same names, and identical types,
		// and identical tags. Two embedded fields are considered to have the same
		// name. Lower-case field names from different packages are always different.
		if y, ok := y.(*Struct); ok {
			if x.NumFields() == y.NumFields() {
				for i, f := range x.fields {
					g := y.fields[i]
					if f.embedded != g.embedded ||
						x.Tag(i) != y.Tag(i) ||
						!f.sameId(g.pkg, g.name) ||
						!u.nify(f.typ, g.typ, p) {
						return false
					}
				}
				return true
			}
		}

	case *Pointer:
		// Two pointer types are identical if they have identical base types.
		if y, ok := y.(*Pointer); ok {
			return u.nify(x.base, y.base, p)
		}

	case *Tuple:
		// Two tuples types are identical if they have the same number of elements
		// and corresponding elements have identical types.
		if y, ok := y.(*Tuple); ok {
			if x.Len() == y.Len() {
				// x may be a nil *Tuple; in that case there are no elements to unify.
				if x != nil {
					for i, v := range x.vars {
						w := y.vars[i]
						if !u.nify(v.typ, w.typ, p) {
							return false
						}
					}
				}
				return true
			}
		}

	case *Signature:
		// Two function types are identical if they have the same number of parameters
		// and result values, corresponding parameter and result types are identical,
		// and either both functions are variadic or neither is. Parameter and result
		// names are not required to match.
		// TODO(gri) handle type parameters or document why we can ignore them.
		if y, ok := y.(*Signature); ok {
			return x.variadic == y.variadic &&
				u.nify(x.params, y.params, p) &&
				u.nify(x.results, y.results, p)
		}

	case *Interface:
		// Two interface types are identical if they have the same set of methods with
		// the same names and identical function types. Lower-case method names from
		// different packages are always different. The order of the methods is irrelevant.
		// In addition, the terms of the two type sets must be equal.
		if y, ok := y.(*Interface); ok {
			xset := x.typeSet()
			yset := y.typeSet()
			if !xset.terms.equal(yset.terms) {
				return false
			}
			a := xset.methods
			b := yset.methods
			if len(a) == len(b) {
				// Interface types are the only types where cycles can occur
				// that are not "terminated" via named types; and such cycles
				// can only be created via method parameter types that are
				// anonymous interfaces (directly or indirectly) embedding
				// the current interface. Example:
				//
				//    type T interface {
				//        m() interface{T}
				//    }
				//
				// If two such (differently named) interfaces are compared,
				// endless recursion occurs if the cycle is not detected.
				//
				// If x and y were compared before, they must be equal
				// (if they were not, the recursion would have stopped);
				// search the ifacePair stack for the same pair.
				//
				// This is a quadratic algorithm, but in practice these stacks
				// are extremely short (bounded by the nesting depth of interface
				// type declarations that recur via parameter types, an extremely
				// rare occurrence). An alternative implementation might use a
				// "visited" map, but that is probably less efficient overall.
				q := &ifacePair{x, y, p}
				for p != nil {
					if p.identical(q) {
						return true // same pair was compared before
					}
					p = p.prev
				}
				if debug {
					assertSortedMethods(a)
					assertSortedMethods(b)
				}
				// Unify the methods pairwise, with the current pair (q)
				// pushed onto the stack for the recursive calls.
				for i, f := range a {
					g := b[i]
					if f.Id() != g.Id() || !u.nify(f.typ, g.typ, q) {
						return false
					}
				}
				return true
			}
		}

	case *Map:
		// Two map types are identical if they have identical key and value types.
		if y, ok := y.(*Map); ok {
			return u.nify(x.key, y.key, p) && u.nify(x.elem, y.elem, p)
		}

	case *Chan:
		// Two channel types are identical if they have identical value types.
		// The channel direction is only compared for exact unification.
		if y, ok := y.(*Chan); ok {
			return (!u.exact || x.dir == y.dir) && u.nify(x.elem, y.elem, p)
		}

	case *Named:
		// TODO(gri) This code differs now from the parallel code in Checker.identical. Investigate.
		// Two named types unify if their objects have the same package and
		// name and if their type arguments unify pairwise.
		if y, ok := y.(*Named); ok {
			xargs := x.targs.list()
			yargs := y.targs.list()

			// TODO(gri) This is not always correct: two types may have the same names
			//           in the same package if one of them is nested in a function.
			//           Extremely unlikely but we need an always correct solution.
			if x.obj.pkg == y.obj.pkg && x.obj.name == y.obj.name {
				assert(len(xargs) == len(yargs))
				for i, x := range xargs {
					if !u.nify(x, yargs[i], p) {
						return false
					}
				}
				return true
			}
		}

	case *TypeParam:
		// Two type parameters (which are not part of the type parameters of the
		// enclosing type as those are handled in the beginning of this function)
		// are identical if they originate in the same declaration.
		return x == y

	case nil:
		// avoid a crash in case of nil type

	default:
		// x has a type that is not handled above.
		panic(fmt.Sprintf("### u.nify(%s, %s), u.x.tparams = %s", x, y, u.x.tparams))
	}

	// x and y are of different kinds, or they differ in their number of
	// fields, elements, or methods, or x is nil: unification fails.
	return false
}
469