1"""
2Extract server ssh public key from key exchange
3"""
4
import base64
import hashlib
import struct
import sys

import dshell.core
from dshell.output.alertout import AlertOutput
10
11
class DshellPlugin(dshell.core.ConnectionPlugin):
    """Report the SSH server public host key seen during key exchange.

    For each connection matching ``tcp port 22`` the plugin records the
    client and server identification banners. It then scans the server's
    key-exchange packets for a reply message (code 31 or 33) that carries
    the host public key. When a key is found, it writes an alert with:

    - the key,
    - its MD5 / SHA-1 / SHA-256 fingerprints,
    - the banners,
    - the connection info.
    """

    # hashlib constructor names used to fingerprint the host key.
    FINGERPRINT_SCHEMES = ("md5", "sha1", "sha256")

    def __init__(self):
        super().__init__(
            name="ssh-pubkey",
            author="amm",
            description="Extract server ssh public key from key exchange",
            bpf="tcp port 22",
            output=AlertOutput(label=__name__)
        )

    def connection_handler(self, conn):
        """Extract banners and the server host key from one connection.

        :param conn: connection whose ``blobs`` are walked in order
        :returns: ``conn`` in three cases:
            - a banner does not start with ``SSH`` (not SSH);
            - the client banner is not valid UTF-8;
            - a host key was found and reported.
            Otherwise the method falls through and returns ``None``.
        """
        sc_blob_count = 0
        cs_blob_count = 0
        info = {}

        for blob in conn.blobs:

            #
            # CS Blobs: only interest is the client banner (first CS blob)
            #
            if blob.direction == 'cs':
                cs_blob_count += 1
                if cs_blob_count > 1:
                    continue
                blob.reassemble(allow_overlap=True, allow_padding=True)
                if not blob.data:
                    continue
                banner = blob.data.split(b'\x0d')[0].rstrip()
                if not banner.startswith(b'SSH'):
                    return conn  # NOT AN SSH CONNECTION
                try:
                    info['clientbanner'] = banner.decode('utf-8')
                except UnicodeDecodeError:
                    return conn
                continue

            #
            # SC Blobs: banner (first SC blob), then key exchange packets
            #
            sc_blob_count += 1
            blob.reassemble(allow_overlap=True, allow_padding=True)
            if not blob.data:
                continue
            data = blob.data

            # Server Banner
            if sc_blob_count == 1:
                banner = data.split(b'\x0d')[0].rstrip()
                if not banner.startswith(b'SSH'):
                    return conn  # NOT AN SSH CONNECTION
                try:
                    banner = banner.decode('utf-8')
                except UnicodeDecodeError:
                    pass  # keep the raw bytes
                info['serverbanner'] = banner
                continue

            # Key Exchange Packet/Messages: stop at the first reply that
            # carries a host key (sshmessage only sets host_pub_key when the
            # code is 31/33).
            host_key = next(
                (m.host_pub_key for m in messagefactory(data)
                 if m.message_code in (31, 33)),
                None)
            if host_key is not None:
                info['host_pubkey'] = host_key
                break

        if 'host_pubkey' in info:
            info['host_fingerprints'] = self._host_fingerprints(
                info['host_pubkey'])
            self.write(str(info['host_pubkey']), **info, **conn.info())
            return conn

        # NOTE(review): no host key found -> returns None, unlike the
        # non-SSH early returns above which return conn. Preserved as-is;
        # confirm which of the two is intended for plugin chaining.
        return None

    def _host_fingerprints(self, host_pubkey):
        """Return {scheme: 'aa:bb:...'} fingerprints of an OpenSSH key line.

        Schemes whose fingerprint cannot be computed (key_fingerprint
        returned None) are omitted rather than crashing the handler.
        """
        fingerprints = {}
        for scheme in self.FINGERPRINT_SCHEMES:
            # getattr instead of eval: same hashlib constructor, no code eval
            digest = key_fingerprint(host_pubkey, getattr(hashlib, scheme))
            if digest is None:
                continue
            fingerprints[scheme] = ':'.join('%02x' % b for b in digest)
        return fingerprints
99
100
def messagefactory(data):
    """Parse consecutive SSH binary packets from the start of ``data``.

    Packets are parsed back to back with ``sshmessage``. Parsing stops at
    the end of the data, or at the first packet that cannot be parsed
    (truncated or malformed). The messages parsed up to that point are
    returned.

    :param data: bytes expected to begin on an SSH packet boundary
    :returns: list of ``sshmessage`` objects (possibly empty)
    """
    msglist = []
    offset = 0
    while offset < len(data):
        try:
            msg = sshmessage(data[offset:])
        except (ValueError, struct.error):
            # struct.error is not a ValueError subclass; it is raised when a
            # length field inside a truncated packet body cannot be unpacked.
            break
        msglist.append(msg)
        # packet_len does not include the 4-byte length field itself
        offset += msg.packet_len + 4

    return msglist
