1
2import struct, logging, random
3from .nmb_constants import *
4from .nmb_structs import *
5from .utils import encode_name
6
class NMBSession:
    """Base class for one NetBIOS session-service connection.

    Subclasses supply the transport by overriding :meth:`write` and receive
    events through the ``onNMBSession*`` callbacks.  Incoming bytes are handed
    to :meth:`feedData`, which reassembles complete packets and dispatches them.
    """

    log = logging.getLogger('NMB.NMBSession')

    # Largest payload each framing can describe: NetBIOS framing has a 16-bit
    # length plus one extension bit in the flags byte (17 bits); direct TCP
    # framing has a 24-bit length below a zero high byte.
    MAX_NETBIOS_LENGTH = 0x01FFFF
    MAX_DIRECT_TCP_LENGTH = 0x00FFFFFF

    def __init__(self, my_name, remote_name, host_type = TYPE_SERVER, is_direct_tcp = False):
        """
        :param my_name: local NetBIOS name (upper-cased here); sent as the
                        calling name in :meth:`requestNMBSession`.
        :param remote_name: remote NetBIOS name (upper-cased here); sent as
                            the called name in :meth:`requestNMBSession`.
        :param host_type: name type used when encoding *remote_name*.
        :param is_direct_tcp: if true, use direct-TCP framing
                              (``DirectTCPSessionMessage``) instead of
                              NetBIOS session framing (``NMBSessionMessage``).
        """
        self.my_name = my_name.upper()
        self.remote_name = remote_name.upper()
        self.host_type = host_type
        self.data_buf = b''  # received bytes not yet consumed as full packets

        # The decoder object and the matching send routine are chosen once, so
        # the rest of the class is independent of the framing in use.
        if is_direct_tcp:
            self.data_nmb = DirectTCPSessionMessage()
            self.sendNMBPacket = self._sendNMBPacket_DirectTCP
        else:
            self.data_nmb = NMBSessionMessage()
            self.sendNMBPacket = self._sendNMBPacket_NetBIOS

    #
    # Overridden Methods
    #

    def write(self, data):
        """Send the raw bytes *data* over the transport.  Must be overridden."""
        raise NotImplementedError

    def onNMBSessionMessage(self, flags, data):
        """Called with the flags and payload of each SESSION_MESSAGE packet."""
        pass

    def onNMBSessionOK(self):
        """Called when a POSITIVE_SESSION_RESPONSE packet is received."""
        pass

    def onNMBSessionFailed(self):
        """Called when a NEGATIVE_SESSION_RESPONSE packet is received."""
        pass

    #
    # Public Methods
    #

    def feedData(self, data):
        """Append received bytes and dispatch every complete packet in the buffer.

        Incomplete trailing data is kept in ``data_buf`` for the next call.

        :raises NMBError: if the decoder reports a malformed packet.
        """
        self.data_buf = self.data_buf + data

        offset = 0
        try:
            while True:
                # decode() returns 0 when no complete packet is available yet,
                # the packet's length on success, and a negative value on error.
                length = self.data_nmb.decode(self.data_buf, offset)
                if length == 0:
                    break
                if length < 0:
                    raise NMBError
                offset += length
                self._processNMBSessionPacket(self.data_nmb)
        finally:
            # Drop everything already dispatched even when decoding or a
            # callback raised, so those packets are never delivered twice.
            if offset > 0:
                self.data_buf = self.data_buf[offset:]

    def sendNMBMessage(self, data):
        """Send *data* as a SESSION_MESSAGE packet."""
        self.sendNMBPacket(SESSION_MESSAGE, data)

    def requestNMBSession(self):
        """Send a SESSION_REQUEST carrying the called (remote) then calling (local) name."""
        my_name_encoded = encode_name(self.my_name, TYPE_WORKSTATION)
        remote_name_encoded = encode_name(self.remote_name, self.host_type)
        self.sendNMBPacket(SESSION_REQUEST, remote_name_encoded + my_name_encoded)

    #
    # Protected Methods
    #

    def _processNMBSessionPacket(self, packet):
        """Route one decoded packet to the matching ``onNMBSession*`` callback."""
        if packet.type == SESSION_MESSAGE:
            self.onNMBSessionMessage(packet.flags, packet.data)
        elif packet.type == POSITIVE_SESSION_RESPONSE:
            self.onNMBSessionOK()
        elif packet.type == NEGATIVE_SESSION_RESPONSE:
            self.onNMBSessionFailed()
        elif packet.type == SESSION_KEEPALIVE:
            # Discard keepalive packets - [RFC1002]: 5.2.2.1
            pass
        else:
            self.log.warning('Unrecognized NMB session type: 0x%02x', packet.type)

    def _sendNMBPacket_NetBIOS(self, packet_type, data):
        """Frame *data* with a 4-byte NetBIOS session header and write it.

        :raises ValueError: if *data* is longer than ``MAX_NETBIOS_LENGTH``.
        """
        length = len(data)
        # An explicit check rather than ``assert``: under ``python -O`` an
        # assert vanishes and the length bits above 17 would be silently lost.
        if length > self.MAX_NETBIOS_LENGTH:
            raise ValueError('NetBIOS session payload too large: %d bytes (max %d)'
                             % (length, self.MAX_NETBIOS_LENGTH))
        # Bit 0x01 of the flags byte carries the 17th bit of the length.
        flags = 0x01 if length > 0xFFFF else 0
        self.write(struct.pack('>BBH', packet_type, flags, length & 0xFFFF) + data)

    def _sendNMBPacket_DirectTCP(self, packet_type, data):
        """Frame *data* with a 4-byte direct-TCP header and write it.

        The header is a 32-bit big-endian length whose high byte must stay
        zero; *packet_type* has no field in this framing and is ignored.

        :raises ValueError: if *data* is longer than ``MAX_DIRECT_TCP_LENGTH``.
        """
        length = len(data)
        # A longer payload would spill into the header's leading zero byte.
        if length > self.MAX_DIRECT_TCP_LENGTH:
            raise ValueError('Direct TCP payload too large: %d bytes (max %d)'
                             % (length, self.MAX_DIRECT_TCP_LENGTH))
        self.write(struct.pack('>I', length) + data)
