1// Copyright 2019+ Klaus Post. All rights reserved. 2// License information can be found in the LICENSE file. 3// Based on work by Yann Collet, released under BSD License. 4 5package zstd 6 7import ( 8 "fmt" 9 "io" 10 "io/ioutil" 11) 12 13type byteBuffer interface { 14 // Read up to 8 bytes. 15 // Returns nil if no more input is available. 16 readSmall(n int) []byte 17 18 // Read >8 bytes. 19 // MAY use the destination slice. 20 readBig(n int, dst []byte) ([]byte, error) 21 22 // Read a single byte. 23 readByte() (byte, error) 24 25 // Skip n bytes. 26 skipN(n int) error 27} 28 29// in-memory buffer 30type byteBuf []byte 31 32func (b *byteBuf) readSmall(n int) []byte { 33 if debugAsserts && n > 8 { 34 panic(fmt.Errorf("small read > 8 (%d). use readBig", n)) 35 } 36 bb := *b 37 if len(bb) < n { 38 return nil 39 } 40 r := bb[:n] 41 *b = bb[n:] 42 return r 43} 44 45func (b *byteBuf) readBig(n int, dst []byte) ([]byte, error) { 46 bb := *b 47 if len(bb) < n { 48 return nil, io.ErrUnexpectedEOF 49 } 50 r := bb[:n] 51 *b = bb[n:] 52 return r, nil 53} 54 55func (b *byteBuf) remain() []byte { 56 return *b 57} 58 59func (b *byteBuf) readByte() (byte, error) { 60 bb := *b 61 if len(bb) < 1 { 62 return 0, nil 63 } 64 r := bb[0] 65 *b = bb[1:] 66 return r, nil 67} 68 69func (b *byteBuf) skipN(n int) error { 70 bb := *b 71 if len(bb) < n { 72 return io.ErrUnexpectedEOF 73 } 74 *b = bb[n:] 75 return nil 76} 77 78// wrapper around a reader. 79type readerWrapper struct { 80 r io.Reader 81 tmp [8]byte 82} 83 84func (r *readerWrapper) readSmall(n int) []byte { 85 if debugAsserts && n > 8 { 86 panic(fmt.Errorf("small read > 8 (%d). use readBig", n)) 87 } 88 n2, err := io.ReadFull(r.r, r.tmp[:n]) 89 // We only really care about the actual bytes read. 90 if n2 != n { 91 if debug { 92 println("readSmall: got", n2, "want", n, "err", err) 93 } 94 return nil 95 } 96 return r.tmp[:n] 97} 98 99func (r *readerWrapper) readBig(n int, dst []byte) ([]byte, error) { 100 if cap(dst) < n { 101 dst = make([]byte, n) 102 } 103 n2, err := io.ReadFull(r.r, dst[:n]) 104 if err == io.EOF && n > 0 { 105 err = io.ErrUnexpectedEOF 106 } 107 return dst[:n2], err 108} 109 110func (r *readerWrapper) readByte() (byte, error) { 111 n2, err := r.r.Read(r.tmp[:1]) 112 if err != nil { 113 return 0, err 114 } 115 if n2 != 1 { 116 return 0, io.ErrUnexpectedEOF 117 } 118 return r.tmp[0], nil 119} 120 121func (r *readerWrapper) skipN(n int) error { 122 n2, err := io.CopyN(ioutil.Discard, r.r, int64(n)) 123 if n2 != int64(n) { 124 err = io.ErrUnexpectedEOF 125 } 126 return err 127} 128