1import unittest
2
3import grpc
4
5from _service import Service, ErroringHandler, ExceptionErroringHandler
6from _tracer import Tracer, SpanRelationship
7from grpc_opentracing import open_tracing_client_interceptor, open_tracing_server_interceptor
8import opentracing
9
10
class OpenTracingTest(unittest.TestCase):
    """Check the spans recorded when both client and server are traced.

    Every test performs one RPC through a service whose client and server
    sides share a single tracer, verifies the RPC result against a direct
    call to the handler, and then verifies that span 0 is tagged as the
    client span, span 1 is tagged as the server span, and that
    ``get_relationship(0, 1)`` on the tracer is ``CHILD_OF``.
    """

    def setUp(self):
        tracer = Tracer()
        self._tracer = tracer
        self._service = Service(
            [open_tracing_client_interceptor(tracer)],
            [open_tracing_server_interceptor(tracer)])

    def _assert_span_kind(self, index, kind):
        """Assert span ``index`` exists and its 'span.kind' tag equals ``kind``."""
        span = self._tracer.get_span(index)
        self.assertIsNotNone(span)
        self.assertEqual(span.get_tag('span.kind'), kind)

    def _assert_client_server_trace(self):
        """Assert the client span, then the server span, then their link."""
        self._assert_span_kind(0, 'client')
        self._assert_span_kind(1, 'server')
        self.assertEqual(
            self._tracer.get_relationship(0, 1),
            opentracing.ReferenceType.CHILD_OF)

    def testUnaryUnaryOpenTracing(self):
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        expected = self._service.handler.handle_unary_unary(payload, None)
        actual = stub(payload)

        self.assertEqual(actual, expected)
        self._assert_client_server_trace()

    def testUnaryUnaryOpenTracingFuture(self):
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        expected = self._service.handler.handle_unary_unary(payload, None)
        pending = stub.future(payload)
        actual = pending.result()

        self.assertEqual(actual, expected)
        self._assert_client_server_trace()

    def testUnaryUnaryOpenTracingWithCall(self):
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        expected = self._service.handler.handle_unary_unary(payload, None)
        actual, call = stub.with_call(payload)

        self.assertEqual(actual, expected)
        self.assertIs(grpc.StatusCode.OK, call.code())
        self._assert_client_server_trace()

    def testUnaryStreamOpenTracing(self):
        stub = self._service.unary_stream_multi_callable
        payload = b'\x01'
        expected = self._service.handler.handle_unary_stream(payload, None)
        replies = stub(payload)

        # Drain the RPC response first, then the directly-computed one.
        self.assertEqual(list(replies), list(expected))
        self._assert_client_server_trace()

    def testStreamUnaryOpenTracing(self):
        stub = self._service.stream_unary_multi_callable
        payloads = [b'\x01', b'\x02']
        expected = self._service.handler.handle_stream_unary(
            iter(payloads), None)
        actual = stub(iter(payloads))

        self.assertEqual(actual, expected)
        self._assert_client_server_trace()

    def testStreamUnaryOpenTracingWithCall(self):
        stub = self._service.stream_unary_multi_callable
        payloads = [b'\x01', b'\x02']
        expected = self._service.handler.handle_stream_unary(
            iter(payloads), None)
        actual, call = stub.with_call(iter(payloads))

        self.assertEqual(actual, expected)
        self.assertIs(grpc.StatusCode.OK, call.code())
        self._assert_client_server_trace()

    def testStreamUnaryOpenTracingFuture(self):
        stub = self._service.stream_unary_multi_callable
        payloads = [b'\x01', b'\x02']
        expected = self._service.handler.handle_stream_unary(
            iter(payloads), None)
        pending = stub.future(iter(payloads))
        actual = pending.result()

        self.assertEqual(actual, expected)
        self._assert_client_server_trace()

    def testStreamStreamOpenTracing(self):
        stub = self._service.stream_stream_multi_callable
        payloads = [b'\x01', b'\x02']
        expected = self._service.handler.handle_stream_stream(
            iter(payloads), None)
        replies = stub(iter(payloads))

        # Drain the RPC response first, then the directly-computed one.
        self.assertEqual(list(replies), list(expected))
        self._assert_client_server_trace()
