1# -*- coding: utf-8 -*- 2# flake8: disable E501 3 4''' 5Tests for contrib/tracking... 6 7Compatibility Tests Table 8 9============= ==================================== ======================================== ======================================== 10client/server native v2 v3 11============= ==================================== ======================================== ======================================== 12native N/A test_native_client_tracked_server_v2 test_native_client_tracked_server_v3 13v2 test_tracked_client_v2_native_server test_tracked_client_v2_tracked_server_v2 test_tracked_client_v2_tracked_server_v3 14v3 test_tracked_client_v3_native_server test_tracked_client_v3_tracked_server_v2 regular tests 15============= ==================================== ======================================== ======================================== 16''' # noqa 17 18from __future__ import absolute_import 19 20import contextlib 21import multiprocessing 22import os 23import pickle 24import random 25import socket 26import tempfile 27import time 28 29import thriftpy2 30 31try: 32 import dbm 33except ImportError: 34 import dbm.ndbm as dbm 35 36import pytest 37 38from thriftpy2.contrib.tracking import TTrackedProcessor, TTrackedClient, \ 39 TrackerBase, track_thrift 40from thriftpy2.contrib.tracking.tracker import ctx 41 42from thriftpy2.thrift import TProcessorFactory, TClient, TProcessor 43from thriftpy2.server import TThreadedServer 44from thriftpy2.transport import TServerSocket, TBufferedTransportFactory, \ 45 TTransportException, TSocket 46from thriftpy2.protocol import TBinaryProtocolFactory 47from compatible.version_2.tracking import ( 48 TTrackedProcessor as TTrackedProcessorV2, 49 TTrackedClient as TTrackedClientV2, 50 TrackerBase as TrackerBaseV2, 51) 52 53try: 54 from pytest_cov.embed import cleanup_on_sigterm 55except ImportError: 56 pass 57else: 58 cleanup_on_sigterm() 59 60addressbook = thriftpy2.load(os.path.join(os.path.dirname(__file__), 61 "addressbook.thrift")) 62_, db_file = tempfile.mkstemp() 63 64 65def _get_port(): 66 while True: 67 port = 20000 + random.randint(1, 9999) 68 for i in range(5): 69 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 70 result = sock.connect_ex(('127.0.0.1', port)) 71 if result == 0: 72 continue 73 else: 74 return port 75 76 77PORT = _get_port() 78 79 80class SampleTracker(TrackerBase): 81 def record(self, header, exception): 82 db = dbm.open(db_file, 'w') 83 key = "%s:%s" % (header.request_id, header.seq) 84 db[key.encode("ascii")] = pickle.dumps(header.__dict__) 85 db.close() 86 87 def handle_response_header(self, response_header): 88 self.response_header = response_header 89 90 def get_response_header(self): 91 return getattr(self, 'response_header', None) 92 93 94class Tracker_V2(TrackerBaseV2): 95 def record(self, header, exception): 96 db = dbm.open(db_file, 'w') 97 key = "%s:%s" % (header.request_id, header.seq) 98 db[key.encode("ascii")] = pickle.dumps(header.__dict__) 99 db.close() 100 101 102tracker = SampleTracker("test_client", "test_server") 103tracker_v2 = Tracker_V2("test_client", "test_server") 104 105 106class Dispatcher(object): 107 def __init__(self): 108 self.ab = addressbook.AddressBook() 109 self.ab.people = {} 110 111 def ping(self): 112 test_response = {'ping': 'pong'} 113 TrackerBase.add_response_meta(**test_response) 114 return True 115 116 def hello(self, name): 117 # Add specially constructed response header for assertion. 118 test_response = {name: name} 119 TrackerBase.add_response_meta(**test_response) 120 return "hello %s" % name 121 122 def sleep(self, ms): 123 return True 124 125 def remove(self, name): 126 person = addressbook.Person(name="mary") 127 with client(port=PORT) as c: 128 c.add(person) 129 130 return True 131 132 def get_phonenumbers(self, name, count): 133 return [addressbook.PhoneNumber(number="sdaf"), 134 addressbook.PhoneNumber(number='saf')] 135 136 def add(self, person): 137 with client(port=PORT + 1) as c: 138 c.get_phonenumbers("jane", 1) 139 140 with client(port=PORT + 1) as c: 141 c.ping() 142 return True 143 144 def get(self, name): 145 if not name: 146 # undeclared exception 147 raise ValueError('name cannot be empty') 148 raise addressbook.PersonNotExistsError() 149 150 151class TSampleServer(TThreadedServer): 152 def __init__(self, processor_factory, trans, trans_factory, prot_factory): 153 self.daemon = False 154 self.processor_factory = processor_factory 155 self.trans = trans 156 157 self.itrans_factory = self.otrans_factory = trans_factory 158 self.iprot_factory = self.oprot_factory = prot_factory 159 self.closed = False 160 161 def handle(self, client): 162 processor = self.processor_factory.get_processor() 163 itrans = self.itrans_factory.get_transport(client) 164 otrans = self.otrans_factory.get_transport(client) 165 iprot = self.iprot_factory.get_protocol(itrans) 166 oprot = self.oprot_factory.get_protocol(otrans) 167 try: 168 while True: 169 processor.process(iprot, oprot) 170 except TTransportException: 171 pass 172 except Exception: 173 raise 174 finally: 175 itrans.close() 176 otrans.close() 177 178 179def gen_server(port, tracker=tracker, processor=TTrackedProcessor): 180 args = [processor, addressbook.AddressBookService, Dispatcher()] 181 if tracker: 182 args.insert(1, tracker) 183 processor = TProcessorFactory(*args) 184 server_socket = TServerSocket(host="localhost", port=port) 185 server = TSampleServer(processor, server_socket, 186 prot_factory=TBinaryProtocolFactory(), 187 trans_factory=TBufferedTransportFactory()) 188 ps = multiprocessing.Process(target=server.serve) 189 ps.start() 190 return ps, server 191 192 193@pytest.fixture(scope="module") 194def server(request): 195 ps, ser = gen_server(PORT) 196 time.sleep(0.15) 197 198 def fin(): 199 if ps.is_alive(): 200 ps.terminate() 201 ps.join() 202 203 request.addfinalizer(fin) 204 return ser 205 206 207@pytest.fixture(scope="module") 208def server1(request): 209 ps, ser = gen_server(PORT + 1) 210 time.sleep(0.15) 211 212 def fin(): 213 if ps.is_alive(): 214 ps.terminate() 215 ps.join() 216 217 request.addfinalizer(fin) 218 return ser 219 220 221@pytest.fixture(scope="module") 222def server2(request): 223 ps, ser = gen_server(PORT + 2) 224 time.sleep(0.15) 225 226 def fin(): 227 if ps.is_alive(): 228 ps.terminate() 229 ps.join() 230 231 request.addfinalizer(fin) 232 return ser 233 234 235@pytest.fixture(scope="module") 236def native_server(request): 237 ps, ser = gen_server(PORT + 3, tracker=None, processor=TProcessor) 238 time.sleep(0.15) 239 240 def fin(): 241 if ps.is_alive(): 242 ps.terminate() 243 ps.join() 244 245 request.addfinalizer(fin) 246 return ser 247 248 249@pytest.fixture(scope="module") 250def tracked_server_v2(request): 251 ps, ser = gen_server(PORT + 4, tracker=tracker_v2, 252 processor=TTrackedProcessorV2) 253 time.sleep(0.15) 254 255 def fin(): 256 if ps.is_alive(): 257 ps.terminate() 258 ps.join() 259 260 request.addfinalizer(fin) 261 return ser 262 263 264@contextlib.contextmanager 265def client(client_class=TTrackedClient, port=PORT): 266 socket = TSocket("localhost", port) 267 268 try: 269 trans = TBufferedTransportFactory().get_transport(socket) 270 proto = TBinaryProtocolFactory().get_protocol(trans) 271 trans.open() 272 args = [addressbook.AddressBookService, proto] 273 if client_class.__name__ == TTrackedClient.__name__: 274 args.insert(0, SampleTracker("test_client", "test_server")) 275 yield client_class(*args) 276 finally: 277 trans.close() 278 279 280@pytest.fixture 281def dbm_db(request): 282 db = dbm.open(db_file, 'n') 283 db.close() 284 285 def fin(): 286 try: 287 os.remove(db_file) 288 except OSError: 289 pass 290 291 request.addfinalizer(fin) 292 293 294@pytest.fixture 295def tracker_ctx(request): 296 def fin(): 297 if hasattr(ctx, "header"): 298 del ctx.header 299 if hasattr(ctx, "counter"): 300 del ctx.counter 301 if hasattr(ctx, "response_header"): 302 del ctx.response_header 303 304 request.addfinalizer(fin) 305 306 307def test_negotiation(server): 308 with client() as c: 309 assert c.is_upgraded is True 310 311 312def test_response_tracker(server, dbm_db, tracker_ctx): 313 with client() as c: 314 c.hello('you') 315 assert c.tracker.response_header.meta == {'you': 'you'} 316 c.hello('me') 317 assert c.tracker.response_header.meta == {'me': 'me'} 318 319 320def test_tracker(server, dbm_db, tracker_ctx): 321 with client() as c: 322 c.hello('you') 323 assert c.tracker.response_header.meta == {'you': 'you'} 324 325 time.sleep(0.2) 326 327 db = dbm.open(db_file, 'r') 328 headers = list(db.keys()) 329 assert len(headers) == 1 330 331 request_id = headers[0] 332 data = pickle.loads(db[request_id]) 333 334 assert "start" in data and "end" in data 335 data.pop("start") 336 data.pop("end") 337 assert data == { 338 "request_id": request_id.decode("ascii").split(':')[0], 339 "seq": '1', 340 "client": "test_client", 341 "server": "test_server", 342 "api": "hello", 343 "status": True, 344 "annotation": {}, 345 "meta": {}, 346 } 347 348 349def test_tracker_chain(server, server1, server2, dbm_db, tracker_ctx): 350 test_meta = {'test': 'test_meta'} 351 with client() as c: 352 with SampleTracker.add_meta(**test_meta): 353 c.remove("jane") 354 c.hello("yes") 355 356 time.sleep(0.2) 357 358 db = dbm.open(db_file, 'r') 359 headers = list(db.keys()) 360 assert len(headers) == 5 361 362 headers = [pickle.loads(db[i]) for i in headers] 363 headers.sort(key=lambda x: x["seq"]) 364 365 assert len(set([i["request_id"] for i in headers])) == 2 366 367 seqs = [i["seq"] for i in headers] 368 metas = [i["meta"] for i in headers] 369 assert seqs == ['1', '1.1', '1.1.1', '1.1.2', '2'] 370 assert metas == [test_meta] * 5 371 372 373def test_exception(server, dbm_db, tracker_ctx): 374 with pytest.raises(addressbook.PersonNotExistsError): 375 with client() as c: 376 c.get("jane") 377 378 db = dbm.open(db_file, 'r') 379 headers = list(db.keys()) 380 assert len(headers) == 1 381 382 header = pickle.loads(db[headers[0]]) 383 assert header["status"] is False 384 385 386def test_undeclared_exception(server, dbm_db, tracker_ctx): 387 with pytest.raises(TTransportException): 388 with client() as c: 389 c.get('') 390 391 392def test_request_id_func(): 393 ctx.__dict__.clear() 394 395 header = track_thrift.RequestHeader() 396 header.request_id = "hello" 397 header.seq = 0 398 399 tracker = TrackerBase() 400 tracker.handle(header) 401 402 header2 = track_thrift.RequestHeader() 403 tracker.gen_header(header2) 404 assert header2.request_id == "hello" 405 406 407def test_annotation(server, dbm_db, tracker_ctx): 408 with client() as c: 409 with SampleTracker.annotate(ann="value"): 410 c.ping() 411 412 with SampleTracker.annotate() as ann: 413 ann.update({"sig": "c.hello()", "user_id": "125"}) 414 c.hello('you') 415 416 time.sleep(0.2) 417 418 db = dbm.open(db_file, 'r') 419 headers = list(db.keys()) 420 421 data = [pickle.loads(db[i]) for i in headers] 422 data.sort(key=lambda x: x["seq"]) 423 424 assert data[0]["annotation"] == {"ann": "value"} and \ 425 data[1]["annotation"] == {"sig": "c.hello()", "user_id": "125"} 426 427 428def test_counter(server, dbm_db, tracker_ctx): 429 with client() as c: 430 c.get_phonenumbers("hello", 1) 431 432 with SampleTracker.counter(): 433 c.ping() 434 c.hello("counter") 435 436 c.sleep(8) 437 438 time.sleep(0.2) 439 440 db = dbm.open(db_file, 'r') 441 headers = list(db.keys()) 442 443 data = [pickle.loads(db[i]) for i in headers] 444 data.sort(key=lambda x: x["api"]) 445 get, hello, ping, sleep = data 446 447 assert get["api"] == "get_phonenumbers" and get["seq"] == '1' 448 assert ping["api"] == "ping" and ping["seq"] == '1' 449 assert hello["api"] == "hello" and hello["seq"] == '2' 450 assert sleep["api"] == "sleep" and sleep["seq"] == '2' 451 452 453def test_native_client_tracked_server_v3(server): 454 with client(TClient) as c: 455 c.ping() 456 c.hello("world") 457 458 459def test_native_client_tracked_server_v2(tracked_server_v2): 460 with client(TClient, port=PORT + 4) as c: 461 c.ping() 462 c.hello("world") 463 464 465def test_tracked_client_v2_native_server(native_server): 466 with client(TTrackedClientV2, PORT + 3) as c: 467 assert c._upgraded is False 468 c.ping() 469 c.hello("cat") 470 a = c.get_phonenumbers("hello", 54) 471 assert len(a) == 2 472 assert a[0].number == 'sdaf' and a[1].number == 'saf' 473 474 475def test_tracked_client_v2_tracked_server_v2( 476 tracked_server_v2, dbm_db, tracker_ctx): 477 with client(TTrackedClientV2, PORT + 4) as c: 478 assert c._upgraded is True 479 480 c.ping() 481 time.sleep(0.2) 482 483 db = dbm.open(db_file, 'r') 484 headers = list(db.keys()) 485 assert len(headers) == 1 486 487 request_id = headers[0] 488 data = pickle.loads(db[request_id]) 489 490 assert "start" in data and "end" in data 491 data.pop("start") 492 data.pop("end") 493 assert data == { 494 "request_id": request_id.decode("ascii").split(':')[0], 495 "seq": '1', 496 "client": "test_client", 497 "server": "test_server", 498 "api": "ping", 499 "status": True, 500 "annotation": {}, 501 "meta": {}, 502 } 503 504 505def test_tracked_client_v2_tracked_server_v3(server, dbm_db, tracker_ctx): 506 with client(TTrackedClientV2) as c: 507 assert c._upgraded is True 508 509 c.ping() 510 time.sleep(0.2) 511 512 db = dbm.open(db_file, 'r') 513 headers = list(db.keys()) 514 assert len(headers) == 1 515 516 request_id = headers[0] 517 data = pickle.loads(db[request_id]) 518 519 assert "start" in data and "end" in data 520 data.pop("start") 521 data.pop("end") 522 assert data == { 523 "request_id": request_id.decode("ascii").split(':')[0], 524 "seq": '1', 525 "client": "test_client", 526 "server": "test_server", 527 "api": "ping", 528 "status": True, 529 "annotation": {}, 530 "meta": {}, 531 } 532 533 assert not hasattr(ctx, 'response_header') 534 535 536def test_tracked_client_v3_native_server(native_server): 537 with client(port=PORT + 3) as c: 538 assert c.is_upgraded is False 539 c.ping() 540 assert not hasattr(ctx, "response_header") 541 542 c.hello("cat") 543 a = c.get_phonenumbers("hello", 54) 544 assert len(a) == 2 545 assert a[0].number == 'sdaf' and a[1].number == 'saf' 546 547 548def test_tracked_client_v3_tracked_server_v2( 549 tracked_server_v2, dbm_db, tracker_ctx): 550 with client(port=PORT + 4) as c: 551 assert c.is_upgraded is True 552 553 c.ping() 554 assert not hasattr(ctx, "response_header") 555 assert c.tracker.get_response_header() is None 556 557 time.sleep(0.2) 558 559 db = dbm.open(db_file, 'r') 560 headers = list(db.keys()) 561 assert len(headers) == 1 562 563 request_id = headers[0] 564 data = pickle.loads(db[request_id]) 565 566 assert "start" in data and "end" in data 567 data.pop("start") 568 data.pop("end") 569 assert data == { 570 "request_id": request_id.decode("ascii").split(':')[0], 571 "seq": '1', 572 "client": "test_client", 573 "server": "test_server", 574 "api": "ping", 575 "status": True, 576 "annotation": {}, 577 "meta": {}, 578 } 579