import functools
import itertools
import struct
import sys
4
# Refuse to run under Python 2: this module is written for Python 3 only.
assert sys.version_info[:2] >= (3,0), "This is Python 3 code"
6
def nbits(n):
    """Return the number of significant bits in the non-negative integer n.

    Deliberately mimics mp_get_nbits for ordinary Python integers rather
    than calling int.bit_length(). First find the smallest power-of-two
    width 2**top such that n >> 2**top is zero. Then binary-search downward:
    shift n right by each halving width that still leaves a nonzero value,
    and accumulate the widths shifted. The result equals n.bit_length(),
    and is 0 for n == 0.
    """
    assert 0 <= n
    top = 0
    while (n >> (1 << top)) != 0:
        top += 1
    bits = 0
    for exponent in range(top - 1, -1, -1):
        width = 1 << exponent
        if n >> width:
            n >>= width
            bits += width
    # All that can remain is the leading 1 bit (or nothing, for zero).
    assert n <= 1
    return bits + n
20
def ssh_byte(n):
    """Return n encoded as an SSH 'byte': a single unsigned octet."""
    octet = struct.pack("B", n)
    return octet
23
def ssh_uint32(n):
    """Return n encoded as an SSH 'uint32': four bytes, big-endian."""
    encoded = struct.pack(">L", n)
    return encoded
26
def ssh_string(s):
    """Return s encoded as an SSH 'string'.

    The encoding is the uint32 length of s, followed by the bytes of s
    themselves.
    """
    length_prefix = ssh_uint32(len(s))
    return length_prefix + s
29
def ssh1_mpint(x):
    """Return the non-negative integer x encoded as an SSH-1 mpint.

    Layout: the bit count of x as a big-endian uint16, followed by
    (bits + 7) // 8 big-endian magnitude bytes. When x is zero there are
    no magnitude bytes.
    """
    bit_count = nbits(x)
    byte_count = (bit_count + 7) // 8
    magnitude = bytes(0xFF & (x >> (8 * i)) for i in reversed(range(byte_count)))
    return struct.pack(">H", bit_count) + magnitude
34
def ssh2_mpint(x):
    """Return the non-negative integer x encoded as an SSH-2 mpint.

    Layout: a big-endian uint32 byte count, then nbits(x) // 8 + 1
    big-endian bytes. Using one more byte than strictly needed whenever
    the bit length is a multiple of 8 keeps the top bit clear (e.g. 0x80
    becomes 00 80).

    This always emits at least one byte, so zero is encoded as a single
    0x00 byte. NOTE(review): RFC 4251 describes zero as an mpint with no
    data bytes; confirm that callers rely on this form.
    """
    byte_count = nbits(x) // 8 + 1
    payload = bytes(0xFF & (x >> (8 * i)) for i in reversed(range(byte_count)))
    return struct.pack(">L", byte_count) + payload
38
def decoder(fn):
    """Wrap a raw parser function as a decoder with an optional remainder.

    fn(s) must return a pair (item, length_consumed). The returned
    function decode(s, return_rest=False) calls fn(s). It returns just
    item, or, when return_rest is true, the pair
    (item, s[length_consumed:]) so callers can keep parsing the rest of s.

    functools.wraps copies fn's name and docstring onto the wrapper, so
    each decorated decoder keeps its own identity in tracebacks and help()
    instead of appearing as "decode".
    """
    @functools.wraps(fn)
    def decode(s, return_rest = False):
        item, length_consumed = fn(s)
        if return_rest:
            return item, s[length_consumed:]
        return item
    return decode
47
@decoder
def ssh_decode_byte(s):
    """Read one unsigned byte from the start of s; consumes 1 byte."""
    (value,) = struct.unpack_from("B", s, 0)
    return value, 1
51
@decoder
def ssh_decode_uint32(s):
    """Read a big-endian uint32 from the start of s; consumes 4 bytes."""
    (value,) = struct.unpack_from(">L", s, 0)
    return value, 4
