1package lua
2
3import (
4	"context"
5	"strings"
6	"testing"
7	"time"
8)
9
// TestLStateIsClosed verifies that IsClosed reports true after Close.
func TestLStateIsClosed(t *testing.T) {
	state := NewState()
	state.Close()
	closed := state.IsClosed()
	errorIfNotEqual(t, true, closed)
}
15
16func TestCallStackOverflowWhenFixed(t *testing.T) {
17	L := NewState(Options{
18		CallStackSize: 3,
19	})
20	defer L.Close()
21
22	// expect fixed stack implementation by default (for backwards compatibility)
23	stack := L.stack
24	if _, ok := stack.(*fixedCallFrameStack); !ok {
25		t.Errorf("expected fixed callframe stack by default")
26	}
27
28	errorIfScriptNotFail(t, L, `
29    local function recurse(count)
30      if count > 0 then
31        recurse(count - 1)
32      end
33    end
34    local function c()
35      print(_printregs())
36      recurse(9)
37    end
38    c()
39    `, "stack overflow")
40}
41
42func TestCallStackOverflowWhenAutoGrow(t *testing.T) {
43	L := NewState(Options{
44		CallStackSize:       3,
45		MinimizeStackMemory: true,
46	})
47	defer L.Close()
48
49	// expect auto growing stack implementation when MinimizeStackMemory is set
50	stack := L.stack
51	if _, ok := stack.(*autoGrowingCallFrameStack); !ok {
52		t.Errorf("expected fixed callframe stack by default")
53	}
54
55	errorIfScriptNotFail(t, L, `
56    local function recurse(count)
57      if count > 0 then
58        recurse(count - 1)
59      end
60    end
61    local function c()
62      print(_printregs())
63      recurse(9)
64    end
65    c()
66    `, "stack overflow")
67}
68
// TestSkipOpenLibs verifies that with SkipOpenLibs the global print is not
// callable, while a default state can call it.
func TestSkipOpenLibs(t *testing.T) {
	bare := NewState(Options{SkipOpenLibs: true})
	defer bare.Close()
	errorIfScriptNotFail(t, bare, `print("")`, "attempt to call a non-function object")

	full := NewState()
	defer full.Close()
	errorIfScriptFail(t, full, `print("")`)
}
78
// TestGetAndReplace covers LState.Get and LState.Replace on ordinary stack
// slots (positive, negative, zero and out-of-range indices) and on the pseudo
// indices RegistryIndex, EnvironIndex, GlobalsIndex and UpvalueIndex. It also
// covers the errors raised when a pseudo index is given a non-table value.
func TestGetAndReplace(t *testing.T) {
	ls := NewState()
	defer ls.Close()

	// Ordinary stack slots; index 0 and far-negative indices read as nil.
	ls.Push(LString("a"))
	ls.Replace(1, LString("b"))
	ls.Replace(0, LString("c"))
	errorIfNotEqual(t, LNil, ls.Get(0))
	errorIfNotEqual(t, LNil, ls.Get(-10))
	errorIfNotEqual(t, ls.Env, ls.Get(EnvironIndex))
	errorIfNotEqual(t, LString("b"), ls.Get(1))
	ls.Push(LString("c"))
	ls.Push(LString("d"))
	ls.Replace(-2, LString("e"))
	errorIfNotEqual(t, LString("e"), ls.Get(-2))

	// Registry pseudo index: a table is accepted, nil is rejected.
	reg := ls.NewTable()
	ls.Replace(RegistryIndex, reg)
	ls.G.Registry = reg
	errorIfGFuncNotFail(t, ls, func(st *LState) int {
		st.Replace(RegistryIndex, LNil)
		return 0
	}, "registry must be a table")

	// Environment pseudo index.
	errorIfGFuncFail(t, ls, func(st *LState) int {
		env := st.NewTable()
		st.Replace(EnvironIndex, env)
		errorIfNotEqual(t, env, st.Get(EnvironIndex))
		return 0
	})
	errorIfGFuncNotFail(t, ls, func(st *LState) int {
		st.Replace(EnvironIndex, LNil)
		return 0
	}, "environment must be a table")

	// Globals pseudo index: replacing it is visible through G.Global.
	errorIfGFuncFail(t, ls, func(st *LState) int {
		globals := st.NewTable()
		st.Replace(GlobalsIndex, globals)
		errorIfNotEqual(t, globals, st.G.Global)
		return 0
	})
	errorIfGFuncNotFail(t, ls, func(st *LState) int {
		st.Replace(GlobalsIndex, LNil)
		return 0
	}, "_G must be a table")

	// Upvalue pseudo index, exercised from inside a Go closure.
	other := NewState()
	defer other.Close()
	clo := other.NewClosure(func(st *LState) int {
		st.Replace(UpvalueIndex(1), LNumber(3))
		errorIfNotEqual(t, LNumber(3), st.Get(UpvalueIndex(1)))
		return 0
	}, LNumber(1), LNumber(2))
	other.SetGlobal("clo", clo)
	errorIfScriptFail(t, other, `clo()`)
}
131
// TestRemove checks LState.Remove. Removing above the top (4) or far below the
// bottom (-10) leaves the stack unchanged. Removing an in-range index drops
// that slot and shifts the later values down.
func TestRemove(t *testing.T) {
	ls := NewState()
	defer ls.Close()
	ls.Push(LString("a"))
	ls.Push(LString("b"))
	ls.Push(LString("c"))

	// Index past the top: no-op.
	ls.Remove(4)
	errorIfNotEqual(t, LString("a"), ls.Get(1))
	errorIfNotEqual(t, LString("b"), ls.Get(2))
	errorIfNotEqual(t, LString("c"), ls.Get(3))
	errorIfNotEqual(t, 3, ls.GetTop())

	// Removing the top element.
	ls.Remove(3)
	errorIfNotEqual(t, LString("a"), ls.Get(1))
	errorIfNotEqual(t, LString("b"), ls.Get(2))
	errorIfNotEqual(t, LNil, ls.Get(3))
	errorIfNotEqual(t, 2, ls.GetTop())
	ls.Push(LString("c"))

	// Index far below the bottom: no-op.
	ls.Remove(-10)
	errorIfNotEqual(t, LString("a"), ls.Get(1))
	errorIfNotEqual(t, LString("b"), ls.Get(2))
	errorIfNotEqual(t, LString("c"), ls.Get(3))
	errorIfNotEqual(t, 3, ls.GetTop())

	// Removing from the middle shifts "c" down.
	ls.Remove(2)
	errorIfNotEqual(t, LString("a"), ls.Get(1))
	errorIfNotEqual(t, LString("c"), ls.Get(2))
	errorIfNotEqual(t, LNil, ls.Get(3))
	errorIfNotEqual(t, 2, ls.GetTop())
}
164
// TestToInt checks ToInt on a number (10), a numeric string ("99.9" yields 99)
// and a table (0).
func TestToInt(t *testing.T) {
	ls := NewState()
	defer ls.Close()
	for _, v := range []LValue{LNumber(10), LString("99.9"), ls.NewTable()} {
		ls.Push(v)
	}
	errorIfNotEqual(t, 10, ls.ToInt(1))
	errorIfNotEqual(t, 99, ls.ToInt(2))
	errorIfNotEqual(t, 0, ls.ToInt(3))
}
175
// TestToInt64 mirrors TestToInt for the int64 accessor: 10 -> 10,
// "99.9" -> 99, table -> 0.
func TestToInt64(t *testing.T) {
	ls := NewState()
	defer ls.Close()
	for _, v := range []LValue{LNumber(10), LString("99.9"), ls.NewTable()} {
		ls.Push(v)
	}
	errorIfNotEqual(t, int64(10), ls.ToInt64(1))
	errorIfNotEqual(t, int64(99), ls.ToInt64(2))
	errorIfNotEqual(t, int64(0), ls.ToInt64(3))
}
186
// TestToNumber checks ToNumber on a number, a numeric string (kept as 99.9)
// and a table (0).
func TestToNumber(t *testing.T) {
	ls := NewState()
	defer ls.Close()
	for _, v := range []LValue{LNumber(10), LString("99.9"), ls.NewTable()} {
		ls.Push(v)
	}
	errorIfNotEqual(t, LNumber(10), ls.ToNumber(1))
	errorIfNotEqual(t, LNumber(99.9), ls.ToNumber(2))
	errorIfNotEqual(t, LNumber(0), ls.ToNumber(3))
}
197
// TestToString checks ToString on a number ("10"), a string (unchanged) and a
// table (empty string).
func TestToString(t *testing.T) {
	ls := NewState()
	defer ls.Close()
	for _, v := range []LValue{LNumber(10), LString("99.9"), ls.NewTable()} {
		ls.Push(v)
	}
	errorIfNotEqual(t, "10", ls.ToString(1))
	errorIfNotEqual(t, "99.9", ls.ToString(2))
	errorIfNotEqual(t, "", ls.ToString(3))
}
208
// TestToTable checks that ToTable returns nil for non-table values and the
// table itself for a table slot.
func TestToTable(t *testing.T) {
	ls := NewState()
	defer ls.Close()
	for _, v := range []LValue{LNumber(10), LString("99.9"), ls.NewTable()} {
		ls.Push(v)
	}
	errorIfFalse(t, ls.ToTable(1) == nil, "index 1 must be nil")
	errorIfFalse(t, ls.ToTable(2) == nil, "index 2 must be nil")
	errorIfNotEqual(t, ls.Get(3), ls.ToTable(3))
}
219
// TestToFunction checks that ToFunction returns nil for non-function values
// and the function itself for a function slot.
func TestToFunction(t *testing.T) {
	ls := NewState()
	defer ls.Close()
	noop := ls.NewFunction(func(*LState) int { return 0 })
	for _, v := range []LValue{LNumber(10), LString("99.9"), noop} {
		ls.Push(v)
	}
	errorIfFalse(t, ls.ToFunction(1) == nil, "index 1 must be nil")
	errorIfFalse(t, ls.ToFunction(2) == nil, "index 2 must be nil")
	errorIfNotEqual(t, ls.Get(3), ls.ToFunction(3))
}
230
// TestToUserData checks that ToUserData returns nil for non-userdata values
// and the userdata itself for a userdata slot.
func TestToUserData(t *testing.T) {
	ls := NewState()
	defer ls.Close()
	for _, v := range []LValue{LNumber(10), LString("99.9"), ls.NewUserData()} {
		ls.Push(v)
	}
	errorIfFalse(t, ls.ToUserData(1) == nil, "index 1 must be nil")
	errorIfFalse(t, ls.ToUserData(2) == nil, "index 2 must be nil")
	errorIfNotEqual(t, ls.Get(3), ls.ToUserData(3))
}
241
// TestToChannel checks that ToChannel returns nil for non-channel values and
// the pushed (here nil) channel for a channel slot.
func TestToChannel(t *testing.T) {
	ls := NewState()
	defer ls.Close()
	var nilChan chan LValue
	for _, v := range []LValue{LNumber(10), LString("99.9"), LChannel(nilChan)} {
		ls.Push(v)
	}
	errorIfFalse(t, ls.ToChannel(1) == nil, "index 1 must be nil")
	errorIfFalse(t, ls.ToChannel(2) == nil, "index 2 must be nil")
	errorIfNotEqual(t, nilChan, ls.ToChannel(3))
}
253
// TestObjLen checks ObjLen on four kinds of value:
//   - a string;
//   - a plain table with two appended elements;
//   - the same table once its metatable defines __len (the metamethod result wins);
//   - a number (0).
func TestObjLen(t *testing.T) {
	ls := NewState()
	defer ls.Close()

	errorIfNotEqual(t, 3, ls.ObjLen(LString("abc")))

	seq := ls.NewTable()
	seq.Append(LTrue)
	seq.Append(LTrue)
	errorIfNotEqual(t, 2, ls.ObjLen(seq))

	meta := ls.NewTable()
	lenPlusOne := ls.NewFunction(func(st *LState) int {
		self := st.CheckTable(1)
		st.Push(LNumber(self.Len() + 1))
		return 1
	})
	ls.SetField(meta, "__len", lenPlusOne)
	ls.SetMetatable(seq, meta)
	errorIfNotEqual(t, 3, ls.ObjLen(seq))

	errorIfNotEqual(t, 0, ls.ObjLen(LNumber(10)))
}
272
// TestConcat checks that Concat joins strings and numbers into one string.
func TestConcat(t *testing.T) {
	ls := NewState()
	defer ls.Close()
	got := ls.Concat(LString("a"), LNumber(1), LString("c"))
	errorIfNotEqual(t, "a1c", got)
}
278
// TestPCall checks two things about a Go function that panics:
//   - the panic surfaces as an error from a script call and from PCall;
//   - PCall's message handler determines the returned error, including when
//     the handler itself raises an error or panics.
func TestPCall(t *testing.T) {
	L := NewState()
	defer L.Close()
	// panic is a terminating statement, so no trailing return is needed.
	L.Register("f1", func(L *LState) int {
		panic("panic!")
	})
	errorIfScriptNotFail(t, L, `f1()`, "panic!")

	// PCall(0, ...) calls the value on top of the stack. f1 is pushed afresh
	// before every call, so each handler is exercised against a real failing
	// f1 rather than against an empty stack.
	L.Push(L.GetGlobal("f1"))
	err := L.PCall(0, 0, L.NewFunction(func(L *LState) int {
		L.Push(LString("by handler"))
		return 1
	}))
	// Check for nil before calling err.Error(), so a missing error is reported
	// instead of causing a nil-pointer panic.
	errorIfFalse(t, err != nil && strings.Contains(err.Error(), "by handler"), "handler result must become the error")

	L.Push(L.GetGlobal("f1"))
	err = L.PCall(0, 0, L.NewFunction(func(L *LState) int {
		L.RaiseError("error!")
		return 1
	}))
	errorIfFalse(t, err != nil && strings.Contains(err.Error(), "error!"), "error raised by handler must be returned")

	L.Push(L.GetGlobal("f1"))
	err = L.PCall(0, 0, L.NewFunction(func(L *LState) int {
		panic("panicc!")
	}))
	errorIfFalse(t, err != nil && strings.Contains(err.Error(), "panicc!"), "panic in handler must be returned")
}
306
// TestCoroutineApi1 drives Lua coroutines from Go through LState.Resume. It
// checks that:
//   - values yielded by the coroutine are returned to Go;
//   - arguments to later Resume calls become the results of coroutine.yield;
//   - a Go function can yield via LState.Yield;
//   - an error inside the coroutine is reported as ResumeError;
//   - resuming the dead coroutine again fails with "can not resume a dead thread".
func TestCoroutineApi1(t *testing.T) {
	L := NewState()
	defer L.Close()

	// expectValues reports any mismatch between got and want. It checks the
	// length first and fails fatally on a mismatch, so a short result cannot
	// cause an index-out-of-range panic that aborts the test binary.
	expectValues := func(got []LValue, want ...LValue) {
		t.Helper()
		if len(got) != len(want) {
			t.Fatalf("expected %d values, got %d: %v", len(want), len(got), got)
		}
		for i := range want {
			if got[i] != want[i] {
				t.Errorf("value %d: expected %v, got %v", i, want[i], got[i])
			}
		}
	}

	co, _ := L.NewThread()
	errorIfScriptFail(t, L, `
      function coro(v)
        assert(v == 10)
        local ret1, ret2 = coroutine.yield(1,2,3)
        assert(ret1 == 11)
        assert(ret2 == 12)
        coroutine.yield(4)
        return 5
      end
    `)
	fn := L.GetGlobal("coro").(*LFunction)
	st, err, values := L.Resume(co, fn, LNumber(10))
	errorIfNotEqual(t, ResumeYield, st)
	errorIfNotNil(t, err)
	expectValues(values, LNumber(1), LNumber(2), LNumber(3))

	st, err, values = L.Resume(co, fn, LNumber(11), LNumber(12))
	errorIfNotEqual(t, ResumeYield, st)
	errorIfNotNil(t, err)
	expectValues(values, LNumber(4))

	st, err, values = L.Resume(co, fn)
	errorIfNotEqual(t, ResumeOK, st)
	errorIfNotNil(t, err)
	expectValues(values, LNumber(5))

	L.Register("myyield", func(L *LState) int {
		return L.Yield(L.ToNumber(1))
	})
	errorIfScriptFail(t, L, `
      function coro_error()
        coroutine.yield(1,2,3)
        myyield(4)
        assert(false, "--failed--")
      end
    `)
	fn = L.GetGlobal("coro_error").(*LFunction)
	co, _ = L.NewThread()
	st, err, values = L.Resume(co, fn)
	errorIfNotEqual(t, ResumeYield, st)
	errorIfNotNil(t, err)
	expectValues(values, LNumber(1), LNumber(2), LNumber(3))

	st, err, values = L.Resume(co, fn)
	errorIfNotEqual(t, ResumeYield, st)
	errorIfNotNil(t, err)
	expectValues(values, LNumber(4))

	st, err, _ = L.Resume(co, fn)
	errorIfNotEqual(t, ResumeError, st)
	if err == nil {
		t.Fatal("expected an error from the failing coroutine")
	}
	errorIfFalse(t, strings.Contains(err.Error(), "--failed--"), "error message must be '--failed--'")

	st, err, _ = L.Resume(co, fn)
	errorIfNotEqual(t, ResumeError, st)
	if err == nil {
		t.Fatal("expected an error when resuming a dead coroutine")
	}
	errorIfFalse(t, strings.Contains(err.Error(), "can not resume a dead thread"), "can not resume a dead thread")
}
378
// TestContextTimeout checks that a busy-waiting Lua script is interrupted
// with "context deadline exceeded" once the context set via SetContext
// expires. It also checks that RemoveContext returns that context and leaves
// L.ctx nil.
func TestContextTimeout(t *testing.T) {
	L := NewState()
	defer L.Close()
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	L.SetContext(ctx)
	errorIfNotEqual(t, ctx, L.Context())
	err := L.DoString(`
	  local clock = os.clock
      function sleep(n)  -- seconds
        local t0 = clock()
        while clock() - t0 <= n do end
      end
	  sleep(3)
	`)
	// Stop here on a nil error; calling err.Error() on it would panic.
	if err == nil {
		t.Fatal("expected execution to be interrupted by the context deadline")
	}
	errorIfFalse(t, strings.Contains(err.Error(), "context deadline exceeded"), "execution must be canceled")

	oldctx := L.RemoveContext()
	errorIfNotEqual(t, ctx, oldctx)
	errorIfNotNil(t, L.ctx)
}
401
// TestContextCancel runs a busy-waiting Lua script in a goroutine, cancels
// the state's context after one second, and checks that DoString returns a
// "context canceled" error.
func TestContextCancel(t *testing.T) {
	L := NewState()
	defer L.Close()
	ctx, cancel := context.WithCancel(context.Background())
	// Buffered so the goroutine can always deliver its result and exit.
	errch := make(chan error, 1)
	L.SetContext(ctx)
	go func() {
		errch <- L.DoString(`
	    local clock = os.clock
        function sleep(n)  -- seconds
          local t0 = clock()
          while clock() - t0 <= n do end
        end
	    sleep(3)
	  `)
	}()
	time.Sleep(1 * time.Second)
	cancel()
	err := <-errch
	// Stop here on a nil error; calling err.Error() on it would panic.
	if err == nil {
		t.Fatal("expected execution to be canceled")
	}
	errorIfFalse(t, strings.Contains(err.Error(), "context canceled"), "execution must be canceled")
}
424
// TestContextWithCroutine checks that a coroutine created with NewThread stops
// once the parent state's context is canceled. After cancellation, the next
// Resume must return a "context canceled" error.
func TestContextWithCroutine(t *testing.T) {
	L := NewState()
	defer L.Close()
	ctx, cancel := context.WithCancel(context.Background())
	L.SetContext(ctx)
	defer cancel()
	// If the script fails to load, GetGlobal("coro") below would not be a
	// function and the type assertion would panic, so fail clearly here.
	if err := L.DoString(`
	    function coro()
		  local i = 0
		  while true do
		    coroutine.yield(i)
			i = i+1
		  end
		  return i
	    end
	`); err != nil {
		t.Fatalf("loading coro: %v", err)
	}
	co, cocancel := L.NewThread()
	defer cocancel()
	fn := L.GetGlobal("coro").(*LFunction)
	_, err, values := L.Resume(co, fn)
	errorIfNotNil(t, err)
	if len(values) == 0 {
		t.Fatal("expected the first resume to yield a value")
	}
	errorIfNotEqual(t, LNumber(0), values[0])
	// cancel the parent context
	cancel()
	_, err, _ = L.Resume(co, fn)
	// Stop here on a nil error; calling err.Error() on it would panic.
	if err == nil {
		t.Fatal("expected resume to fail after the parent context is canceled")
	}
	errorIfFalse(t, strings.Contains(err.Error(), "context canceled"), "coroutine execution must be canceled when the parent context is canceled")
}
454
// TestPCallAfterFail checks error replacement across nested PCalls. A Go
// function catches an error with an inner PCall and then raises a new one.
// The outer PCall must return that new error.
func TestPCallAfterFail(t *testing.T) {
	L := NewState()
	defer L.Close()
	errFn := L.NewFunction(func(L *LState) int {
		L.RaiseError("error!")
		return 0
	})
	changeError := L.NewFunction(func(L *LState) int {
		L.Push(errFn)
		if err := L.PCall(0, 0, nil); err != nil {
			L.RaiseError("A New Error")
		}
		return 0
	})
	L.Push(changeError)
	err := L.PCall(0, 0, nil)
	// Check for nil before calling err.Error(), so a missing error is reported
	// instead of causing a nil-pointer panic.
	errorIfFalse(t, err != nil && strings.Contains(err.Error(), "A New Error"), "error not propagated correctly")
}
474
475func TestRegistryFixedOverflow(t *testing.T) {
476	state := NewState()
477	defer state.Close()
478	reg := state.reg
479	expectedPanic := false
480	// should be non auto grow by default
481	errorIfFalse(t, reg.maxSize == 0, "state should default to non-auto growing implementation")
482	// fill the stack and check we get a panic
483	test := LString("test")
484	for i := 0; i < len(reg.array); i++ {
485		reg.Push(test)
486	}
487	defer func() {
488		rcv := recover()
489		if rcv != nil {
490			if expectedPanic {
491				errorIfFalse(t, rcv.(error).Error() != "registry overflow", "expected registry overflow exception, got "+rcv.(error).Error())
492			} else {
493				t.Errorf("did not expect registry overflow")
494			}
495		} else if expectedPanic {
496			t.Errorf("expected registry overflow exception, but didn't get panic")
497		}
498	}()
499	expectedPanic = true
500	reg.Push(test)
501}
502
503func TestRegistryAutoGrow(t *testing.T) {
504	state := NewState(Options{RegistryMaxSize: 300, RegistrySize: 200, RegistryGrowStep: 25})
505	defer state.Close()
506	expectedPanic := false
507	defer func() {
508		rcv := recover()
509		if rcv != nil {
510			if expectedPanic {
511				errorIfFalse(t, rcv.(error).Error() != "registry overflow", "expected registry overflow exception, got "+rcv.(error).Error())
512			} else {
513				t.Errorf("did not expect registry overflow")
514			}
515		} else if expectedPanic {
516			t.Errorf("expected registry overflow exception, but didn't get panic")
517		}
518	}()
519	reg := state.reg
520	test := LString("test")
521	for i := 0; i < 300; i++ {
522		reg.Push(test)
523	}
524	expectedPanic = true
525	reg.Push(test)
526}
527
// BenchmarkCallFrameStackPushPopAutoGrow pushes 256 empty call frames onto an
// auto-growing stack created with size 256, then pops all of them.
func BenchmarkCallFrameStackPushPopAutoGrow(b *testing.B) {
	frames := newAutoGrowingCallFrameStack(256)
	const depth = 256

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		for i := 0; i < depth; i++ {
			frames.Push(callFrame{})
		}
		for i := depth; i > 0; i-- {
			frames.Pop()
		}
	}
}
543
// BenchmarkCallFrameStackPushPopFixed pushes 256 empty call frames onto a
// fixed stack of size 256, then pops all of them.
func BenchmarkCallFrameStackPushPopFixed(b *testing.B) {
	frames := newFixedCallFrameStack(256)
	const depth = 256

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		for i := 0; i < depth; i++ {
			frames.Push(callFrame{})
		}
		for i := depth; i > 0; i-- {
			frames.Pop()
		}
	}
}
559
// BenchmarkCallFrameStackPushPopShallowAutoGrow pushes and pops only 8 frames
// on an auto-growing stack of size 256. The shallow depth intentionally
// avoids stack growth, so this measures the cost when no allocations happen.
func BenchmarkCallFrameStackPushPopShallowAutoGrow(b *testing.B) {
	frames := newAutoGrowingCallFrameStack(256)
	const depth = 8

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		for i := 0; i < depth; i++ {
			frames.Push(callFrame{})
		}
		for i := depth; i > 0; i-- {
			frames.Pop()
		}
	}
}
576
// BenchmarkCallFrameStackPushPopShallowFixed is the fixed-stack counterpart
// of the shallow auto-grow benchmark: it pushes and pops 8 frames on a fixed
// stack of size 256.
func BenchmarkCallFrameStackPushPopShallowFixed(b *testing.B) {
	frames := newFixedCallFrameStack(256)
	const depth = 8

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		for i := 0; i < depth; i++ {
			frames.Push(callFrame{})
		}
		for i := depth; i > 0; i-- {
			frames.Pop()
		}
	}
}
592
// BenchmarkCallFrameStackPushPopFixedNoInterface matches
// BenchmarkCallFrameStackPushPopFixed but calls the concrete
// *fixedCallFrameStack directly rather than through the interface.
func BenchmarkCallFrameStackPushPopFixedNoInterface(b *testing.B) {
	frames := newFixedCallFrameStack(256).(*fixedCallFrameStack)
	const depth = 256

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		for i := 0; i < depth; i++ {
			frames.Push(callFrame{})
		}
		for i := depth; i > 0; i-- {
			frames.Pop()
		}
	}
}
608
// BenchmarkCallFrameStackUnwindAutoGrow pushes 256 frames onto an
// auto-growing stack, then unwinds them all at once with SetSp(0).
func BenchmarkCallFrameStackUnwindAutoGrow(b *testing.B) {
	frames := newAutoGrowingCallFrameStack(256)
	const depth = 256

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		for i := 0; i < depth; i++ {
			frames.Push(callFrame{})
		}
		frames.SetSp(0)
	}
}
622
// BenchmarkCallFrameStackUnwindFixed pushes 256 frames onto a fixed stack,
// then unwinds them all at once with SetSp(0).
func BenchmarkCallFrameStackUnwindFixed(b *testing.B) {
	frames := newFixedCallFrameStack(256)
	const depth = 256

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		for i := 0; i < depth; i++ {
			frames.Push(callFrame{})
		}
		frames.SetSp(0)
	}
}
636
// BenchmarkCallFrameStackUnwindFixedNoInterface matches
// BenchmarkCallFrameStackUnwindFixed but calls the concrete
// *fixedCallFrameStack directly rather than through the interface.
func BenchmarkCallFrameStackUnwindFixedNoInterface(b *testing.B) {
	frames := newFixedCallFrameStack(256).(*fixedCallFrameStack)
	const depth = 256

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		for i := 0; i < depth; i++ {
			frames.Push(callFrame{})
		}
		frames.SetSp(0)
	}
}
650
// registryTestHandler is a minimal overflow handler used when the registry
// benchmarks in this file call newRegistry. On overflow it panics with the
// string "registry overflow".
type registryTestHandler int

