1from functools import wraps
2
3from pyaxo import a2b
4
5from . import elements
6from . import errors
7
8
# Expected length of the a2b-decoded IV (see check_iv).
IV_LEN = 8
# Expected length of a2b-decoded keys: the request key of a request packet
# and the identity, handshake and ratchet keys of a handshake packet.
KEY_LEN = 32
# Expected length of the a2b-decoded handshake key carried by a reply packet.
# NOTE(review): presumably longer than KEY_LEN because the key travels
# encrypted -- confirm.
ENC_KEY_LEN = 72
# Expected length of a2b-decoded hashes (IV, payload and handshake packet).
HASH_LEN = 32

# Separator used to join packet fields when packets are serialized.
LINESEP = '\n'
15
16
def raise_malformed(f):
    """Wrap a packet builder so malformed input raises MalformedPacketError.

    The decorated builder is called with the raw packet ``data``. When it
    fails with AssertionError, IndexError, TypeError or ValueError, an
    ``errors.MalformedPacketError`` is raised instead. Its message is
    extended with every line of ``data`` prefixed by the line's index.

    The packet type passed to the error is the second underscore-separated
    component of the builder's name (``build_intro_packet`` -> ``intro``),
    so decorated functions must follow the ``build_<type>_packet`` naming.

    Args:
        f: builder function taking the raw packet text as its only argument.

    Returns:
        The wrapped builder, with the metadata of ``f`` preserved.
    """
    @wraps(f)
    def try_building(data):
        try:
            return f(data)
        # ValueError is included because the int() conversions done by
        # ElementPacket fail with it on non-numeric part fields.
        except (AssertionError, IndexError, TypeError, ValueError):
            # `__name__` is available on both Python 2 and Python 3, whereas
            # `func_name` only exists on Python 2.
            packet_type = f.__name__.split('_')[1]
            e = errors.MalformedPacketError(packet_type)
            indexed_lines = ['[{}]: {}'.format(index, line)
                             for index, line in enumerate(data.splitlines())]
            # NOTE(review): relies on MalformedPacketError exposing a
            # `message` attribute -- confirm against the errors module.
            e.message = LINESEP.join([e.message] + indexed_lines)
            raise e
    return try_building
30
31
def check_iv(packet):
    """Validate the decoded lengths of a packet's IV and IV hash.

    Args:
        packet: object exposing ``iv`` and ``iv_hash`` attributes that
            ``a2b`` can decode.

    Raises:
        AssertionError: if the decoded IV is not IV_LEN long, or the decoded
            IV hash is not HASH_LEN long.
    """
    # Raised explicitly instead of through `assert`: assert statements are
    # removed under `python -O`, which would silently disable validation.
    # AssertionError is kept because raise_malformed translates it.
    if len(a2b(packet.iv)) != IV_LEN:
        raise AssertionError('unexpected IV length')
    if len(a2b(packet.iv_hash)) != HASH_LEN:
        raise AssertionError('unexpected IV hash length')
35
36
def check_payload(packet):
    """Validate a packet's payload hash length and that its payload decodes.

    Args:
        packet: object exposing ``payload_hash`` and ``payload`` attributes.

    Raises:
        AssertionError: if the decoded payload hash is not HASH_LEN long.
        Any exception ``a2b`` raises while decoding either field propagates.
    """
    # Raised explicitly instead of through `assert`: assert statements are
    # removed under `python -O`, which would silently disable validation.
    # AssertionError is kept because raise_malformed translates it.
    if len(a2b(packet.payload_hash)) != HASH_LEN:
        raise AssertionError('unexpected payload hash length')
    # The decoded payload is discarded; the call only checks that the
    # payload can be decoded at all.
    a2b(packet.payload)
40
41
@raise_malformed
def build_intro_packet(data):
    """Build an IntroductionPacket from raw packet text.

    The first line of ``data`` is used as the IV and the second as the IV
    hash; the complete, unmodified ``data`` is kept on the packet as well.
    Both fields are validated with check_iv before the packet is returned.
    """
    fields = data.splitlines()
    intro = IntroductionPacket(fields[0], fields[1], data)
    check_iv(intro)
    return intro
52
53
@raise_malformed
def build_regular_packet(data):
    """Build a RegularPacket that carries no handshake key.

    Each line of ``data`` becomes one positional field of RegularPacket, so
    a wrong number of lines fails with TypeError. The payload fields are
    validated with check_payload and the handshake key must decode to an
    empty value.

    Raises:
        AssertionError: if the decoded handshake key is not empty (translated
            to MalformedPacketError by the raise_malformed decorator).
    """
    packet = RegularPacket(*data.splitlines())

    check_payload(packet)
    # Raised explicitly instead of through `assert` so the check is not
    # stripped when Python runs with -O.
    if len(a2b(packet.handshake_key)):
        raise AssertionError('regular packet must not carry a handshake key')

    return packet
62
63
@raise_malformed
def build_reply_packet(data):
    """Build a RegularPacket that carries a handshake key of ENC_KEY_LEN.

    Each line of ``data`` becomes one positional field of RegularPacket, so
    a wrong number of lines fails with TypeError. The payload fields are
    validated with check_payload and the decoded handshake key must be
    exactly ENC_KEY_LEN long.

    Raises:
        AssertionError: if the decoded handshake key has another length
            (translated to MalformedPacketError by the raise_malformed
            decorator).
    """
    packet = RegularPacket(*data.splitlines())

    check_payload(packet)
    # Raised explicitly instead of through `assert` so the check is not
    # stripped when Python runs with -O.
    if len(a2b(packet.handshake_key)) != ENC_KEY_LEN:
        raise AssertionError('unexpected handshake key length')

    return packet
