1import unittest
2
3import gevent
4from gevent import sleep
5from gevent.queue import Queue
6
7import mock
8from locust import events
9from locust.core import Locust, TaskSet, task
10from locust.exception import LocustError
11from locust.rpc import Message
12from locust.runners import LocalLocustRunner, MasterLocustRunner, SlaveNode, STATE_INIT, STATE_HATCHING, STATE_RUNNING, STATE_MISSING
13from locust.stats import global_stats, RequestStats
14from locust.test.testcases import LocustTestCase
15
def mocked_rpc_server():
    """
    Build a fresh fake class to patch in place of ``locust.rpc.rpc.Server``.

    A brand-new class is created on every call, so its class-level ``queue``
    and ``outbox`` are not shared between tests that each call this factory.

    Returns:
        The ``MockedRpcServer`` class (not an instance).
    """
    class MockedRpcServer(object):
        # Serialized messages pushed in by the test via mocked_send();
        # consumed by recv() / recv_from_client().
        queue = Queue()
        # Everything the code under test sends: plain serialized messages from
        # send(), and [node_id, serialized] pairs from send_to_client().
        outbox = []

        def __init__(self, host, port):
            # Accepts the real server's constructor arguments and ignores them.
            pass

        @classmethod
        def mocked_send(cls, message):
            """Inject *message* into the inbound queue, then yield to other greenlets."""
            payload = message.serialize()
            cls.queue.put(payload)
            # Yield so a greenlet blocked in recv*/queue.get can presumably
            # process the message before the test continues.
            sleep(0)

        def recv(self):
            """Block until a message is queued and return it unserialized."""
            return Message.unserialize(self.queue.get())

        def send(self, message):
            """Record the serialized *message* in ``outbox``."""
            self.outbox.append(message.serialize())

        def send_to_client(self, message):
            """Record ``[message.node_id, serialized message]`` in ``outbox``."""
            target = message.node_id
            self.outbox.append([target, message.serialize()])

        def recv_from_client(self):
            """Block until a message is queued; return ``(node_id, message)``."""
            incoming = Message.unserialize(self.queue.get())
            return incoming.node_id, incoming

    return MockedRpcServer