55
@decoder
def ssh_decode_string(s):
    """Read an SSH 'string' (uint32 length, then that many bytes) from s.

    Returns the string body without its length prefix, together with the
    total number of bytes consumed (4 + length).
    """
    body_len = ssh_decode_uint32(s)
    end = 4 + body_len
    assert end <= len(s)
    return s[4:end], end
61
@decoder
def ssh1_get_mpint(s):
    """Extract an SSH-1 mpint from the start of s without decoding it.

    The returned value is still in wire encoding: the 2-byte big-endian bit
    count plus its (bits + 7) // 8 magnitude bytes. The byte count consumed
    is the length of that slice.
    """
    (bit_count,) = struct.unpack_from(">H", s, 0)
    total = 2 + (bit_count + 7) // 8
    assert total <= len(s)
    return s[:total], total
68
@decoder
def ssh1_decode_mpint(s):
    """Decode an SSH-1 mpint at the start of s into a Python int.

    Wire format: a big-endian uint16 bit count, then (bits + 7) // 8
    big-endian magnitude bytes. The raw function returns
    (value, bytes_consumed); @decoder turns that into the public
    decode(s, return_rest=False) interface.

    Raises AssertionError if s is shorter than the header claims, and
    struct.error if s has fewer than 2 bytes.
    """
    # Named bit_count (not nbits) so it does not shadow the module-level
    # nbits() function.
    bit_count = struct.unpack_from(">H", s, 0)[0]
    nbytes = (bit_count + 7) // 8
    assert nbytes + 2 <= len(s)
    # int.from_bytes runs in linear time. A per-byte shift/OR loop would
    # rebuild an ever-larger int on every byte, which is quadratic.
    value = int.from_bytes(s[2:nbytes+2], "big")
    return value, nbytes + 2
79
# Maximum agent message length, in bytes (262144 = 256 KiB).
# NOTE(review): presumably the cap on a single agent protocol message
# enforced by the agent under test; confirm against the implementation.
AGENT_MAX_MSGLEN = 262144

# Agent protocol message type codes.
# NOTE(review): values appear to follow the SSH agent protocol
# (draft-miller-ssh-agent); verify against the spec if changing them.

# Legacy SSH-1 RSA identity requests and replies.
SSH1_AGENTC_REQUEST_RSA_IDENTITIES = 1
SSH1_AGENT_RSA_IDENTITIES_ANSWER = 2
SSH1_AGENTC_RSA_CHALLENGE = 3
SSH1_AGENT_RSA_RESPONSE = 4
SSH1_AGENTC_ADD_RSA_IDENTITY = 7
SSH1_AGENTC_REMOVE_RSA_IDENTITY = 8
SSH1_AGENTC_REMOVE_ALL_RSA_IDENTITIES = 9
# Generic replies, presumably shared by the SSH-1 and SSH-2 message sets.
SSH_AGENT_FAILURE = 5
SSH_AGENT_SUCCESS = 6
# SSH-2 identity requests and replies.
SSH2_AGENTC_REQUEST_IDENTITIES = 11
SSH2_AGENT_IDENTITIES_ANSWER = 12
SSH2_AGENTC_SIGN_REQUEST = 13
SSH2_AGENT_SIGN_RESPONSE = 14
SSH2_AGENTC_ADD_IDENTITY = 17
SSH2_AGENTC_REMOVE_IDENTITY = 18
SSH2_AGENTC_REMOVE_ALL_IDENTITIES = 19
SSH2_AGENTC_EXTENSION = 27

# Bit flags, presumably for the flags word of SSH2_AGENTC_SIGN_REQUEST.
# They would request an rsa-sha2-256 or rsa-sha2-512 signature from an
# RSA key. NOTE(review): confirm against the agent protocol spec.
SSH_AGENT_RSA_SHA2_256 = 2
SSH_AGENT_RSA_SHA2_512 = 4
102