99
100
class NBNS:
    """NetBIOS Name Service (NBNS) datagram builder and parser.

    This class only encodes and decodes packets; subclasses must implement
    :meth:`write` to transmit them.  Field layouts follow RFC 1002, section 4.2.
    """

    log = logging.getLogger('NMB.NBNS')

    # Fixed header: NAME_TRN_ID, the opcode/flags/rcode word, then the
    # question, answer, authority and additional record counts.
    HEADER_STRUCT_FORMAT = '>HHHHHH'
    HEADER_STRUCT_SIZE = struct.calcsize(HEADER_STRUCT_FORMAT)

    def write(self, data, ip, port):
        """Send *data* to *ip*:*port*.  Must be implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def _checkLength(data, needed, what):
        """Raise ValueError unless *data* holds at least *needed* bytes."""
        if len(data) < needed:
            raise ValueError('Truncated NBNS packet: %s needs %d bytes, got %d'
                             % (what, needed, len(data)))

    def _unpackHeader(self, data):
        """Validate and unpack the fixed header, returning its six fields.

        :raises ValueError: if *data* is shorter than the header.
        """
        self._checkLength(data, self.HEADER_STRUCT_SIZE, 'header')
        return struct.unpack_from(self.HEADER_STRUCT_FORMAT, data)

    def decodePacket(self, data):
        """Decode a name query response into the IPv4 addresses it lists.

        :param data: a raw NBNS datagram (bytes).
        :return: ``(trn_id, addresses)``.  For a name query response (opcode 0
                 with the response bit set) *addresses* is a possibly empty list
                 of dotted-quad strings; for any other packet it is ``None``.
        :raises ValueError: if *data* is shorter than the header, or than the
                 answer section the response declares.
        """
        trn_id, code = self._unpackHeader(data)[:2]
        is_response = bool((code >> 15) & 0x01)
        opcode = (code >> 11) & 0x0F
        if opcode != 0x0000 or not is_response:
            return trn_id, None

        offset = self.HEADER_STRUCT_SIZE
        self._checkLength(data, offset + 1, 'answer name length')
        name_len = data[offset]
        # Constant 2 for the length byte just read and the byte following the
        # Name (assumes no scope labels), and constant 8 for the Type, Class
        # and TTL fields of the Answer section:
        offset += 2 + name_len + 8
        self._checkLength(data, offset + 2, 'answer data length')
        (rdlength,) = struct.unpack_from('>H', data, offset)

        # The answer data is a run of 6-byte entries: 2 bytes of NB flags
        # followed by a 4-byte IPv4 address.
        record_count = rdlength // 6
        rdata_start = offset + 2
        self._checkLength(data, rdata_start + record_count * 6, 'answer records')

        return trn_id, [
            '%d.%d.%d.%d' % struct.unpack_from('4B', data, rdata_start + 6 * i + 2)
            for i in range(record_count)
        ]

    def prepareNameQuery(self, trn_id, name, is_broadcast=True):
        """Build a name query request for *name*.

        :param trn_id: 16-bit transaction id echoed back in the response.
        :param name: NetBIOS name to resolve; encoded with name type 0x20.
        :param is_broadcast: also set the broadcast flag.
        :return: the datagram as bytes.
        """
        # 0x0100 sets recursion desired; 0x0010 additionally sets broadcast.
        code = 0x0110 if is_broadcast else 0x0100
        header = struct.pack(self.HEADER_STRUCT_FORMAT, trn_id, code, 1, 0, 0, 0)
        # Question type NB (0x0020), question class IN (0x0001).
        return header + encode_name(name, 0x20) + b'\x00\x20\x00\x01'

    #
    # Contributed by Jason Anderson
    #
    def decodeIPQueryPacket(self, data):
        """Decode a node status response into the NetBIOS names it lists.

        :param data: a raw NBNS datagram (bytes), presumably the reply to
                     :meth:`prepareNetNameQuery`.
        :return: ``(trn_id, names)`` where *names* is a list of
                 ``(name, suffix)`` tuples (*name* a ``str`` with surrounding
                 whitespace stripped, *suffix* the type byte as an ``int``),
                 or ``None`` if the response lists no names.
        :raises ValueError: if *data* is too short for the header or for the
                 name table it declares.
        """
        trn_id = self._unpackHeader(data)[0]

        # The name count is read 44 bytes past the header; this assumes a
        # 34-byte encoded name followed by Type (2), Class (2), TTL (4) and
        # data length (2).
        offset = self.HEADER_STRUCT_SIZE + 44
        self._checkLength(data, offset + 1, 'name count')
        numnames = data[offset]
        if numnames == 0:
            return trn_id, None
        offset += 1

        # Entries are 18 bytes apart: a 15-byte padded name, a 1-byte suffix,
        # then 2 more bytes that are not read -- so the last entry only needs
        # its first 16 bytes to be present.
        self._checkLength(data, offset + numnames * 18 - 2, 'name table')

        names = []
        for entry in range(offset, offset + numnames * 18, 18):
            raw_name = data[entry:entry + 15].strip()
            # Names may contain non-ASCII bytes; replace them instead of
            # aborting the whole decode with UnicodeDecodeError.
            names.append((str(raw_name, 'ascii', 'replace'), data[entry + 15]))
        return trn_id, names

    #
    # Contributed by Jason Anderson
    #
    def prepareNetNameQuery(self, trn_id, is_broadcast=True):
        """Build a node status request for the wildcard name ``'*'``.

        :param trn_id: 16-bit transaction id echoed back in the response.
        :param is_broadcast: set the broadcast flag.
        :return: the datagram as bytes.
        """
        # 0x0010 is the broadcast flag; recursion desired is left clear.
        code = 0x0010 if is_broadcast else 0x0000
        header = struct.pack(self.HEADER_STRUCT_FORMAT, trn_id, code, 1, 0, 0, 0)
        # Question type NBSTAT (0x0021), question class IN (0x0001).
        return header + encode_name('*', 0) + b'\x00\x21\x00\x01'
184