1# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
2
3import contextlib
4import struct
5
6import dns.exception
7import dns.name
8
class Parser:
    """Sequential reader over a DNS wire-format buffer.

    The parser keeps a cursor (``current``) into ``wire`` and an upper
    bound (``end``) that reads may not cross.  ``end`` starts at
    ``len(wire)`` and is temporarily lowered by :meth:`restrict_to`.
    ``furthest`` records the highest position reached by a read (or the
    initial offset), so :meth:`restore_furthest` can return there after
    the caller has seeked elsewhere.

    Bounds violations detected by this class raise
    ``dns.exception.FormError``.
    """

    def __init__(self, wire, current=0):
        """Create a parser over *wire*, optionally starting at *current*.

        Raises ``dns.exception.FormError`` if *current* lies outside
        ``[0, len(wire)]``.
        """
        self.wire = wire
        self.current = 0
        self.end = len(self.wire)
        if current:
            # seek() performs the bounds check on the starting offset.
            self.seek(current)
        self.furthest = current

    def remaining(self):
        """Return the number of bytes left before ``end``."""
        return self.end - self.current

    def get_bytes(self, size):
        """Consume and return the next *size* bytes.

        Raises ``dns.exception.FormError`` if *size* is negative or larger
        than :meth:`remaining`.
        """
        # A negative size would otherwise slice to an empty result and
        # move the cursor backwards, silently corrupting the position.
        if size < 0 or size > self.remaining():
            raise dns.exception.FormError
        output = self.wire[self.current:self.current + size]
        self.current += size
        self.furthest = max(self.furthest, self.current)
        return output

    def get_counted_bytes(self, length_size=1):
        """Read a big-endian unsigned length prefix of *length_size* bytes,
        then consume and return that many bytes."""
        length = int.from_bytes(self.get_bytes(length_size), 'big')
        return self.get_bytes(length)

    def get_remaining(self):
        """Consume and return every byte up to ``end``."""
        return self.get_bytes(self.remaining())

    def get_uint8(self):
        """Consume one byte and return it as an unsigned integer."""
        return struct.unpack('!B', self.get_bytes(1))[0]

    def get_uint16(self):
        """Consume two bytes and return them as a big-endian unsigned int."""
        return struct.unpack('!H', self.get_bytes(2))[0]

    def get_uint32(self):
        """Consume four bytes and return them as a big-endian unsigned int."""
        return struct.unpack('!I', self.get_bytes(4))[0]

    def get_uint48(self):
        """Consume six bytes and return them as a big-endian unsigned int."""
        return int.from_bytes(self.get_bytes(6), 'big')

    def get_struct(self, format):
        """Consume ``struct.calcsize(format)`` bytes and return the tuple
        produced by ``struct.unpack(format, ...)``."""
        return struct.unpack(format, self.get_bytes(struct.calcsize(format)))

    def get_name(self, origin=None):
        """Read a name with ``dns.name.from_wire_parser``.

        If *origin* is truthy, the name is relativized to it before being
        returned.
        """
        name = dns.name.from_wire_parser(self)
        if origin:
            name = name.relativize(origin)
        return name

    def seek(self, where):
        """Move the cursor to absolute offset *where*.

        Raises ``dns.exception.FormError`` if *where* is outside
        ``[0, end]``.  ``furthest`` is not updated.
        """
        # Note that seeking to the end is OK!  (If you try to read
        # after such a seek, you'll get an exception as expected.)
        if where < 0 or where > self.end:
            raise dns.exception.FormError
        self.current = where

    @contextlib.contextmanager
    def restrict_to(self, size):
        """Context manager limiting reads to the next *size* bytes.

        Raises ``dns.exception.FormError`` on entry if *size* is negative
        or exceeds :meth:`remaining`.  It also raises on normal exit if
        the managed block did not consume exactly *size* bytes.  The
        previous ``end`` is restored on exit in every case.
        """
        # A negative size would place the temporary end before the cursor.
        if size < 0 or size > self.remaining():
            raise dns.exception.FormError
        saved_end = self.end
        try:
            self.end = self.current + size
            yield
            # We make this check here and not in the finally as we
            # don't want to raise if we're already raising for some
            # other reason.
            if self.current != self.end:
                raise dns.exception.FormError
        finally:
            self.end = saved_end

    @contextlib.contextmanager
    def restore_furthest(self):
        """Context manager that moves the cursor to ``furthest`` on exit,
        whether the managed block exits normally or by raising."""
        # NOTE(review): presumably used after seeking backwards to read
        # data referenced elsewhere in the message (e.g. name compression
        # pointers); confirm against callers.
        try:
            yield None
        finally:
            self.current = self.furthest
86