115
116
class sshmessage:
    """One SSH binary-protocol packet parsed from the start of a byte string.

    Attributes set on a successful parse:
        packet_len   -- 4-byte big-endian packet length field
        padding_len  -- 1-byte padding length field
        message_code -- SSH message number; reset to 0 when a code 31/33
                        packet does not hold a well-formed host key
        body         -- bytes after the message code, up to packet_len
        host_pub_key -- "<key type> <base64 key blob>", only for 31/33

    Raises ValueError when the data is too short for the packet header or
    for the packet length it declares.
    """

    def __init__(self, rawdata):
        self.__parse_raw(rawdata)

    def __parse_raw(self, data):
        datalen = len(data)
        if datalen < 6:
            raise ValueError("truncated SSH packet header")

        (self.packet_len, self.padding_len,
         self.message_code) = struct.unpack(">IBB", data[0:6])
        if datalen < self.packet_len + 4:
            raise ValueError("truncated SSH packet")
        self.body = data[6:4 + self.packet_len]

        # Key exchange replies (31: KEXDH/KEX_ECDH_REPLY,
        # 33: KEX_DH_GEX_REPLY) start with the server host key.
        if self.message_code in (31, 33):
            host_pub_key = self.__parse_host_key(self.body)
            if host_pub_key is None:
                # something went wrong; this probably isn't a KEX reply
                self.message_code = 0
            else:
                self.host_pub_key = host_pub_key

    @staticmethod
    def __parse_host_key(body):
        """Return "<type> <base64 blob>" for the key blob at the start of
        body, or None if body does not hold a well-formed key blob."""
        if len(body) < 4:
            return None
        host_key_len = struct.unpack(">I", body[0:4])[0]
        full_key_net = body[4:4 + host_key_len]
        # Blob must be complete and large enough for its own type-name length
        if host_key_len < 4 or len(full_key_net) < host_key_len:
            return None

        key_type_name_len = struct.unpack(">I", full_key_net[0:4])[0]
        # Key type names are short ("ssh-ed25519", "ecdsa-sha2-nistp256");
        # a large value means this is not really a host key.
        if key_type_name_len > 50 or 4 + key_type_name_len > host_key_len:
            return None
        key_type_name = full_key_net[4:4 + key_type_name_len]
        try:
            key_type = key_type_name.decode('utf-8')
        except UnicodeDecodeError:
            return None

        return "%s %s" % (key_type,
                          base64.b64encode(full_key_net).decode('utf-8'))
147
148
def key_fingerprint(ssh_pubkey, hashfunction=hashlib.sha256):
    """Return the raw digest of an OpenSSH-format public key.

    Accepts either a full "<type> <base64 blob> [comment]" line or just the
    base64 blob, as str or bytes. Only the first line is used. The base64
    blob is decoded and hashed with ``hashfunction``.

    :param ssh_pubkey: public key text (str or bytes)
    :param hashfunction: hashlib-style constructor, e.g. ``hashlib.md5``
    :returns: digest bytes, or None when the blob is not valid base64
        (a diagnostic is written to stderr in that case)
    """
    # Treat as bytes, not string
    if isinstance(ssh_pubkey, str):
        ssh_pubkey = ssh_pubkey.encode('utf-8')

    # Strip whitespace/NULs from both ends; leading spaces would otherwise
    # shift which field is taken as the key blob below.
    ssh_pubkey = ssh_pubkey.strip(b"\r\n\0 ")

    # Only look at first line
    ssh_pubkey = ssh_pubkey.split(b"\n")[0]
    # With at least one space the line is "<type> <blob> [comment]":
    # the blob is the second field.
    if b" " in ssh_pubkey:
        ssh_pubkey = ssh_pubkey.split(b" ")[1]

    # Try to decode key as base64 (binascii.Error subclasses ValueError)
    try:
        keybin = base64.b64decode(ssh_pubkey)
    except ValueError:
        sys.stderr.write("Invalid key value:\n")
        sys.stderr.write("  \"%s\":\n" % ssh_pubkey)
        return None

    # Fingerprint
    return hashfunction(keybin).digest()
174
175
if __name__ == "__main__":
    # Running the module directly only instantiates the plugin and prints it.
    plugin = DshellPlugin()
    print(plugin)
178