1package transfer
2
3import (
4	"crypto/md5"
5	"crypto/rand"
6	"encoding/base64"
7	"encoding/hex"
8	"errors"
9	"fmt"
10	"net"
11	"strings"
12
13	"github.com/colinmarc/hdfs/v2/internal/sasl"
14)
15
// saslIntegrityPrefixLength is a 4-byte field size used when framing
// SASL-wrapped messages. NOTE(review): looks like a length prefix; confirm
// against its users.
const saslIntegrityPrefixLength = 4

// Field sizes, in bytes, of a DIGEST-MD5 MAC. RFC 2831 section 2.3 defines the
// MAC as the first 10 bytes of an HMAC-MD5, a 2-byte message type and a 4-byte
// sequence number, which presumably map to the constants below.
const (
	macHMACLen    = 10
	macMsgTypeLen = 2
	macSeqNumLen  = 4
	// NOTE(review): macDataLen is another 4-byte field; its exact role is
	// not visible here.
	macDataLen = 4
)

// macMsgType holds the 2-byte message type value 0x0001 (big-endian), the
// constant message type RFC 2831 section 2.3 places in every MAC.
var macMsgType = [...]byte{0x00, 0x01}
25
// digestMD5Conn is a net.Conn that can additionally decode a payload received
// over a DIGEST-MD5 protected connection.
// NOTE(review): the exact decode contract (integrity check vs. decryption) is
// defined by the implementations, which are not visible here.
type digestMD5Conn interface {
	// decode turns a received payload into its plain form, or reports why
	// it cannot.
	decode(in []byte) ([]byte, error)

	net.Conn
}
30
// digestMD5Handshake represents the negotiation state in a token-digestmd5
// authentication flow.
//
// The credential and endpoint fields are not set by any method in this block's
// file section; presumably the code that creates the handshake fills them in.
// challengeStep1 populates token, cnonce and cipher, and challengeStep2 reuses
// that state to verify the server's rspauth.
type digestMD5Handshake struct {
	// Credentials and endpoint. authID is sent as the username directive and,
	// together with passwd, hashed into A1; passwd is not included in the
	// response text. service and hostname form the digest-uri
	// "service/hostname".
	authID   []byte
	passwd   string
	hostname string
	service  string

	// token is the server's parsed digest-challenge, set by challengeStep1.
	token *sasl.Challenge

	// cnonce is the client nonce generated in challengeStep1; cipher is the
	// cipher chosen there from the server's offer ("" if none of the
	// supported ciphers was offered).
	cnonce string
	cipher string
}
44
45// challengeStep1 implements step one of RFC 2831.
46func (d *digestMD5Handshake) challengeStep1(challenge []byte) ([]byte, error) {
47	var err error
48	d.token, err = sasl.ParseChallenge(challenge)
49	if err != nil {
50		return nil, err
51	}
52
53	d.cnonce, err = genCnonce()
54	if err != nil {
55		return nil, err
56	}
57
58	d.cipher = chooseCipher(d.token.Cipher)
59	rspdigest := d.compute(true)
60
61	ret := fmt.Sprintf(`username="%s", realm="%s", nonce="%s", cnonce="%s", nc=%08x, qop=%s, digest-uri="%s/%s", response=%s, charset=utf-8`,
62		d.authID, d.token.Realm, d.token.Nonce, d.cnonce, 1, d.token.Qop[0], d.service, d.hostname, rspdigest)
63
64	if d.cipher != "" {
65		ret += ", cipher=" + d.cipher
66	}
67
68	return []byte(ret), nil
69}
70
71// challengeStep2 implements step two of RFC 2831.
72func (d *digestMD5Handshake) challengeStep2(challenge []byte) error {
73	rspauth := strings.Split(string(challenge), "=")
74
75	if rspauth[0] != "rspauth" {
76		return fmt.Errorf("rspauth not in '%s'", string(challenge))
77	}
78
79	if rspauth[1] != d.compute(false) {
80		return errors.New("rspauth did not match digest")
81	}
82
83	return nil
84}
85
86// compute implements the computation of md5 digest authentication per RFC 2831.
87// The response value computation is defined as:
88//
89//     HEX(KD(HEX(H(A1)),
90//       { nonce-value, ":", nc-value, ":", cnonce-value, ":", qop-value,
91//         ":", HEX(H(A2)) }))
92//     A1 = { H({ username-value, ":", realm-value, ":", passwd }),
93//            ":", nonce-value, ":", cnonce-value }
94//
95//   If "qop" is "auth":
96//
97//		 A2 = { "AUTHENTICATE:", digest-uri-value }
98//
99//   If "qop" is "auth-int" or "auth-conf":
100//
101//       A2 = { "AUTHENTICATE:", digest-uri-value,
102//              ":00000000000000000000000000000000" }
103//
104//   Where:
105//
106//     - { a, b, ... } is the concatenation of the octet strings a, b, ...
107//     - H(s) is the 16 octet MD5 Hash [RFC1321] of the octet string s
108//     - KD(k, s) is H({k, ":", s})
109//     - HEX(n) is the representation of the 16 octet MD5 hash n as a string of
110//       32 hex digits (with alphabetic characters in lower case)
111func (d *digestMD5Handshake) compute(initial bool) string {
112	x := hex.EncodeToString(h(d.a1()))
113	y := strings.Join([]string{
114		d.token.Nonce,
115		fmt.Sprintf("%08x", 1),
116		d.cnonce,
117		d.token.Qop[0],
118		hex.EncodeToString(h(d.a2(initial))),
119	}, ":")
120	return hex.EncodeToString(kd(x, y))
121}
122
123func (d *digestMD5Handshake) a1() string {
124	x := h(strings.Join([]string{string(d.authID), d.token.Realm, d.passwd}, ":"))
125	return strings.Join([]string{string(x[:]), d.token.Nonce, d.cnonce}, ":")
126
127}
128
129func (d *digestMD5Handshake) a2(initial bool) string {
130	digestURI := d.service + "/" + d.hostname
131	var a2 string
132
133	// When validating the server's response-auth, we need to leave out the
134	// 'AUTHENTICATE:' prefix.
135	if initial {
136		a2 = strings.Join([]string{"AUTHENTICATE", digestURI}, ":")
137	} else {
138		a2 = ":" + digestURI
139	}
140
141	if d.token.Qop[0] == sasl.QopPrivacy || d.token.Qop[0] == sasl.QopIntegrity {
142		a2 = a2 + ":00000000000000000000000000000000"
143	}
144
145	return a2
146}
147
// genCnonce produces the client nonce (cnonce) sent in the digest-response:
// 12 bytes from crypto/rand, base64-encoded with the standard alphabet, which
// always yields 16 characters without padding. It is a package-level variable
// rather than a plain function so tests can substitute a deterministic value.
var genCnonce = func() (string, error) {
	var raw [12]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(raw[:]), nil
}
156
// h is RFC 2831's H(s): the raw 16-byte MD5 digest of s.
func h(s string) []byte {
	hasher := md5.New()
	// hash.Hash's Write never returns an error.
	hasher.Write([]byte(s))
	return hasher.Sum(nil)
}
161
// kd is RFC 2831's KD(k, s) = H({k, ":", s}): the raw MD5 digest of k and s
// joined by a single colon.
func kd(k, s string) []byte {
	joined := strings.Join([]string{k, s}, ":")
	digest := md5.Sum([]byte(joined))
	return digest[:]
}
165
// generateIntegrityKeys derives the DIGEST-MD5 signing keys from the A1
// string (RFC 2831 section 2.3). It returns Kic, for client-to-server MACs,
// and Kis, for server-to-client MACs. Each key is
// MD5(H(A1) || magic) with the corresponding magic constant.
func generateIntegrityKeys(a1 string) ([]byte, []byte) {
	sessionKey := md5.Sum([]byte(a1))

	derive := func(magic string) []byte {
		hasher := md5.New()
		// hash.Hash's Write never returns an error.
		hasher.Write(sessionKey[:])
		hasher.Write([]byte(magic))
		return hasher.Sum(nil)
	}

	return derive("Digest session key to client-to-server signing key magic constant"),
		derive("Digest session key to server-to-client signing key magic constant")
}
176
// generatePrivacyKeys derives the DIGEST-MD5 sealing keys from the A1 string
// (RFC 2831 section 2.4). It returns Kcc, for client-to-server traffic, and
// Kcs, for server-to-client traffic. Each key is MD5(H(A1)[:n] || magic).
// n is 5 bytes for "rc4-40", 7 for "rc4-56", and the full 16 bytes for any
// other cipher name.
func generatePrivacyKeys(a1 string, cipher string) ([]byte, []byte) {
	ha1 := md5.Sum([]byte(a1))

	keyMaterial := ha1[:]
	switch cipher {
	case "rc4-40":
		keyMaterial = ha1[:5]
	case "rc4-56":
		keyMaterial = ha1[:7]
	}

	seal := func(magic string) []byte {
		hasher := md5.New()
		// hash.Hash's Write never returns an error.
		hasher.Write(keyMaterial)
		hasher.Write([]byte(magic))
		return hasher.Sum(nil)
	}

	return seal("Digest H(A1) to client-to-server sealing key magic constant"),
		seal("Digest H(A1) to server-to-client sealing key magic constant")
}
196
// chooseCipher picks the strongest supported cipher from the server's offered
// options. Preference order is "rc4", then "rc4-56", then "rc4-40". It returns
// "" if none of them is offered. Options must match exactly.
func chooseCipher(options []string) string {
	// TODO: Support 3DES
	for _, preferred := range [...]string{"rc4", "rc4-56", "rc4-40"} {
		for _, offered := range options {
			if offered == preferred {
				return preferred
			}
		}
	}
	return ""
}
216
// lenEncodeBytes encodes the low 32 bits of seqnum as 4 big-endian bytes
// (most significant byte first). Higher bits are discarded, and negative
// values encode as their two's-complement bit pattern.
func lenEncodeBytes(seqnum int) [4]byte {
	v := uint32(seqnum)
	return [4]byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}
224