1/*
2 * gomacro - A Go interpreter with Lisp-like macros
3 *
4 * Copyright (C) 2017-2019 Massimiliano Ghilardi
5 *
6 *     This Source Code Form is subject to the terms of the Mozilla Public
7 *     License, v. 2.0. If a copy of the MPL was not distributed with this
8 *     file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 *
10 *
11 * select.go
12 *
13 *  Created on Jun 05, 2017
14 *      Author Massimiliano Ghilardi
15 */
16
17package fast
18
19import (
20	"go/ast"
21	"go/token"
22	r "reflect"
23	"sort"
24)
25
// selectEntry is the compiled form of one clause of a select statement.
// At run time Select turns each entry into a reflect.SelectCase.
type selectEntry struct {
	// Dir is copied verbatim into reflect.SelectCase.Dir.
	Dir  r.SelectDir
	// Chan, when non-nil, is evaluated to obtain reflect.SelectCase.Chan.
	// It is left nil for the default clause.
	Chan func(*Env) r.Value
	// Send, when non-nil, is evaluated to obtain reflect.SelectCase.Send.
	// It is only consulted if Chan is non-nil too.
	Send func(*Env) r.Value
}
31
32func (c *Comp) Select(node *ast.SelectStmt, labels []string) {
33	if node.Body == nil || len(node.Body.List) == 0 {
34		return
35	}
36	sort.Strings(labels)
37
38	// unnamed bind, contains received value. Nil means nothing received
39	// note: containLocalBinds knows we create a local bind,
40	// and returns true if it encounters a non-empty SelectStmt
41	bindrecv := c.NewBind("", VarBind, c.TypeOfInterface())
42	idxrecv := bindrecv.Desc.Index()
43
44	list := node.Body.List
45	n := len(list)
46	entries := make([]selectEntry, n)
47	ips := make([]int, n)
48	defaultip := -1
49	defaultpos := token.NoPos
50
51	// restore current Comp.Loop before returning
52	defer func(loop *LoopInfo) {
53		c.Loop = loop
54	}(c.Loop)
55	c.Loop = &LoopInfo{
56		Break:      new(int),
57		ThisLabels: labels,
58	}
59
60	c.append(func(env *Env) (Stmt, *Env) {
61		cases := make([]r.SelectCase, len(entries))
62		for i := range entries {
63			c := &cases[i]
64			e := &entries[i]
65			c.Dir = e.Dir
66			if e.Chan != nil {
67				c.Chan = e.Chan(env)
68				if e.Send != nil {
69					c.Send = e.Send(env)
70				}
71			}
72		}
73		chosen, recv, _ := r.Select(cases)
74		env.Vals[idxrecv] = recv
75		ip := ips[chosen]
76		env.IP = ip
77		return env.Code[ip], env
78	})
79
80	for i, stmt := range list {
81		ips[i] = c.Code.Len()
82		switch clause := stmt.(type) {
83		case *ast.CommClause:
84			if clause.Comm == nil {
85				if defaultip >= 0 {
86					c.Errorf("multiple defaults in select (first at %s)", c.Fileset.Position(defaultpos))
87				}
88				defaultip = c.Code.Len()
89				defaultpos = clause.Pos()
90				entries[i] = c.selectDefault(clause)
91			} else {
92				entries[i] = c.selectCase(clause, bindrecv)
93			}
94		default:
95			c.Errorf("invalid statement inside select: expecting case or default, found: %v <%v>", stmt, r.TypeOf(stmt))
96		}
97	}
98	// we finally know this
99	*c.Loop.Break = c.Code.Len()
100}
101
102// selectDefault compiles the default case in a switch
103func (c *Comp) selectDefault(node *ast.CommClause) selectEntry {
104	if len(node.Body) != 0 {
105		c.List(node.Body)
106	}
107	c.jumpOut(0, c.Loop.Break)
108	return selectEntry{Dir: r.SelectDefault}
109}
110
111// selectCase compiles a case in a select.
112func (c *Comp) selectCase(clause *ast.CommClause, bind *Bind) selectEntry {
113
114	var entry selectEntry
115	var nbind [2]int
116	stmt := clause.Comm
117	c2 := c
118	locals := false
119
120	switch node := stmt.(type) {
121	case *ast.ExprStmt:
122		// <-ch
123		entry = selectEntry{
124			Dir:  r.SelectRecv,
125			Chan: c.selectRecv(stmt, node.X).AsX1(),
126		}
127	case *ast.AssignStmt:
128		// v := <-ch or v = <-ch
129		lhs := node.Lhs
130		n := len(lhs)
131		if (n != 1 && n != 2) || len(node.Rhs) != 1 {
132			c.badSelectCase(stmt)
133		}
134		var l0, l1 ast.Expr = lhs[0], nil
135		if n == 2 {
136			l1 = lhs[1]
137		}
138		r0 := node.Rhs[0]
139		switch node.Tok {
140		case token.DEFINE:
141			id0 := asIdent(l0)
142			id1 := asIdent(l1)
143			if (id0 == nil && l0 != nil) || (id1 == nil && l1 != nil) {
144				c.badSelectCase(stmt)
145			}
146			echan := c.selectRecv(node, r0)
147			entry = selectEntry{Dir: r.SelectRecv, Chan: echan.AsX1()}
148
149			if id0 != nil && id0.Name != "_" || id1 != nil && id1.Name != "_" {
150				c2, locals = c.pushEnvIfFlag(&nbind, true)
151
152				if id0 != nil && id0.Name != "_" {
153					t := echan.Type.Elem()
154					c2.DeclVar0(id0.Name, t, unwrapBindUp1(bind, t))
155				}
156				if id1 != nil && id1.Name != "_" {
157					idx := bind.Desc.Index()
158					c2.DeclVar0(id1.Name, c.TypeOfBool(), c.exprBool(func(env *Env) bool {
159						return env.Outer.Vals[idx].IsValid()
160					}))
161				}
162			} else if len(clause.Body) != 0 {
163				c2, locals = c.pushEnvIfLocalBinds(&nbind, clause.Body...)
164			}
165
166		case token.ASSIGN:
167			echan := c.selectRecv(stmt, r0)
168			entry = selectEntry{Dir: r.SelectRecv, Chan: echan.AsX1()}
169
170			if l0 != nil {
171				place := c.Place(l0)
172				t := echan.Type.Elem()
173				tplace := place.Type
174				if !t.AssignableTo(tplace) {
175					c.Errorf("cannot use <%v> as <%v> in assignment: %v = %v", t, tplace, l0, r0)
176				}
177				c.SetPlace(place, token.ASSIGN, unwrapBind(bind, t))
178			}
179			if l1 != nil {
180				place := c.Place(l1)
181				t := c.TypeOfBool()
182				tplace := place.Type
183				if !t.AssignableTo(tplace) {
184					c.Errorf("cannot use <%v> as <%v> in assignment: _, %v = %v", t, tplace, l1, r0)
185				}
186				idx := bind.Desc.Index()
187				c.SetPlace(place, token.ASSIGN, c.exprBool(func(env *Env) bool {
188					return env.Vals[idx].IsValid()
189				}))
190			}
191
192			if len(clause.Body) != 0 {
193				c2, locals = c.pushEnvIfLocalBinds(&nbind, clause.Body...)
194			}
195		}
196
197	case *ast.SendStmt:
198		// ch <- v
199		echan := c.Expr1(node.Chan, nil)
200		if echan.Type.Kind() != r.Chan {
201			c.Errorf("cannot use %v <%v> as channel in select case", node, echan.Type)
202		}
203		esend := c.Expr1(node.Value, nil)
204		tactual := esend.Type
205		texpected := echan.Type.Elem()
206		if !tactual.AssignableTo(texpected) {
207			c.Errorf("cannot use %v <%v> as <%v> in channel send", node.Value, tactual, texpected)
208		}
209		entry = selectEntry{Dir: r.SelectSend, Chan: echan.AsX1(), Send: esend.AsX1()}
210
211	default:
212		c.badSelectCase(stmt)
213	}
214
215	if len(clause.Body) != 0 {
216		c2.List(clause.Body)
217	}
218	if c2 != c {
219		c2.popEnvIfFlag(&nbind, locals)
220	}
221	c.jumpOut(0, c.Loop.Break)
222	return entry
223}
224
225func (c *Comp) selectRecv(stmt ast.Stmt, node ast.Expr) *Expr {
226	for {
227		switch expr := node.(type) {
228		case *ast.ParenExpr:
229			node = expr.X
230			continue
231		case *ast.UnaryExpr:
232			if expr.Op == token.ARROW {
233				e := c.Expr1(expr.X, nil)
234				if e.Type.Kind() != r.Chan {
235					c.Errorf("cannot use %v <%v> as channel in select case", node, e.Type)
236				}
237				return e
238			}
239		}
240		c.badSelectCase(stmt)
241		return nil
242	}
243}
244
245func (c *Comp) badSelectCase(stmt ast.Stmt) {
246	c.Errorf("invalid select case, expecting [ch <- val] or [<-ch] or [vars := <-ch] or [places = <-ch], found: %v <%v>",
247		stmt, r.TypeOf(stmt))
248}
249