1// Copyright (c) 2019 Andreas Auernhammer. All rights reserved.
2// Use of this source code is governed by a license that can be
3// found in the LICENSE file.
4
5package sio
6
7import (
8	"crypto/cipher"
9	"encoding/binary"
10	"io"
11	"io/ioutil"
12	"math"
13	"sync"
14)
15
// An EncReader encrypts and authenticates everything it reads
// from an underlying io.Reader.
//
// The plaintext is processed in fragments. Every fragment is
// sealed on its own under a nonce whose last four bytes hold the
// fragment's sequence number, and the last fragment is sealed
// with the first associated-data byte set to 0x80.
type EncReader struct {
	// r supplies the plaintext and cipher seals each fragment.
	// bufSize is the plaintext length of every non-final fragment.
	r       io.Reader
	cipher  cipher.AEAD
	bufSize int

	// seqNum is written (little endian) into the last four bytes
	// of nonce before a fragment is sealed and is incremented
	// afterwards. Once it has wrapped around to 0 no further
	// fragment is sealed and ErrExceeded is reported instead.
	seqNum         uint32
	nonce          []byte
	associatedData []byte

	// buffer is the working space plaintext is read into.
	// ciphertextBuffer holds a sealed fragment that did not fit
	// into the caller's buffer; offset is the number of its bytes
	// that have already been handed out.
	buffer           []byte
	ciphertextBuffer []byte
	offset           int

	// err is a sticky error. carry is the one byte that was read
	// past the end of a full fragment; it becomes the first byte
	// of the next fragment. firstRead is cleared as soon as the
	// first fragment is requested. closed is set once the
	// underlying reader has reported io.EOF.
	err               error
	carry             byte
	firstRead, closed bool
}
35
36// Read behaves as specified by the io.Reader interface.
37// In particular, Read reads up to len(p) encrypted bytes
38// into p. It returns the number of bytes read (0 <= n <= len(p))
39// and any error encountered while reading from the underlying
40// io.Reader.
41//
42// When Read cannot encrypt any more bytes securely it returns
43// ErrExceeded.
44func (r *EncReader) Read(p []byte) (n int, err error) {
45	if r.err != nil {
46		return n, r.err
47	}
48	if r.firstRead {
49		r.firstRead = false
50		n, err = r.readFragment(p, 0)
51		if err != nil {
52			return n, err
53		}
54		p = p[n:]
55	}
56	if r.offset > 0 {
57		nn := copy(p, r.ciphertextBuffer[r.offset:])
58		n += nn
59		if nn == len(p) {
60			r.offset += nn
61			return n, nil
62		}
63		p = p[nn:]
64		r.offset = 0
65	}
66	if r.closed {
67		return n, io.EOF
68	}
69	nn, err := r.readFragment(p, 1)
70	return n + nn, err
71}
72
73// ReadByte behaves as specified by the io.ByteReader
74// interface. In particular, ReadByte returns the next
75// encrypted byte or any error encountered.
76//
77// When ReadByte cannot encrypt one more byte
78// securely it returns ErrExceeded.
79func (r *EncReader) ReadByte() (byte, error) {
80	if r.err != nil {
81		return 0, r.err
82	}
83	if r.firstRead {
84		r.firstRead = false
85		if _, err := r.readFragment(nil, 0); err != nil {
86			return 0, err
87		}
88		b := r.ciphertextBuffer[0]
89		r.offset = 1
90		return b, nil
91	}
92
93	if r.offset > 0 && r.offset < len(r.ciphertextBuffer) {
94		b := r.ciphertextBuffer[r.offset]
95		r.offset++
96		return b, nil
97	}
98	if r.closed {
99		return 0, io.EOF
100	}
101
102	r.offset = 0
103	if _, err := r.readFragment(nil, 1); err != nil {
104		return 0, err
105	}
106	b := r.ciphertextBuffer[0]
107	r.offset = 1
108	return b, nil
109}
110
111// WriteTo behaves as specified by the io.WriterTo
112// interface. In particular, WriteTo writes encrypted
113// data to w until either there's no more data to write,
114// an error occurs or no more data can be encrypted
115// securely. When WriteTo cannot encrypt more data
116// securely it returns ErrExceeded.
117func (r *EncReader) WriteTo(w io.Writer) (int64, error) {
118	var n int64
119	if r.firstRead {
120		r.firstRead = false
121		nn, err := r.readFragment(r.buffer, 0)
122		if err != nil && err != io.EOF {
123			return n, err
124		}
125		nn, err = writeTo(w, r.buffer[:nn])
126		if err != nil {
127			return n, err
128		}
129		n += int64(nn)
130		if r.closed {
131			return n, nil
132		}
133	}
134	if r.err != nil {
135		return n, r.err
136	}
137	if r.offset > 0 {
138		nn, err := writeTo(w, r.ciphertextBuffer[r.offset:])
139		if err != nil {
140			r.err = err
141			return n, err
142		}
143		r.offset = 0
144		n += int64(nn)
145	}
146	if r.closed {
147		return n, io.EOF
148	}
149	for {
150		nn, err := r.readFragment(r.buffer, 1)
151		if err != nil && err != io.EOF {
152			return n, err
153		}
154		nn, err = writeTo(w, r.buffer[:nn])
155		if err != nil {
156			r.err = err
157			return n, err
158		}
159		n += int64(nn)
160		if r.closed {
161			return n, nil
162		}
163	}
164}
165
166func (r *EncReader) readFragment(p []byte, firstReadOffset int) (int, error) {
167	if r.seqNum == 0 {
168		r.err = ErrExceeded
169		return 0, r.err
170	}
171	binary.LittleEndian.PutUint32(r.nonce[r.cipher.NonceSize()-4:], r.seqNum)
172	r.seqNum++
173
174	r.buffer[0] = r.carry
175	n, err := readFrom(r.r, r.buffer[firstReadOffset:1+r.bufSize])
176	switch {
177	default:
178		r.carry = r.buffer[r.bufSize]
179		if len(p) < r.bufSize+r.cipher.Overhead() {
180			r.ciphertextBuffer = r.cipher.Seal(r.buffer[:0], r.nonce, r.buffer[:r.bufSize], r.associatedData)
181			r.offset = copy(p, r.ciphertextBuffer)
182			return r.offset, nil
183		}
184		r.cipher.Seal(p[:0], r.nonce, r.buffer[:r.bufSize], r.associatedData)
185		return r.bufSize + r.cipher.Overhead(), nil
186	case err == io.EOF:
187		r.closed = true
188		r.associatedData[0] = 0x80
189		if len(p) < firstReadOffset+n+r.cipher.Overhead() {
190			r.ciphertextBuffer = r.cipher.Seal(r.buffer[:0], r.nonce, r.buffer[:firstReadOffset+n], r.associatedData)
191			r.offset = copy(p, r.ciphertextBuffer)
192			return r.offset, nil
193		}
194		r.cipher.Seal(p[:0], r.nonce, r.buffer[:firstReadOffset+n], r.associatedData)
195		return firstReadOffset + n + r.cipher.Overhead(), io.EOF
196	case err != nil:
197		r.err = err
198		return 0, r.err
199	}
200}
201
// A DecReader decrypts and verifies everything it reads
// from an underlying io.Reader. A DecReader never returns
// invalid (i.e. not authentic) data.
//
// The ciphertext is processed in fragments. Every fragment is
// opened on its own under a nonce whose last four bytes hold the
// fragment's sequence number, and the last fragment is opened
// with the first associated-data byte set to 0x80.
type DecReader struct {
	// r supplies the ciphertext and cipher opens each fragment.
	r      io.Reader
	cipher cipher.AEAD

	// bufSize is the plaintext length of every non-final fragment.
	// seqNum is written (little endian) into the last four bytes
	// of nonce before a fragment is opened and is incremented
	// afterwards. Once it has wrapped around to 0 no further
	// fragment is opened and ErrExceeded is reported instead.
	bufSize        int
	seqNum         uint32
	nonce          []byte
	associatedData []byte

	// buffer is the working space ciphertext is read into.
	// plaintextBuffer holds an opened fragment that did not fit
	// into the caller's buffer; offset is the number of its bytes
	// that have already been handed out.
	buffer          []byte
	plaintextBuffer []byte
	offset          int

	// err is a sticky error. carry is the one byte that was read
	// past the end of a full fragment; it becomes the first byte
	// of the next fragment. firstRead is cleared as soon as the
	// first fragment is requested. closed is set once the final
	// fragment has been read from the underlying reader.
	err               error
	carry             byte
	firstRead, closed bool
}
222
223// Read behaves like specified by the io.Reader interface.
224// In particular, Read reads up to len(p) decrypted bytes
225// into p. It returns the number of bytes read (0 <= n <= len(p))
226// and any error encountered while reading from the underlying
227// io.Reader.
228//
229// When Read fails to decrypt some data returned by the underlying
230// io.Reader it returns NotAuthentic. This error indicates
231// that the encrypted data has been (maliciously) modified.
232//
233// When Read cannot decrypt more bytes securely it returns
234// ErrExceeded. However, this can only happen when the
235// underlying io.Reader returns valid but too many
236// encrypted bytes. Therefore, ErrExceeded indicates
237// a misbehaving producer of encrypted data.
238func (r *DecReader) Read(p []byte) (n int, err error) {
239	if r.err != nil {
240		return n, r.err
241	}
242	if r.firstRead {
243		r.firstRead = false
244		n, err = r.readFragment(p, 0)
245		if err != nil {
246			return n, err
247		}
248		p = p[n:]
249	}
250	if r.offset > 0 {
251		nn := copy(p, r.plaintextBuffer[r.offset:])
252		n += nn
253		if nn == len(p) {
254			r.offset += nn
255			return n, nil
256		}
257		p = p[nn:]
258		r.offset = 0
259	}
260	if r.closed {
261		return n, io.EOF
262	}
263	nn, err := r.readFragment(p, 1)
264	return n + nn, err
265}
266
267// ReadByte behaves as specified by the io.ByteReader
268// interface. In particular, ReadByte returns the next
269// decrypted byte or any error encountered.
270//
271// When ReadByte fails to decrypt the next byte returned by
272// the underlying io.Reader it returns NotAuthentic. This
273// error indicates that the encrypted byte has been
274// (maliciously) modified.
275//
276// When Read cannot decrypt one more byte securely it
277// returns ErrExceeded. However, this can only happen
278// when the underlying io.Reader returns valid but too
279// many encrypted bytes. Therefore, ErrExceeded indicates
280// a misbehaving producer of encrypted data.
281func (r *DecReader) ReadByte() (byte, error) {
282	if r.err != nil {
283		return 0, r.err
284	}
285	if r.firstRead {
286		r.firstRead = false
287		if _, err := r.readFragment(nil, 0); err != nil {
288			return 0, err
289		}
290		b := r.plaintextBuffer[0]
291		r.offset = 1
292		return b, nil
293	}
294	if r.offset > 0 && r.offset < len(r.plaintextBuffer) {
295		b := r.plaintextBuffer[r.offset]
296		r.offset++
297		return b, nil
298	}
299	if r.closed {
300		return 0, io.EOF
301	}
302
303	r.offset = 0
304	if _, err := r.readFragment(nil, 1); err != nil {
305		return 0, err
306	}
307	b := r.plaintextBuffer[0]
308	r.offset = 1
309	return b, nil
310}
311
312// WriteTo behaves as specified by the io.WriterTo
313// interface. In particular, WriteTo writes decrypted
314// data to w until either there's no more data to write,
315// an error occurs or the encrypted data is invalid.
316//
317// When WriteTo fails to decrypt some data it returns
318// NotAuthentic. This error indicates that the encrypted
319// bytes has been (maliciously) modified.
320//
321// When WriteTo cannot decrypt any more bytes securely it
322// returns ErrExceeded. However, this can only happen
323// when the underlying io.Reader returns valid but too
324// many encrypted bytes. Therefore, ErrExceeded indicates
325// a misbehaving producer of encrypted data.
326func (r *DecReader) WriteTo(w io.Writer) (int64, error) {
327	var n int64
328	if r.err != nil {
329		return n, r.err
330	}
331	if r.firstRead {
332		r.firstRead = false
333		nn, err := r.readFragment(r.buffer, 0)
334		if err != nil && err != io.EOF {
335			return n, err
336		}
337		nn, err = writeTo(w, r.buffer[:nn])
338		if err != nil {
339			return n, err
340		}
341		n += int64(nn)
342		if r.closed {
343			return n, nil
344		}
345	}
346	if r.offset > 0 {
347		nn, err := writeTo(w, r.plaintextBuffer[r.offset:])
348		if err != nil {
349			r.err = err
350			return n, err
351		}
352		r.offset = 0
353		n += int64(nn)
354	}
355	if r.closed {
356		return n, io.EOF
357	}
358	for {
359		nn, err := r.readFragment(r.buffer, 1)
360		if err != nil && err != io.EOF {
361			return n, err
362		}
363		nn, err = writeTo(w, r.buffer[:nn])
364		if err != nil {
365			r.err = err
366			return n, err
367		}
368		n += int64(nn)
369		if r.closed {
370			return n, nil
371		}
372	}
373}
374
375func (r *DecReader) readFragment(p []byte, firstReadOffset int) (int, error) {
376	if r.seqNum == 0 {
377		r.err = ErrExceeded
378		return 0, r.err
379	}
380	binary.LittleEndian.PutUint32(r.nonce[r.cipher.NonceSize()-4:], r.seqNum)
381	r.seqNum++
382
383	ciphertextLen := r.bufSize + r.cipher.Overhead()
384
385	r.buffer[0] = r.carry
386	n, err := readFrom(r.r, r.buffer[firstReadOffset:1+ciphertextLen])
387	switch {
388	default:
389		r.carry = r.buffer[ciphertextLen]
390		if len(p) < r.bufSize {
391			r.plaintextBuffer, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:ciphertextLen], r.associatedData)
392			if err != nil {
393				r.err = NotAuthentic
394				return 0, r.err
395			}
396			r.offset = copy(p, r.plaintextBuffer)
397			return r.offset, nil
398		}
399		if _, err = r.cipher.Open(p[:0], r.nonce, r.buffer[:ciphertextLen], r.associatedData); err != nil {
400			r.err = NotAuthentic
401			return 0, r.err
402		}
403		return r.bufSize, nil
404	case err == io.EOF:
405		if firstReadOffset+n < r.cipher.Overhead() {
406			r.err = NotAuthentic
407			return 0, r.err
408		}
409		r.closed = true
410		r.associatedData[0] = 0x80
411		if len(p) < firstReadOffset+n-r.cipher.Overhead() {
412			r.plaintextBuffer, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:firstReadOffset+n], r.associatedData)
413			if err != nil {
414				r.err = NotAuthentic
415				return 0, r.err
416			}
417			r.offset = copy(p, r.plaintextBuffer)
418			return r.offset, nil
419		}
420		if _, err = r.cipher.Open(p[:0], r.nonce, r.buffer[:firstReadOffset+n], r.associatedData); err != nil {
421			r.err = NotAuthentic
422			return 0, r.err
423
424		}
425		return firstReadOffset + n - r.cipher.Overhead(), io.EOF
426	case err != nil:
427		r.err = err
428		return 0, r.err
429	}
430}
431
// A DecReaderAt decrypts and verifies everything it reads
// from an underlying io.ReaderAt. A DecReaderAt never returns
// invalid (i.e. not authentic) data.
type DecReaderAt struct {
	// r supplies the ciphertext and cipher opens each fragment.
	r      io.ReaderAt
	cipher cipher.AEAD

	// bufPool hands out the *[]byte working buffers used by ReadAt;
	// each one is returned to the pool when the call completes.
	// bufSize is the plaintext length of every non-final fragment.
	bufPool sync.Pool
	bufSize int

	// nonce and associatedData are copied into a fresh DecReader
	// for every ReadAt call rather than being used directly.
	nonce          []byte
	associatedData []byte
}
445
446// ReadAt behaves like specified by the io.ReaderAt interface.
447// In particular, ReadAt reads len(p) decrypted bytes into p.
448// It returns the number of bytes read (0 <= n <= len(p))
449// and any error encountered while reading from the underlying
450// io.Reader. When ReadAt returns n < len(p), it returns a non-nil
451// error explaining why more bytes were not returned.
452//
453// When ReadAt fails to decrypt some data returned by the underlying
454// io.ReaderAt it returns NotAuthentic. This error indicates
455// that the encrypted data has been (maliciously) modified.
456//
457// When ReadAt cannot decrypt more bytes securely it returns
458// ErrExceeded. However, this can only happen when the
459// underlying io.ReaderAt returns valid but too many
460// encrypted bytes. Therefore, ErrExceeded indicates
461// a misbehaving producer of encrypted data.
462func (r *DecReaderAt) ReadAt(p []byte, offset int64) (int, error) {
463	if offset < 0 {
464		return 0, errorType("sio: DecReaderAt.ReadAt: offset is negative")
465	}
466
467	t := offset / int64(r.bufSize)
468	if t+1 > math.MaxUint32 {
469		return 0, ErrExceeded
470	}
471
472	buffer := r.bufPool.Get().(*[]byte)
473	defer r.bufPool.Put(buffer)
474
475	decReader := DecReader{
476		r:              &sectionReader{r: r.r, off: t * int64(r.bufSize+r.cipher.Overhead())},
477		cipher:         r.cipher,
478		bufSize:        r.bufSize,
479		nonce:          make([]byte, r.cipher.NonceSize()),
480		associatedData: make([]byte, 1+r.cipher.Overhead()),
481		seqNum:         1 + uint32(t),
482		buffer:         *buffer,
483		firstRead:      true,
484	}
485	copy(decReader.nonce, r.nonce)
486	copy(decReader.associatedData, r.associatedData)
487
488	if k := offset % int64(r.bufSize); k > 0 {
489		if _, err := io.CopyN(ioutil.Discard, &decReader, k); err != nil {
490			return 0, err
491		}
492	}
493	return readFrom(&decReader, p)
494}
495
// sectionReader turns an io.ReaderAt into an io.Reader that starts
// at off and advances by the number of bytes each call has read.
// A custom type is used since io.SectionReader demands a read limit.
type sectionReader struct {
	r   io.ReaderAt
	off int64
	err error
}

// Read reads from the underlying io.ReaderAt at the current offset
// and moves the offset forward by the number of bytes read. The first
// error reported by the underlying ReadAt is remembered and returned
// by every following call without reading again.
func (r *sectionReader) Read(p []byte) (int, error) {
	if r.err == nil {
		n, err := r.r.ReadAt(p, r.off)
		r.off += int64(n)
		r.err = err
		return n, err
	}
	return 0, r.err
}
515