190
191
class OpenTracingInteroperabilityClientTest(unittest.TestCase):
    """Check that a traced client works against a server with no interceptor.

    The service is built with a client interceptor and an empty list of
    server interceptors. Each test verifies the RPC result and then that the
    tracer holds exactly one span of interest: span 0 tagged 'client', with
    ``get_span(1)`` returning None.
    """

    def setUp(self):
        tracer = Tracer()
        self._tracer = tracer
        self._service = Service(
            [open_tracing_client_interceptor(tracer)], [])

    def _assert_only_client_span(self):
        """Assert span 0 is a client span and that no span 1 was recorded."""
        client_span = self._tracer.get_span(0)
        self.assertIsNotNone(client_span)
        self.assertEqual(client_span.get_tag('span.kind'), 'client')

        self.assertIsNone(self._tracer.get_span(1))

    def testUnaryUnaryOpenTracing(self):
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        expected = self._service.handler.handle_unary_unary(payload, None)
        actual = stub(payload)

        self.assertEqual(actual, expected)
        self._assert_only_client_span()

    def testUnaryUnaryOpenTracingWithCall(self):
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        expected = self._service.handler.handle_unary_unary(payload, None)
        actual, call = stub.with_call(payload)

        self.assertEqual(actual, expected)
        self.assertIs(grpc.StatusCode.OK, call.code())
        self._assert_only_client_span()

    def testUnaryStreamOpenTracing(self):
        stub = self._service.unary_stream_multi_callable
        payload = b'\x01'
        expected = self._service.handler.handle_unary_stream(payload, None)
        replies = stub(payload)

        # Drain the RPC response first, then the directly-computed one.
        self.assertEqual(list(replies), list(expected))
        self._assert_only_client_span()

    def testStreamUnaryOpenTracing(self):
        stub = self._service.stream_unary_multi_callable
        payloads = [b'\x01', b'\x02']
        expected = self._service.handler.handle_stream_unary(
            iter(payloads), None)
        actual = stub(iter(payloads))

        self.assertEqual(actual, expected)
        self._assert_only_client_span()

    def testStreamUnaryOpenTracingWithCall(self):
        stub = self._service.stream_unary_multi_callable
        payloads = [b'\x01', b'\x02']
        expected = self._service.handler.handle_stream_unary(
            iter(payloads), None)
        actual, call = stub.with_call(iter(payloads))

        self.assertEqual(actual, expected)
        self.assertIs(grpc.StatusCode.OK, call.code())
        self._assert_only_client_span()

    def testStreamStreamOpenTracing(self):
        stub = self._service.stream_stream_multi_callable
        payloads = [b'\x01', b'\x02']
        expected = self._service.handler.handle_stream_stream(
            iter(payloads), None)
        replies = stub(iter(payloads))

        # Drain the RPC response first, then the directly-computed one.
        self.assertEqual(list(replies), list(expected))
        self._assert_only_client_span()
297
298
class OpenTracingMetadataTest(unittest.TestCase):
    """Check that metadata still passes through an RPC when tracing is on.

    Both the client and the server interceptor are installed; the tests look
    for a caller-supplied ('abc', '123') pair in the handler's invocation
    metadata and in the trailing metadata reported by the call future.
    """

    def setUp(self):
        tracer = Tracer()
        self._tracer = tracer
        self._service = Service(
            [open_tracing_client_interceptor(tracer)],
            [open_tracing_server_interceptor(tracer)])

    def testInvocationMetadata(self):
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        metadata = (('abc', '123'),)
        # Positional arguments after the request: timeout (None), metadata.
        stub(payload, None, metadata)
        self.assertIn(('abc', '123'), self._service.handler.invocation_metadata)

    def testTrailingMetadata(self):
        metadata = (('abc', '123'),)
        self._service.handler.trailing_metadata = metadata
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        pending = stub.future(payload, None, metadata)
        self.assertIn(('abc', '123'), pending.trailing_metadata())
