1// Copyright (C) 2019 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package streams
5
6import (
7	"errors"
8	"io"
9
10	"github.com/zeebo/errs"
11)
12
// PeekThresholdReader wraps an io.Reader and allows a check to see whether the
// size of the wrapped data exceeds a threshold (such as the maximum inline
// segment size) before the data is consumed.
//
// Bytes peeked by IsLargerThan are buffered and returned first by Read, after
// which Read continues from the wrapped reader, so the caller still receives
// the complete stream.
type PeekThresholdReader struct {
	r              io.Reader
	thresholdBuf   []byte // bytes read by IsLargerThan that Read has not returned yet
	thresholdErr   error  // error seen while peeking; reported once thresholdBuf drains
	isLargerCalled bool
	readCalled     bool
}

// PeekThresholdReader must satisfy io.Reader.
var _ io.Reader = (*PeekThresholdReader)(nil)

// NewPeekThresholdReader creates a new instance of PeekThresholdReader that
// wraps r.
func NewPeekThresholdReader(r io.Reader) (pt *PeekThresholdReader) {
	return &PeekThresholdReader{r: r}
}

// Read implements io.Reader. Bytes buffered by a prior IsLargerThan call are
// returned first. The call that drains them also returns the error recorded
// while peeking (e.g. io.EOF when the stream ended within the threshold), and
// that error is then cleared. Once the buffer is empty and no error is
// pending, Read delegates directly to the wrapped reader. The number of bytes
// read is returned as n.
func (pt *PeekThresholdReader) Read(p []byte) (n int, err error) {
	pt.readCalled = true

	if len(pt.thresholdBuf) == 0 && pt.thresholdErr == nil {
		return pt.r.Read(p)
	}

	n = copy(p, pt.thresholdBuf)
	pt.thresholdBuf = pt.thresholdBuf[n:]
	if len(pt.thresholdBuf) > 0 {
		return n, nil
	}

	// Everything peeked has been handed out: drop the slice so its backing
	// array (up to threshold+1 bytes) can be garbage collected, and surface
	// the peek error exactly once.
	pt.thresholdBuf = nil
	err, pt.thresholdErr = pt.thresholdErr, nil
	return n, err
}
47
48// IsLargerThan returns a bool to determine whether a reader's size
49// is larger than the given threshold or not.
50func (pt *PeekThresholdReader) IsLargerThan(thresholdSize int) (bool, error) {
51	if pt.isLargerCalled {
52		return false, errs.New("IsLargerThan can't be called more than once")
53	}
54	if pt.readCalled {
55		return false, errs.New("IsLargerThan can't be called after Read has been called")
56	}
57	pt.isLargerCalled = true
58	buf := make([]byte, thresholdSize+1)
59	n, err := io.ReadFull(pt.r, buf)
60	pt.thresholdBuf = buf[:n]
61	pt.thresholdErr = err
62	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
63		if errors.Is(err, io.ErrUnexpectedEOF) {
64			pt.thresholdErr = io.EOF
65		}
66		return false, nil
67	}
68	if err != nil {
69		return false, err
70	}
71	return true, nil
72}
73