1/*
2   Copyright The ocicrypt Authors.
3
4   Licensed under the Apache License, Version 2.0 (the "License");
5   you may not use this file except in compliance with the License.
6   You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10   Unless required by applicable law or agreed to in writing, software
11   distributed under the License is distributed on an "AS IS" BASIS,
12   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   See the License for the specific language governing permissions and
14   limitations under the License.
15*/
16
17package utils
18
19import (
20	"io"
21)
22
// min returns the smaller of the two integers a and b.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
29
// DelayedReader is an io.Reader that wraps another io.Reader and withholds
// the most recently read bytes in a delay buffer. Because the client only
// receives data with this delay, an error encountered by the wrapped reader
// can be reported to the client before all of the bytes preceding it have
// been handed out.
type DelayedReader struct {
	// reader is the wrapped source of data.
	reader io.Reader
	// err records the first error (io.EOF included) seen while reading from
	// reader; it is never cleared afterwards.
	err error

	// buffer is the delay buffer: buffer[bufoff:bufoff+bufbytes] holds the
	// bytes that were read from reader but not yet returned by Read.
	buffer   []byte
	bufbytes int // count of pending bytes in buffer, starting at bufoff
	bufoff   int // index in buffer of the next pending byte
}
41
42// NewDelayedReader wraps a io.Reader and allocates a delay buffer of bufsize bytes
43func NewDelayedReader(reader io.Reader, bufsize uint) io.Reader {
44	return &DelayedReader{
45		reader: reader,
46		buffer: make([]byte, bufsize),
47	}
48}
49
50// Read implements the io.Reader interface
51func (dr *DelayedReader) Read(p []byte) (int, error) {
52	if dr.err != nil && dr.err != io.EOF {
53		return 0, dr.err
54	}
55
56	// if we are completely drained, return io.EOF
57	if dr.err == io.EOF && dr.bufbytes == 0 {
58		return 0, io.EOF
59	}
60
61	// only at the beginning we fill our delay buffer in an extra step
62	if dr.bufbytes < len(dr.buffer) && dr.err == nil {
63		dr.bufbytes, dr.err = FillBuffer(dr.reader, dr.buffer)
64		if dr.err != nil && dr.err != io.EOF {
65			return 0, dr.err
66		}
67	}
68	// dr.err != nil means we have EOF and can drain the delay buffer
69	// otherwise we need to still read from the reader
70
71	var tmpbuf []byte
72	tmpbufbytes := 0
73	if dr.err == nil {
74		tmpbuf = make([]byte, len(p))
75		tmpbufbytes, dr.err = FillBuffer(dr.reader, tmpbuf)
76		if dr.err != nil && dr.err != io.EOF {
77			return 0, dr.err
78		}
79	}
80
81	// copy out of the delay buffer into 'p'
82	tocopy1 := min(len(p), dr.bufbytes)
83	c1 := copy(p[:tocopy1], dr.buffer[dr.bufoff:])
84	dr.bufoff += c1
85	dr.bufbytes -= c1
86
87	c2 := 0
88	// can p still hold more data?
89	if c1 < len(p) {
90		// copy out of the tmpbuf into 'p'
91		c2 = copy(p[tocopy1:], tmpbuf[:tmpbufbytes])
92	}
93
94	// if tmpbuf holds data we need to hold onto, copy them
95	// into the delay buffer
96	if tmpbufbytes-c2 > 0 {
97		// left-shift the delay buffer and append the tmpbuf's remaining data
98		dr.buffer = dr.buffer[dr.bufoff : dr.bufoff+dr.bufbytes]
99		dr.buffer = append(dr.buffer, tmpbuf[c2:tmpbufbytes]...)
100		dr.bufoff = 0
101		dr.bufbytes = len(dr.buffer)
102	}
103
104	var err error
105	if dr.bufbytes == 0 {
106		err = io.EOF
107	}
108	return c1 + c2, err
109}
110