1// Copyright 2016, Joe Tsai. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE.md file. 4 5// +build cgo 6 7// Package zstd implements the Zstandard compressed data format using C wrappers. 8package zstd 9 10/* 11// This relies upon the shared library built from github.com/facebook/zstd. 12// 13// The steps to build and install the shared library is as follows: 14// curl -L https://github.com/facebook/zstd/archive/v1.3.2.tar.gz | tar -zxv 15// cd zstd-1.3.2 16// sudo make install 17 18#cgo LDFLAGS: -lzstd 19 20#include <stdlib.h> 21#include <stdint.h> 22#include "zstd.h" 23 24ZSTD_DStream* zsDecCreate() { 25 ZSTD_DStream* state = ZSTD_createDStream(); 26 ZSTD_initDStream(state); 27 return state; 28} 29 30size_t zsDecStream( 31 ZSTD_DStream* state, 32 size_t* avail_in, uint8_t* next_in, 33 size_t* avail_out, uint8_t* next_out 34) { 35 ZSTD_inBuffer in = {next_in, *avail_in, 0}; 36 ZSTD_outBuffer out = {next_out, *avail_out, 0}; 37 size_t ret = ZSTD_decompressStream(state, &out, &in); 38 *avail_in = in.size - in.pos; 39 *avail_out = out.size - out.pos; 40 in.src = NULL; 41 out.dst = NULL; 42 return ret; 43} 44 45void zsDecDestroy(ZSTD_DStream* state) { 46 ZSTD_freeDStream(state); 47} 48 49ZSTD_CStream* zsEncCreate(int level) { 50 ZSTD_CStream* state = ZSTD_createCStream(); 51 ZSTD_initCStream(state, level); 52 return state; 53} 54 55size_t zsEncStream( 56 ZSTD_CStream* state, int finish, 57 size_t* avail_in, uint8_t* next_in, 58 size_t* avail_out, uint8_t* next_out 59) { 60 ZSTD_inBuffer in = {next_in, *avail_in, 0}; 61 ZSTD_outBuffer out = {next_out, *avail_out, 0}; 62 size_t ret = finish ? 63 ZSTD_endStream(state, &out) : ZSTD_compressStream(state, &out, &in); 64 *avail_in = in.size - in.pos; 65 *avail_out = out.size - out.pos; 66 in.src = NULL; 67 out.dst = NULL; 68 return ret; 69} 70 71void zsEncDestroy(ZSTD_CStream* state) { 72 ZSTD_freeCStream(state); 73} 74*/ 75import "C" 76 77import ( 78 "errors" 79 "io" 80 "unsafe" 81) 82 83type reader struct { 84 r io.Reader 85 err error 86 state *C.ZSTD_DStream 87 buf []byte 88 arr [1 << 14]byte 89} 90 91func NewReader(r io.Reader) io.ReadCloser { 92 zr := &reader{r: r, state: C.zsDecCreate()} 93 if zr.state == nil { 94 panic("zstd: could not allocate decoder state") 95 } 96 return zr 97} 98 99func (zr *reader) Read(buf []byte) (int, error) { 100 if zr.state == nil { 101 return 0, io.ErrClosedPipe 102 } 103 104 var n int 105 for zr.err == nil && (len(buf) > 0 && n == 0) { 106 availIn, availOut, ptrIn, ptrOut := sizePtrs(zr.buf, buf) 107 ret := C.zsDecStream(zr.state, &availIn, ptrIn, &availOut, ptrOut) 108 n += len(buf) - int(availOut) 109 buf = buf[len(buf)-int(availOut):] 110 zr.buf = zr.buf[len(zr.buf)-int(availIn):] 111 112 switch { 113 case C.ZSTD_isError(ret) > 0: 114 zr.err = errors.New("zstd: corrupted input") 115 case ret == 0: 116 return n, io.EOF 117 case n > 0: 118 return n, nil 119 case len(zr.buf) == 0 && n == 0: 120 n1, err := zr.r.Read(zr.arr[:]) 121 if n1 > 0 { 122 zr.buf = zr.arr[:n1] 123 } else if err != nil { 124 if err == io.EOF { 125 err = io.ErrUnexpectedEOF 126 } 127 zr.err = err 128 } 129 } 130 } 131 return n, zr.err 132} 133 134func (zr *reader) Close() error { 135 if zr.state != nil { 136 defer func() { 137 C.zsDecDestroy(zr.state) 138 zr.state = nil 139 }() 140 } 141 return zr.err 142} 143 144type writer struct { 145 w io.Writer 146 err error 147 state *C.ZSTD_CStream 148 buf []byte 149 arr [1 << 14]byte 150} 151 152func NewWriter(w io.Writer, level int) io.WriteCloser { 153 if level < 1 || level > 22 { 154 panic("zstd: invalid compression level") 155 } 156 157 zw := &writer{w: w, state: C.zsEncCreate(C.int(level))} 158 if zw.state == nil { 159 panic("zstd: could not allocate encoder state") 160 } 161 return zw 162} 163 164func (zw *writer) Write(buf []byte) (int, error) { 165 return zw.write(buf, 0) 166} 167 168func (zw *writer) write(buf []byte, finish C.int) (int, error) { 169 if zw.state == nil { 170 return 0, io.ErrClosedPipe 171 } 172 173 var n int 174 for zw.err == nil && (len(buf) > 0 || finish > 0) { 175 availIn, availOut, ptrIn, ptrOut := sizePtrs(buf, zw.arr[:]) 176 ret := C.zsEncStream(zw.state, finish, &availIn, ptrIn, &availOut, ptrOut) 177 n += len(buf) - int(availIn) 178 buf = buf[len(buf)-int(availIn):] 179 zw.buf = zw.arr[:len(zw.arr)-int(availOut)] 180 181 if len(zw.buf) > 0 { 182 if _, err := zw.w.Write(zw.buf); err != nil { 183 zw.err = err 184 } 185 } 186 switch { 187 case C.ZSTD_isError(ret) > 0: 188 zw.err = errors.New("zstd: compression error") 189 case len(buf) == 0 && len(zw.buf) == 0: 190 return n, zw.err 191 case ret == 0 && finish > 0: 192 return n, zw.err 193 } 194 } 195 return n, zw.err 196} 197 198func (zw *writer) Close() error { 199 if zw.state != nil { 200 defer func() { 201 C.zsEncDestroy(zw.state) 202 zw.state = nil 203 }() 204 zw.write(nil, 1) 205 } 206 return zw.err 207} 208 209func sizePtrs(in, out []byte) (sizeIn, sizeOut C.size_t, ptrIn, ptrOut *C.uint8_t) { 210 sizeIn = C.size_t(len(in)) 211 sizeOut = C.size_t(len(out)) 212 if len(in) > 0 { 213 ptrIn = (*C.uint8_t)(unsafe.Pointer(&in[0])) 214 } 215 if len(out) > 0 { 216 ptrOut = (*C.uint8_t)(unsafe.Pointer(&out[0])) 217 } 218 return 219} 220