1/* 2 * Copyright 2011-2012 Branimir Karadzic. All rights reserved. 3 * 4 * Redistribution and use in source and binary forms, with or without modification, 5 * are permitted provided that the following conditions are met: 6 * 7 * 1. Redistributions of source code must retain the above copyright notice, this 8 * list of conditions and the following disclaimer. 9 * 10 * 2. Redistributions in binary form must reproduce the above copyright notice, 11 * this list of conditions and the following disclaimer in the documentation 12 * and/or other materials provided with the distribution. 13 * 14 * THIS SOFTWARE IS PROVIDED BY COPYRIGHT HOLDER ``AS IS'' AND ANY EXPRESS OR 15 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 16 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT 17 * SHALL COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 18 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 19 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 20 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 21 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 22 * OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 23 * THE POSSIBILITY OF SUCH DAMAGE. 24 */ 25 26package lz4 27 28import ( 29 "encoding/binary" 30 "errors" 31 "io" 32) 33 34var ( 35 // ErrCorrupt indicates the input was corrupt 36 ErrCorrupt = errors.New("corrupt input") 37) 38 39const ( 40 mlBits = 4 41 mlMask = (1 << mlBits) - 1 42 runBits = 8 - mlBits 43 runMask = (1 << runBits) - 1 44) 45 46type decoder struct { 47 src []byte 48 dst []byte 49 spos uint32 50 dpos uint32 51 ref uint32 52} 53 54func (d *decoder) readByte() (uint8, error) { 55 if int(d.spos) == len(d.src) { 56 return 0, io.EOF 57 } 58 b := d.src[d.spos] 59 d.spos++ 60 return b, nil 61} 62 63func (d *decoder) getLen() (uint32, error) { 64 65 length := uint32(0) 66 ln, err := d.readByte() 67 if err != nil { 68 return 0, ErrCorrupt 69 } 70 for ln == 255 { 71 length += 255 72 ln, err = d.readByte() 73 if err != nil { 74 return 0, ErrCorrupt 75 } 76 } 77 length += uint32(ln) 78 79 return length, nil 80} 81 82func (d *decoder) cp(length, decr uint32) { 83 84 if int(d.ref+length) < int(d.dpos) { 85 copy(d.dst[d.dpos:], d.dst[d.ref:d.ref+length]) 86 } else { 87 for ii := uint32(0); ii < length; ii++ { 88 d.dst[d.dpos+ii] = d.dst[d.ref+ii] 89 } 90 } 91 d.dpos += length 92 d.ref += length - decr 93} 94 95func (d *decoder) finish(err error) error { 96 if err == io.EOF { 97 return nil 98 } 99 100 return err 101} 102 103// Decode returns the decoded form of src. The returned slice may be a 104// subslice of dst if it was large enough to hold the entire decoded block. 105func Decode(dst, src []byte) ([]byte, error) { 106 107 if len(src) < 4 { 108 return nil, ErrCorrupt 109 } 110 111 uncompressedLen := binary.LittleEndian.Uint32(src) 112 113 if uncompressedLen == 0 { 114 return nil, nil 115 } 116 117 if uncompressedLen > MaxInputSize { 118 return nil, ErrTooLarge 119 } 120 121 if dst == nil || len(dst) < int(uncompressedLen) { 122 dst = make([]byte, uncompressedLen) 123 } 124 125 d := decoder{src: src, dst: dst[:uncompressedLen], spos: 4} 126 127 decr := []uint32{0, 3, 2, 3} 128 129 for { 130 code, err := d.readByte() 131 if err != nil { 132 return d.dst, d.finish(err) 133 } 134 135 length := uint32(code >> mlBits) 136 if length == runMask { 137 ln, err := d.getLen() 138 if err != nil { 139 return nil, ErrCorrupt 140 } 141 length += ln 142 } 143 144 if int(d.spos+length) > len(d.src) || int(d.dpos+length) > len(d.dst) { 145 return nil, ErrCorrupt 146 } 147 148 for ii := uint32(0); ii < length; ii++ { 149 d.dst[d.dpos+ii] = d.src[d.spos+ii] 150 } 151 152 d.spos += length 153 d.dpos += length 154 155 if int(d.spos) == len(d.src) { 156 return d.dst, nil 157 } 158 159 if int(d.spos+2) >= len(d.src) { 160 return nil, ErrCorrupt 161 } 162 163 back := uint32(d.src[d.spos]) | uint32(d.src[d.spos+1])<<8 164 165 if back > d.dpos { 166 return nil, ErrCorrupt 167 } 168 169 d.spos += 2 170 d.ref = d.dpos - back 171 172 length = uint32(code & mlMask) 173 if length == mlMask { 174 ln, err := d.getLen() 175 if err != nil { 176 return nil, ErrCorrupt 177 } 178 length += ln 179 } 180 181 literal := d.dpos - d.ref 182 183 if literal < 4 { 184 if int(d.dpos+4) > len(d.dst) { 185 return nil, ErrCorrupt 186 } 187 188 d.cp(4, decr[literal]) 189 } else { 190 length += 4 191 } 192 193 if d.dpos+length > uncompressedLen { 194 return nil, ErrCorrupt 195 } 196 197 d.cp(length, 0) 198 } 199} 200