1// Copyright 2016, Joe Tsai. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5// +build cgo
6
7// Package zstd implements the Zstandard compressed data format using C wrappers.
8package zstd
9
10/*
11// This relies upon the shared library built from github.com/facebook/zstd.
12//
// The steps to build and install the shared library are as follows:
14//	curl -L https://github.com/facebook/zstd/archive/v1.3.2.tar.gz | tar -zxv
15//	cd zstd-1.3.2
16//	sudo make install
17
18#cgo LDFLAGS: -lzstd
19
20#include <stdlib.h>
21#include <stdint.h>
22#include "zstd.h"
23
24ZSTD_DStream* zsDecCreate() {
25	ZSTD_DStream* state = ZSTD_createDStream();
26	ZSTD_initDStream(state);
27	return state;
28}
29
30size_t zsDecStream(
31	ZSTD_DStream* state,
32	size_t* avail_in, uint8_t* next_in,
33	size_t* avail_out, uint8_t* next_out
34) {
35	ZSTD_inBuffer in = {next_in, *avail_in, 0};
36	ZSTD_outBuffer out = {next_out, *avail_out, 0};
37	size_t ret = ZSTD_decompressStream(state, &out, &in);
38	*avail_in = in.size - in.pos;
39	*avail_out = out.size - out.pos;
40	in.src = NULL;
41	out.dst = NULL;
42	return ret;
43}
44
45void zsDecDestroy(ZSTD_DStream* state) {
46	ZSTD_freeDStream(state);
47}
48
49ZSTD_CStream* zsEncCreate(int level) {
50	ZSTD_CStream* state = ZSTD_createCStream();
51	ZSTD_initCStream(state, level);
52	return state;
53}
54
55size_t zsEncStream(
56	ZSTD_CStream* state, int finish,
57	size_t* avail_in, uint8_t* next_in,
58	size_t* avail_out, uint8_t* next_out
59) {
60	ZSTD_inBuffer in = {next_in, *avail_in, 0};
61	ZSTD_outBuffer out = {next_out, *avail_out, 0};
62	size_t ret = finish ?
63		ZSTD_endStream(state, &out) : ZSTD_compressStream(state, &out, &in);
64	*avail_in = in.size - in.pos;
65	*avail_out = out.size - out.pos;
66	in.src = NULL;
67	out.dst = NULL;
68	return ret;
69}
70
71void zsEncDestroy(ZSTD_CStream* state) {
72	ZSTD_freeCStream(state);
73}
74*/
75import "C"
76
77import (
78	"errors"
79	"io"
80	"unsafe"
81)
82
83type reader struct {
84	r     io.Reader
85	err   error
86	state *C.ZSTD_DStream
87	buf   []byte
88	arr   [1 << 14]byte
89}
90
91func NewReader(r io.Reader) io.ReadCloser {
92	zr := &reader{r: r, state: C.zsDecCreate()}
93	if zr.state == nil {
94		panic("zstd: could not allocate decoder state")
95	}
96	return zr
97}
98
99func (zr *reader) Read(buf []byte) (int, error) {
100	if zr.state == nil {
101		return 0, io.ErrClosedPipe
102	}
103
104	var n int
105	for zr.err == nil && (len(buf) > 0 && n == 0) {
106		availIn, availOut, ptrIn, ptrOut := sizePtrs(zr.buf, buf)
107		ret := C.zsDecStream(zr.state, &availIn, ptrIn, &availOut, ptrOut)
108		n += len(buf) - int(availOut)
109		buf = buf[len(buf)-int(availOut):]
110		zr.buf = zr.buf[len(zr.buf)-int(availIn):]
111
112		switch {
113		case C.ZSTD_isError(ret) > 0:
114			zr.err = errors.New("zstd: corrupted input")
115		case ret == 0:
116			return n, io.EOF
117		case n > 0:
118			return n, nil
119		case len(zr.buf) == 0 && n == 0:
120			n1, err := zr.r.Read(zr.arr[:])
121			if n1 > 0 {
122				zr.buf = zr.arr[:n1]
123			} else if err != nil {
124				if err == io.EOF {
125					err = io.ErrUnexpectedEOF
126				}
127				zr.err = err
128			}
129		}
130	}
131	return n, zr.err
132}
133
134func (zr *reader) Close() error {
135	if zr.state != nil {
136		defer func() {
137			C.zsDecDestroy(zr.state)
138			zr.state = nil
139		}()
140	}
141	return zr.err
142}
143
144type writer struct {
145	w     io.Writer
146	err   error
147	state *C.ZSTD_CStream
148	buf   []byte
149	arr   [1 << 14]byte
150}
151
152func NewWriter(w io.Writer, level int) io.WriteCloser {
153	if level < 1 || level > 22 {
154		panic("zstd: invalid compression level")
155	}
156
157	zw := &writer{w: w, state: C.zsEncCreate(C.int(level))}
158	if zw.state == nil {
159		panic("zstd: could not allocate encoder state")
160	}
161	return zw
162}
163
164func (zw *writer) Write(buf []byte) (int, error) {
165	return zw.write(buf, 0)
166}
167
168func (zw *writer) write(buf []byte, finish C.int) (int, error) {
169	if zw.state == nil {
170		return 0, io.ErrClosedPipe
171	}
172
173	var n int
174	for zw.err == nil && (len(buf) > 0 || finish > 0) {
175		availIn, availOut, ptrIn, ptrOut := sizePtrs(buf, zw.arr[:])
176		ret := C.zsEncStream(zw.state, finish, &availIn, ptrIn, &availOut, ptrOut)
177		n += len(buf) - int(availIn)
178		buf = buf[len(buf)-int(availIn):]
179		zw.buf = zw.arr[:len(zw.arr)-int(availOut)]
180
181		if len(zw.buf) > 0 {
182			if _, err := zw.w.Write(zw.buf); err != nil {
183				zw.err = err
184			}
185		}
186		switch {
187		case C.ZSTD_isError(ret) > 0:
188			zw.err = errors.New("zstd: compression error")
189		case len(buf) == 0 && len(zw.buf) == 0:
190			return n, zw.err
191		case ret == 0 && finish > 0:
192			return n, zw.err
193		}
194	}
195	return n, zw.err
196}
197
198func (zw *writer) Close() error {
199	if zw.state != nil {
200		defer func() {
201			C.zsEncDestroy(zw.state)
202			zw.state = nil
203		}()
204		zw.write(nil, 1)
205	}
206	return zw.err
207}
208
209func sizePtrs(in, out []byte) (sizeIn, sizeOut C.size_t, ptrIn, ptrOut *C.uint8_t) {
210	sizeIn = C.size_t(len(in))
211	sizeOut = C.size_t(len(out))
212	if len(in) > 0 {
213		ptrIn = (*C.uint8_t)(unsafe.Pointer(&in[0]))
214	}
215	if len(out) > 0 {
216		ptrOut = (*C.uint8_t)(unsafe.Pointer(&out[0]))
217	}
218	return
219}
220