1# -*- coding: utf-8 -*-
2
3from __future__ import absolute_import
4
5import copy
6import contextlib
7import threading
8import uuid
9
# Per-thread tracking context. TrackerBase stores the current request
# header, seq counter, meta, annotation and pending response meta as
# attributes on this object, so each thread sees only its own values.
ctx = threading.local()
11
12
class VersionMixin(object):
    """Mixin class to handle version compatibilities.

    Versions are plain integers compared with ordering operators; each
    ``VERSION_*`` constant is the lowest version that has that feature.
    """
    DEFAULT_VERSION = 0  # support only request header
    VERSION_SUPPORT_REQUEST_HEADER = 1  # add request header
    VERSION_SUPPORT_RESPONSE_HEADER = 2  # add response header

    CURRENT = VERSION_SUPPORT_RESPONSE_HEADER

    def init_version_mixin(self):
        """Reset the version state to the baseline, not-upgraded version."""
        self.is_upgraded = False
        self.current_version = self.DEFAULT_VERSION

    def check_version(self, feature_version):
        """Tell whether the current version includes *feature_version*."""
        return feature_version <= self.current_version

    def upgrade_version(self, target_version):
        """Flag an upgrade and move to *target_version* when allowed.

        The version only ever increases, and never beyond
        ``VersionMixin.CURRENT``; out-of-range targets are ignored, but
        ``is_upgraded`` is set regardless.
        """
        self.is_upgraded = True
        ceiling = VersionMixin.CURRENT
        if self.current_version < target_version <= ceiling:
            self.current_version = target_version
32
33
class TrackerBase(object):
    """Base tracker propagating request ids, sequence numbers and meta.

    Per-request state lives on the thread-local ``ctx`` object, not on the
    tracker instance. Hook methods (``record``, ``handle_response_header``,
    the handshake hooks) are no-ops here for subclasses to override.
    """

    def __init__(self, client=None, server=None):
        self.client = client
        self.server = server

    def handle(self, header):
        """Make the incoming request *header* the current context.

        Also resets the per-thread counter, so outgoing calls made while
        serving this request get seq ``<header.seq>.1``, ``<header.seq>.2``...
        """
        ctx.header = header
        ctx.counter = 0

    def handle_response_header(self, response_header):
        """Hook for a response header; no-op by default."""
        pass

    def gen_header(self, header):
        """Populate an outgoing request *header* from the current context.

        Sets ``request_id`` (inherited from the incoming request or newly
        generated), ``seq`` (nested under the incoming seq when there is
        one) and ``meta`` (incoming meta overlaid with ``add_meta`` values).
        """
        header.request_id = self.get_request_id()

        if not hasattr(ctx, "counter"):
            ctx.counter = 0
        ctx.counter += 1

        if hasattr(ctx, "header"):
            header.seq = "{prev_seq}.{cur_counter}".format(
                prev_seq=ctx.header.seq, cur_counter=ctx.counter)
            # Copy rather than alias: the update below must not write
            # ``add_meta`` values back into the incoming request's meta,
            # where they would outlive the ``add_meta`` block.
            header.meta = dict(ctx.header.meta or {})
        else:
            header.meta = {}
            header.seq = str(ctx.counter)

        if hasattr(ctx, "meta"):
            header.meta.update(ctx.meta)

    def gen_response_header(self, response_header):
        """Move pending response meta onto *response_header*.

        The pending meta (built by ``add_response_meta``) is consumed, i.e.
        removed from the context after being attached.
        """
        if hasattr(ctx, "response_meta"):
            response_header.meta = ctx.response_meta
            del ctx.response_meta

    def record(self, header, exception):
        """Hook receiving *header* and *exception*; no-op by default."""
        pass

    @staticmethod
    def _restore_ctx(name, had_old, old):
        """Reset ``ctx.<name>`` to *old*, or remove it if it was unset."""
        if had_old:
            setattr(ctx, name, old)
        elif hasattr(ctx, name):
            delattr(ctx, name)

    @classmethod
    @contextlib.contextmanager
    def counter(cls, init=0):
        """Context for manually setting counter of seq number.

        :init: init value; the previous counter is restored on exit.
        """
        old = getattr(ctx, "counter", 0)
        ctx.counter = init
        try:
            yield
        finally:
            ctx.counter = old

    @classmethod
    @contextlib.contextmanager
    def annotate(cls, **kwargs):
        """Set ``ctx.annotation`` to *kwargs* for the ``with`` body.

        Yields the annotation dict. On exit the enclosing annotation (if
        any) is restored, so nested ``annotate`` blocks no longer wipe the
        outer one or make its exit raise ``AttributeError``.
        """
        had_old = hasattr(ctx, "annotation")
        old = ctx.annotation if had_old else None
        ctx.annotation = kwargs
        try:
            yield ctx.annotation
        finally:
            cls._restore_ctx("annotation", had_old, old)

    @classmethod
    @contextlib.contextmanager
    def add_meta(cls, **kwds):
        """Overlay *kwds* on ``ctx.meta`` for the ``with`` body.

        Yields the merged dict. A new dict is built instead of updating the
        enclosing one in place, so an outer ``add_meta`` block's dict keeps
        its own contents and is restored as-is on exit.
        """
        had_old = hasattr(ctx, "meta")
        old = ctx.meta if had_old else None
        merged = dict(old) if had_old else {}
        merged.update(kwds)
        ctx.meta = merged
        try:
            yield merged
        finally:
            cls._restore_ctx("meta", had_old, old)

    @classmethod
    def add_response_meta(cls, **kwds):
        """Accumulate *kwds* into the pending response meta and return it."""
        if hasattr(ctx, 'response_meta'):
            ctx.response_meta.update(kwds)
        else:
            ctx.response_meta = kwds
        return ctx.response_meta

    @property
    def meta(self):
        """Effective meta: incoming header meta overlaid with ``add_meta``.

        Returns a fresh dict; reading it never modifies the incoming
        request header.
        """
        if hasattr(ctx, "header"):
            meta = dict(ctx.header.meta or {})
        else:
            meta = {}
        if hasattr(ctx, "meta"):
            meta.update(ctx.meta)
        return meta

    @property
    def annotation(self):
        """The current annotation dict, or ``{}`` outside ``annotate``."""
        return ctx.annotation if hasattr(ctx, "annotation") else {}

    def get_request_id(self):
        """Return the incoming request's id, or a new random UUID4 string."""
        if hasattr(ctx, "header"):
            return ctx.header.request_id
        return str(uuid.uuid4())

    def init_handshake_info(self, handshake_obj):
        """Handshake hook; no-op by default."""
        pass

    def handle_handshake_info(self, handshake_obj):
        """Handshake hook; no-op by default."""
        pass
148
149
class ConsoleTracker(TrackerBase):
    """Tracker whose ``record`` hook prints each header to stdout."""

    def record(self, header, exception):
        """Print *header*; *exception* is accepted but not used."""
        print(header)
153