1package lua
2
3import (
4	"reflect"
5)
6
7func checkChannel(L *LState, idx int) reflect.Value {
8	ch := L.CheckChannel(idx)
9	return reflect.ValueOf(ch)
10}
11
12func checkGoroutineSafe(L *LState, idx int) LValue {
13	v := L.CheckAny(2)
14	if !isGoroutineSafe(v) {
15		L.ArgError(2, "can not send a function, userdata, thread or table that has a metatable")
16	}
17	return v
18}
19
20func OpenChannel(L *LState) int {
21	var mod LValue
22	//_, ok := L.G.builtinMts[int(LTChannel)]
23	//	if !ok {
24	mod = L.RegisterModule(ChannelLibName, channelFuncs)
25	mt := L.SetFuncs(L.NewTable(), channelMethods)
26	mt.RawSetString("__index", mt)
27	L.G.builtinMts[int(LTChannel)] = mt
28	//	}
29	L.Push(mod)
30	return 1
31}
32
33var channelFuncs = map[string]LGFunction{
34	"make":   channelMake,
35	"select": channelSelect,
36}
37
38func channelMake(L *LState) int {
39	buffer := L.OptInt(1, 0)
40	L.Push(LChannel(make(chan LValue, buffer)))
41	return 1
42}
43
44func channelSelect(L *LState) int {
45	//TODO check case table size
46	cases := make([]reflect.SelectCase, L.GetTop())
47	top := L.GetTop()
48	for i := 0; i < top; i++ {
49		cas := reflect.SelectCase{
50			Dir:  reflect.SelectSend,
51			Chan: reflect.ValueOf(nil),
52			Send: reflect.ValueOf(nil),
53		}
54		tbl := L.CheckTable(i + 1)
55		dir, ok1 := tbl.RawGetInt(1).(LString)
56		if !ok1 {
57			L.ArgError(i+1, "invalid select case")
58		}
59		switch string(dir) {
60		case "<-|":
61			ch, ok := tbl.RawGetInt(2).(LChannel)
62			if !ok {
63				L.ArgError(i+1, "invalid select case")
64			}
65			cas.Chan = reflect.ValueOf((chan LValue)(ch))
66			v := tbl.RawGetInt(3)
67			if !isGoroutineSafe(v) {
68				L.ArgError(i+1, "can not send a function, userdata, thread or table that has a metatable")
69			}
70			cas.Send = reflect.ValueOf(v)
71		case "|<-":
72			ch, ok := tbl.RawGetInt(2).(LChannel)
73			if !ok {
74				L.ArgError(i+1, "invalid select case")
75			}
76			cas.Chan = reflect.ValueOf((chan LValue)(ch))
77			cas.Dir = reflect.SelectRecv
78		case "default":
79			cas.Dir = reflect.SelectDefault
80		default:
81			L.ArgError(i+1, "invalid channel direction:"+string(dir))
82		}
83		cases[i] = cas
84	}
85
86	if L.ctx != nil {
87		cases = append(cases, reflect.SelectCase{
88			Dir:  reflect.SelectRecv,
89			Chan: reflect.ValueOf(L.ctx.Done()),
90			Send: reflect.ValueOf(nil),
91		})
92	}
93
94	pos, recv, rok := reflect.Select(cases)
95
96	if L.ctx != nil && pos == L.GetTop() {
97		return 0
98	}
99
100	lv := LNil
101	if recv.Kind() != 0 {
102		lv, _ = recv.Interface().(LValue)
103		if lv == nil {
104			lv = LNil
105		}
106	}
107	tbl := L.Get(pos + 1).(*LTable)
108	last := tbl.RawGetInt(tbl.Len())
109	if last.Type() == LTFunction {
110		L.Push(last)
111		switch cases[pos].Dir {
112		case reflect.SelectRecv:
113			if rok {
114				L.Push(LTrue)
115			} else {
116				L.Push(LFalse)
117			}
118			L.Push(lv)
119			L.Call(2, 0)
120		case reflect.SelectSend:
121			L.Push(tbl.RawGetInt(3))
122			L.Call(1, 0)
123		case reflect.SelectDefault:
124			L.Call(0, 0)
125		}
126	}
127	L.Push(LNumber(pos + 1))
128	L.Push(lv)
129	if rok {
130		L.Push(LTrue)
131	} else {
132		L.Push(LFalse)
133	}
134	return 3
135}
136
137var channelMethods = map[string]LGFunction{
138	"receive": channelReceive,
139	"send":    channelSend,
140	"close":   channelClose,
141}
142
143func channelReceive(L *LState) int {
144	rch := checkChannel(L, 1)
145	var v reflect.Value
146	var ok bool
147	if L.ctx != nil {
148		cases := []reflect.SelectCase{{
149			Dir:  reflect.SelectRecv,
150			Chan: reflect.ValueOf(L.ctx.Done()),
151			Send: reflect.ValueOf(nil),
152		}, {
153			Dir:  reflect.SelectRecv,
154			Chan: rch,
155			Send: reflect.ValueOf(nil),
156		}}
157		_, v, ok = reflect.Select(cases)
158	} else {
159		v, ok = rch.Recv()
160	}
161	if ok {
162		L.Push(LTrue)
163		L.Push(v.Interface().(LValue))
164	} else {
165		L.Push(LFalse)
166		L.Push(LNil)
167	}
168	return 2
169}
170
171func channelSend(L *LState) int {
172	rch := checkChannel(L, 1)
173	v := checkGoroutineSafe(L, 2)
174	rch.Send(reflect.ValueOf(v))
175	return 0
176}
177
178func channelClose(L *LState) int {
179	rch := checkChannel(L, 1)
180	rch.Close()
181	return 0
182}
183
184//
185