72
73
@raise_malformed
def build_request_packet(data):
    """Build a RequestPacket from raw packet text.

    Each line of ``data`` becomes one positional field of RequestPacket, so
    a wrong number of lines fails with TypeError. The decoded handshake
    packet hash must be HASH_LEN long, the decoded request key KEY_LEN long,
    and the handshake packet must be decodable by ``a2b``.

    Raises:
        AssertionError: if either length check fails (translated to
            MalformedPacketError by the raise_malformed decorator).
    """
    packet = RequestPacket(*data.splitlines())

    # Raised explicitly instead of through `assert` so the checks are not
    # stripped when Python runs with -O.
    if len(a2b(packet.handshake_packet_hash)) != HASH_LEN:
        raise AssertionError('unexpected handshake packet hash length')
    if len(a2b(packet.request_key)) != KEY_LEN:
        raise AssertionError('unexpected request key length')
    # The decoded value is discarded; the call only checks decodability.
    a2b(packet.handshake_packet)

    return packet
83
84
@raise_malformed
def build_handshake_packet(data):
    """Build a HandshakePacket from raw packet text.

    Each line of ``data`` becomes one positional field of HandshakePacket,
    so a wrong number of lines fails with TypeError. The identity, handshake
    and ratchet keys must each decode to exactly KEY_LEN; the identity field
    itself is not validated.

    Raises:
        AssertionError: if any key has another decoded length (translated to
            MalformedPacketError by the raise_malformed decorator).
    """
    packet = HandshakePacket(*data.splitlines())

    # Raised explicitly instead of through `assert` so the checks are not
    # stripped when Python runs with -O. Keys are checked in the same order
    # as before: identity, handshake, ratchet.
    for key_name in ('identity_key', 'handshake_key', 'ratchet_key'):
        if len(a2b(getattr(packet, key_name))) != KEY_LEN:
            raise AssertionError('unexpected {} length'.format(key_name))

    return packet
94
95
@raise_malformed
def build_element_packet(data):
    """Build an ElementPacket from raw packet text.

    The first four lines of ``data`` are the type, id, part number and part
    length, in that order. Every remaining line is re-joined with LINESEP
    and used as the payload.
    """
    fields = data.splitlines()
    type_, id_, part_num, part_len = (fields[0], fields[1],
                                      fields[2], fields[3])
    payload = LINESEP.join(fields[4:])
    return ElementPacket(type_, payload,
                         id_=id_,
                         part_num=part_num,
                         part_len=part_len)
104
105
class IntroductionPacket:
    """Packet holding an IV, its hash and the raw text it was built from."""

    def __init__(self, iv, iv_hash, data):
        """Store the IV, the IV hash and the raw packet data as given."""
        self.iv, self.iv_hash = iv, iv_hash
        self.data = data

    def __str__(self):
        """Return the raw packet data unchanged."""
        return self.data
114
115
class RegularPacket:
    """Packet made of an IV, IV hash, payload hash, handshake key, payload."""

    def __init__(self, iv, iv_hash, payload_hash, handshake_key, payload):
        """Store every field as given, without validation."""
        self.iv, self.iv_hash = iv, iv_hash
        self.payload_hash, self.payload = payload_hash, payload
        self.handshake_key = handshake_key

    def __str__(self):
        """Serialize the fields, one per line, in constructor order."""
        fields = (self.iv,
                  self.iv_hash,
                  self.payload_hash,
                  self.handshake_key,
                  self.payload)
        return LINESEP.join(fields)
130
131
class RequestPacket:
    """Packet made of an IV, IV hash, handshake packet hash, request key and
    handshake packet."""

    def __init__(self, iv, iv_hash, handshake_packet_hash, request_key,
                 handshake_packet):
        """Store every field as given, without validation."""
        self.iv, self.iv_hash = iv, iv_hash
        self.handshake_packet_hash = handshake_packet_hash
        self.request_key = request_key
        self.handshake_packet = handshake_packet

    def __str__(self):
        """Serialize the fields, one per line, in constructor order.

        The handshake packet is converted with ``str`` first, so it may be
        either a string or an object with its own ``__str__``.
        """
        fields = (self.iv,
                  self.iv_hash,
                  self.handshake_packet_hash,
                  self.request_key,
                  str(self.handshake_packet))
        return LINESEP.join(fields)
147
148
class HandshakePacket:
    """Packet made of an identity plus identity, handshake and ratchet keys."""

    def __init__(self, identity, identity_key, handshake_key, ratchet_key):
        """Store every field as given, without validation."""
        self.identity, self.identity_key = identity, identity_key
        self.handshake_key, self.ratchet_key = handshake_key, ratchet_key

    def __str__(self):
        """Serialize the fields, one per line, in constructor order."""
        fields = (self.identity,
                  self.identity_key,
                  self.handshake_key,
                  self.ratchet_key)
        return LINESEP.join(fields)
161
162
class ElementPacket:
    """Packet made of a type, id, part number, part length and payload."""

    def __init__(self, type_, payload, id_=None, part_num=1, part_len=1):
        """Store the fields, filling in the id and normalizing part counters.

        A falsy ``id_`` (``None`` or an empty string) is replaced with
        ``elements.get_random_id()``. ``part_num`` and ``part_len`` are
        converted with ``int``, so numeric strings are accepted.
        """
        self.type_ = type_
        self.payload = payload
        self.id_ = id_ or elements.get_random_id()
        self.part_num, self.part_len = int(part_num), int(part_len)

    def __str__(self):
        """Serialize as type, id, part number, part length and payload,
        one per line."""
        fields = (self.type_,
                  self.id_,
                  str(self.part_num),
                  str(self.part_len),
                  self.payload)
        return LINESEP.join(fields)
179