1from __future__ import unicode_literals
2import pytest
3
4from redis import exceptions
5
6
# Lua: fetch KEYS[1], coerce it with tonumber() and multiply it by ARGV[1].
# The source deliberately begins with a newline and has no trailing newline.
multiply_script = (
    "\n"
    "local value = redis.call('GET', KEYS[1])\n"
    "value = tonumber(value)\n"
    "return value * ARGV[1]"
)
11
# Lua: decode the msgpack map passed as ARGV[1] and greet its 'name' field.
msgpack_hello_script = (
    "\n"
    "local message = cmsgpack.unpack(ARGV[1])\n"
    "local name = message['name']\n"
    'return "hello " .. name\n'
)
# Intentionally broken variant of the hello script: the field is stored in
# `names`, but the return statement concatenates `name`, an undefined global
# (nil), so running it raises a Lua runtime error inside Redis.
msgpack_hello_script_broken = (
    "\n"
    "local message = cmsgpack.unpack(ARGV[1])\n"
    "local names = message['name']\n"
    'return "hello " .. name\n'
)
22
23
class TestScripting(object):
    """Tests for Lua scripting: EVAL, EVALSHA, SCRIPT LOAD/EXISTS/FLUSH and
    the callable objects returned by ``register_script``.

    Each test receives ``r``, a client fixture defined outside this module
    (presumably in conftest -- confirm).
    """

    @pytest.fixture(autouse=True)
    def reset_scripts(self, r):
        """Flush Redis's script cache before every test."""
        r.script_flush()

    def test_eval(self, r):
        """EVAL runs the script source directly."""
        r.set('a', 2)
        # 2 * 3 == 6
        assert r.eval(multiply_script, 1, 'a', 3) == 6

    def test_evalsha(self, r):
        """EVALSHA runs a script previously cached with SCRIPT LOAD."""
        r.set('a', 2)
        sha = r.script_load(multiply_script)
        # 2 * 3 == 6
        assert r.evalsha(sha, 1, 'a', 3) == 6

    def test_evalsha_script_not_loaded(self, r):
        """EVALSHA raises NoScriptError once the script has been flushed."""
        r.set('a', 2)
        sha = r.script_load(multiply_script)
        # remove the script from Redis's cache
        r.script_flush()
        with pytest.raises(exceptions.NoScriptError):
            r.evalsha(sha, 1, 'a', 3)

    def test_script_loading(self, r):
        """SCRIPT EXISTS tracks SCRIPT FLUSH and SCRIPT LOAD."""
        # get the sha, then clear the cache
        sha = r.script_load(multiply_script)
        r.script_flush()
        assert r.script_exists(sha) == [False]
        r.script_load(multiply_script)
        assert r.script_exists(sha) == [True]

    def test_script_object(self, r):
        """A registered script loads itself on first call and reuses its sha."""
        r.set('a', 2)
        multiply = r.register_script(multiply_script)
        precalculated_sha = multiply.sha
        assert precalculated_sha
        # Registering only computes the sha; nothing is cached in Redis yet.
        assert r.script_exists(multiply.sha) == [False]
        # First call: exercises the path taken after EVALSHA fails with
        # NoScriptError (the script is loaded, then run).
        assert multiply(keys=['a'], args=[3]) == 6
        # At this point, the script should be loaded
        assert r.script_exists(multiply.sha) == [True]
        # Test that the precalculated sha matches the one from redis
        assert multiply.sha == precalculated_sha
        # Second call: the plain EVALSHA path against the cached script.
        assert multiply(keys=['a'], args=[3]) == 6

    def test_script_object_in_pipeline(self, r):
        """Pipelines load any missing registered scripts during execute()."""
        multiply = r.register_script(multiply_script)
        precalculated_sha = multiply.sha
        assert precalculated_sha
        pipe = r.pipeline()
        pipe.set('a', 2)
        pipe.get('a')
        multiply(keys=['a'], args=[3], client=pipe)
        # Queuing the call in the pipeline does not load the script.
        assert r.script_exists(multiply.sha) == [False]
        # [SET worked, GET 'a', result of multiply script]
        assert pipe.execute() == [True, b'2', 6]
        # The script should have been loaded by pipe.execute()
        assert r.script_exists(multiply.sha) == [True]
        # The precalculated sha should have been the correct one
        assert multiply.sha == precalculated_sha

        # purge the script from redis's cache and re-run the pipeline
        # the multiply script should be reloaded by pipe.execute()
        r.script_flush()
        pipe = r.pipeline()
        pipe.set('a', 2)
        pipe.get('a')
        multiply(keys=['a'], args=[3], client=pipe)
        assert r.script_exists(multiply.sha) == [False]
        # [SET worked, GET 'a', result of multiply script]
        assert pipe.execute() == [True, b'2', 6]
        assert r.script_exists(multiply.sha) == [True]

    def test_eval_msgpack_pipeline_error_in_lua(self, r):
        """A Lua runtime error in a pipelined script surfaces as ResponseError.

        A working msgpack script is run through the pipeline first. Then a
        deliberately broken one is run, and its failure must propagate out
        of ``execute()``.
        """
        msgpack_hello = r.register_script(msgpack_hello_script)
        assert msgpack_hello.sha

        pipe = r.pipeline()

        # avoiding a dependency to msgpack, this is the output of
        # msgpack.dumps({"name": "Joe"}): 0x81 = 1-entry map,
        # 0xa4 = 4-byte str "name", 0xa3 = 3-byte str "Joe"
        msgpack_message_1 = b'\x81\xa4name\xa3Joe'

        msgpack_hello(args=[msgpack_message_1], client=pipe)

        assert r.script_exists(msgpack_hello.sha) == [False]
        assert pipe.execute()[0] == b'hello Joe'
        assert r.script_exists(msgpack_hello.sha) == [True]

        msgpack_hello_broken = r.register_script(msgpack_hello_script_broken)

        msgpack_hello_broken(args=[msgpack_message_1], client=pipe)
        with pytest.raises(exceptions.ResponseError) as excinfo:
            pipe.execute()
        # pytest.raises also accepts subclasses. Require the exact type so a
        # more specific error (e.g. redis-py's NoScriptError, which derives
        # from ResponseError) cannot stand in for the Lua runtime failure.
        assert excinfo.type is exceptions.ResponseError
121