1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"fmt"
9	"io"
10	"io/ioutil"
11)
12
// byteBuffer is the set of read operations implemented by *byteBuf and
// *readerWrapper in this file.
type byteBuffer interface {
	// Read up to 8 bytes.
	// Returns nil if fewer than n bytes are available.
	// *readerWrapper returns a slice of its internal scratch array, so the
	// result should be consumed before the next call that reuses it.
	readSmall(n int) []byte

	// Read >8 bytes.
	// MAY use the destination slice.
	// *byteBuf ignores dst and returns a slice of its own storage instead.
	readBig(n int, dst []byte) ([]byte, error)

	// Read a single byte.
	readByte() (byte, error)

	// Skip n bytes.
	skipN(n int) error
}
28
// byteBuf is an in-memory buffer.
// Its read and skip methods consume bytes from the front by re-slicing,
// so no data is copied.
type byteBuf []byte
31
32func (b *byteBuf) readSmall(n int) []byte {
33	if debugAsserts && n > 8 {
34		panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
35	}
36	bb := *b
37	if len(bb) < n {
38		return nil
39	}
40	r := bb[:n]
41	*b = bb[n:]
42	return r
43}
44
45func (b *byteBuf) readBig(n int, dst []byte) ([]byte, error) {
46	bb := *b
47	if len(bb) < n {
48		return nil, io.ErrUnexpectedEOF
49	}
50	r := bb[:n]
51	*b = bb[n:]
52	return r, nil
53}
54
55func (b *byteBuf) remain() []byte {
56	return *b
57}
58
59func (b *byteBuf) readByte() (byte, error) {
60	bb := *b
61	if len(bb) < 1 {
62		return 0, nil
63	}
64	r := bb[0]
65	*b = bb[1:]
66	return r, nil
67}
68
69func (b *byteBuf) skipN(n int) error {
70	bb := *b
71	if len(bb) < n {
72		return io.ErrUnexpectedEOF
73	}
74	*b = bb[n:]
75	return nil
76}
77
// readerWrapper is a wrapper around a reader.
// r is the source every method reads from. tmp is scratch space for small
// reads: readSmall and readByte read into it, and the slice returned by
// readSmall points into it.
type readerWrapper struct {
	r   io.Reader
	tmp [8]byte
}
83
84func (r *readerWrapper) readSmall(n int) []byte {
85	if debugAsserts && n > 8 {
86		panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
87	}
88	n2, err := io.ReadFull(r.r, r.tmp[:n])
89	// We only really care about the actual bytes read.
90	if n2 != n {
91		if debug {
92			println("readSmall: got", n2, "want", n, "err", err)
93		}
94		return nil
95	}
96	return r.tmp[:n]
97}
98
99func (r *readerWrapper) readBig(n int, dst []byte) ([]byte, error) {
100	if cap(dst) < n {
101		dst = make([]byte, n)
102	}
103	n2, err := io.ReadFull(r.r, dst[:n])
104	if err == io.EOF && n > 0 {
105		err = io.ErrUnexpectedEOF
106	}
107	return dst[:n2], err
108}
109
110func (r *readerWrapper) readByte() (byte, error) {
111	n2, err := r.r.Read(r.tmp[:1])
112	if err != nil {
113		return 0, err
114	}
115	if n2 != 1 {
116		return 0, io.ErrUnexpectedEOF
117	}
118	return r.tmp[0], nil
119}
120
121func (r *readerWrapper) skipN(n int) error {
122	n2, err := io.CopyN(ioutil.Discard, r.r, int64(n))
123	if n2 != int64(n) {
124		err = io.ErrUnexpectedEOF
125	}
126	return err
127}
128