1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package zip
6
7import (
8	"compress/flate"
9	"errors"
10	"io"
11	"io/ioutil"
12	"sync"
13)
14
// A Compressor returns a new compressing writer, writing to w.
// The WriteCloser's Close method must be used to flush pending data to w.
// The Compressor itself must be safe to invoke from multiple goroutines
// simultaneously, but each returned writer will be used only by
// one goroutine at a time.
//
// A Compressor is associated with a method ID through RegisterCompressor.
// Unlike a Decompressor, it may report a failure to create the writer
// through its error result.
type Compressor func(w io.Writer) (io.WriteCloser, error)
21
// A Decompressor returns a new decompressing reader, reading from r.
// The ReadCloser's Close method must be used to release associated resources.
// The Decompressor itself must be safe to invoke from multiple goroutines
// simultaneously, but each returned reader will be used only by
// one goroutine at a time.
//
// A Decompressor is associated with a method ID through
// RegisterDecompressor. Its signature has no error result, so any failure
// has to surface through the returned reader.
type Decompressor func(r io.Reader) io.ReadCloser
28
// flateWriterPool holds *flate.Writer values handed back by
// pooledFlateWriter.Close so that newFlateWriter can reuse them.
var flateWriterPool sync.Pool

// newFlateWriter returns a deflate compressor that writes to w.
// A *flate.Writer taken from flateWriterPool is reset onto w when one is
// available; otherwise a new writer is created at compression level 5.
func newFlateWriter(w io.Writer) io.WriteCloser {
	if cached, reused := flateWriterPool.Get().(*flate.Writer); reused {
		cached.Reset(w)
		return &pooledFlateWriter{fw: cached}
	}
	// The error result is discarded: 5 is a fixed level inside the range
	// documented by compress/flate.
	fresh, _ := flate.NewWriter(w, 5)
	return &pooledFlateWriter{fw: fresh}
}

// pooledFlateWriter wraps a *flate.Writer and returns it to
// flateWriterPool on Close. After Close, fw is nil and Write fails.
type pooledFlateWriter struct {
	mu sync.Mutex // guards Close and Write
	fw *flate.Writer
}

// Write compresses p through the underlying flate writer.
// It reports an error without writing anything once Close has been called.
func (w *pooledFlateWriter) Write(p []byte) (n int, err error) {
	w.mu.Lock()
	defer w.mu.Unlock()
	if fw := w.fw; fw != nil {
		return fw.Write(p)
	}
	return 0, errors.New("Write after Close")
}

// Close closes the underlying flate writer, puts it back into
// flateWriterPool, and returns the error from that close.
// Calling Close again is a no-op that returns nil.
func (w *pooledFlateWriter) Close() error {
	w.mu.Lock()
	defer w.mu.Unlock()
	fw := w.fw
	if fw == nil {
		return nil
	}
	closeErr := fw.Close()
	flateWriterPool.Put(fw)
	w.fw = nil // marks this wrapper as closed
	return closeErr
}
66
// flateReaderPool holds flate readers handed back by
// pooledFlateReader.Close so that newFlateReader can reuse them.
var flateReaderPool sync.Pool

// newFlateReader returns a deflate decompressor that reads from r.
// A reader taken from flateReaderPool is reset onto r (with no preset
// dictionary) when one is available; otherwise a new one is created.
func newFlateReader(r io.Reader) io.ReadCloser {
	cached, reused := flateReaderPool.Get().(io.ReadCloser)
	if !reused {
		return &pooledFlateReader{fr: flate.NewReader(r)}
	}
	cached.(flate.Resetter).Reset(r, nil)
	return &pooledFlateReader{fr: cached}
}

// pooledFlateReader wraps a flate reader and returns it to
// flateReaderPool on Close. After Close, fr is nil and Read fails.
type pooledFlateReader struct {
	mu sync.Mutex // guards Close and Read
	fr io.ReadCloser
}

// Read decompresses into p from the underlying flate reader.
// It reports an error without reading anything once Close has been called.
func (r *pooledFlateReader) Read(p []byte) (n int, err error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if fr := r.fr; fr != nil {
		return fr.Read(p)
	}
	return 0, errors.New("Read after Close")
}

// Close closes the underlying flate reader, puts it back into
// flateReaderPool, and returns the error from that close.
// Calling Close again is a no-op that returns nil.
func (r *pooledFlateReader) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()
	fr := r.fr
	if fr == nil {
		return nil
	}
	closeErr := fr.Close()
	flateReaderPool.Put(fr)
	r.fr = nil // marks this wrapper as closed
	return closeErr
}
104
// Registries that map a method ID to its Compressor or Decompressor.
// init stores the built-in Store and Deflate entries; RegisterCompressor
// and RegisterDecompressor add further entries, and the compressor and
// decompressor helpers look entries up.
var (
	compressors   sync.Map // map[uint16]Compressor
	decompressors sync.Map // map[uint16]Decompressor
)
109
110func init() {
111	compressors.Store(Store, Compressor(func(w io.Writer) (io.WriteCloser, error) { return &nopCloser{w}, nil }))
112	compressors.Store(Deflate, Compressor(func(w io.Writer) (io.WriteCloser, error) { return newFlateWriter(w), nil }))
113
114	decompressors.Store(Store, Decompressor(ioutil.NopCloser))
115	decompressors.Store(Deflate, Decompressor(newFlateReader))
116}
117
118// RegisterDecompressor allows custom decompressors for a specified method ID.
119// The common methods Store and Deflate are built in.
120func RegisterDecompressor(method uint16, dcomp Decompressor) {
121	if _, dup := decompressors.LoadOrStore(method, dcomp); dup {
122		panic("decompressor already registered")
123	}
124}
125
126// RegisterCompressor registers custom compressors for a specified method ID.
127// The common methods Store and Deflate are built in.
128func RegisterCompressor(method uint16, comp Compressor) {
129	if _, dup := compressors.LoadOrStore(method, comp); dup {
130		panic("compressor already registered")
131	}
132}
133
134func compressor(method uint16) Compressor {
135	ci, ok := compressors.Load(method)
136	if !ok {
137		return nil
138	}
139	return ci.(Compressor)
140}
141
142func decompressor(method uint16) Decompressor {
143	di, ok := decompressors.Load(method)
144	if !ok {
145		return nil
146	}
147	return di.(Decompressor)
148}
149