func (h registryTestHandler) registryOverflow() {
	panic("registry overflow")
}
656
// BenchmarkRegistryPushPopAutoGrow pushes 256*20 values onto the registry and
// then pops them all. The registry is built with
// newRegistry(handler, size/2, 64, size, alloc). Presumably that means it
// starts at half size and grows in steps of 64 up to size; confirm against
// newRegistry's signature.
func BenchmarkRegistryPushPopAutoGrow(b *testing.B) {
	alloc := newAllocator(32)
	size := 256 * 20
	reg := newRegistry(registryTestHandler(0), size/2, 64, size, alloc)
	value := LString("test")

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		for i := 0; i < size; i++ {
			reg.Push(value)
		}
		for i := size; i > 0; i-- {
			reg.Pop()
		}
	}
}
675
// BenchmarkRegistryPushPopFixed pushes 256*20 values onto the registry and
// then pops them all. The registry is built with
// newRegistry(handler, size, 0, size, alloc); the grow step of 0 appears to
// make it fixed-size.
func BenchmarkRegistryPushPopFixed(b *testing.B) {
	alloc := newAllocator(32)
	size := 256 * 20
	reg := newRegistry(registryTestHandler(0), size, 0, size, alloc)
	value := LString("test")

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		for i := 0; i < size; i++ {
			reg.Push(value)
		}
		for i := size; i > 0; i-- {
			reg.Pop()
		}
	}
}
693
// BenchmarkRegistrySetTop repeatedly raises the registry top to 256*20 with
// SetTop and then drops it back to 0.
func BenchmarkRegistrySetTop(b *testing.B) {
	alloc := newAllocator(32)
	size := 256 * 20
	reg := newRegistry(registryTestHandler(0), size, 32, size*2, alloc)

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		reg.SetTop(size)
		reg.SetTop(0)
	}
}
706