321
322
class OpenTracingInteroperabilityServerTest(unittest.TestCase):
    """Check that a traced server works with a client that has no interceptor.

    The service is built with an empty list of client interceptors and a
    server interceptor. Each test verifies the RPC result and then that the
    tracer holds exactly one span of interest: span 0 tagged 'server', with
    ``get_span(1)`` returning None.
    """

    def setUp(self):
        tracer = Tracer()
        self._tracer = tracer
        self._service = Service(
            [], [open_tracing_server_interceptor(tracer)])

    def _assert_only_server_span(self):
        """Assert span 0 is a server span and that no span 1 was recorded."""
        server_span = self._tracer.get_span(0)
        self.assertIsNotNone(server_span)
        self.assertEqual(server_span.get_tag('span.kind'), 'server')

        self.assertIsNone(self._tracer.get_span(1))

    def testUnaryUnaryOpenTracing(self):
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        expected = self._service.handler.handle_unary_unary(payload, None)
        actual = stub(payload)

        self.assertEqual(actual, expected)
        self._assert_only_server_span()

    def testUnaryUnaryOpenTracingWithCall(self):
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        expected = self._service.handler.handle_unary_unary(payload, None)
        actual, call = stub.with_call(payload)

        self.assertEqual(actual, expected)
        self.assertIs(grpc.StatusCode.OK, call.code())
        self._assert_only_server_span()

    def testUnaryStreamOpenTracing(self):
        stub = self._service.unary_stream_multi_callable
        payload = b'\x01'
        expected = self._service.handler.handle_unary_stream(payload, None)
        replies = stub(payload)

        # Drain the RPC response first, then the directly-computed one.
        self.assertEqual(list(replies), list(expected))
        self._assert_only_server_span()

    def testStreamUnaryOpenTracing(self):
        stub = self._service.stream_unary_multi_callable
        payloads = [b'\x01', b'\x02']
        expected = self._service.handler.handle_stream_unary(
            iter(payloads), None)
        actual = stub(iter(payloads))

        self.assertEqual(actual, expected)
        self._assert_only_server_span()

    def testStreamUnaryOpenTracingWithCall(self):
        stub = self._service.stream_unary_multi_callable
        payloads = [b'\x01', b'\x02']
        expected = self._service.handler.handle_stream_unary(
            iter(payloads), None)
        actual, call = stub.with_call(iter(payloads))

        self.assertEqual(actual, expected)
        self.assertIs(grpc.StatusCode.OK, call.code())
        self._assert_only_server_span()

    def testStreamStreamOpenTracing(self):
        stub = self._service.stream_stream_multi_callable
        payloads = [b'\x01', b'\x02']
        expected = self._service.handler.handle_stream_stream(
            iter(payloads), None)
        replies = stub(iter(payloads))

        # Drain the RPC response first, then the directly-computed one.
        self.assertEqual(list(replies), list(expected))
        self._assert_only_server_span()
