1from .error import AMFError
2from .packet import Packet
3from .types import AMF0String, AMF0Value, U8, U16BE, U32BE
4
5
class AMFHeader(Packet):
    """A single header entry of an AMF remoting packet.

    Wire layout, as produced by ``_serialize`` and consumed by
    ``_deserialize``::

        name             AMF0 string
        must_understand  U8 (0 or 1)
        length           U32BE, byte length of the encoded value
        value            AMF0 value
    """

    # Error type associated with this packet class (presumably used by the
    # Packet base class to report decode failures -- base class not shown).
    exception = AMFError

    def __init__(self, name, value, must_understand=False):
        """Create a header.

        :param name: header name, encoded as an AMF0 string.
        :param value: header payload, encoded as an AMF0 value.
        :param must_understand: flag written as a U8 (``int(bool)``) telling
            the receiver whether it must understand this header.
        """
        self.name = name
        self.value = value
        self.must_understand = must_understand

    @property
    def size(self):
        """Total number of bytes this header occupies when serialized."""
        # 4 bytes for the U32BE length field + 1 byte for the U8 flag.
        size = 4 + 1
        size += AMF0String.size(self.name)
        size += AMF0Value.size(self.value)

        return size

    def _serialize(self, packet):
        """Append the encoded header to *packet*.

        :param packet: mutable byte buffer that is extended in place.
        """
        # Encode the value first so the length field is guaranteed to match
        # the bytes actually written.
        value = AMF0Value(self.value)

        packet += AMF0String(self.name)
        packet += U8(int(self.must_understand))
        # The length field covers only the encoded value -- not the name,
        # the flag, or the length field itself.
        packet += U32BE(len(value))
        packet += value

    @classmethod
    def _deserialize(cls, io):
        """Read one header from the file-like object *io*.

        :returns: a new :class:`AMFHeader`.
        """
        name = AMF0String.read(io)
        must_understand = bool(U8.read(io))
        # The length field must be consumed to advance the stream, but the
        # AMF0 value below is read without it, so its value is ignored.
        U32BE.read(io)
        value = AMF0Value.read(io)

        return cls(name, value, must_understand)
36
37
class AMFMessage(Packet):
    """A single message (body) entry of an AMF remoting packet.

    Wire layout, as produced by ``_serialize`` and consumed by
    ``_deserialize``::

        target_uri    AMF0 string
        response_uri  AMF0 string
        length        U32BE, byte length of the encoded value
        value         AMF0 value
    """

    # Error type associated with this packet class (presumably used by the
    # Packet base class to report decode failures -- base class not shown).
    exception = AMFError

    def __init__(self, target_uri, response_uri, value):
        """Create a message.

        :param target_uri: target URI, encoded as an AMF0 string.
        :param response_uri: response URI, encoded as an AMF0 string.
        :param value: message payload, encoded as an AMF0 value.
        """
        self.target_uri = target_uri
        self.response_uri = response_uri
        self.value = value

    @property
    def size(self):
        """Total number of bytes this message occupies when serialized."""
        # 4 bytes for the U32BE length field.
        size = 4
        size += AMF0String.size(self.target_uri)
        size += AMF0String.size(self.response_uri)
        size += AMF0Value.size(self.value)

        return size

    def _serialize(self, packet):
        """Append the encoded message to *packet*.

        :param packet: mutable byte buffer that is extended in place.
        """
        # Encode the value first so the length field is guaranteed to match
        # the bytes actually written.
        value = AMF0Value.pack(self.value)

        packet += AMF0String(self.target_uri)
        packet += AMF0String(self.response_uri)
        # The length field covers only the encoded value -- not the URIs or
        # the length field itself.
        packet += U32BE(len(value))
        packet += value

    @classmethod
    def _deserialize(cls, io):
        """Read one message from the file-like object *io*.

        :returns: a new :class:`AMFMessage`.
        """
        target_uri = AMF0String.read(io)
        response_uri = AMF0String.read(io)
        # The length field must be consumed to advance the stream, but the
        # AMF0 value below is read without it, so its value is ignored.
        U32BE.read(io)
        value = AMF0Value.read(io)

        return cls(target_uri, response_uri, value)
69
70
class AMFPacket(Packet):
    """An AMF remoting packet: a version followed by headers and messages.

    Wire layout, as produced by ``_serialize`` and consumed by
    ``_deserialize``::

        version        U16BE (only 0 or 3 accepted when reading)
        header_count   U16BE
        headers        header_count x AMFHeader
        message_count  U16BE
        messages       message_count x AMFMessage
    """

    # Error type associated with this packet class (presumably used by the
    # Packet base class to report decode failures -- base class not shown).
    exception = AMFError

    def __init__(self, version, headers=None, messages=None):
        """Create a packet.

        :param version: AMF version number written as a U16BE.
        :param headers: list of :class:`AMFHeader`; defaults to a new empty list.
        :param messages: list of :class:`AMFMessage`; defaults to a new empty list.
        """
        self.version = version
        # ``None`` defaults give every instance its own fresh list.
        self.headers = [] if headers is None else headers
        self.messages = [] if messages is None else messages

    @property
    def size(self):
        """Total number of bytes this packet occupies when serialized."""
        # version + header count + message count, each a U16BE.
        size = 2 + 2 + 2
        size += sum(header.size for header in self.headers)
        size += sum(message.size for message in self.messages)

        return size

    def _serialize(self, packet):
        """Append the encoded packet to *packet*.

        :param packet: mutable byte buffer that is extended in place.
        """
        packet += U16BE(self.version)

        # NOTE(review): relies on Packet.serialize appending into ``packet``
        # in place (base class not shown here).
        packet += U16BE(len(self.headers))
        for header in self.headers:
            header.serialize(packet)

        packet += U16BE(len(self.messages))
        for message in self.messages:
            message.serialize(packet)

    @classmethod
    def _deserialize(cls, io):
        """Read a full packet from the file-like object *io*.

        :returns: a new :class:`AMFPacket`.
        :raises AMFError: if the version is neither 0 nor 3.
        """
        version = U16BE.read(io)
        if version not in (0, 3):
            raise AMFError("AMF version must be 0 or 3")

        # Each count is read immediately before the entries it describes.
        header_count = U16BE.read(io)
        headers = [AMFHeader.deserialize(io) for _ in range(header_count)]

        message_count = U16BE.read(io)
        messages = [AMFMessage.deserialize(io) for _ in range(message_count)]

        return cls(version, headers, messages)
129
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "AMFPacket",
    "AMFHeader",
    "AMFMessage",
]
131