1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package zip
6
7import (
8	"errors"
9	"io"
10	"io/ioutil"
11	"sync"
12
13	"github.com/klauspost/compress/flate"
14)
15
type (
	// A Compressor returns a writer that compresses everything written to it
	// and forwards the compressed bytes to the supplied io.Writer. Closing the
	// returned writer should flush any data that is still pending.
	Compressor func(io.Writer) (io.WriteCloser, error)

	// A Decompressor wraps an io.Reader of compressed data in a ReadCloser
	// that yields the decompressed bytes. That ReadCloser is handed to callers
	// who open files from within the archive, and those callers must close it
	// once they have finished reading.
	Decompressor func(io.Reader) io.ReadCloser
)
25
26var flateWriterPool sync.Pool
27
28func newFlateWriter(w io.Writer) io.WriteCloser {
29	fw, ok := flateWriterPool.Get().(*flate.Writer)
30	if ok {
31		fw.Reset(w)
32	} else {
33		fw, _ = flate.NewWriter(w, 5)
34	}
35	return &pooledFlateWriter{fw: fw}
36}
37
38type pooledFlateWriter struct {
39	mu sync.Mutex // guards Close and Write
40	fw *flate.Writer
41}
42
43func (w *pooledFlateWriter) Write(p []byte) (n int, err error) {
44	w.mu.Lock()
45	defer w.mu.Unlock()
46	if w.fw == nil {
47		return 0, errors.New("Write after Close")
48	}
49	return w.fw.Write(p)
50}
51
52func (w *pooledFlateWriter) Close() error {
53	w.mu.Lock()
54	defer w.mu.Unlock()
55	var err error
56	if w.fw != nil {
57		err = w.fw.Close()
58		flateWriterPool.Put(w.fw)
59		w.fw = nil
60	}
61	return err
62}
63
64var (
65	mu sync.RWMutex // guards compressor and decompressor maps
66
67	compressors = map[uint16]Compressor{
68		Store:   func(w io.Writer) (io.WriteCloser, error) { return &nopCloser{w}, nil },
69		Deflate: func(w io.Writer) (io.WriteCloser, error) { return newFlateWriter(w), nil },
70	}
71
72	decompressors = map[uint16]Decompressor{
73		Store:   ioutil.NopCloser,
74		Deflate: flate.NewReader,
75	}
76)
77
78// RegisterDecompressor allows custom decompressors for a specified method ID.
79func RegisterDecompressor(method uint16, d Decompressor) {
80	mu.Lock()
81	defer mu.Unlock()
82
83	if _, ok := decompressors[method]; ok {
84		panic("decompressor already registered")
85	}
86	decompressors[method] = d
87}
88
89// RegisterCompressor registers custom compressors for a specified method ID.
90// The common methods Store and Deflate are built in.
91func RegisterCompressor(method uint16, comp Compressor) {
92	mu.Lock()
93	defer mu.Unlock()
94
95	if _, ok := compressors[method]; ok {
96		panic("compressor already registered")
97	}
98	compressors[method] = comp
99}
100
101func compressor(method uint16) Compressor {
102	mu.RLock()
103	defer mu.RUnlock()
104	return compressors[method]
105}
106
107func decompressor(method uint16) Decompressor {
108	mu.RLock()
109	defer mu.RUnlock()
110	return decompressors[method]
111}
112