428
429
class OpenTracingErroringTest(unittest.TestCase):
    """Check that spans carry the 'error' tag when the RPC fails.

    The service is built with an ``ErroringHandler``. Each test expects the
    RPC to raise ``grpc.RpcError`` and then expects spans 0 and 1 to both
    exist with a truthy 'error' tag.
    """

    def setUp(self):
        tracer = Tracer()
        self._tracer = tracer
        self._service = Service(
            [open_tracing_client_interceptor(tracer)],
            [open_tracing_server_interceptor(tracer)],
            ErroringHandler())

    def _assert_span_errored(self, index):
        """Assert span ``index`` exists and has a truthy 'error' tag."""
        span = self._tracer.get_span(index)
        self.assertIsNotNone(span)
        self.assertTrue(span.get_tag('error'))

    def _assert_both_spans_errored(self):
        """Assert the error tag on span 0 and then on span 1."""
        self._assert_span_errored(0)
        self._assert_span_errored(1)

    def testUnaryUnaryOpenTracing(self):
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        with self.assertRaises(grpc.RpcError):
            stub(payload)

        self._assert_both_spans_errored()

    def testUnaryUnaryOpenTracingWithCall(self):
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        with self.assertRaises(grpc.RpcError):
            stub.with_call(payload)

        self._assert_both_spans_errored()

    def testUnaryStreamOpenTracing(self):
        stub = self._service.unary_stream_multi_callable
        payload = b'\x01'
        # The invocation itself must succeed; the error is expected only
        # once the response stream is consumed.
        replies = stub(payload)
        with self.assertRaises(grpc.RpcError):
            list(replies)

        self._assert_both_spans_errored()

    def testStreamUnaryOpenTracing(self):
        stub = self._service.stream_unary_multi_callable
        payloads = [b'\x01', b'\x02']
        with self.assertRaises(grpc.RpcError):
            stub(iter(payloads))

        self._assert_both_spans_errored()

    def testStreamUnaryOpenTracingWithCall(self):
        stub = self._service.stream_unary_multi_callable
        payloads = [b'\x01', b'\x02']
        with self.assertRaises(grpc.RpcError):
            stub.with_call(iter(payloads))

        self._assert_both_spans_errored()

    def testStreamStreamOpenTracing(self):
        stub = self._service.stream_stream_multi_callable
        payloads = [b'\x01', b'\x02']
        # The invocation itself must succeed; the error is expected only
        # once the response stream is consumed.
        replies = stub(iter(payloads))
        with self.assertRaises(grpc.RpcError):
            list(replies)

        self._assert_both_spans_errored()
519
520
class OpenTracingExceptionErroringTest(unittest.TestCase):
    """Check that spans carry the 'error' tag for exception-erroring RPCs.

    The service is built with an ``ExceptionErroringHandler``. Each test
    expects the RPC to raise ``grpc.RpcError`` and then expects spans 0 and 1
    to both exist with a truthy 'error' tag.
    """

    def setUp(self):
        tracer = Tracer()
        self._tracer = tracer
        self._service = Service(
            [open_tracing_client_interceptor(tracer)],
            [open_tracing_server_interceptor(tracer)],
            ExceptionErroringHandler())

    def _assert_span_errored(self, index):
        """Assert span ``index`` exists and has a truthy 'error' tag."""
        span = self._tracer.get_span(index)
        self.assertIsNotNone(span)
        self.assertTrue(span.get_tag('error'))

    def _assert_both_spans_errored(self):
        """Assert the error tag on span 0 and then on span 1."""
        self._assert_span_errored(0)
        self._assert_span_errored(1)

    def testUnaryUnaryOpenTracing(self):
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        with self.assertRaises(grpc.RpcError):
            stub(payload)

        self._assert_both_spans_errored()

    def testUnaryUnaryOpenTracingWithCall(self):
        stub = self._service.unary_unary_multi_callable
        payload = b'\x01'
        with self.assertRaises(grpc.RpcError):
            stub.with_call(payload)

        self._assert_both_spans_errored()

    def testUnaryStreamOpenTracing(self):
        stub = self._service.unary_stream_multi_callable
        payload = b'\x01'
        # The invocation itself must succeed; the error is expected only
        # once the response stream is consumed.
        replies = stub(payload)
        with self.assertRaises(grpc.RpcError):
            list(replies)

        self._assert_both_spans_errored()

    def testStreamUnaryOpenTracing(self):
        stub = self._service.stream_unary_multi_callable
        payloads = [b'\x01', b'\x02']
        with self.assertRaises(grpc.RpcError):
            stub(iter(payloads))

        self._assert_both_spans_errored()

    def testStreamUnaryOpenTracingWithCall(self):
        stub = self._service.stream_unary_multi_callable
        payloads = [b'\x01', b'\x02']
        with self.assertRaises(grpc.RpcError):
            stub.with_call(iter(payloads))

        self._assert_both_spans_errored()

    def testStreamStreamOpenTracing(self):
        stub = self._service.stream_stream_multi_callable
        payloads = [b'\x01', b'\x02']
        # The invocation itself must succeed; the error is expected only
        # once the response stream is consumed.
        replies = stub(iter(payloads))
        with self.assertRaises(grpc.RpcError):
            list(replies)

        self._assert_both_spans_errored()
612