1/*
2 * Copyright 2011-2012 Branimir Karadzic. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification,
5 * are permitted provided that the following conditions are met:
6 *
7 *    1. Redistributions of source code must retain the above copyright notice, this
8 *       list of conditions and the following disclaimer.
9 *
10 *    2. Redistributions in binary form must reproduce the above copyright notice,
11 *       this list of conditions and the following disclaimer in the documentation
12 *       and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY COPYRIGHT HOLDER ``AS IS'' AND ANY EXPRESS OR
15 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
16 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
17 * SHALL COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
18 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
19 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
20 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
21 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
22 * OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23 * THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26package lz4
27
28import (
29	"encoding/binary"
30	"errors"
31	"io"
32)
33
// ErrCorrupt is the error reported when the compressed input is malformed.
var ErrCorrupt = errors.New("corrupt input")
38
// Layout of the one-byte token that starts every sequence: the low mlBits
// bits carry the match length and the remaining high runBits bits carry the
// literal run length. A field equal to its mask means that additional
// length bytes follow in the stream.
const (
	mlBits  = 4
	runBits = 8 - mlBits

	mlMask  = 1<<mlBits - 1
	runMask = 1<<runBits - 1
)
45
// decoder holds the state of a single block decode.
type decoder struct {
	// src is the compressed input and dst is the output buffer.
	src, dst []byte

	// spos is the index of the next unread byte in src, dpos is the index
	// of the next byte to write in dst, and ref is the index in dst that a
	// match copy reads from.
	spos, dpos, ref uint32
}
53
54func (d *decoder) readByte() (uint8, error) {
55	if int(d.spos) == len(d.src) {
56		return 0, io.EOF
57	}
58	b := d.src[d.spos]
59	d.spos++
60	return b, nil
61}
62
63func (d *decoder) getLen() (uint32, error) {
64
65	length := uint32(0)
66	ln, err := d.readByte()
67	if err != nil {
68		return 0, ErrCorrupt
69	}
70	for ln == 255 {
71		length += 255
72		ln, err = d.readByte()
73		if err != nil {
74			return 0, ErrCorrupt
75		}
76	}
77	length += uint32(ln)
78
79	return length, nil
80}
81
82func (d *decoder) cp(length, decr uint32) {
83
84	if int(d.ref+length) < int(d.dpos) {
85		copy(d.dst[d.dpos:], d.dst[d.ref:d.ref+length])
86	} else {
87		for ii := uint32(0); ii < length; ii++ {
88			d.dst[d.dpos+ii] = d.dst[d.ref+ii]
89		}
90	}
91	d.dpos += length
92	d.ref += length - decr
93}
94
95func (d *decoder) finish(err error) error {
96	if err == io.EOF {
97		return nil
98	}
99
100	return err
101}
102
103// Decode returns the decoded form of src.  The returned slice may be a
104// subslice of dst if it was large enough to hold the entire decoded block.
105func Decode(dst, src []byte) ([]byte, error) {
106
107	if len(src) < 4 {
108		return nil, ErrCorrupt
109	}
110
111	uncompressedLen := binary.LittleEndian.Uint32(src)
112
113	if uncompressedLen == 0 {
114		return nil, nil
115	}
116
117	if uncompressedLen > MaxInputSize {
118		return nil, ErrTooLarge
119	}
120
121	if dst == nil || len(dst) < int(uncompressedLen) {
122		dst = make([]byte, uncompressedLen)
123	}
124
125	d := decoder{src: src, dst: dst[:uncompressedLen], spos: 4}
126
127	decr := []uint32{0, 3, 2, 3}
128
129	for {
130		code, err := d.readByte()
131		if err != nil {
132			return d.dst, d.finish(err)
133		}
134
135		length := uint32(code >> mlBits)
136		if length == runMask {
137			ln, err := d.getLen()
138			if err != nil {
139				return nil, ErrCorrupt
140			}
141			length += ln
142		}
143
144		if int(d.spos+length) > len(d.src) || int(d.dpos+length) > len(d.dst) {
145			return nil, ErrCorrupt
146		}
147
148		for ii := uint32(0); ii < length; ii++ {
149			d.dst[d.dpos+ii] = d.src[d.spos+ii]
150		}
151
152		d.spos += length
153		d.dpos += length
154
155		if int(d.spos) == len(d.src) {
156			return d.dst, nil
157		}
158
159		if int(d.spos+2) >= len(d.src) {
160			return nil, ErrCorrupt
161		}
162
163		back := uint32(d.src[d.spos]) | uint32(d.src[d.spos+1])<<8
164
165		if back > d.dpos {
166			return nil, ErrCorrupt
167		}
168
169		d.spos += 2
170		d.ref = d.dpos - back
171
172		length = uint32(code & mlMask)
173		if length == mlMask {
174			ln, err := d.getLen()
175			if err != nil {
176				return nil, ErrCorrupt
177			}
178			length += ln
179		}
180
181		literal := d.dpos - d.ref
182
183		if literal < 4 {
184			if int(d.dpos+4) > len(d.dst) {
185				return nil, ErrCorrupt
186			}
187
188			d.cp(4, decr[literal])
189		} else {
190			length += 4
191		}
192
193		if d.dpos+length > uncompressedLen {
194			return nil, ErrCorrupt
195		}
196
197		d.cp(length, 0)
198	}
199}
200