45
class mocked_options(object):
    """Minimal stand-in for the parsed options object handed to the runners."""

    def __init__(self):
        settings = {
            # Hatching parameters.
            "num_clients": 5,
            "hatch_rate": 5,
            "host": '/',
            # Address a slave would use to reach the master.
            "master_host": 'localhost',
            "master_port": 5557,
            # Address the master would bind to.
            "master_bind_host": '*',
            "master_bind_port": 5557,
            # Heartbeat tuning: a tiny interval (liveness 3 * 0.01s) presumably
            # lets a silent slave be marked missing within a short test sleep.
            "heartbeat_interval": 0.01,
            "heartbeat_liveness": 3,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

    def reset_stats(self):
        """No-op placeholder for the ``reset_stats`` option; returns None."""
        return None
60
class TestMasterRunner(LocustTestCase):
    """Runner tests driven through a fake RPC server (see mocked_rpc_server)."""

    def setUp(self):
        # Run LocustTestCase's own setup first; overriding setUp without calling
        # it silently skipped whatever the base test case prepares.
        super(TestMasterRunner, self).setUp()
        global_stats.reset_all()
        # Snapshot slave_report handlers so any handlers attached by runners
        # created in a test can be dropped again in tearDown.
        self._slave_report_event_handlers = list(events.slave_report._handlers)
        self.options = mocked_options()

    def tearDown(self):
        events.slave_report._handlers = self._slave_report_event_handlers
        # Let LocustTestCase undo whatever it set up as well.
        super(TestMasterRunner, self).tearDown()

    def test_slave_connect(self):
        """client_ready adds a slave to master.clients; quit removes it."""
        class MyTestLocust(Locust):
            pass

        with mock.patch("locust.rpc.rpc.Server", mocked_rpc_server()) as server:
            master = MasterLocustRunner(MyTestLocust, self.options)
            server.mocked_send(Message("client_ready", None, "zeh_fake_client1"))
            self.assertEqual(1, len(master.clients))
            self.assertIn("zeh_fake_client1", master.clients, "Could not find fake client in master instance's clients dict")
            server.mocked_send(Message("client_ready", None, "zeh_fake_client2"))
            server.mocked_send(Message("client_ready", None, "zeh_fake_client3"))
            server.mocked_send(Message("client_ready", None, "zeh_fake_client4"))
            self.assertEqual(4, len(master.clients))

            server.mocked_send(Message("quit", None, "zeh_fake_client3"))
            self.assertEqual(3, len(master.clients))

    def test_slave_stats_report_median(self):
        """Stats reported by a slave are merged so the median reflects them."""
        class MyTestLocust(Locust):
            pass

        with mock.patch("locust.rpc.rpc.Server", mocked_rpc_server()) as server:
            master = MasterLocustRunner(MyTestLocust, self.options)
            server.mocked_send(Message("client_ready", None, "fake_client"))

            master.stats.get("/", "GET").log(100, 23455)
            master.stats.get("/", "GET").log(800, 23455)
            master.stats.get("/", "GET").log(700, 23455)

            data = {"user_count":1}
            # Presumably a report_to_master handler fills `data` with the current
            # stats; clearing the master afterwards means the "stats" message
            # below is the only source of the numbers being asserted.
            events.report_to_master.fire(client_id="fake_client", data=data)
            master.stats.clear_all()

            server.mocked_send(Message("stats", data, "fake_client"))
            s = master.stats.get("/", "GET")
            self.assertEqual(700, s.median_response_time)

    def test_master_marks_downed_slaves_as_missing(self):
        """A slave that sends no heartbeats is eventually marked STATE_MISSING."""
        class MyTestLocust(Locust):
            pass

        with mock.patch("locust.rpc.rpc.Server", mocked_rpc_server()) as server:
            master = MasterLocustRunner(MyTestLocust, self.options)
            server.mocked_send(Message("client_ready", None, "fake_client"))
            # Longer than heartbeat_liveness * heartbeat_interval in mocked_options.
            sleep(0.1)
            self.assertEqual(STATE_MISSING, master.clients['fake_client'].state)

    def test_master_total_stats(self):
        """Totals from successive slave reports are aggregated on the master."""
        class MyTestLocust(Locust):
            pass

        with mock.patch("locust.rpc.rpc.Server", mocked_rpc_server()) as server:
            master = MasterLocustRunner(MyTestLocust, self.options)
            server.mocked_send(Message("client_ready", None, "fake_client"))
            stats = RequestStats()
            stats.log_request("GET", "/1", 100, 3546)
            stats.log_request("GET", "/1", 800, 56743)
            stats2 = RequestStats()
            stats2.log_request("GET", "/2", 700, 2201)
            server.mocked_send(Message("stats", {
                "stats":stats.serialize_stats(),
                "stats_total": stats.total.serialize(),
                "errors":stats.serialize_errors(),
                "user_count": 1,
            }, "fake_client"))
            server.mocked_send(Message("stats", {
                "stats":stats2.serialize_stats(),
                "stats_total": stats2.total.serialize(),
                "errors":stats2.serialize_errors(),
                "user_count": 2,
            }, "fake_client"))
            self.assertEqual(700, master.stats.total.median_response_time)

    def test_master_current_response_times(self):
        """Current response-time percentiles only consider recent reports."""
        class MyTestLocust(Locust):
            pass

        start_time = 1
        with mock.patch("time.time") as mocked_time:
            mocked_time.return_value = start_time
            global_stats.reset_all()
            with mock.patch("locust.rpc.rpc.Server", mocked_rpc_server()) as server:
                master = MasterLocustRunner(MyTestLocust, self.options)
                mocked_time.return_value += 1
                server.mocked_send(Message("client_ready", None, "fake_client"))
                stats = RequestStats()
                stats.log_request("GET", "/1", 100, 3546)
                stats.log_request("GET", "/1", 800, 56743)
                server.mocked_send(Message("stats", {
                    "stats":stats.serialize_stats(),
                    "stats_total": stats.total.get_stripped_report(),
                    "errors":stats.serialize_errors(),
                    "user_count": 1,
                }, "fake_client"))
                mocked_time.return_value += 1
                stats2 = RequestStats()
                stats2.log_request("GET", "/2", 400, 2201)
                server.mocked_send(Message("stats", {
                    "stats":stats2.serialize_stats(),
                    "stats_total": stats2.total.get_stripped_report(),
                    "errors":stats2.serialize_errors(),
                    "user_count": 2,
                }, "fake_client"))
                mocked_time.return_value += 4
                self.assertEqual(400, master.stats.total.get_current_response_time_percentile(0.5))
                self.assertEqual(800, master.stats.total.get_current_response_time_percentile(0.95))

                # Let 10 seconds pass, do some more requests, send them to the master and
                # make sure the current response time percentiles only account for these
                # new requests.
                mocked_time.return_value += 10
                stats.log_request("GET", "/1", 20, 1)
                stats.log_request("GET", "/1", 30, 1)
                stats.log_request("GET", "/1", 3000, 1)
                server.mocked_send(Message("stats", {
                    "stats":stats.serialize_stats(),
                    "stats_total": stats.total.get_stripped_report(),
                    "errors":stats.serialize_errors(),
                    "user_count": 2,
                }, "fake_client"))
                self.assertEqual(30, master.stats.total.get_current_response_time_percentile(0.5))
                self.assertEqual(3000, master.stats.total.get_current_response_time_percentile(0.95))

    def test_sends_hatch_data_to_ready_running_hatching_slaves(self):
        '''Sends hatch job to running, ready, or hatching slaves'''
        class MyTestLocust(Locust):
            pass

        with mock.patch("locust.rpc.rpc.Server", mocked_rpc_server()) as server:
            master = MasterLocustRunner(MyTestLocust, self.options)
            master.clients[1] = SlaveNode(1)
            master.clients[2] = SlaveNode(2)
            master.clients[3] = SlaveNode(3)
            master.clients[1].state = STATE_INIT
            master.clients[2].state = STATE_HATCHING
            master.clients[3].state = STATE_RUNNING
            master.start_hatching(5,5)

            self.assertEqual(3, len(server.outbox))

    def test_spawn_zero_locusts(self):
        """Hatching 0 locusts must finish instead of running until the timeout."""
        class MyTaskSet(TaskSet):
            @task
            def my_task(self):
                pass

        class MyTestLocust(Locust):
            task_set = MyTaskSet
            min_wait = 100
            max_wait = 100

        runner = LocalLocustRunner([MyTestLocust], self.options)

        timeout = gevent.Timeout(2.0)
        timeout.start()

        try:
            runner.start_hatching(0, 1, wait=True)
            runner.greenlet.join()
        except gevent.Timeout:
            self.fail("Got Timeout exception. A locust seems to have been spawned, even though 0 was specified.")
        finally:
            timeout.cancel()

    def test_spawn_uneven_locusts(self):
        """
        Tests that we can accurately spawn a certain number of locusts, even if it's not an
        even number of the connected slaves
        """
        class MyTestLocust(Locust):
            pass

        with mock.patch("locust.rpc.rpc.Server", mocked_rpc_server()) as server:
            master = MasterLocustRunner(MyTestLocust, self.options)
            for i in range(5):
                server.mocked_send(Message("client_ready", None, "fake_client%i" % i))

            master.start_hatching(7, 7)
            self.assertEqual(5, len(server.outbox))

            num_clients = 0
            for _, msg in server.outbox:
                num_clients += Message.unserialize(msg).data["num_clients"]

            self.assertEqual(7, num_clients, "Total number of locusts that would have been spawned is not 7")

    def test_spawn_fewer_locusts_than_slaves(self):
        """Every slave gets a hatch message even when locusts < slaves; the total is exact."""
        class MyTestLocust(Locust):
            pass

        with mock.patch("locust.rpc.rpc.Server", mocked_rpc_server()) as server:
            master = MasterLocustRunner(MyTestLocust, self.options)
            for i in range(5):
                server.mocked_send(Message("client_ready", None, "fake_client%i" % i))

            master.start_hatching(2, 2)
            self.assertEqual(5, len(server.outbox))

            num_clients = 0
            for _, msg in server.outbox:
                num_clients += Message.unserialize(msg).data["num_clients"]

            self.assertEqual(2, num_clients, "Total number of locusts that would have been spawned is not 2")

    def test_exception_in_task(self):
        """With _catch_exceptions off, task errors propagate but are still recorded once per kind."""
        class HeyAnException(Exception):
            pass

        class MyLocust(Locust):
            class task_set(TaskSet):
                @task
                def will_error(self):
                    raise HeyAnException(":(")

        runner = LocalLocustRunner([MyLocust], self.options)

        l = MyLocust()
        l._catch_exceptions = False

        self.assertRaises(HeyAnException, l.run)
        self.assertRaises(HeyAnException, l.run)
        self.assertEqual(1, len(runner.exceptions))

        _, exception = runner.exceptions.popitem()
        self.assertIn("traceback", exception)
        self.assertIn("HeyAnException", exception["traceback"])
        self.assertEqual(2, exception["count"])

    def test_exception_is_catched(self):
        """ Test that exceptions are stored, and execution continues """
        class HeyAnException(Exception):
            pass

        class MyTaskSet(TaskSet):
            def __init__(self, *a, **kw):
                super(MyTaskSet, self).__init__(*a, **kw)
                self._task_queue = [
                    {"callable":self.will_error, "args":[], "kwargs":{}},
                    {"callable":self.will_stop, "args":[], "kwargs":{}},
                ]

            @task(1)
            def will_error(self):
                raise HeyAnException(":(")

            @task(1)
            def will_stop(self):
                self.interrupt()

        class MyLocust(Locust):
            min_wait = 10
            max_wait = 10
            task_set = MyTaskSet

        runner = LocalLocustRunner([MyLocust], self.options)
        l = MyLocust()

        # suppress stderr
        with mock.patch("sys.stderr") as mocked:
            l.task_set._task_queue = [l.task_set.will_error, l.task_set.will_stop]
            self.assertRaises(LocustError, l.run) # make sure HeyAnException isn't raised
            l.task_set._task_queue = [l.task_set.will_error, l.task_set.will_stop]
            self.assertRaises(LocustError, l.run) # make sure HeyAnException isn't raised
        self.assertEqual(2, len(mocked.method_calls))

        # make sure exception was stored
        self.assertEqual(1, len(runner.exceptions))
        _, exception = runner.exceptions.popitem()
        self.assertIn("traceback", exception)
        self.assertIn("HeyAnException", exception["traceback"])
        self.assertEqual(2, exception["count"])
342
343
class TestMessageSerializing(unittest.TestCase):
    """Round-trip checks for ``Message.serialize`` / ``Message.unserialize``."""

    def test_message_serialize(self):
        """A message survives serialize -> unserialize with type, data and node_id intact."""
        original = Message("client_ready", None, "my_id")
        wire = original.serialize()
        restored = Message.unserialize(wire)
        for field in ("type", "data", "node_id"):
            self.assertEqual(getattr(original, field), getattr(restored, field))
351