1# -*- coding: utf-8 -*-
2# flake8: disable E501
3
4'''
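Tests for thriftpy2.contrib.tracking, including compatibility with the v2 tracking implementation.

Compatibility test matrix (rows: client implementation, columns: server implementation):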
8
9============= ==================================== ======================================== ========================================
10client/server native                               v2                                       v3
11============= ==================================== ======================================== ========================================
12native        N/A                                  test_native_client_tracked_server_v2     test_native_client_tracked_server_v3
13v2            test_tracked_client_v2_native_server test_tracked_client_v2_tracked_server_v2 test_tracked_client_v2_tracked_server_v3
14v3            test_tracked_client_v3_native_server test_tracked_client_v3_tracked_server_v2 regular tests
15============= ==================================== ======================================== ========================================
16'''  # noqa
17
18from __future__ import absolute_import
19
20import contextlib
21import multiprocessing
22import os
23import pickle
24import random
25import socket
26import tempfile
27import time
28
29import thriftpy2
30
31try:
32    import dbm
33except ImportError:
34    import dbm.ndbm as dbm
35
36import pytest
37
38from thriftpy2.contrib.tracking import TTrackedProcessor, TTrackedClient, \
39    TrackerBase, track_thrift
40from thriftpy2.contrib.tracking.tracker import ctx
41
42from thriftpy2.thrift import TProcessorFactory, TClient, TProcessor
43from thriftpy2.server import TThreadedServer
44from thriftpy2.transport import TServerSocket, TBufferedTransportFactory, \
45    TTransportException, TSocket
46from thriftpy2.protocol import TBinaryProtocolFactory
47from compatible.version_2.tracking import (
48    TTrackedProcessor as TTrackedProcessorV2,
49    TTrackedClient as TTrackedClientV2,
50    TrackerBase as TrackerBaseV2,
51)
52
53try:
54    from pytest_cov.embed import cleanup_on_sigterm
55except ImportError:
56    pass
57else:
58    cleanup_on_sigterm()
59
addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__),
                                          "addressbook.thrift"))

# Path of the dbm database shared by the test process and the server
# subprocesses. mkstemp() also hands back an open OS-level file descriptor
# that nothing here uses; close it immediately instead of leaking it for the
# whole session (only the path is needed).
_db_fd, db_file = tempfile.mkstemp()
os.close(_db_fd)
63
64
def _port_in_use(port, host='127.0.0.1'):
    """Return True if a TCP connect to ``host:port`` succeeds."""
    # contextlib.closing() guarantees the probe socket is released; it also
    # works where socket objects are not context managers (Python 2).
    with contextlib.closing(
            socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        return sock.connect_ex((host, port)) == 0


def _get_port(count=5):
    """Pick a random base port whose ``count`` consecutive ports look free.

    The fixtures in this module start servers on ``PORT`` through
    ``PORT + 4``, so by default the five ports starting at the candidate are
    probed. If any of them accepts a connection, a new candidate is drawn.

    :param count: number of consecutive ports, starting at the returned
        value, that must all be free. Defaults to 5 to match the fixtures.
    :returns: an int between 20001 and 29999 inclusive.
    """
    while True:
        port = 20000 + random.randint(1, 9999)
        if not any(_port_in_use(port + offset) for offset in range(count)):
            return port


PORT = _get_port()
78
79
class SampleTracker(TrackerBase):
    """Tracker for the current (v3) tracking implementation used by the tests.

    Every recorded request header is pickled into the shared ``db_file`` dbm
    database under the key ``"<request_id>:<seq>"``, so tests can inspect
    it afterwards. The last response header received is kept on the
    instance.
    """

    def record(self, header, exception):
        """Store ``header.__dict__`` in the dbm database.

        ``exception`` is not used by this implementation.
        """
        db = dbm.open(db_file, 'w')
        try:
            key = "%s:%s" % (header.request_id, header.seq)
            db[key.encode("ascii")] = pickle.dumps(header.__dict__)
        finally:
            # Release the handle even if pickling or the write fails.
            db.close()

    def handle_response_header(self, response_header):
        """Remember the response header of the most recent call."""
        self.response_header = response_header

    def get_response_header(self):
        """Return the last response header, or None if none was received."""
        return getattr(self, 'response_header', None)
92
93
class Tracker_V2(TrackerBaseV2):
    """v2 tracker that stores request headers in the shared ``db_file``.

    Each header is stored as the pickled ``header.__dict__`` under the key
    ``"<request_id>:<seq>"``.
    """

    def record(self, header, exception):
        """Store ``header.__dict__`` in the dbm database.

        ``exception`` is not used by this implementation.
        """
        db = dbm.open(db_file, 'w')
        try:
            key = "%s:%s" % (header.request_id, header.seq)
            db[key.encode("ascii")] = pickle.dumps(header.__dict__)
        finally:
            # Release the handle even if pickling or the write fails.
            db.close()
100
101
# Shared tracker instances: ``tracker`` is gen_server()'s default tracker,
# ``tracker_v2`` backs the v2 tracked-server fixture. Both identify the
# client as "test_client" and the server as "test_server".
tracker, tracker_v2 = (
    tracker_cls("test_client", "test_server")
    for tracker_cls in (SampleTracker, Tracker_V2)
)
104
105
class Dispatcher(object):
    """AddressBookService handler used by every test server.

    ``remove`` and ``add`` call other test servers through ``client`` so
    that the tracking tests can observe nested request chains.
    """

    def __init__(self):
        book = addressbook.AddressBook()
        book.people = {}
        self.ab = book

    def ping(self):
        """Attach ``{'ping': 'pong'}`` to the response meta; return True."""
        TrackerBase.add_response_meta(ping='pong')
        return True

    def hello(self, name):
        """Return a greeting; the response meta is ``{name: name}``.

        The meta is built from ``name`` so tests can assert on it.
        """
        TrackerBase.add_response_meta(**{name: name})
        return "hello %s" % name

    def sleep(self, ms):
        """Stand-in that returns True immediately (``ms`` is ignored)."""
        return True

    def remove(self, name):
        """Call ``add`` on the server at PORT (``name`` is ignored)."""
        mary = addressbook.Person(name="mary")
        with client(port=PORT) as nested:
            nested.add(mary)
        return True

    def get_phonenumbers(self, name, count):
        """Return two fixed phone numbers regardless of the arguments."""
        return [addressbook.PhoneNumber(number=num) for num in ("sdaf", "saf")]

    def add(self, person):
        """Make two calls to the server at PORT + 1, each on its own connection."""
        downstream = PORT + 1
        with client(port=downstream) as nested:
            nested.get_phonenumbers("jane", 1)
        with client(port=downstream) as nested:
            nested.ping()
        return True

    def get(self, name):
        """Always fail.

        Raises the declared ``PersonNotExistsError`` for a non-empty name and
        an undeclared ``ValueError`` for an empty one.
        """
        if name:
            raise addressbook.PersonNotExistsError()
        # undeclared exception
        raise ValueError('name cannot be empty')
149
150
class TSampleServer(TThreadedServer):
    """Threaded server that serves each connection until a transport error.

    ``__init__`` sets the attributes the server uses directly rather than
    calling the base-class constructor. The same transport and protocol
    factories are used for input and output.
    """

    def __init__(self, processor_factory, trans, trans_factory, prot_factory):
        self.processor_factory = processor_factory
        self.trans = trans
        self.itrans_factory = trans_factory
        self.otrans_factory = trans_factory
        self.iprot_factory = prot_factory
        self.oprot_factory = prot_factory
        self.daemon = False
        self.closed = False

    def handle(self, client):
        """Process requests from ``client`` with a fresh processor.

        A TTransportException ends the loop silently; any other exception
        propagates. Both transports are closed in every case.
        """
        processor = self.processor_factory.get_processor()
        itrans = self.itrans_factory.get_transport(client)
        otrans = self.otrans_factory.get_transport(client)
        iprot, oprot = (self.iprot_factory.get_protocol(itrans),
                        self.oprot_factory.get_protocol(otrans))
        try:
            while True:
                processor.process(iprot, oprot)
        except TTransportException:
            pass
        finally:
            itrans.close()
            otrans.close()
177
178
def _wait_until_listening(ps, port, timeout):
    """Block until ``localhost:port`` accepts a TCP connection.

    Gives up without raising once ``ps`` is no longer alive or ``timeout``
    seconds have passed; any failure then surfaces in the tests themselves.
    """
    deadline = time.time() + timeout
    while ps.is_alive() and time.time() < deadline:
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with contextlib.closing(probe):
            if probe.connect_ex(("localhost", port)) == 0:
                return
        time.sleep(0.01)


def gen_server(port, tracker=tracker, processor=TTrackedProcessor,
               ready_timeout=5.0):
    """Start an AddressBookService server on ``port`` in a child process.

    :param port: TCP port the server listens on (host "localhost").
    :param tracker: tracker inserted as the processor's first argument; pass
        a falsy value for processors that take no tracker (e.g. TProcessor).
    :param processor: processor class handed to ``TProcessorFactory``.
    :param ready_timeout: maximum seconds to wait for the child to start
        accepting connections before returning anyway.
    :returns: ``(process, server)`` tuple.
    """
    args = [processor, addressbook.AddressBookService, Dispatcher()]
    if tracker:
        args.insert(1, tracker)
    processor_factory = TProcessorFactory(*args)
    server_socket = TServerSocket(host="localhost", port=port)
    server = TSampleServer(processor_factory, server_socket,
                           prot_factory=TBinaryProtocolFactory(),
                           trans_factory=TBufferedTransportFactory())
    ps = multiprocessing.Process(target=server.serve)
    ps.start()
    # The child binds and listens asynchronously; wait for it rather than
    # relying only on a fixed sleep in the callers.
    _wait_until_listening(ps, port, ready_timeout)
    return ps, server
191
192
@pytest.fixture(scope="module")
def server(request):
    """Tracked server on PORT, using gen_server's default tracker/processor.

    The child process is terminated and joined at module teardown if it is
    still alive.
    """
    proc, srv = gen_server(PORT)
    time.sleep(0.15)

    def stop():
        if not proc.is_alive():
            return
        proc.terminate()
        proc.join()

    request.addfinalizer(stop)
    return srv
205
206
@pytest.fixture(scope="module")
def server1(request):
    """Tracked server on PORT + 1, the server that ``Dispatcher.add`` calls.

    The child process is terminated and joined at module teardown if it is
    still alive.
    """
    proc, srv = gen_server(PORT + 1)
    time.sleep(0.15)

    def stop():
        if not proc.is_alive():
            return
        proc.terminate()
        proc.join()

    request.addfinalizer(stop)
    return srv
219
220
@pytest.fixture(scope="module")
def server2(request):
    """Tracked server on PORT + 2.

    The child process is terminated and joined at module teardown if it is
    still alive.
    """
    proc, srv = gen_server(PORT + 2)
    time.sleep(0.15)

    def stop():
        if not proc.is_alive():
            return
        proc.terminate()
        proc.join()

    request.addfinalizer(stop)
    return srv
233
234
@pytest.fixture(scope="module")
def native_server(request):
    """Untracked server (plain TProcessor, no tracker) on PORT + 3.

    The child process is terminated and joined at module teardown if it is
    still alive.
    """
    proc, srv = gen_server(PORT + 3, tracker=None, processor=TProcessor)
    time.sleep(0.15)

    def stop():
        if not proc.is_alive():
            return
        proc.terminate()
        proc.join()

    request.addfinalizer(stop)
    return srv
247
248
@pytest.fixture(scope="module")
def tracked_server_v2(request):
    """v2 tracked server (TTrackedProcessorV2 + tracker_v2) on PORT + 4.

    The child process is terminated and joined at module teardown if it is
    still alive.
    """
    proc, srv = gen_server(PORT + 4, tracker=tracker_v2,
                           processor=TTrackedProcessorV2)
    time.sleep(0.15)

    def stop():
        if not proc.is_alive():
            return
        proc.terminate()
        proc.join()

    request.addfinalizer(stop)
    return srv
262
263
@contextlib.contextmanager
def client(client_class=TTrackedClient, port=PORT):
    """Yield an AddressBookService client connected to ``localhost:port``.

    Tracked client classes receive a fresh ``SampleTracker`` as their first
    constructor argument. The check compares class *names*, so the aliased
    v2 ``TTrackedClientV2`` is also treated as tracked; presumably it is
    still named ``TTrackedClient``, which the v2 tests rely on.
    The transport is closed on exit, including when ``open()`` fails.
    """
    # ``sock`` rather than ``socket`` so the stdlib module is not shadowed.
    sock = TSocket("localhost", port)
    # Build the transport before ``try`` so ``finally`` never refers to an
    # unbound name if construction itself raises.
    trans = TBufferedTransportFactory().get_transport(sock)
    try:
        proto = TBinaryProtocolFactory().get_protocol(trans)
        trans.open()
        args = [addressbook.AddressBookService, proto]
        if client_class.__name__ == TTrackedClient.__name__:
            args.insert(0, SampleTracker("test_client", "test_server"))
        yield client_class(*args)
    finally:
        trans.close()
278
279
@pytest.fixture
def dbm_db(request):
    """Create an empty dbm database at ``db_file``; delete it after the test.

    Depending on the backend ``dbm`` selects, the data may live in
    ``db_file`` itself or in files with an added suffix (``.db``, ``.dir``,
    ``.pag``, ``.dat``, ``.bak``). Teardown removes whichever of them exist.
    """
    db = dbm.open(db_file, 'n')
    db.close()

    def fin():
        for suffix in ('', '.db', '.dir', '.pag', '.dat', '.bak'):
            try:
                os.remove(db_file + suffix)
            except OSError:
                # Most suffixes do not exist for any given backend.
                pass

    request.addfinalizer(fin)
292
293
@pytest.fixture
def tracker_ctx(request):
    """After the test, drop any header/counter/response_header left on ctx."""
    def fin():
        for attr in ("header", "counter", "response_header"):
            if hasattr(ctx, attr):
                delattr(ctx, attr)

    request.addfinalizer(fin)
305
306
def test_negotiation(server):
    """A tracked client connected to a tracked server reports an upgrade."""
    with client() as tracked:
        upgraded = tracked.is_upgraded
        assert upgraded is True
310
311
def test_response_tracker(server, dbm_db, tracker_ctx):
    """Each call's response meta replaces the previous one on the tracker."""
    with client() as c:
        for who in ('you', 'me'):
            c.hello(who)
            assert c.tracker.response_header.meta == {who: who}
318
319
def test_tracker(server, dbm_db, tracker_ctx):
    """One tracked call yields exactly one stored request header."""
    with client() as c:
        c.hello('you')
        assert c.tracker.response_header.meta == {'you': 'you'}

    # Presumably gives the server process time to write its record.
    time.sleep(0.2)

    # closing() releases the read handle, which was previously leaked.
    with contextlib.closing(dbm.open(db_file, 'r')) as db:
        headers = list(db.keys())
        assert len(headers) == 1
        request_id = headers[0]
        data = pickle.loads(db[request_id])

    assert "start" in data and "end" in data
    data.pop("start")
    data.pop("end")
    assert data == {
        "request_id": request_id.decode("ascii").split(':')[0],
        "seq": '1',
        "client": "test_client",
        "server": "test_server",
        "api": "hello",
        "status": True,
        "annotation": {},
        "meta": {},
    }
347
348
def test_tracker_chain(server, server1, server2, dbm_db, tracker_ctx):
    """Nested calls get hierarchical seq numbers and inherit the meta.

    ``remove`` triggers ``add``, which in turn calls ``get_phonenumbers``
    and ``ping``, and ``hello`` follows. The five records carry seqs 1, 1.1,
    1.1.1, 1.1.2 and 2, span two request ids, and all carry the meta set via
    ``add_meta``.
    """
    test_meta = {'test': 'test_meta'}
    with client() as c:
        with SampleTracker.add_meta(**test_meta):
            c.remove("jane")
            c.hello("yes")

    time.sleep(0.2)

    # closing() releases the read handle, which was previously leaked.
    with contextlib.closing(dbm.open(db_file, 'r')) as db:
        keys = list(db.keys())
        assert len(keys) == 5
        headers = [pickle.loads(db[key]) for key in keys]

    headers.sort(key=lambda x: x["seq"])

    assert len({i["request_id"] for i in headers}) == 2

    seqs = [i["seq"] for i in headers]
    metas = [i["meta"] for i in headers]
    assert seqs == ['1', '1.1', '1.1.1', '1.1.2', '2']
    assert metas == [test_meta] * 5
371
372
def test_exception(server, dbm_db, tracker_ctx):
    """A declared service exception is recorded with ``status`` False."""
    with pytest.raises(addressbook.PersonNotExistsError):
        with client() as c:
            c.get("jane")

    # Same settle delay the other tests use before reading the server-written
    # database; without it the read may race the server's record() call.
    time.sleep(0.2)

    # closing() releases the read handle, which was previously leaked.
    with contextlib.closing(dbm.open(db_file, 'r')) as db:
        headers = list(db.keys())
        assert len(headers) == 1
        header = pickle.loads(db[headers[0]])

    assert header["status"] is False
384
385
def test_undeclared_exception(server, dbm_db, tracker_ctx):
    """An exception the service does not declare reaches the client as a
    TTransportException."""
    with pytest.raises(TTransportException), client() as conn:
        conn.get('')
390
391
def test_request_id_func(tracker_ctx):
    """gen_header() reuses the request_id of the header given to handle().

    ``tracker_ctx`` removes the state that handle() appears to leave on the
    shared ``ctx``. Without it, that state leaked into later tests.
    """
    ctx.__dict__.clear()

    header = track_thrift.RequestHeader()
    header.request_id = "hello"
    header.seq = 0

    # Named ``base_tracker`` so the module-level ``tracker`` is not shadowed.
    base_tracker = TrackerBase()
    base_tracker.handle(header)

    header2 = track_thrift.RequestHeader()
    base_tracker.gen_header(header2)
    assert header2.request_id == "hello"
405
406
def test_annotation(server, dbm_db, tracker_ctx):
    """Annotations set via ``SampleTracker.annotate`` are recorded per call."""
    with client() as c:
        with SampleTracker.annotate(ann="value"):
            c.ping()

        with SampleTracker.annotate() as ann:
            ann.update({"sig": "c.hello()", "user_id": "125"})
            c.hello('you')

    time.sleep(0.2)

    # closing() releases the read handle, which was previously leaked.
    with contextlib.closing(dbm.open(db_file, 'r')) as db:
        data = [pickle.loads(db[key]) for key in db.keys()]

    data.sort(key=lambda x: x["seq"])

    # Separate asserts so a failure shows which annotation was wrong.
    assert data[0]["annotation"] == {"ann": "value"}
    assert data[1]["annotation"] == {"sig": "c.hello()", "user_id": "125"}
426
427
def test_counter(server, dbm_db, tracker_ctx):
    """Calls inside ``SampleTracker.counter()`` are numbered separately.

    get_phonenumbers and sleep (outside the block) get seqs '1' and '2', and
    ping and hello (inside the block) also get '1' and '2'.
    """
    with client() as c:
        c.get_phonenumbers("hello", 1)

        with SampleTracker.counter():
            c.ping()
            c.hello("counter")

        c.sleep(8)

    time.sleep(0.2)

    # closing() releases the read handle, which was previously leaked.
    with contextlib.closing(dbm.open(db_file, 'r')) as db:
        data = [pickle.loads(db[key]) for key in db.keys()]

    data.sort(key=lambda x: x["api"])
    get, hello, ping, sleep = data

    assert get["api"] == "get_phonenumbers" and get["seq"] == '1'
    assert ping["api"] == "ping" and ping["seq"] == '1'
    assert hello["api"] == "hello" and hello["seq"] == '2'
    assert sleep["api"] == "sleep" and sleep["seq"] == '2'
451
452
def test_native_client_tracked_server_v3(server):
    """A plain TClient can still call the tracked v3 server."""
    with client(client_class=TClient) as native:
        native.ping()
        native.hello("world")
457
458
def test_native_client_tracked_server_v2(tracked_server_v2):
    """A plain TClient can still call the tracked v2 server."""
    v2_port = PORT + 4
    with client(client_class=TClient, port=v2_port) as native:
        native.ping()
        native.hello("world")
463
464
def test_tracked_client_v2_native_server(native_server):
    """A v2 tracked client does not upgrade against a native server, and
    plain calls still work."""
    with client(client_class=TTrackedClientV2, port=PORT + 3) as v2:
        assert v2._upgraded is False
        v2.ping()
        v2.hello("cat")
        numbers = v2.get_phonenumbers("hello", 54)
        assert len(numbers) == 2
        assert [n.number for n in numbers] == ['sdaf', 'saf']
473
474
def test_tracked_client_v2_tracked_server_v2(
        tracked_server_v2, dbm_db, tracker_ctx):
    """A v2 client and a v2 server upgrade and record one header per call."""
    with client(TTrackedClientV2, PORT + 4) as c:
        assert c._upgraded is True

        c.ping()
        time.sleep(0.2)

        # closing() releases the read handle, which was previously leaked.
        with contextlib.closing(dbm.open(db_file, 'r')) as db:
            headers = list(db.keys())
            assert len(headers) == 1
            request_id = headers[0]
            data = pickle.loads(db[request_id])

        assert "start" in data and "end" in data
        data.pop("start")
        data.pop("end")
        assert data == {
            "request_id": request_id.decode("ascii").split(':')[0],
            "seq": '1',
            "client": "test_client",
            "server": "test_server",
            "api": "ping",
            "status": True,
            "annotation": {},
            "meta": {},
        }
503
504
def test_tracked_client_v2_tracked_server_v3(server, dbm_db, tracker_ctx):
    """A v2 client upgrades against a v3 server; the call is recorded and no
    response header is left on ctx."""
    with client(TTrackedClientV2) as c:
        assert c._upgraded is True

        c.ping()
        time.sleep(0.2)

        # closing() releases the read handle, which was previously leaked.
        with contextlib.closing(dbm.open(db_file, 'r')) as db:
            headers = list(db.keys())
            assert len(headers) == 1
            request_id = headers[0]
            data = pickle.loads(db[request_id])

        assert "start" in data and "end" in data
        data.pop("start")
        data.pop("end")
        assert data == {
            "request_id": request_id.decode("ascii").split(':')[0],
            "seq": '1',
            "client": "test_client",
            "server": "test_server",
            "api": "ping",
            "status": True,
            "annotation": {},
            "meta": {},
        }

        assert not hasattr(ctx, 'response_header')
532
533        assert not hasattr(ctx, 'response_header')
534
535
def test_tracked_client_v3_native_server(native_server):
    """A v3 tracked client does not upgrade against a native server.

    After ``ping`` there is no response header on ctx, and further calls
    still work.
    """
    with client(port=PORT + 3) as v3:
        assert v3.is_upgraded is False
        v3.ping()
        assert not hasattr(ctx, "response_header")

        v3.hello("cat")
        numbers = v3.get_phonenumbers("hello", 54)
        assert len(numbers) == 2
        assert [n.number for n in numbers] == ['sdaf', 'saf']
546
547
def test_tracked_client_v3_tracked_server_v2(
        tracked_server_v2, dbm_db, tracker_ctx):
    """A v3 client upgrades against a v2 server but receives no response
    header; the call is still recorded."""
    with client(port=PORT + 4) as c:
        assert c.is_upgraded is True

        c.ping()
        assert not hasattr(ctx, "response_header")
        assert c.tracker.get_response_header() is None

        time.sleep(0.2)

        # closing() releases the read handle, which was previously leaked.
        with contextlib.closing(dbm.open(db_file, 'r')) as db:
            headers = list(db.keys())
            assert len(headers) == 1
            request_id = headers[0]
            data = pickle.loads(db[request_id])

        assert "start" in data and "end" in data
        data.pop("start")
        data.pop("end")
        assert data == {
            "request_id": request_id.decode("ascii").split(':')[0],
            "seq": '1',
            "client": "test_client",
            "server": "test_server",
            "api": "ping",
            "status": True,
            "annotation": {},
            "meta": {},
        }
579