1// Copyright 2016 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package ioutil
16
17import (
18	"io"
19)
20
// defaultBufferBytes is the buffer watermark, in bytes, that NewPageWriter
// gives each new PageWriter. It is read once at construction time, so changing
// it does not affect writers that already exist.
var defaultBufferBytes = 128 * 1024

// PageWriter is an io.Writer that buffers data so that every write it issues
// to the underlying writer from Write ends on a page boundary (relative to the
// initial page offset). Only an explicit Flush may issue an unaligned write.
type PageWriter struct {
	w io.Writer
	// pageOffset tracks the page offset of the base of the buffer
	pageOffset int
	// pageBytes is the number of bytes per page
	pageBytes int
	// bufferedBytes counts the number of bytes pending for write in the buffer
	bufferedBytes int
	// buf holds the write buffer
	buf []byte
	// bufWatermarkBytes is the number of bytes the buffer can hold before it needs
	// to be flushed. It is less than len(buf) so there is space for slack writes
	// to bring the writer to page alignment; the buffer may therefore hold more
	// than this many bytes while a partial page is being completed.
	bufWatermarkBytes int
}

// Compile-time check that PageWriter satisfies io.Writer.
var _ io.Writer = (*PageWriter)(nil)

// NewPageWriter creates a new PageWriter that forwards data to w. pageBytes is
// the number of bytes per page and must be greater than zero. pageOffset is the
// current write offset of w; only its position within a page (pageOffset modulo
// pageBytes) affects where page boundaries fall.
//
// NewPageWriter panics if pageBytes is not positive, because all page
// arithmetic divides by it.
func NewPageWriter(w io.Writer, pageBytes, pageOffset int) *PageWriter {
	if pageBytes <= 0 {
		panic("ioutil: invalid pageBytes, it must be greater than 0")
	}
	return &PageWriter{
		w:          w,
		pageOffset: pageOffset,
		pageBytes:  pageBytes,
		// Extra pageBytes of room lets Write top up a partial page past the
		// watermark before flushing.
		buf:               make([]byte, defaultBufferBytes+pageBytes),
		bufWatermarkBytes: defaultBufferBytes,
	}
}

// Write implements io.Writer. Data is accumulated in the internal buffer while
// it stays within the watermark. Once the watermark would be exceeded, Write
// completes the current partial page (if any) from p, flushes the buffer,
// writes any further whole pages from p directly to the underlying writer, and
// buffers the remaining tail. If p is too short to complete the partial page,
// it is buffered and nothing is flushed.
//
// It returns the number of bytes of p consumed and the first error returned by
// the underlying writer.
func (pw *PageWriter) Write(p []byte) (n int, err error) {
	if len(p)+pw.bufferedBytes <= pw.bufWatermarkBytes {
		// no overflow
		copy(pw.buf[pw.bufferedBytes:], p)
		pw.bufferedBytes += len(p)
		return len(p), nil
	}
	// complete the slack page in the buffer if unaligned
	slack := pw.pageBytes - ((pw.pageOffset + pw.bufferedBytes) % pw.pageBytes)
	if slack != pw.pageBytes {
		partial := slack > len(p)
		if partial {
			// not enough data to complete the slack page
			slack = len(p)
		}
		// special case: writing to slack page in buffer
		copy(pw.buf[pw.bufferedBytes:], p[:slack])
		pw.bufferedBytes += slack
		n = slack
		p = p[slack:]
		if partial {
			// avoid forcing an unaligned flush
			return n, nil
		}
	}
	// buffer contents are now page-aligned; clear out
	if err = pw.Flush(); err != nil {
		return n, err
	}
	// directly write all complete pages without copying
	if len(p) > pw.pageBytes {
		pages := len(p) / pw.pageBytes
		c, werr := pw.w.Write(p[:pages*pw.pageBytes])
		n += c
		if werr != nil {
			return n, werr
		}
		p = p[pages*pw.pageBytes:]
	}
	// Buffer the remaining tail, which is at most one page long. The buffer is
	// empty after the flush and len(pw.buf) >= pw.pageBytes, so it always fits.
	// Copying here instead of recursing into Write avoids unbounded recursion
	// when pageBytes exceeds the watermark: a tail larger than the watermark
	// would otherwise re-enter this slow path with nothing left to consume.
	copy(pw.buf, p)
	pw.bufferedBytes = len(p)
	n += len(p)
	return n, nil
}

// Flush writes all buffered data to the underlying writer, regardless of page
// alignment, and returns the underlying writer's error. The buffer is emptied
// and the page offset advanced even when that write fails, so buffered data is
// not retried by a later Flush.
func (pw *PageWriter) Flush() error {
	if pw.bufferedBytes == 0 {
		return nil
	}
	_, err := pw.w.Write(pw.buf[:pw.bufferedBytes])
	pw.pageOffset = (pw.pageOffset + pw.bufferedBytes) % pw.pageBytes
	pw.bufferedBytes = 0
	return err
}
107