1# -*- coding: utf-8 -*-
2import csv
3import json
4import sys
5import traceback
6
7import gevent
8import requests
9from gevent import pywsgi
10
11from locust import events, runners, stats, web
12from locust.main import parse_options
13from locust.runners import LocustRunner
14from six.moves import StringIO
15
16from .testcases import LocustTestCase
17
18
class TestWebUI(LocustTestCase):
    """Integration tests for the Locust web UI served by ``web.app``.

    Each test talks over HTTP to a real ``pywsgi.WSGIServer`` bound to an
    OS-assigned port on 127.0.0.1. Global stats, the web stats cache and
    ``runners.locust_runner`` are reset before every test.
    """

    def setUp(self):
        super(TestWebUI, self).setUp()

        # Reset the module-level state the web UI reads from so that stats,
        # logged exceptions and cached responses do not leak between tests.
        stats.global_stats.clear_all()
        parser = parse_options()[0]
        options = parser.parse_args([])[0]
        runners.locust_runner = LocustRunner([], options)

        web.request_stats.clear_cache()

        # Port 0 lets the OS pick a free port. start() binds the socket and
        # begins accepting without blocking, so server_port is valid right
        # away; there is no need to spawn serve_forever() and sleep until the
        # greenlet has had a chance to bind.
        self._web_ui_server = pywsgi.WSGIServer(('127.0.0.1', 0), web.app, log=None)
        self._web_ui_server.start()
        self.web_port = self._web_ui_server.server_port

    def tearDown(self):
        # Undo setUp in reverse order: stop the server this class started
        # first, and always run the base-class cleanup even if stop() fails.
        try:
            self._web_ui_server.stop()
        finally:
            super(TestWebUI, self).tearDown()

    def _url(self, path):
        """Return the absolute URL of *path* on the test web UI server."""
        return "http://127.0.0.1:%i%s" % (self.web_port, path)

    def test_index(self):
        self.assertEqual(200, requests.get(self._url("/")).status_code)

    def test_stats_no_data(self):
        self.assertEqual(200, requests.get(self._url("/stats/requests")).status_code)

    def test_stats(self):
        stats.global_stats.log_request("GET", "/test", 120, 5612)
        response = requests.get(self._url("/stats/requests"))
        self.assertEqual(200, response.status_code)

        data = json.loads(response.text)
        self.assertEqual(2, len(data["stats"]))  # one entry plus Total
        self.assertEqual("/test", data["stats"][0]["name"])
        self.assertEqual("GET", data["stats"][0]["method"])
        self.assertEqual(120, data["stats"][0]["avg_response_time"])

        self.assertEqual("Total", data["stats"][1]["name"])
        self.assertEqual(1, data["stats"][1]["num_requests"])
        self.assertEqual(120, data["stats"][1]["avg_response_time"])

    def test_stats_cache(self):
        """The /stats/requests response stays cached until clear_cache()."""
        stats.global_stats.log_request("GET", "/test", 120, 5612)
        response = requests.get(self._url("/stats/requests"))
        self.assertEqual(200, response.status_code)
        data = json.loads(response.text)
        self.assertEqual(2, len(data["stats"]))  # one entry plus Total

        # add another entry
        stats.global_stats.log_request("GET", "/test2", 120, 5612)
        data = json.loads(requests.get(self._url("/stats/requests")).text)
        self.assertEqual(2, len(data["stats"]))  # old value should be cached now

        web.request_stats.clear_cache()

        data = json.loads(requests.get(self._url("/stats/requests")).text)
        self.assertEqual(3, len(data["stats"]))  # this should no longer be cached

    def test_request_stats_csv(self):
        stats.global_stats.log_request("GET", "/test2", 120, 5612)
        response = requests.get(self._url("/stats/requests/csv"))
        self.assertEqual(200, response.status_code)

    def test_distribution_stats_csv(self):
        for _ in range(19):
            stats.global_stats.log_request("GET", "/test2", 400, 5612)
        stats.global_stats.log_request("GET", "/test2", 1200, 5612)
        response = requests.get(self._url("/stats/distribution/csv"))
        self.assertEqual(200, response.status_code)

        # Parse with the csv module instead of splitting on "\n" and ",":
        # a trailing newline, CRLF line endings or a quoted comma would
        # otherwise shift rows/columns. csv also strips the field quotes.
        rows = list(csv.reader(StringIO(response.text)))
        # check that /test2 is present in stats (row just before Total)
        self.assertEqual("GET /test2", rows[-2][0])
        # check total row
        total_cols = rows[-1]
        self.assertEqual("Total", total_cols[0])
        # verify that the 95%, 98%, 99% and 100% percentiles are 1200
        for value in total_cols[-4:]:
            self.assertEqual("1200", value)

    def test_request_stats_with_errors(self):
        stats.global_stats.log_error("GET", "/", Exception("Error1337"))
        response = requests.get(self._url("/stats/requests"))
        self.assertEqual(200, response.status_code)
        self.assertIn("Error1337", response.text)

    def test_exceptions(self):
        try:
            raise Exception(u"A cool test exception")
        except Exception as e:
            tb = sys.exc_info()[2]
            # Log the same exception twice, as test_exceptions_csv does.
            runners.locust_runner.log_exception("local", str(e), "".join(traceback.format_tb(tb)))
            runners.locust_runner.log_exception("local", str(e), "".join(traceback.format_tb(tb)))

        response = requests.get(self._url("/exceptions"))
        self.assertEqual(200, response.status_code)
        self.assertIn("A cool test exception", response.text)

        response = requests.get(self._url("/stats/requests"))
        self.assertEqual(200, response.status_code)

    def test_exceptions_csv(self):
        try:
            raise Exception("Test exception")
        except Exception as e:
            tb = sys.exc_info()[2]
            # Logged twice so the CSV count column can be checked below.
            runners.locust_runner.log_exception("local", str(e), "".join(traceback.format_tb(tb)))
            runners.locust_runner.log_exception("local", str(e), "".join(traceback.format_tb(tb)))

        response = requests.get(self._url("/exceptions/csv"))
        self.assertEqual(200, response.status_code)

        rows = list(csv.reader(StringIO(response.text)))

        # header row plus a single (deduplicated) exception row
        self.assertEqual(2, len(rows))
        self.assertEqual("Test exception", rows[1][1])
        self.assertEqual(2, int(rows[1][0]), "Exception count should be 2")
139