1"""
Tests for basic IPC behavior of the Component class: handshake, reading and
sending messages, logging, and run-loop exit handling.
3"""
4
from __future__ import absolute_import, print_function, unicode_literals

import logging
import os
import shutil
import tempfile
import unittest
from io import BytesIO

import simplejson as json

try:
    from unittest.mock import patch
except ImportError:
    from mock import patch

from pystorm import Component
from pystorm.exceptions import StormWentAwayError
21
22
# Module-level logger; the read tests use it to trace which message is checked.
log = logging.getLogger(name=__name__)
24
25
class ComponentTests(unittest.TestCase):
    """IPC tests for :class:`pystorm.Component`.

    Covers the handshake, reading messages / commands / task IDs, sending
    messages, logging, and how ``run()`` exits when ``_run()`` or
    ``_handle_run_exception()`` raise.
    """

    conf = {"topology.message.timeout.secs": 3,
            "topology.tick.tuple.freq.secs": 1,
            "topology.debug": True,
            "topology.name": "foo"}
    context = {
        "task->component": {
            "1": "example-spout",
            "2": "__acker",
            "3": "example-bolt1",
            "4": "example-bolt2"
        },
        "taskid": 3,
        # Everything below this line is only available in Storm 0.11.0+
        "componentid": "example-bolt1",
        "stream->target->grouping": {
            "default": {
                "example-bolt2": {
                    "type": "SHUFFLE"
                }
            }
        },
        "streams": ["default"],
        "stream->outputfields": {"default": ["word"]},
        "source->stream->grouping": {
            "example-spout": {
                "default": {
                    "type": "FIELDS",
                    "fields": ["word"]
                }
            }
        },
        "source->stream->fields": {
            "example-spout": {
                "default": ["sentence", "word", "number"]
            }
        }
    }

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _make_pid_dir(self):
        """Create a temporary PID directory that is removed after the test.

        The handshake writes a PID file into ``pidDir`` (see
        ``test_read_handshake``). A private directory keeps those files out
        of the working directory, and ``addCleanup`` removes it even when an
        assertion fails.
        """
        pid_dir = tempfile.mkdtemp(prefix='pystorm-test-')
        self.addCleanup(shutil.rmtree, pid_dir, ignore_errors=True)
        return pid_dir

    def _make_handshake_component(self, pid_dir=None):
        """Return a Component whose input stream holds one handshake message.

        :param pid_dir: directory advertised as ``pidDir``; a fresh temporary
                        directory (cleaned up automatically) when omitted.
        """
        if pid_dir is None:
            pid_dir = self._make_pid_dir()
        handshake_dict = {"conf": self.conf,
                          "pidDir": pid_dir,
                          "context": self.context}
        inputs = ["{}\n".format(json.dumps(handshake_dict)),
                  "end\n"]
        return Component(input_stream=BytesIO(''.join(inputs).encode('utf-8')),
                         output_stream=BytesIO())

    def _check_read_message(self, inputs):
        """Check ``read_message`` against a scripted input stream.

        ``inputs`` alternates JSON payloads with their ``end`` delimiters.
        Each non-empty payload must come back decoded. Once those are
        consumed, the next read must raise ``StormWentAwayError``.
        """
        expected = [json.loads(msg) for msg in inputs[::2] if msg]
        component = Component(input_stream=BytesIO(''.join(inputs).encode('utf-8')),
                              output_stream=BytesIO())
        for output in expected:
            log.info('Checking msg for %r', output)
            self.assertEqual(output, component.read_message())
        with self.assertRaises(StormWentAwayError):
            component.read_message()

    def _check_send_message(self, commands):
        """Check that ``send_message`` writes exactly the serialized command.

        The serializer's output stream is swapped for a fresh buffer before
        each command, so every comparison sees only that command's bytes.
        """
        component = Component(input_stream=BytesIO(), output_stream=BytesIO())
        serializer = component.serializer
        for cmd in commands:
            serializer.output_stream.close()
            serializer.output_stream = serializer._wrap_stream(BytesIO())
            component.send_message(cmd)
            self.assertEqual(serializer.serialize_dict(cmd).encode('utf-8'),
                             serializer.output_stream.buffer.getvalue())

        # Check that we properly skip over invalid input
        self.assertIsNone(component.send_message(['foo', 'bar']))

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_read_handshake(self):
        pid_dir = self._make_pid_dir()
        component = self._make_handshake_component(pid_dir)
        given_conf, given_context = component.read_handshake()
        pid_path = os.path.join(pid_dir, str(component.pid))
        self.assertTrue(os.path.exists(pid_path))
        self.assertEqual(given_conf, self.conf)
        self.assertEqual(given_context, self.context)
        # The handshake reply is the component's PID.
        self.assertEqual(component.serializer.serialize_dict({"pid": component.pid}).encode('utf-8'),
                         component.serializer.output_stream.buffer.getvalue())

    def test_setup_component(self):
        conf = self.conf
        component = Component(input_stream=BytesIO(),
                              output_stream=BytesIO())
        component._setup_component(conf, self.context)
        self.assertEqual(component.topology_name, conf['topology.name'])
        self.assertEqual(component.task_id, self.context['taskid'])
        self.assertEqual(component.component_name,
                         self.context['task->component'][str(self.context['taskid'])])
        self.assertEqual(component.storm_conf, conf)
        self.assertEqual(component.context, self.context)

    def test_read_message(self):
        self._check_read_message([
            # Task IDs
            '[12, 22, 24]\n', 'end\n',
            # Incoming Tuple for bolt
            ('{ "id": "-6955786537413359385", "comp": "1", "stream": "1"'
             ', "task": 9, "tuple": ["snow white and the seven dwarfs", '
             '"field2", 3]}\n'), 'end\n',
            # next command for spout
            '{"command": "next"}\n', 'end\n',
            # empty message with no "end": the stream is exhausted, so the
            # next read should raise StormWentAwayError
            '', '',
        ])

    def test_read_message_unicode(self):
        self._check_read_message([
            # Task IDs
            '[12, 22, 24]\n', 'end\n',
            # Incoming Tuple for bolt (contains a non-ASCII character)
            ('{ "id": "-6955786537413359385", "comp": "1", "stream": "1"'
             ', "task": 9, "tuple": ["snow white \uFFE6 the seven dwarfs"'
             ', "field2", 3]}\n'), 'end\n',
            # next command for spout
            '{"command": "next"}\n', 'end\n',
            # empty message with no "end": the stream is exhausted, so the
            # next read should raise StormWentAwayError
            '', '',
        ])

    def test_read_split_message(self):
        # Make sure we can read something that's broken up into many "lines"
        inputs = ['{ "id": "-6955786537413359385", ',
                  '"comp": "1", "stream": "1"\n',
                  '\n',
                  ', "task": 9, "tuple": ["snow white and the seven dwarfs", ',
                  '"field2", 3]}\n',
                  'end\n']
        output = json.loads(''.join(inputs[:-1]))

        component = Component(input_stream=BytesIO(''.join(inputs).encode('utf-8')),
                              output_stream=BytesIO())
        msg = component.read_message()
        self.assertEqual(output, msg)

    def test_read_command(self):
        # Check that we properly queue task IDs and return only commands
        inputs = [# Task IDs
                  '[12, 22, 24]\n', 'end\n',
                  # Incoming Tuple for bolt
                  ('{ "id": "-6955786537413359385", "comp": "1", "stream": "1"'
                   ', "task": 9, "tuple": ["snow white and the seven dwarfs", '
                   '"field2", 3]}\n'), 'end\n',
                  # next command for spout
                  '{"command": "next"}\n', 'end\n']
        outputs = [json.loads(msg) for msg in inputs[::2]]
        component = Component(input_stream=BytesIO(''.join(inputs).encode('utf-8')),
                              output_stream=BytesIO())

        # Skip first output, because it's a task ID, and won't be returned by
        # read_command
        for output in outputs[1:]:
            log.info('Checking msg for %r', output)
            msg = component.read_command()
            self.assertEqual(output, msg)
        self.assertEqual(component._pending_task_ids.pop(), outputs[0])

    def test_read_task_ids(self):
        # Check that we properly queue commands and return only task IDs
        inputs = [# Task IDs
                  '[4, 8, 15]\n', 'end\n',
                  # Incoming Tuple for bolt
                  ('{ "id": "-6955786537413359385", "comp": "1", "stream": "1"'
                   ', "task": 9, "tuple": ["snow white and the seven dwarfs", '
                   '"field2", 3]}\n'), 'end\n',
                  # next command for spout
                  '{"command": "next"}\n', 'end\n',
                  # Task IDs
                  '[16, 23, 42]\n', 'end\n']
        outputs = [json.loads(msg) for msg in inputs[::2]]
        component = Component(input_stream=BytesIO(''.join(inputs).encode('utf-8')),
                              output_stream=BytesIO())

        # Skip middle outputs, because they're commands and won't be returned by
        # read_task_ids
        for output in (outputs[0], outputs[-1]):
            log.info('Checking msg for %r', output)
            msg = component.read_task_ids()
            self.assertEqual(output, msg)
        for output in outputs[1:-1]:
            self.assertEqual(component._pending_commands.popleft(), output)

    def test_send_message(self):
        self._check_send_message([{"command": "emit", "id": 4, "stream": "", "task": 9,
                                   "tuple": ["field1", 2, 3]},
                                  {"command": "log", "msg": "I am a robot monkey."},
                                  {"command": "next"},
                                  {"command": "sync"}])

    def test_send_message_unicode(self):
        self._check_send_message([{"command": "emit", "id": 4, "stream": "", "task": 9,
                                   "tuple": ["field\uFFE6", 2, 3]},
                                  {"command": "log", "msg": "I am a robot monkey."},
                                  {"command": "next"},
                                  {"command": "sync"}])

    @patch.object(Component, 'send_message', autospec=True)
    def test_log(self, send_message_mock):
        component = Component(input_stream=BytesIO(), output_stream=BytesIO())
        # (message, level passed to log(), level expected in the command)
        inputs = [("I am a robot monkey.", None, 2),
                  ("I am a monkey who learned to talk.", 'warning', 3)]
        for msg, level, storm_level in inputs:
            component.log(msg, level=level)
            send_message_mock.assert_called_with(component, {'command': 'log',
                                                             'msg': msg,
                                                             'level': storm_level})

    def test_exit_on_exception_true(self):
        component = self._make_handshake_component()
        component.exit_on_exception = True
        with self.assertRaises(SystemExit) as raises_fixture:
            component.run()
        self.assertEqual(raises_fixture.exception.code, 1)

    @patch.object(Component, '_run', autospec=True)
    def test_exit_on_exception_false(self, _run_mock):
        # The first _run() call fails with an ordinary exception. With
        # exit_on_exception False that must not end the process. The second
        # call reports that Storm went away, which exits with code 2.
        # ``call_count`` is used rather than ``called``: mock records the
        # call before invoking ``side_effect``, so ``called`` is already True
        # on the very first call and the ordinary-exception branch would
        # never run.
        def raiser(*args):  # lambdas can't raise
            if _run_mock.call_count > 1:
                raise StormWentAwayError
            raise NotImplementedError
        _run_mock.side_effect = raiser

        component = self._make_handshake_component()
        component.exit_on_exception = False
        with self.assertRaises(SystemExit) as raises_fixture:
            component.run()
        self.assertEqual(raises_fixture.exception.code, 2)
        # run() must have retried after the first failure instead of exiting.
        self.assertEqual(_run_mock.call_count, 2)

    @patch.object(Component, '_handle_run_exception', autospec=True)
    @patch('pystorm.component.log', autospec=True)
    def test_nested_exception(self, log_mock, _handle_run_exception_mock):
        # Make self._handle_run_exception raise an exception. Accept *args:
        # the autospecced mock forwards the component plus whatever run()
        # passes it (presumably the original exception), so a single-parameter
        # raiser could fail with TypeError instead of raising 'Oops'.
        def raiser(*args):  # lambdas can't raise
            raise Exception('Oops')
        _handle_run_exception_mock.side_effect = raiser

        component = self._make_handshake_component()
        component.exit_on_exception = True

        with self.assertRaises(SystemExit) as raises_fixture:
            component.run()
        self.assertEqual(log_mock.error.call_count, 2)
        self.assertEqual(raises_fixture.exception.code, 1)

    @patch.object(Component, '_handle_run_exception', autospec=True)
    @patch('pystorm.component.log', autospec=True)
    def test_nested_went_away_exception(self, log_mock, _handle_run_exception_mock):
        # Make self._handle_run_exception report that Storm went away
        def raiser(*args):  # lambdas can't raise
            raise StormWentAwayError
        _handle_run_exception_mock.side_effect = raiser

        component = self._make_handshake_component()
        component.exit_on_exception = True

        with self.assertRaises(SystemExit) as raises_fixture:
            component.run()
        self.assertEqual(log_mock.error.call_count, 1)
        self.assertEqual(log_mock.info.call_count, 1)
        self.assertEqual(raises_fixture.exception.code, 2)
331
332
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
335