1/*
2NNCP -- Node to Node copy, utilities for store-and-forward data exchange
3Copyright (C) 2016-2021 Sergey Matveev <stargrave@stargrave.org>
4
5This program is free software: you can redistribute it and/or modify
6it under the terms of the GNU General Public License as published by
7the Free Software Foundation, version 3 of the License.
8
9This program is distributed in the hope that it will be useful,
10but WITHOUT ANY WARRANTY; without even the implied warranty of
11MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12GNU General Public License for more details.
13
14You should have received a copy of the GNU General Public License
15along with this program.  If not, see <http://www.gnu.org/licenses/>.
16*/
17
18package nncp
19
20import (
21	"bytes"
22	"crypto/cipher"
23	"crypto/rand"
24	"errors"
25	"io"
26
27	xdr "github.com/davecgh/go-xdr/xdr2"
28	"golang.org/x/crypto/chacha20poly1305"
29	"golang.org/x/crypto/curve25519"
30	"golang.org/x/crypto/ed25519"
31	"golang.org/x/crypto/nacl/box"
32	"golang.org/x/crypto/poly1305"
33	"lukechampine.com/blake3"
34)
35
// PktType identifies the kind of payload carried by a plain packet; it is
// stored in Pkt.Type and therefore serialized into every packet.
type PktType uint8

