1// Copyright 2016, Joe Tsai. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5// +build cgo
6
7// Package lzma implements the LZMA2 compressed data format using C wrappers.
8package lzma
9
10/*
11#cgo LDFLAGS: -llzma
12
13#include <assert.h>
14#include <stdlib.h>
15#include "lzma.h"
16
// zlState is a tuple of C allocated data structures.
//
// The liblzma documentation is not clear about whether the filters struct must
// stay live past calls to lzma_raw_encoder and lzma_raw_decoder.
// To be on the safe side, we allocate them and keep them around until the end.
typedef struct {
	// Coder state handed to lzma_code on every step and to lzma_end on teardown.
	lzma_stream stream;
	// Filter chain: an LZMA2 entry followed by the LZMA_VLI_UNKNOWN terminator.
	lzma_filter filters[2];
	// LZMA2 options; filters[0].options points at this member.
	lzma_options_lzma options;
} zlState;
27
28zlState* zlDecCreate() {
29	zlState* state = calloc(1, sizeof(zlState));
30	state->filters[0].id = LZMA_FILTER_LZMA2;
31	state->filters[0].options = &state->options;
32	state->filters[1].id = LZMA_VLI_UNKNOWN;
33	state->options.dict_size = LZMA_DICT_SIZE_DEFAULT;
34
35	assert(lzma_raw_decoder(&state->stream, state->filters) == LZMA_OK);
36	return state;
37}
38
39zlState* zlEncCreate(int level) {
40	zlState* state = calloc(1, sizeof(zlState));
41	state->filters[0].id = LZMA_FILTER_LZMA2;
42	state->filters[0].options = &state->options;
43	state->filters[1].id = LZMA_VLI_UNKNOWN;
44
45	assert(!lzma_lzma_preset(&state->options, level));
46	assert(lzma_raw_encoder(&state->stream, state->filters) == LZMA_OK);
47	return state;
48}
49
50lzma_ret zlStream(
51	lzma_stream* strm, lzma_action action,
52	size_t* avail_in, uint8_t* next_in,
53	size_t* avail_out, uint8_t* next_out
54) {
55	strm->avail_in = *avail_in;
56	strm->avail_out = *avail_out;
57	strm->next_in = next_in;
58	strm->next_out = next_out;
59	lzma_ret ret = lzma_code(strm, action);
60	*avail_in = strm->avail_in;
61	*avail_out = strm->avail_out;
62	strm->next_in = NULL;
63	strm->next_out = NULL;
64	return ret;
65}
66
67void zlDestroy(zlState* state) {
68	lzma_end(&state->stream);
69	free(state);
70}
71*/
72import "C"
73
74import (
75	"errors"
76	"io"
77	"unsafe"
78)
79
// reader decompresses a raw LZMA2 stream pulled from an io.Reader by driving
// a liblzma decoder held in C memory.
type reader struct {
	// r is the source of compressed bytes.
	r io.Reader
	// err is the sticky failure; once set, Read stops decoding and Close reports it.
	err error
	// state is the C decoder state; it is nil once Close has run.
	state *C.zlState
	// buf holds the compressed bytes not yet consumed by the decoder.
	// It is a sub-slice of arr.
	buf []byte
	// arr is the backing storage that reads from r are placed into.
	arr [1 << 14]byte
}
87
88func NewReader(r io.Reader) io.ReadCloser {
89	zr := &reader{r: r, state: C.zlDecCreate()}
90	if zr.state == nil {
91		panic("lzma: could not allocate decoder state")
92	}
93	return zr
94}
95
96func (zr *reader) Read(buf []byte) (int, error) {
97	if zr.state == nil {
98		return 0, io.ErrClosedPipe
99	}
100
101	var n int
102	for zr.err == nil && (len(buf) > 0 && n == 0) {
103		availIn, availOut, ptrIn, ptrOut := sizePtrs(zr.buf, buf)
104		ret := C.zlStream(&zr.state.stream, 0, &availIn, ptrIn, &availOut, ptrOut)
105		n += len(buf) - int(availOut)
106		buf = buf[len(buf)-int(availOut):]
107		zr.buf = zr.buf[len(zr.buf)-int(availIn):]
108
109		switch ret {
110		case C.LZMA_OK:
111			return n, nil
112		case C.LZMA_BUF_ERROR:
113			if len(zr.buf) == 0 {
114				n1, err := zr.r.Read(zr.arr[:])
115				if n1 > 0 {
116					zr.buf = zr.arr[:n1]
117				} else if err != nil {
118					if err == io.EOF {
119						err = io.ErrUnexpectedEOF
120					}
121					zr.err = err
122				}
123			}
124		case C.LZMA_STREAM_END:
125			return n, io.EOF
126		default:
127			zr.err = errors.New("lzma: corrupted input")
128		}
129	}
130	return n, zr.err
131}
132
133func (zr *reader) Close() error {
134	if zr.state != nil {
135		defer func() {
136			C.zlDestroy(zr.state)
137			zr.state = nil
138		}()
139	}
140	return zr.err
141}
142
// writer compresses data into a raw LZMA2 stream by driving a liblzma encoder
// held in C memory and forwarding its output to an io.Writer.
type writer struct {
	// w is the destination for compressed bytes.
	w io.Writer
	// err is the sticky failure; once set, writes stop and Close reports it.
	err error
	// state is the C encoder state; it is nil once Close has run.
	state *C.zlState
	// buf is the compressed output of the most recent encoder step.
	// It is a sub-slice of arr.
	buf []byte
	// arr is the output buffer handed to the encoder on every step.
	arr [1 << 14]byte
}
150
151func NewWriter(w io.Writer, level int) io.WriteCloser {
152	if level < 0 || level > 9 {
153		panic("lzma: invalid compression level")
154	}
155
156	zw := &writer{w: w, state: C.zlEncCreate(C.int(level))}
157	if zw.state == nil {
158		panic("lzma: could not allocate encoder state")
159	}
160	return zw
161}
162
163func (zw *writer) Write(buf []byte) (int, error) {
164	return zw.write(buf, C.LZMA_RUN)
165}
166
167func (zw *writer) write(buf []byte, op C.lzma_action) (int, error) {
168	if zw.state == nil {
169		return 0, io.ErrClosedPipe
170	}
171
172	var n int
173	flush := op != C.LZMA_RUN
174	for zw.err == nil && (len(buf) > 0 || flush) {
175		availIn, availOut, ptrIn, ptrOut := sizePtrs(buf, zw.arr[:])
176		ret := C.zlStream(&zw.state.stream, op, &availIn, ptrIn, &availOut, ptrOut)
177		n += len(buf) - int(availIn)
178		buf = buf[len(buf)-int(availIn):]
179		zw.buf = zw.arr[:len(zw.arr)-int(availOut)]
180
181		if len(zw.buf) > 0 {
182			if _, err := zw.w.Write(zw.buf); err != nil {
183				zw.err = err
184			}
185		}
186		switch ret {
187		case C.LZMA_OK, C.LZMA_BUF_ERROR:
188			continue // Do nothing
189		case C.LZMA_STREAM_END:
190			return n, zw.err
191		default:
192			zw.err = errors.New("lzma: compression error")
193		}
194	}
195	return n, zw.err
196}
197
198func (zw *writer) Close() error {
199	if zw.state != nil {
200		defer func() {
201			C.zlDestroy(zw.state)
202			zw.state = nil
203		}()
204		zw.write(nil, C.LZMA_FINISH)
205	}
206	return zw.err
207}
208
209func sizePtrs(in, out []byte) (sizeIn, sizeOut C.size_t, ptrIn, ptrOut *C.uint8_t) {
210	sizeIn = C.size_t(len(in))
211	sizeOut = C.size_t(len(out))
212	if len(in) > 0 {
213		ptrIn = (*C.uint8_t)(unsafe.Pointer(&in[0]))
214	}
215	if len(out) > 0 {
216		ptrOut = (*C.uint8_t)(unsafe.Pointer(&out[0]))
217	}
218	return
219}
220