1// Copyright 2016, Joe Tsai. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5package xflate
6
7import (
8	"bufio"
9	"compress/flate"
10	"io"
11)
12
// TODO(dsnet): The standard library's flate.Reader and flate.Writer do not
// track the input and output offsets. When we eventually switch over to using
// the DEFLATE implementation in this repository, we can delete these wrappers.
16
// countReader wraps an io.Reader and tallies how many bytes have been
// read through it.
type countReader struct {
	R io.Reader // underlying source
	N int64     // running total of byte counts returned by R.Read
}

// Read forwards to the underlying reader and adds the byte count it reports
// to N, even when an error is returned alongside it. The count and error are
// passed back to the caller unchanged.
func (cr *countReader) Read(buf []byte) (n int, err error) {
	if n, err = cr.R.Read(buf); n != 0 {
		cr.N += int64(n)
	}
	return n, err
}
28
29// flateReader is a trivial wrapper around flate.Reader keeps tracks of offsets.
30type flateReader struct {
31	InputOffset  int64 // Total number of bytes read from underlying io.Reader
32	OutputOffset int64 // Total number of bytes emitted from Read
33
34	zr io.ReadCloser
35	br *bufio.Reader
36	cr countReader
37}
38
39func newFlateReader(rd io.Reader) (*flateReader, error) {
40	fr := new(flateReader)
41	fr.cr = countReader{R: rd}
42	fr.br = bufio.NewReader(&fr.cr)
43	fr.zr = flate.NewReader(fr.br)
44	return fr, nil
45}
46
47func (fr *flateReader) Reset(rd io.Reader) {
48	*fr = flateReader{zr: fr.zr, br: fr.br}
49	fr.cr = countReader{R: rd}
50	fr.br.Reset(&fr.cr)
51	fr.zr.(flate.Resetter).Reset(fr.br, nil)
52}
53
54func (fr *flateReader) Read(buf []byte) (int, error) {
55	offset := fr.cr.N - int64(fr.br.Buffered())
56	n, err := fr.zr.Read(buf)
57	fr.InputOffset += (fr.cr.N - int64(fr.br.Buffered())) - offset
58	fr.OutputOffset += int64(n)
59	return n, errWrap(err)
60}
61
// countWriter wraps an io.Writer and tallies how many bytes have been
// written through it.
type countWriter struct {
	W io.Writer // underlying destination
	N int64     // running total of byte counts reported by W.Write
}

// Write forwards buf to the underlying writer and adds the byte count it
// reports to N, even when an error is returned alongside it. The count and
// error are passed back to the caller unchanged.
func (cw *countWriter) Write(buf []byte) (n int, err error) {
	if n, err = cw.W.Write(buf); n != 0 {
		cw.N += int64(n)
	}
	return n, err
}
73
74// flateWriter is a trivial wrapper around flate.Writer keeps tracks of offsets.
75type flateWriter struct {
76	InputOffset  int64 // Total number of bytes issued to Write
77	OutputOffset int64 // Total number of bytes written to underlying io.Writer
78
79	zw *flate.Writer
80	cw countWriter
81}
82
83func newFlateWriter(wr io.Writer, lvl int) (*flateWriter, error) {
84	var err error
85	fw := new(flateWriter)
86	switch lvl {
87	case 0:
88		lvl = flate.DefaultCompression
89	case -1:
90		lvl = flate.NoCompression
91	}
92	fw.cw = countWriter{W: wr}
93	fw.zw, err = flate.NewWriter(&fw.cw, lvl)
94	return fw, errWrap(err)
95}
96
97func (fw *flateWriter) Reset(wr io.Writer) {
98	*fw = flateWriter{zw: fw.zw}
99	fw.cw = countWriter{W: wr}
100	fw.zw.Reset(&fw.cw)
101}
102
103func (fw *flateWriter) Write(buf []byte) (int, error) {
104	offset := fw.cw.N
105	n, err := fw.zw.Write(buf)
106	fw.OutputOffset += fw.cw.N - offset
107	fw.InputOffset += int64(n)
108	return n, errWrap(err)
109}
110
111func (fw *flateWriter) Flush() error {
112	offset := fw.cw.N
113	err := fw.zw.Flush()
114	fw.OutputOffset += fw.cw.N - offset
115	return errWrap(err)
116}
117