const (
	// EncBlkSize is the plaintext size of one encrypted block (128 KiB).
	// Each sealed block additionally carries an AEAD authentication tag.
	EncBlkSize = 128 * (1 << 10)

	// NOTE: iota is the index of the ConstSpec within this const block and
	// EncBlkSize occupies index 0, so PktTypeFile == 1, PktTypeFreq == 2,
	// PktTypeExec == 3, PktTypeTrns == 4, PktTypeExecFat == 5 and
	// PktTypeArea == 6. These values end up in marshalled packets: do not
	// reorder these specs or insert new ones above them.
	PktTypeFile    PktType = iota
	PktTypeFreq    PktType = iota
	PktTypeExec    PktType = iota
	PktTypeTrns    PktType = iota
	PktTypeExecFat PktType = iota
	PktTypeArea    PktType = iota

	// MaxPathSize is the capacity of Pkt.Path: 255 bytes, the largest
	// length representable by the uint8 Pkt.PathLen.
	MaxPathSize = 1<<8 - 1

	// NNCPBundlePrefix is not referenced in this file.
	// NOTE(review): presumably a marker used by bundle handling elsewhere.
	NNCPBundlePrefix = "NNCP"
)
52
53var (
54	BadPktType error = errors.New("Unknown packet type")
55
56	DeriveKeyFullCtx = string(MagicNNCPEv6.B[:]) + " FULL"
57	DeriveKeySizeCtx = string(MagicNNCPEv6.B[:]) + " SIZE"
58	DeriveKeyPadCtx  = string(MagicNNCPEv6.B[:]) + " PAD"
59
60	PktOverhead     int64
61	PktEncOverhead  int64
62	PktSizeOverhead int64
63
64	TooBig = errors.New("Too big than allowed")
65)
66
// Pkt is the plain packet header. PktEncWrite XDR-marshals it and places it
// in front of the payload before encryption, so the field order below
// defines the on-wire layout and must not be changed.
type Pkt struct {
	// Magic identifies the plain packet format (NewPkt sets MagicNNCPPv3).
	Magic   [8]byte
	Type    PktType
	Nice    uint8
	// PathLen is the number of meaningful bytes at the start of Path.
	PathLen uint8
	// Path is a fixed-size buffer; NewPkt copies the path into its prefix.
	Path    [MaxPathSize]byte
}
74
// PktTbs is the "to be signed" part of the encrypted packet header. Its XDR
// encoding is signed with Ed25519 by the sender, and its BLAKE3-256 digest
// serves as AEAD associated data for every encrypted block. It mirrors
// PktEnc without the Sign field; field order defines the encoding.
type PktTbs struct {
	Magic     [8]byte
	Nice      uint8
	Sender    *NodeId
	Recipient *NodeId
	ExchPub   [32]byte
}
82
// PktEnc is the encrypted packet header written in front of the ciphertext
// blocks. Field order defines the XDR encoding.
type PktEnc struct {
	Magic     [8]byte
	Nice      uint8
	Sender    *NodeId
	Recipient *NodeId
	// ExchPub is the per-packet ephemeral public key generated by PktEncWrite.
	ExchPub   [32]byte
	// Sign is the Ed25519 signature over the XDR-encoded PktTbs.
	Sign      [ed25519.SignatureSize]byte
}
91
// PktSize is the record stored at the start of the block sealed with the
// SIZE key. Field order defines the XDR encoding.
type PktSize struct {
	// Payload is the number of plaintext bytes; PktEncWrite counts the
	// marshalled Pkt header together with the data read from its reader.
	Payload uint64
	// Pad is the total number of padding bytes following the payload.
	Pad     uint64
}
96
97func NewPkt(typ PktType, nice uint8, path []byte) (*Pkt, error) {
98	if len(path) > MaxPathSize {
99		return nil, errors.New("Too long path")
100	}
101	pkt := Pkt{
102		Magic:   MagicNNCPPv3.B,
103		Type:    typ,
104		Nice:    nice,
105		PathLen: uint8(len(path)),
106	}
107	copy(pkt.Path[:], path)
108	return &pkt, nil
109}
110
111func init() {
112	var buf bytes.Buffer
113	pkt := Pkt{Type: PktTypeFile}
114	n, err := xdr.Marshal(&buf, pkt)
115	if err != nil {
116		panic(err)
117	}
118	PktOverhead = int64(n)
119	buf.Reset()
120
121	dummyId, err := NodeIdFromString(DummyB32Id)
122	if err != nil {
123		panic(err)
124	}
125	pktEnc := PktEnc{
126		Magic:     MagicNNCPEv6.B,
127		Sender:    dummyId,
128		Recipient: dummyId,
129	}
130	n, err = xdr.Marshal(&buf, pktEnc)
131	if err != nil {
132		panic(err)
133	}
134	PktEncOverhead = int64(n)
135	buf.Reset()
136
137	size := PktSize{}
138	n, err = xdr.Marshal(&buf, size)
139	if err != nil {
140		panic(err)
141	}
142	PktSizeOverhead = int64(n)
143}
144
// ctrIncr increments b in place, treating it as a big-endian unsigned
// counter; it is used to advance the AEAD nonce between blocks. It panics
// with "counter overflow" when the counter wraps around to all zeros
// (including the degenerate case of an empty slice).
func ctrIncr(b []byte) {
	i := len(b)
	for i > 0 {
		i--
		b[i]++
		if b[i] > 0 {
			// No carry out of this byte: the increment is complete.
			return
		}
	}
	panic("counter overflow")
}
154
155func TbsPrepare(our *NodeOur, their *Node, pktEnc *PktEnc) []byte {
156	tbs := PktTbs{
157		Magic:     MagicNNCPEv6.B,
158		Nice:      pktEnc.Nice,
159		Sender:    their.Id,
160		Recipient: our.Id,
161		ExchPub:   pktEnc.ExchPub,
162	}
163	var tbsBuf bytes.Buffer
164	if _, err := xdr.Marshal(&tbsBuf, &tbs); err != nil {
165		panic(err)
166	}
167	return tbsBuf.Bytes()
168}
169
170func TbsVerify(our *NodeOur, their *Node, pktEnc *PktEnc) ([]byte, bool, error) {
171	tbs := TbsPrepare(our, their, pktEnc)
172	return tbs, ed25519.Verify(their.SignPub, tbs, pktEnc.Sign[:]), nil
173}
174
175func sizeWithTags(size int64) (fullSize int64) {
176	size += PktSizeOverhead
177	fullSize = size + (size/EncBlkSize)*poly1305.TagSize
178	if size%EncBlkSize != 0 {
179		fullSize += poly1305.TagSize
180	}
181	return
182}
183
184func sizePadCalc(sizePayload, minSize int64, wrappers int) (sizePad int64) {
185	expectedSize := sizePayload - PktOverhead
186	for i := 0; i < wrappers; i++ {
187		expectedSize = PktEncOverhead + sizeWithTags(PktOverhead+expectedSize)
188	}
189	sizePad = minSize - expectedSize
190	if sizePad < 0 {
191		sizePad = 0
192	}
193	return
194}
195
// PktEncWrite encrypts pkt followed by the data read from r for node their,
// signs the header with our Ed25519 key and writes the result to w.
//
// Output layout, as produced below:
//   - the XDR-encoded PktEnc header (also returned as pktEncRaw);
//   - ChaCha20-Poly1305 blocks of EncBlkSize plaintext bytes sealed with
//     the key derived via DeriveKeyFullCtx, covering the marshalled pkt
//     followed by r's data;
//   - the "size" block sealed with the DeriveKeySizeCtx key: the
//     XDR-encoded PktSize record followed by the last partial chunk of
//     plaintext. If that chunk does not fit next to PktSize, the size block
//     is filled completely and the remaining bytes go into one more
//     FULL-key block. Zero padding is appended inside whichever block is
//     last, up to EncBlkSize;
//   - any padding that does not fit there, as raw (unsealed) output of a
//     BLAKE3 XOF keyed with the DeriveKeyPadCtx-derived key.
//
// All blocks share one incrementing nonce counter and use the BLAKE3-256
// digest of the XDR-encoded PktTbs as associated data. The padding amount
// is sizePadCalc(sizePayload, minSize, wrappers).
//
// size is the number of plaintext bytes consumed, which includes the
// marshalled pkt header; TooBig is returned once that count exceeds
// maxSize. On error, part of the output may already have been written to w.
func PktEncWrite(
	our *NodeOur, their *Node,
	pkt *Pkt, nice uint8,
	minSize, maxSize int64, wrappers int,
	r io.Reader, w io.Writer,
) (pktEncRaw []byte, size int64, err error) {
	// Ephemeral key pair, generated afresh for every packet.
	pub, prv, err := box.GenerateKey(rand.Reader)
	if err != nil {
		return nil, 0, err
	}

	var buf bytes.Buffer
	_, err = xdr.Marshal(&buf, pkt)
	if err != nil {
		return
	}
	// Copy out: buf is reused for the following marshallings.
	pktRaw := make([]byte, buf.Len())
	copy(pktRaw, buf.Bytes())
	buf.Reset()

	tbs := PktTbs{
		Magic:     MagicNNCPEv6.B,
		Nice:      nice,
		Sender:    our.Id,
		Recipient: their.Id,
		ExchPub:   *pub,
	}
	_, err = xdr.Marshal(&buf, &tbs)
	if err != nil {
		return
	}
	// Both the signature and the AEAD associated data cover the encoded PktTbs.
	signature := new([ed25519.SignatureSize]byte)
	copy(signature[:], ed25519.Sign(our.SignPrv, buf.Bytes()))
	ad := blake3.Sum256(buf.Bytes())
	buf.Reset()

	pktEnc := PktEnc{
		Magic:     MagicNNCPEv6.B,
		Nice:      nice,
		Sender:    our.Id,
		Recipient: their.Id,
		ExchPub:   *pub,
		Sign:      *signature,
	}
	_, err = xdr.Marshal(&buf, &pktEnc)
	if err != nil {
		return
	}
	pktEncRaw = make([]byte, buf.Len())
	copy(pktEncRaw, buf.Bytes())
	buf.Reset()
	_, err = w.Write(pktEncRaw)
	if err != nil {
		return
	}

	// Independent keys for regular blocks and for the size block, both
	// derived from the X25519 shared secret.
	sharedKey := new([32]byte)
	curve25519.ScalarMult(sharedKey, prv, their.ExchPub)
	keyFull := make([]byte, chacha20poly1305.KeySize)
	keySize := make([]byte, chacha20poly1305.KeySize)
	blake3.DeriveKey(keyFull, DeriveKeyFullCtx, sharedKey[:])
	blake3.DeriveKey(keySize, DeriveKeySizeCtx, sharedKey[:])
	aeadFull, err := chacha20poly1305.New(keyFull)
	if err != nil {
		return
	}
	aeadSize, err := chacha20poly1305.New(keySize)
	if err != nil {
		return
	}
	nonce := make([]byte, aeadFull.NonceSize())

	// Spare capacity for the tag lets Seal work in place without reallocating.
	data := make([]byte, EncBlkSize, EncBlkSize+aeadFull.Overhead())
	mr := io.MultiReader(bytes.NewReader(pktRaw), r)
	var sizePayload int64
	var n int
	var ct []byte
	// Seal every complete EncBlkSize chunk with the FULL key. The loop ends
	// with the final partial chunk (possibly empty) in data[:n].
	for {
		n, err = io.ReadFull(mr, data)
		sizePayload += int64(n)
		if sizePayload > maxSize {
			err = TooBig
			return
		}
		if err == nil {
			ct = aeadFull.Seal(data[:0], nonce, data[:n], ad[:])
			_, err = w.Write(ct)
			if err != nil {
				return
			}
			ctrIncr(nonce)
			continue
		}
		// io.EOF / io.ErrUnexpectedEOF only mark the final short chunk.
		if !(err == io.EOF || err == io.ErrUnexpectedEOF) {
			return
		}
		break
	}

	sizePad := sizePadCalc(sizePayload, minSize, wrappers)
	_, err = xdr.Marshal(&buf, &PktSize{uint64(sizePayload), uint64(sizePad)})
	if err != nil {
		return
	}

	var aeadLast cipher.AEAD
	if n+int(PktSizeOverhead) > EncBlkSize {
		// PktSize plus the tail overflow one block: keep the overflowing
		// bytes aside, seal a full size block, and move the leftover into
		// a following FULL-key block.
		left := make([]byte, (n+int(PktSizeOverhead))-EncBlkSize)
		copy(left, data[n-len(left):])
		copy(data[PktSizeOverhead:], data[:n-len(left)])
		copy(data[:PktSizeOverhead], buf.Bytes())
		ct = aeadSize.Seal(data[:0], nonce, data[:EncBlkSize], ad[:])
		_, err = w.Write(ct)
		if err != nil {
			return
		}
		ctrIncr(nonce)
		copy(data, left)
		n = len(left)
		aeadLast = aeadFull
	} else {
		// Shift the tail right and put PktSize in front of it.
		copy(data[PktSizeOverhead:], data[:n])
		copy(data[:PktSizeOverhead], buf.Bytes())
		n += int(PktSizeOverhead)
		aeadLast = aeadSize
	}

	// Zero padding fills the last block up to EncBlkSize at most; whatever
	// does not fit is emitted as keystream after it.
	var sizeBlockPadded int
	var sizePadLeft int64
	if sizePad > EncBlkSize-int64(n) {
		sizeBlockPadded = EncBlkSize
		sizePadLeft = sizePad - (EncBlkSize - int64(n))
	} else {
		sizeBlockPadded = n + int(sizePad)
		sizePadLeft = 0
	}
	// data[n:] may still hold stale bytes from earlier blocks.
	for i := n; i < sizeBlockPadded; i++ {
		data[i] = 0
	}
	ct = aeadLast.Seal(data[:0], nonce, data[:sizeBlockPadded], ad[:])
	_, err = w.Write(ct)
	if err != nil {
		return
	}

	size = sizePayload
	if sizePadLeft > 0 {
		// Remaining padding is raw BLAKE3 XOF output; it is not sealed.
		keyPad := make([]byte, chacha20poly1305.KeySize)
		blake3.DeriveKey(keyPad, DeriveKeyPadCtx, sharedKey[:])
		_, err = io.CopyN(w, blake3.New(32, keyPad).XOF(), sizePadLeft)
	}
	return
}
349
350func PktEncRead(
351	our *NodeOur, nodes map[NodeId]*Node,
352	r io.Reader, w io.Writer,
353	signatureVerify bool,
354	sharedKeyCached []byte,
355) (sharedKey []byte, their *Node, size int64, err error) {
356	var pktEnc PktEnc
357	_, err = xdr.Unmarshal(r, &pktEnc)
358	if err != nil {
359		return
360	}
361	switch pktEnc.Magic {
362	case MagicNNCPEv1.B:
363		err = MagicNNCPEv1.TooOld()
364	case MagicNNCPEv2.B:
365		err = MagicNNCPEv2.TooOld()
366	case MagicNNCPEv3.B:
367		err = MagicNNCPEv3.TooOld()
368	case MagicNNCPEv4.B:
369		err = MagicNNCPEv4.TooOld()
370	case MagicNNCPEv5.B:
371		err = MagicNNCPEv5.TooOld()
372	case MagicNNCPEv6.B:
373	default:
374		err = BadMagic
375	}
376	if err != nil {
377		return
378	}
379	if *pktEnc.Recipient != *our.Id {
380		err = errors.New("Invalid recipient")
381		return
382	}
383
384	var tbsRaw []byte
385	if signatureVerify {
386		their = nodes[*pktEnc.Sender]
387		if their == nil {
388			err = errors.New("Unknown sender")
389			return
390		}
391		var verified bool
392		tbsRaw, verified, err = TbsVerify(our, their, &pktEnc)
393		if err != nil {
394			return
395		}
396		if !verified {
397			err = errors.New("Invalid signature")
398			return
399		}
400	} else {
401		tbsRaw = TbsPrepare(our, &Node{Id: pktEnc.Sender}, &pktEnc)
402	}
403	ad := blake3.Sum256(tbsRaw)
404	if sharedKeyCached == nil {
405		key := new([32]byte)
406		curve25519.ScalarMult(key, our.ExchPrv, &pktEnc.ExchPub)
407		sharedKey = key[:]
408	} else {
409		sharedKey = sharedKeyCached
410	}
411
412	keyFull := make([]byte, chacha20poly1305.KeySize)
413	keySize := make([]byte, chacha20poly1305.KeySize)
414	blake3.DeriveKey(keyFull, DeriveKeyFullCtx, sharedKey[:])
415	blake3.DeriveKey(keySize, DeriveKeySizeCtx, sharedKey[:])
416	aeadFull, err := chacha20poly1305.New(keyFull)
417	if err != nil {
418		return
419	}
420	aeadSize, err := chacha20poly1305.New(keySize)
421	if err != nil {
422		return
423	}
424	nonce := make([]byte, aeadFull.NonceSize())
425
426	ct := make([]byte, EncBlkSize+aeadFull.Overhead())
427	pt := make([]byte, EncBlkSize)
428	var n int
429FullRead:
430	for {
431		n, err = io.ReadFull(r, ct)
432		switch err {
433		case nil:
434			pt, err = aeadFull.Open(pt[:0], nonce, ct, ad[:])
435			if err != nil {
436				break FullRead
437			}
438			size += EncBlkSize
439			_, err = w.Write(pt)
440			if err != nil {
441				return
442			}
443			ctrIncr(nonce)
444			continue
445		case io.ErrUnexpectedEOF:
446			break FullRead
447		default:
448			return
449		}
450	}
451
452	pt, err = aeadSize.Open(pt[:0], nonce, ct[:n], ad[:])
453	if err != nil {
454		return
455	}
456	var pktSize PktSize
457	_, err = xdr.Unmarshal(bytes.NewReader(pt), &pktSize)
458	if err != nil {
459		return
460	}
461	pt = pt[PktSizeOverhead:]
462
463	left := int64(pktSize.Payload) - size
464	for left > int64(len(pt)) {
465		size += int64(len(pt))
466		left -= int64(len(pt))
467		_, err = w.Write(pt)
468		if err != nil {
469			return
470		}
471		n, err = io.ReadFull(r, ct)
472		if err != nil && err != io.ErrUnexpectedEOF {
473			return
474		}
475		ctrIncr(nonce)
476		pt, err = aeadFull.Open(pt[:0], nonce, ct[:n], ad[:])
477		if err != nil {
478			return
479		}
480	}
481	size += left
482	_, err = w.Write(pt[:left])
483	if err != nil {
484		return
485	}
486	pt = pt[left:]
487
488	if pktSize.Pad < uint64(len(pt)) {
489		err = errors.New("unexpected pad")
490		return
491	}
492	for i := 0; i < len(pt); i++ {
493		if pt[i] != 0 {
494			err = errors.New("non-zero pad byte")
495			return
496		}
497	}
498	sizePad := int64(pktSize.Pad) - int64(len(pt))
499	if sizePad == 0 {
500		return
501	}
502
503	keyPad := make([]byte, chacha20poly1305.KeySize)
504	blake3.DeriveKey(keyPad, DeriveKeyPadCtx, sharedKey[:])
505	xof := blake3.New(32, keyPad).XOF()
506	pt = make([]byte, len(ct))
507	for sizePad > 0 {
508		n, err = io.ReadFull(r, ct)
509		if err != nil && err != io.ErrUnexpectedEOF {
510			return
511		}
512		_, err = io.ReadFull(xof, pt[:n])
513		if err != nil {
514			panic(err)
515		}
516		if bytes.Compare(ct[:n], pt[:n]) != 0 {
517			err = errors.New("wrong pad value")
518			return
519		}
520		sizePad -= int64(n)
521	}
522	if sizePad < 0 {
523		err = errors.New("excess pad")
524	}
525	return
526}
527