1from circuits import Component, handler
2from circuits.core import Value
3from circuits.net.events import write
4from .utils import dump_event, load_value, load_event, dump_value
5
6
# Byte sequence appended to every outgoing packet; received data is split on
# it in Protocol.add_buffer to recover individual packets.
DELIMITER = b"~~~"
8
9
class Protocol(Component):
    """Frame, send and receive circuits events over one node connection.

    Outgoing events (``send``) and results (``send_result``) are serialised
    with the ``.utils`` helpers, UTF-8 encoded and terminated with
    ``DELIMITER``. Incoming bytes are accumulated with ``add_buffer`` and each
    complete packet is dispatched either as a result for a pending call or as
    a call to fire locally.
    """

    # Immutable defaults only: the methods below rebind these on ``self``, so
    # each instance ends up with its own value. Mutable per-connection state
    # (the pending-call table) is created in ``init``.
    __buffer = b''
    __nid = 0

    def init(self, sock=None, server=None, **kwargs):
        """Store the connection context for this protocol instance.

        :param sock: socket of the peer; written to explicitly when
            ``server`` is set, and compared against incoming events'
            ``node_sock`` to decide which results this instance answers.
        :param server: when not ``None``, packets are fired as
            ``write(sock, packet)``; otherwise as ``write(packet)``.
        :param kwargs: optional ``receive_event_firewall`` and
            ``send_event_firewall`` callables, invoked as
            ``firewall(event, sock)``; a falsy return rejects the event.
        """
        self.__server = server
        self.__sock = sock
        self.__receive_event_firewall = kwargs.get('receive_event_firewall',
                                                   None)
        self.__send_event_firewall = kwargs.get('send_event_firewall', None)
        # Pending outgoing calls awaiting a remote result, keyed by call id.
        # Must be per instance: each instance numbers its calls from 0, so a
        # class-level dict would let connections clobber each other's calls.
        self.__events = {}

    @staticmethod
    def __to_str(value):
        """Decode ``bytes`` as UTF-8; return any other value unchanged."""
        return value.decode('utf-8') if isinstance(value, bytes) else value

    def add_buffer(self, data=''):
        """Append received bytes and process every complete packet.

        Packets are terminated by ``DELIMITER``; whatever follows the last
        delimiter is an unfinished packet and is kept for the next call.

        :param data: newly received bytes; may be empty to just re-scan the
            buffer.
        """
        if data:
            self.__buffer += data

        packets = self.__buffer.split(DELIMITER)
        # The last element is the (possibly empty) unterminated tail. Keep it
        # before processing so it survives even if a handler below raises.
        self.__buffer = packets.pop()

        for packet in packets:
            if not packet:
                # Consecutive delimiters: nothing to decode.
                continue
            try:
                self.__process_packet(packet)
            except ValueError:
                # A delimited packet that fails to decode/parse can never
                # become valid by waiting for more data, so drop it and keep
                # processing the remaining packets.
                pass

    @handler(channel='node_result', priority=100)
    def result_handler(self, event, *args, **kwargs):
        """Send the result of a peer-requested event back to that peer.

        Calls received in ``__process_packet_call`` are fired with
        ``success_channels = ('node_result',)``, so their ``*_success``
        events arrive here with the completed event as ``args[0]``.
        """
        if not event.name.endswith('_success'):
            return

        source_event = args[0]

        # Only events received from a peer carry a node_call_id.
        if getattr(source_event, 'node_call_id', False) is False:
            return

        # Every Protocol instance listens on 'node_result'. Only the instance
        # whose socket the call arrived on may answer it; otherwise the result
        # would be sent to every connected peer.
        if getattr(source_event, 'node_sock', None) is not self.__sock:
            return

        self.send_result(source_event.node_call_id, source_event.value)

    def send(self, event):
        """Send ``event`` to the peer and wait cooperatively for its result.

        This is a generator: while waiting it yields bare ``None`` so the
        circuits loop can keep running; the final yielded item is the result.

        - If the send firewall rejects the event, an empty ``Value`` is
          yielded and nothing is sent.
        - If ``event.node_without_result`` is truthy, the packet is sent and
          the generator finishes without yielding anything.
        - Otherwise it waits until ``__process_packet_value`` marks the event
          ``remote_finish`` and then yields ``event.value``.
        """
        if self.__send_event_firewall and \
                not self.__send_event_firewall(event, self.__sock):
            yield Value(event, self)
            return

        call_id = self.__nid
        self.__nid += 1

        packet = dump_event(event, call_id).encode('utf-8') + DELIMITER
        self.__send(packet)

        if getattr(event, 'node_without_result', False):
            return

        self.__events[call_id] = event
        # NOTE(review): if the peer never replies (e.g. connection lost) this
        # loop never ends and the entry stays in __events — confirm how
        # callers bound the wait.
        while not hasattr(event, 'remote_finish'):
            yield

        del self.__events[call_id]
        yield event.value

    def send_result(self, id, value):
        """Send ``value`` to the peer as the result of its call ``id``.

        The ``Value`` is tagged with ``node_call_id`` and ``node_sock`` before
        being serialised with ``dump_value``.
        """
        value.node_call_id = id
        value.node_sock = self.__sock
        packet = dump_value(value).encode('utf-8') + DELIMITER
        self.__send(packet)

    def __send(self, packet):
        """Fire a ``write`` event for ``packet``, addressed to the socket
        when running on the server side."""
        if self.__server is not None:
            self.fire(write(self.__sock, packet))
        else:
            self.fire(write(packet))

    def __process_packet(self, packet):
        """Decode one complete packet and dispatch it as a result or a call.

        Decoding errors (``UnicodeDecodeError`` is a ``ValueError``) and any
        ``ValueError`` from the loaders propagate to ``add_buffer``.
        """
        packet = packet.decode('utf-8')

        # NOTE(review): routing on a substring is fragile — a call whose
        # payload contains '"value":' (e.g. a kwarg named ``value``) would be
        # treated as a result packet. Confirm against the .utils wire format.
        if '"value":' in packet:
            self.__process_packet_value(packet)
        else:
            self.__process_packet_call(packet)

    def __process_packet_call(self, packet):
        """Fire an event requested by the peer, or reject it via firewall.

        A rejected event is answered immediately with an empty ``Value``.
        An accepted one is fired on its own channels with success reporting
        routed to ``node_result`` so ``result_handler`` can reply.
        """
        event, call_id = load_event(packet)

        if self.__receive_event_firewall and \
                not self.__receive_event_firewall(event, self.__sock):
            self.send_result(call_id, Value(event, self))
            return

        event.success = True  # fire %s_success event
        event.success_channels = ('node_result',)
        event.node_call_id = call_id
        event.node_sock = self.__sock

        # Convert bytes to str. kwargs is rebuilt rather than mutated while
        # being iterated: deleting/re-inserting keys during iteration raises
        # RuntimeError on Python 3.8+.
        event.args = [self.__to_str(arg) for arg in event.args]
        event.kwargs = {
            self.__to_str(key): self.__to_str(val)
            for key, val in event.kwargs.items()
        }

        self.fire(event, *event.channels)

    def __process_packet_value(self, packet):
        """Store a result received from the peer on the matching pending call.

        Results for ids that are not pending are ignored. Sets ``value``,
        ``errors`` and ``remote_finish`` on the event, then copies every
        ``meta`` entry onto it as an attribute.
        """
        value, call_id, error, meta = load_value(packet)

        event = self.__events.get(call_id)
        if event is None:
            return

        # convert byte to str
        value = self.__to_str(value)
        error = self.__to_str(error)

        if not getattr(event, 'value', None):
            event.value = Value(event, self)

        # save result
        event.value.setValue(value)
        event.errors = error
        event.remote_finish = True

        for key, meta_value in dict(meta).items():
            setattr(event, key, meta_value)