1package stdcopy
2
3import (
4	"bytes"
5	"encoding/binary"
6	"errors"
7	"fmt"
8	"io"
9	"sync"
10)
11
// StdType identifies one of the logical streams that a stdWriter can
// multiplex into a single underlying writer.
type StdType byte

// Stream identifiers carried in the first header byte of every frame.
const (
	// Stdin represents standard input stream type.
	Stdin StdType = 0

	// Stdout represents standard output stream type.
	Stdout StdType = 1

	// Stderr represents standard error stream type.
	Stderr StdType = 2

	// Systemerr represents errors originating from the system that make it
	// into the multiplexed stream.
	Systemerr StdType = 3
)

// Frame header layout: an 8-byte prefix whose byte at stdWriterFdIndex
// holds the StdType and whose four bytes starting at stdWriterSizeIndex
// hold the big-endian payload length.
const (
	stdWriterPrefixLen = 8
	stdWriterFdIndex   = 0
	stdWriterSizeIndex = 4

	// startingBufLen is the initial size of StdCopy's read buffer:
	// 32 KiB plus one header length plus one byte.
	startingBufLen = 32<<10 + stdWriterPrefixLen + 1
)
33
// bufPool recycles the scratch *bytes.Buffer values that stdWriter.Write
// uses to assemble a frame before handing it to the underlying writer.
var bufPool = &sync.Pool{
	New: func() interface{} {
		return new(bytes.Buffer)
	},
}
35
// stdWriter wraps an io.Writer and tags everything written through it
// with a one-byte stream identifier.
type stdWriter struct {
	// Writer is the destination that receives the framed output.
	io.Writer
	// prefix is the stream identifier stored in each frame header.
	prefix byte
}
41
42// Write sends the buffer to the underneath writer.
43// It inserts the prefix header before the buffer,
44// so stdcopy.StdCopy knows where to multiplex the output.
45// It makes stdWriter to implement io.Writer.
46func (w *stdWriter) Write(p []byte) (n int, err error) {
47	if w == nil || w.Writer == nil {
48		return 0, errors.New("Writer not instantiated")
49	}
50	if p == nil {
51		return 0, nil
52	}
53
54	header := [stdWriterPrefixLen]byte{stdWriterFdIndex: w.prefix}
55	binary.BigEndian.PutUint32(header[stdWriterSizeIndex:], uint32(len(p)))
56	buf := bufPool.Get().(*bytes.Buffer)
57	buf.Write(header[:])
58	buf.Write(p)
59
60	n, err = w.Writer.Write(buf.Bytes())
61	n -= stdWriterPrefixLen
62	if n < 0 {
63		n = 0
64	}
65
66	buf.Reset()
67	bufPool.Put(buf)
68	return
69}
70
71// NewStdWriter instantiates a new Writer.
72// Everything written to it will be encapsulated using a custom format,
73// and written to the underlying `w` stream.
74// This allows multiple write streams (e.g. stdout and stderr) to be muxed into a single connection.
75// `t` indicates the id of the stream to encapsulate.
76// It can be stdcopy.Stdin, stdcopy.Stdout, stdcopy.Stderr.
77func NewStdWriter(w io.Writer, t StdType) io.Writer {
78	return &stdWriter{
79		Writer: w,
80		prefix: byte(t),
81	}
82}
83
84// StdCopy is a modified version of io.Copy.
85//
86// StdCopy will demultiplex `src`, assuming that it contains two streams,
87// previously multiplexed together using a StdWriter instance.
88// As it reads from `src`, StdCopy will write to `dstout` and `dsterr`.
89//
90// StdCopy will read until it hits EOF on `src`. It will then return a nil error.
91// In other words: if `err` is non nil, it indicates a real underlying error.
92//
93// `written` will hold the total number of bytes written to `dstout` and `dsterr`.
94func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error) {
95	var (
96		buf       = make([]byte, startingBufLen)
97		bufLen    = len(buf)
98		nr, nw    int
99		er, ew    error
100		out       io.Writer
101		frameSize int
102	)
103
104	for {
105		// Make sure we have at least a full header
106		for nr < stdWriterPrefixLen {
107			var nr2 int
108			nr2, er = src.Read(buf[nr:])
109			nr += nr2
110			if er == io.EOF {
111				if nr < stdWriterPrefixLen {
112					return written, nil
113				}
114				break
115			}
116			if er != nil {
117				return 0, er
118			}
119		}
120
121		stream := StdType(buf[stdWriterFdIndex])
122		// Check the first byte to know where to write
123		switch stream {
124		case Stdin:
125			fallthrough
126		case Stdout:
127			// Write on stdout
128			out = dstout
129		case Stderr:
130			// Write on stderr
131			out = dsterr
132		case Systemerr:
133			// If we're on Systemerr, we won't write anywhere.
134			// NB: if this code changes later, make sure you don't try to write
135			// to outstream if Systemerr is the stream
136			out = nil
137		default:
138			return 0, fmt.Errorf("Unrecognized input header: %d", buf[stdWriterFdIndex])
139		}
140
141		// Retrieve the size of the frame
142		frameSize = int(binary.BigEndian.Uint32(buf[stdWriterSizeIndex : stdWriterSizeIndex+4]))
143
144		// Check if the buffer is big enough to read the frame.
145		// Extend it if necessary.
146		if frameSize+stdWriterPrefixLen > bufLen {
147			buf = append(buf, make([]byte, frameSize+stdWriterPrefixLen-bufLen+1)...)
148			bufLen = len(buf)
149		}
150
151		// While the amount of bytes read is less than the size of the frame + header, we keep reading
152		for nr < frameSize+stdWriterPrefixLen {
153			var nr2 int
154			nr2, er = src.Read(buf[nr:])
155			nr += nr2
156			if er == io.EOF {
157				if nr < frameSize+stdWriterPrefixLen {
158					return written, nil
159				}
160				break
161			}
162			if er != nil {
163				return 0, er
164			}
165		}
166
167		// we might have an error from the source mixed up in our multiplexed
168		// stream. if we do, return it.
169		if stream == Systemerr {
170			return written, fmt.Errorf("error from daemon in stream: %s", string(buf[stdWriterPrefixLen:frameSize+stdWriterPrefixLen]))
171		}
172
173		// Write the retrieved frame (without header)
174		nw, ew = out.Write(buf[stdWriterPrefixLen : frameSize+stdWriterPrefixLen])
175		if ew != nil {
176			return 0, ew
177		}
178
179		// If the frame has not been fully written: error
180		if nw != frameSize {
181			return 0, io.ErrShortWrite
182		}
183		written += int64(nw)
184
185		// Move the rest of the buffer to the beginning
186		copy(buf, buf[frameSize+stdWriterPrefixLen:])
187		// Move the index
188		nr -= frameSize + stdWriterPrefixLen
189	}
190}
191