1package fastzip
2
import (
	stdflate "compress/flate"
	"errors"
	"io"
	"sync"

	"github.com/klauspost/compress/flate"
)
11
// flater is the slice of a flate compressor's API that the writer pool
// depends on. A flater streams compressed output (Write), can emit any
// pending output (Flush), terminates the stream (Close), and can be
// re-targeted at a new destination (Reset) so one instance can be reused
// across many zip entries.
type flater interface {
	io.WriteCloser

	Flush() error
	Reset(dst io.Writer)
}
18
// errReadAfterClose is returned when a decompressor handle is read from after
// Close has released its decompressor back to the pool.
var errReadAfterClose = errors.New("fastzip: read after close")

// newFlateReaderPool returns a pool of reusable decompressors created by
// newReaderFn. The pooled values are the bare io.ReadClosers, which must also
// implement flate.Resetter. Callers never see them directly; each use gets a
// single-use flateReader handle from acquireFlateReader.
func newFlateReaderPool(newReaderFn func(r io.Reader) io.ReadCloser) *sync.Pool {
	return &sync.Pool{
		New: func() interface{} {
			// The real input is supplied by Reset every time the
			// decompressor is taken from the pool.
			return newReaderFn(nil)
		},
	}
}

// flateReader is a single-use handle to a pooled decompressor.
//
// Every call to a decompressor function yields a fresh handle. Close returns
// the decompressor to the pool and detaches it from the handle. Closing twice
// (for example an explicit Close plus a deferred one) therefore cannot put
// the same decompressor into the pool twice. It also cannot release a
// decompressor that another reader has since acquired.
//
// The handle costs one small allocation per use, which is the same trade-off
// archive/zip makes for its own pooled flate readers.
type flateReader struct {
	pool   *sync.Pool
	decomp io.ReadCloser // nil once Close has returned it to pool
}

// acquireFlateReader takes a decompressor from pool, points it at the
// compressed stream r and wraps it in a fresh handle.
func acquireFlateReader(pool *sync.Pool, r io.Reader) *flateReader {
	fr := &flateReader{pool: pool, decomp: pool.Get().(io.ReadCloser)}
	fr.Reset(r)
	return fr
}

// Reset points the handle's decompressor at a new compressed stream r,
// discarding any state from the previous stream. It must not be called after
// Close.
func (fr *flateReader) Reset(r io.Reader) {
	fr.decomp.(flate.Resetter).Reset(r, nil)
}

// Read decompresses into p. After Close it fails with errReadAfterClose
// instead of touching a decompressor that may now belong to another reader.
func (fr *flateReader) Read(p []byte) (int, error) {
	if fr.decomp == nil {
		return 0, errReadAfterClose
	}
	return fr.decomp.Read(p)
}

// Close closes the decompressor and returns it to the pool. Only the first
// call has an effect; later calls return nil.
func (fr *flateReader) Close() error {
	if fr.decomp == nil {
		return nil
	}
	decomp := fr.decomp
	fr.decomp = nil
	err := decomp.Close()
	fr.pool.Put(decomp)
	return err
}

// FlateDecompressor returns a pooled performant zip.Decompressor.
//
// Each returned io.ReadCloser must be closed to return its decompressor to
// the pool. Closing it more than once is harmless.
func FlateDecompressor() func(r io.Reader) io.ReadCloser {
	pool := newFlateReaderPool(flate.NewReader)

	return func(r io.Reader) io.ReadCloser {
		return acquireFlateReader(pool, r)
	}
}

// StdFlateDecompressor returns a pooled standard library zip.Decompressor.
//
// Each returned io.ReadCloser must be closed to return its decompressor to
// the pool. Closing it more than once is harmless.
func StdFlateDecompressor() func(r io.Reader) io.ReadCloser {
	pool := newFlateReaderPool(stdflate.NewReader)

	return func(r io.Reader) io.ReadCloser {
		return acquireFlateReader(pool, r)
	}
}
63
64func newFlateWriterPool(level int, newWriterFn func(w io.Writer, level int) (flater, error)) *sync.Pool {
65	pool := &sync.Pool{}
66	pool.New = func() interface{} {
67		fw, err := newWriterFn(nil, level)
68		if err != nil {
69			panic(err)
70		}
71
72		return &flateWriter{pool, fw}
73	}
74	return pool
75}
76
77type flateWriter struct {
78	pool *sync.Pool
79	flater
80}
81
82func (fw *flateWriter) Reset(w io.Writer) {
83	fw.flater.Reset(w)
84}
85
86func (fw *flateWriter) Close() error {
87	err := fw.flater.Close()
88	fw.pool.Put(fw)
89	return err
90}
91
92// FlateCompressor returns a pooled performant zip.Compressor configured to a
93// specified compression level. Invalid flate levels will panic.
94func FlateCompressor(level int) func(w io.Writer) (io.WriteCloser, error) {
95	pool := newFlateWriterPool(level, func(w io.Writer, level int) (flater, error) {
96		return flate.NewWriter(w, level)
97	})
98
99	return func(w io.Writer) (io.WriteCloser, error) {
100		fw := pool.Get().(*flateWriter)
101		fw.Reset(w)
102		return fw, nil
103	}
104}
105
106// StdFlateCompressor returns a pooled standard library zip.Compressor
107// configured to a specified compression level. Invalid flate levels will
108// panic.
109func StdFlateCompressor(level int) func(w io.Writer) (io.WriteCloser, error) {
110	pool := newFlateWriterPool(level, func(w io.Writer, level int) (flater, error) {
111		return stdflate.NewWriter(w, level)
112	})
113
114	return func(w io.Writer) (io.WriteCloser, error) {
115		fw := pool.Get().(*flateWriter)
116		fw.Reset(w)
117		return fw, nil
118	}
119}
120