1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package walk
6
7import (
8	"go/constant"
9	"go/token"
10	"sort"
11
12	"cmd/compile/internal/base"
13	"cmd/compile/internal/ir"
14	"cmd/compile/internal/typecheck"
15	"cmd/compile/internal/types"
16	"cmd/internal/src"
17)
18
19// walkSwitch walks a switch statement.
20func walkSwitch(sw *ir.SwitchStmt) {
21	// Guard against double walk, see #25776.
22	if sw.Walked() {
23		return // Was fatal, but eliminating every possible source of double-walking is hard
24	}
25	sw.SetWalked(true)
26
27	if sw.Tag != nil && sw.Tag.Op() == ir.OTYPESW {
28		walkSwitchType(sw)
29	} else {
30		walkSwitchExpr(sw)
31	}
32}
33
// walkSwitchExpr generates an AST implementing sw. sw is an
// expression switch.
//
// The lowered statements are appended to sw.Compiled in this order:
// whatever copyExpr records for a non-constant tag, the case-dispatch
// tests built by exprSwitch, a jump to the default case (or a break
// when there is no default), and finally every case body preceded by
// its label. sw.Tag and sw.Cases are cleared along the way.
func walkSwitchExpr(sw *ir.SwitchStmt) {
	// lno is restored into base.Pos once the tag has been walked.
	lno := ir.SetPos(sw)

	cond := sw.Tag
	sw.Tag = nil

	// convert switch {...} to switch true {...}
	if cond == nil {
		cond = ir.NewBool(true)
		cond = typecheck.Expr(cond)
		cond = typecheck.DefaultLit(cond, nil)
	}

	// Given "switch string(byteslice)",
	// with all cases being side-effect free,
	// use a zero-cost alias of the byte slice.
	// Do this before calling walkExpr on cond,
	// because walkExpr will lower the string
	// conversion into a runtime call.
	// See issue 24937 for more discussion.
	if cond.Op() == ir.OBYTES2STR && allCaseExprsAreSideEffectFree(sw) {
		cond := cond.(*ir.ConvExpr)
		cond.SetOp(ir.OBYTES2STRTMP)
	}

	cond = walkExpr(cond, sw.PtrInit())
	// Unless the tag is a literal or nil, copy it (copyExpr is handed
	// sw.Compiled to hold the copy) so that every case comparison below
	// refers to the same evaluated value.
	if cond.Op() != ir.OLITERAL && cond.Op() != ir.ONIL {
		cond = copyExpr(cond, cond.Type(), &sw.Compiled)
	}

	base.Pos = lno

	s := exprSwitch{
		exprname: cond,
	}

	// defaultGoto jumps to the default case, if any. body collects the
	// labeled case bodies, which are emitted after all dispatch code.
	var defaultGoto ir.Node
	var body ir.Nodes
	for _, ncase := range sw.Cases {
		// Each case gets a fresh label; jmp transfers control to it.
		label := typecheck.AutoLabel(".s")
		jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label)

		// Process case dispatch.
		// A case with an empty expression list is the default case.
		if len(ncase.List) == 0 {
			if defaultGoto != nil {
				base.Fatalf("duplicate default case not detected during typechecking")
			}
			defaultGoto = jmp
		}

		for _, n1 := range ncase.List {
			s.Add(ncase.Pos(), n1, jmp)
		}

		// Process body.
		body.Append(ir.NewLabelStmt(ncase.Pos(), label))
		body.Append(ncase.Body...)
		// Unless the body ends in "fallthrough", leave the switch after
		// running it. The break takes the position returned by
		// endsInFallthrough (that of the body's last real statement).
		if fall, pos := endsInFallthrough(ncase.Body); !fall {
			br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
			br.SetPos(pos)
			body.Append(br)
		}
	}
	sw.Cases = nil

	// Without a default case, an unmatched value simply leaves the
	// switch. This synthesized break has its position marked WithNotStmt.
	if defaultGoto == nil {
		br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
		br.SetPos(br.Pos().WithNotStmt())
		defaultGoto = br
	}

	s.Emit(&sw.Compiled)
	sw.Compiled.Append(defaultGoto)
	sw.Compiled.Append(body.Take()...)
	walkStmtList(sw.Compiled)
}
112
113// An exprSwitch walks an expression switch.
114type exprSwitch struct {
115	exprname ir.Node // value being switched on
116
117	done    ir.Nodes
118	clauses []exprClause
119}
120
121type exprClause struct {
122	pos    src.XPos
123	lo, hi ir.Node
124	jmp    ir.Node
125}
126
127func (s *exprSwitch) Add(pos src.XPos, expr, jmp ir.Node) {
128	c := exprClause{pos: pos, lo: expr, hi: expr, jmp: jmp}
129	if types.IsOrdered[s.exprname.Type().Kind()] && expr.Op() == ir.OLITERAL {
130		s.clauses = append(s.clauses, c)
131		return
132	}
133
134	s.flush()
135	s.clauses = append(s.clauses, c)
136	s.flush()
137}
138
139func (s *exprSwitch) Emit(out *ir.Nodes) {
140	s.flush()
141	out.Append(s.done.Take()...)
142}
143
// flush lowers the pending clauses into s.done and clears s.clauses.
//
// String switches with two or more clauses use a two-level binary search:
// first on len(exprname), then on the value among the cases of that
// length. Every other switch sorts its clauses by constant value. For
// integer switches, adjacent values (next == last+1) with the same jump
// target are then merged into [lo, hi] range clauses. Finally the
// clauses are binary-searched.
func (s *exprSwitch) flush() {
	cc := s.clauses
	s.clauses = nil
	if len(cc) == 0 {
		return
	}

	// Caution: If len(cc) == 1, then cc[0] might not be an OLITERAL.
	// The code below is structured to implicitly handle this case
	// (e.g., sort.Slice doesn't need to invoke the less function
	// when there's only a single slice element).

	if s.exprname.Type().IsString() && len(cc) >= 2 {
		// Sort strings by length and then by value. It is
		// much cheaper to compare lengths than values, and
		// all we need here is consistency. We respect this
		// sorting below.
		sort.Slice(cc, func(i, j int) bool {
			si := ir.StringVal(cc[i].lo)
			sj := ir.StringVal(cc[j].lo)
			if len(si) != len(sj) {
				return len(si) < len(sj)
			}
			return si < sj
		})

		// runLen returns the string length associated with a
		// particular run of exprClauses.
		runLen := func(run []exprClause) int64 { return int64(len(ir.StringVal(run[0].lo))) }

		// Collapse runs of consecutive strings with the same length.
		var runs [][]exprClause
		start := 0
		for i := 1; i < len(cc); i++ {
			if runLen(cc[start:]) != runLen(cc[i:]) {
				runs = append(runs, cc[start:i])
				start = i
			}
		}
		runs = append(runs, cc[start:])

		// Perform two-level binary search.
		binarySearch(len(runs), &s.done,
			func(i int) ir.Node {
				// Outer level: split on the length of the run just before i.
				return ir.NewBinaryExpr(base.Pos, ir.OLE, ir.NewUnaryExpr(base.Pos, ir.OLEN, s.exprname), ir.NewInt(runLen(runs[i-1])))
			},
			func(i int, nif *ir.IfStmt) {
				run := runs[i]
				nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, ir.NewUnaryExpr(base.Pos, ir.OLEN, s.exprname), ir.NewInt(runLen(run)))
				// Inner level: search by value among strings of this length.
				s.search(run, &nif.Body)
			},
		)
		return
	}

	sort.Slice(cc, func(i, j int) bool {
		return constant.Compare(cc[i].lo.Val(), token.LSS, cc[j].lo.Val())
	})

	// Merge consecutive integer cases.
	// Sorted values that differ by exactly one and share a jump target
	// are folded into a single [lo, hi] clause, compacted in place in cc.
	if s.exprname.Type().IsInteger() {
		consecutive := func(last, next constant.Value) bool {
			delta := constant.BinaryOp(next, token.SUB, last)
			return constant.Compare(delta, token.EQL, constant.MakeInt64(1))
		}

		merged := cc[:1]
		for _, c := range cc[1:] {
			last := &merged[len(merged)-1]
			if last.jmp == c.jmp && consecutive(last.hi.Val(), c.lo.Val()) {
				last.hi = c.lo
			} else {
				merged = append(merged, c)
			}
		}
		cc = merged
	}

	s.search(cc, &s.done)
}
224
225func (s *exprSwitch) search(cc []exprClause, out *ir.Nodes) {
226	binarySearch(len(cc), out,
227		func(i int) ir.Node {
228			return ir.NewBinaryExpr(base.Pos, ir.OLE, s.exprname, cc[i-1].hi)
229		},
230		func(i int, nif *ir.IfStmt) {
231			c := &cc[i]
232			nif.Cond = c.test(s.exprname)
233			nif.Body = []ir.Node{c.jmp}
234		},
235	)
236}
237
238func (c *exprClause) test(exprname ir.Node) ir.Node {
239	// Integer range.
240	if c.hi != c.lo {
241		low := ir.NewBinaryExpr(c.pos, ir.OGE, exprname, c.lo)
242		high := ir.NewBinaryExpr(c.pos, ir.OLE, exprname, c.hi)
243		return ir.NewLogicalExpr(c.pos, ir.OANDAND, low, high)
244	}
245
246	// Optimize "switch true { ...}" and "switch false { ... }".
247	if ir.IsConst(exprname, constant.Bool) && !c.lo.Type().IsInterface() {
248		if ir.BoolVal(exprname) {
249			return c.lo
250		} else {
251			return ir.NewUnaryExpr(c.pos, ir.ONOT, c.lo)
252		}
253	}
254
255	return ir.NewBinaryExpr(c.pos, ir.OEQ, exprname, c.lo)
256}
257
258func allCaseExprsAreSideEffectFree(sw *ir.SwitchStmt) bool {
259	// In theory, we could be more aggressive, allowing any
260	// side-effect-free expressions in cases, but it's a bit
261	// tricky because some of that information is unavailable due
262	// to the introduction of temporaries during order.
263	// Restricting to constants is simple and probably powerful
264	// enough.
265
266	for _, ncase := range sw.Cases {
267		for _, v := range ncase.List {
268			if v.Op() != ir.OLITERAL {
269				return false
270			}
271		}
272	}
273	return true
274}
275
276// endsInFallthrough reports whether stmts ends with a "fallthrough" statement.
277func endsInFallthrough(stmts []ir.Node) (bool, src.XPos) {
278	// Search backwards for the index of the fallthrough
279	// statement. Do not assume it'll be in the last
280	// position, since in some cases (e.g. when the statement
281	// list contains autotmp_ variables), one or more OVARKILL
282	// nodes will be at the end of the list.
283
284	i := len(stmts) - 1
285	for i >= 0 && stmts[i].Op() == ir.OVARKILL {
286		i--
287	}
288	if i < 0 {
289		return false, src.NoXPos
290	}
291	return stmts[i].Op() == ir.OFALL, stmts[i].Pos()
292}
293
// walkSwitchType generates an AST that implements sw, where sw is a
// type switch.
//
// The lowered statements are appended to sw.Compiled in this order:
// the copy of the switched-on interface value, a nil check that jumps to
// the nil case (or the default), the copy of the type hash loaded from
// the interface's type/itab word, the per-case type assertions built by
// typeSwitch, a jump to the default case (or a break), and finally the
// labeled case bodies. sw.Tag and sw.Cases are cleared along the way.
func walkSwitchType(sw *ir.SwitchStmt) {
	var s typeSwitch
	s.facename = sw.Tag.(*ir.TypeSwitchGuard).X
	sw.Tag = nil

	// Walk the switched-on value and copy it, so all dispatch code
	// refers to a single evaluation.
	s.facename = walkExpr(s.facename, sw.PtrInit())
	s.facename = copyExpr(s.facename, s.facename.Type(), &sw.Compiled)
	// okname receives the "ok" result of each comma-ok type assertion.
	s.okname = typecheck.Temp(types.Types[types.TBOOL])

	// Get interface descriptor word.
	// For empty interfaces this will be the type.
	// For non-empty interfaces this will be the itab.
	itab := ir.NewUnaryExpr(base.Pos, ir.OITAB, s.facename)

	// For empty interfaces, do:
	//     if e._type == nil {
	//         do nil case if it exists, otherwise default
	//     }
	//     h := e._type.hash
	// Use a similar strategy for non-empty interfaces.
	ifNil := ir.NewIfStmt(base.Pos, nil, nil, nil)
	ifNil.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, itab, typecheck.NodNil())
	base.Pos = base.Pos.WithNotStmt() // disable statement marks after the first check.
	ifNil.Cond = typecheck.Expr(ifNil.Cond)
	ifNil.Cond = typecheck.DefaultLit(ifNil.Cond, nil)
	// ifNil.Body is assigned at the end, once nilGoto is known.
	sw.Compiled.Append(ifNil)

	// Load hash from type or itab.
	dotHash := typeHashFieldOf(base.Pos, itab)
	s.hashname = copyExpr(dotHash, dotHash.Type(), &sw.Compiled)

	// br leaves the switch. It ends every case body and stands in for the
	// default case when there is none.
	br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
	var defaultGoto, nilGoto ir.Node
	var body ir.Nodes
	for _, ncase := range sw.Cases {
		caseVar := ncase.Var

		// For single-type cases with an interface type,
		// we initialize the case variable as part of the type assertion.
		// In other cases, we initialize it in the body.
		var singleType *types.Type
		if len(ncase.List) == 1 && ncase.List[0].Op() == ir.OTYPE {
			singleType = ncase.List[0].Type()
		}
		caseVarInitialized := false

		// Each case gets a fresh label; jmp transfers control to it.
		label := typecheck.AutoLabel(".s")
		jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label)

		if len(ncase.List) == 0 { // default:
			if defaultGoto != nil {
				base.Fatalf("duplicate default case not detected during typechecking")
			}
			defaultGoto = jmp
		}

		for _, n1 := range ncase.List {
			if ir.IsNil(n1) { // case nil:
				if nilGoto != nil {
					base.Fatalf("duplicate nil case not detected during typechecking")
				}
				nilGoto = jmp
				continue
			}

			if singleType != nil && singleType.IsInterface() {
				s.Add(ncase.Pos(), n1, caseVar, jmp)
				caseVarInitialized = true
			} else {
				s.Add(ncase.Pos(), n1, nil, jmp)
			}
		}

		body.Append(ir.NewLabelStmt(ncase.Pos(), label))
		// Declare and assign the case variable at the top of the body
		// unless the type assertion emitted by Add already did so. By
		// default it receives the interface value itself.
		if caseVar != nil && !caseVarInitialized {
			val := s.facename
			if singleType != nil {
				// We have a single concrete type. Extract the data.
				if singleType.IsInterface() {
					base.Fatalf("singleType interface should have been handled in Add")
				}
				val = ifaceData(ncase.Pos(), s.facename, singleType)
			}
			// A single generic (ODYNAMICTYPE) case: convert the value
			// with a dynamic type assertion to the case variable's type.
			if len(ncase.List) == 1 && ncase.List[0].Op() == ir.ODYNAMICTYPE {
				dt := ncase.List[0].(*ir.DynamicType)
				x := ir.NewDynamicTypeAssertExpr(ncase.Pos(), ir.ODYNAMICDOTTYPE, val, dt.X)
				if dt.ITab != nil {
					// TODO: make ITab a separate field in DynamicTypeAssertExpr?
					x.T = dt.ITab
				}
				x.SetType(caseVar.Type())
				x.SetTypecheck(1)
				val = x
			}
			l := []ir.Node{
				ir.NewDecl(ncase.Pos(), ir.ODCL, caseVar),
				ir.NewAssignStmt(ncase.Pos(), caseVar, val),
			}
			typecheck.Stmts(l)
			body.Append(l...)
		}
		body.Append(ncase.Body...)
		body.Append(br)
	}
	sw.Cases = nil

	if defaultGoto == nil {
		defaultGoto = br
	}
	// A nil interface with no explicit "case nil" takes the default path.
	if nilGoto == nil {
		nilGoto = defaultGoto
	}
	ifNil.Body = []ir.Node{nilGoto}

	s.Emit(&sw.Compiled)
	sw.Compiled.Append(defaultGoto)
	sw.Compiled.Append(body.Take()...)

	walkStmtList(sw.Compiled)
}
417
418// typeHashFieldOf returns an expression to select the type hash field
419// from an interface's descriptor word (whether a *runtime._type or
420// *runtime.itab pointer).
421func typeHashFieldOf(pos src.XPos, itab *ir.UnaryExpr) *ir.SelectorExpr {
422	if itab.Op() != ir.OITAB {
423		base.Fatalf("expected OITAB, got %v", itab.Op())
424	}
425	var hashField *types.Field
426	if itab.X.Type().IsEmptyInterface() {
427		// runtime._type's hash field
428		if rtypeHashField == nil {
429			rtypeHashField = runtimeField("hash", int64(2*types.PtrSize), types.Types[types.TUINT32])
430		}
431		hashField = rtypeHashField
432	} else {
433		// runtime.itab's hash field
434		if itabHashField == nil {
435			itabHashField = runtimeField("hash", int64(2*types.PtrSize), types.Types[types.TUINT32])
436		}
437		hashField = itabHashField
438	}
439	return boundedDotPtr(pos, itab, hashField)
440}
441
442var rtypeHashField, itabHashField *types.Field
443
444// A typeSwitch walks a type switch.
445type typeSwitch struct {
446	// Temporary variables (i.e., ONAMEs) used by type switch dispatch logic:
447	facename ir.Node // value being type-switched on
448	hashname ir.Node // type hash of the value being type-switched on
449	okname   ir.Node // boolean used for comma-ok type assertions
450
451	done    ir.Nodes
452	clauses []typeClause
453}
454
455type typeClause struct {
456	hash uint32
457	body ir.Nodes
458}
459
// Add emits dispatch code for one type case n1, an OTYPE (static) or
// ODYNAMICTYPE (generic) node, that executes jmp when the switched-on
// value holds that type. The generated code is
//
//	cv, ok = facename.(type)
//	if ok { goto label }
//
// When caseVar is non-nil, it is first declared (an ODCL plus an
// assignment with a nil right-hand side, presumably zero-initialization;
// confirm). The assertion then assigns the case variable. Otherwise the
// asserted value goes to the blank identifier.
//
// Static, non-interface cases are buffered in s.clauses under their type
// hash so that flush can binary-search them. Every other case first
// flushes the buffer and is then appended to s.done directly. This keeps
// such cases in source order relative to the buffered ones.
func (s *typeSwitch) Add(pos src.XPos, n1 ir.Node, caseVar *ir.Name, jmp ir.Node) {
	typ := n1.Type()
	var body ir.Nodes
	if caseVar != nil {
		l := []ir.Node{
			ir.NewDecl(pos, ir.ODCL, caseVar),
			ir.NewAssignStmt(pos, caseVar, nil),
		}
		typecheck.Stmts(l)
		body.Append(l...)
	} else {
		caseVar = ir.BlankNode.(*ir.Name)
	}

	// cv, ok = iface.(type)
	as := ir.NewAssignListStmt(pos, ir.OAS2, nil, nil)
	as.Lhs = []ir.Node{caseVar, s.okname} // cv, ok =
	switch n1.Op() {
	case ir.OTYPE:
		// Static type assertion (non-generic)
		dot := ir.NewTypeAssertExpr(pos, s.facename, nil)
		dot.SetType(typ) // iface.(type)
		as.Rhs = []ir.Node{dot}
	case ir.ODYNAMICTYPE:
		// Dynamic type assertion (generic)
		dt := n1.(*ir.DynamicType)
		dot := ir.NewDynamicTypeAssertExpr(pos, ir.ODYNAMICDOTTYPE, s.facename, dt.X)
		if dt.ITab != nil {
			dot.T = dt.ITab
		}
		dot.SetType(typ)
		dot.SetTypecheck(1)
		as.Rhs = []ir.Node{dot}
	default:
		base.Fatalf("unhandled type case %s", n1.Op())
	}
	appendWalkStmt(&body, as)

	// if ok { goto label }
	nif := ir.NewIfStmt(pos, nil, nil, nil)
	nif.Cond = s.okname
	nif.Body = []ir.Node{jmp}
	body.Append(nif)

	if n1.Op() == ir.OTYPE && !typ.IsInterface() {
		// Defer static, noninterface cases so they can be binary searched by hash.
		s.clauses = append(s.clauses, typeClause{
			hash: types.TypeHash(n1.Type()),
			body: body,
		})
		return
	}

	// This case must be tested in order: emit the buffered hash cases
	// first, then this case's code.
	s.flush()
	s.done.Append(body.Take()...)
}
516
517func (s *typeSwitch) Emit(out *ir.Nodes) {
518	s.flush()
519	out.Append(s.done.Take()...)
520}
521
522func (s *typeSwitch) flush() {
523	cc := s.clauses
524	s.clauses = nil
525	if len(cc) == 0 {
526		return
527	}
528
529	sort.Slice(cc, func(i, j int) bool { return cc[i].hash < cc[j].hash })
530
531	// Combine adjacent cases with the same hash.
532	merged := cc[:1]
533	for _, c := range cc[1:] {
534		last := &merged[len(merged)-1]
535		if last.hash == c.hash {
536			last.body.Append(c.body.Take()...)
537		} else {
538			merged = append(merged, c)
539		}
540	}
541	cc = merged
542
543	binarySearch(len(cc), &s.done,
544		func(i int) ir.Node {
545			return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashname, ir.NewInt(int64(cc[i-1].hash)))
546		},
547		func(i int, nif *ir.IfStmt) {
548			// TODO(mdempsky): Omit hash equality check if
549			// there's only one type.
550			c := cc[i]
551			nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, s.hashname, ir.NewInt(int64(c.hash)))
552			nif.Body.Append(c.body.Take()...)
553		},
554	)
555}
556
557// binarySearch constructs a binary search tree for handling n cases,
558// and appends it to out. It's used for efficiently implementing
559// switch statements.
560//
561// less(i) should return a boolean expression. If it evaluates true,
562// then cases before i will be tested; otherwise, cases i and later.
563//
564// leaf(i, nif) should setup nif (an OIF node) to test case i. In
565// particular, it should set nif.Left and nif.Nbody.
566func binarySearch(n int, out *ir.Nodes, less func(i int) ir.Node, leaf func(i int, nif *ir.IfStmt)) {
567	const binarySearchMin = 4 // minimum number of cases for binary search
568
569	var do func(lo, hi int, out *ir.Nodes)
570	do = func(lo, hi int, out *ir.Nodes) {
571		n := hi - lo
572		if n < binarySearchMin {
573			for i := lo; i < hi; i++ {
574				nif := ir.NewIfStmt(base.Pos, nil, nil, nil)
575				leaf(i, nif)
576				base.Pos = base.Pos.WithNotStmt()
577				nif.Cond = typecheck.Expr(nif.Cond)
578				nif.Cond = typecheck.DefaultLit(nif.Cond, nil)
579				out.Append(nif)
580				out = &nif.Else
581			}
582			return
583		}
584
585		half := lo + n/2
586		nif := ir.NewIfStmt(base.Pos, nil, nil, nil)
587		nif.Cond = less(half)
588		base.Pos = base.Pos.WithNotStmt()
589		nif.Cond = typecheck.Expr(nif.Cond)
590		nif.Cond = typecheck.DefaultLit(nif.Cond, nil)
591		do(lo, half, &nif.Body)
592		do(half, hi, &nif.Else)
593		out.Append(nif)
594	}
595
596	do(0, n, out)
597}
598