1package filexfer
2
3import (
4	"encoding/binary"
5	"errors"
6)
7
// Errors reported while encoding or decoding SSH wire-format data.
var (
	// ErrShortPacket is returned by the Consume methods when the buffer holds
	// fewer unconsumed bytes than the value being decoded requires.
	ErrShortPacket = errors.New("packet too short")

	// ErrLongPacket indicates a packet that is longer than permitted.
	// NOTE(review): nothing in this section returns it; presumably callers
	// elsewhere use it to enforce a maximum packet size — confirm.
	ErrLongPacket = errors.New("packet too long")
)
13
// Buffer wraps up the various encoding details of the SSH format.
//
// Data types are encoded as per section 4 of
// https://tools.ietf.org/html/draft-ietf-secsh-architecture-09#page-8
// (the draft was later published as RFC 4251, where this is section 5).
//
// The zero value is an empty buffer ready for use.
type Buffer struct {
	// b holds the buffer contents; Append methods grow it.
	b []byte
	// off is the index in b of the next unconsumed byte; Consume methods advance it.
	off int
}
21
22// NewBuffer creates and initializes a new buffer using buf as its initial contents.
23// The new buffer takes ownership of buf, and the caller should not use buf after this call.
24//
25// In most cases, new(Buffer) (or just declaring a Buffer variable) is sufficient to initialize a Buffer.
26func NewBuffer(buf []byte) *Buffer {
27	return &Buffer{
28		b: buf,
29	}
30}
31
32// NewMarshalBuffer creates a new Buffer ready to start marshaling a Packet into.
33// It preallocates enough space for uint32(length), uint8(type), uint32(request-id) and size more bytes.
34func NewMarshalBuffer(size int) *Buffer {
35	return NewBuffer(make([]byte, 4+1+4+size))
36}
37
38// Bytes returns a slice of length b.Len() holding the unconsumed bytes in the Buffer.
39// The slice is valid for use only until the next buffer modification
40// (that is, only until the next call to an Append or Consume method).
41func (b *Buffer) Bytes() []byte {
42	return b.b[b.off:]
43}
44
45// Len returns the number of unconsumed bytes in the buffer.
46func (b *Buffer) Len() int { return len(b.b) - b.off }
47
48// Cap returns the capacity of the buffer’s underlying byte slice,
49// that is, the total space allocated for the buffer’s data.
50func (b *Buffer) Cap() int { return cap(b.b) }
51
52// Reset resets the buffer to be empty, but it retains the underlying storage for use by future Appends.
53func (b *Buffer) Reset() {
54	b.b = b.b[:0]
55	b.off = 0
56}
57
58// StartPacket resets and initializes the buffer to be ready to start marshaling a packet into.
59// It truncates the buffer, reserves space for uint32(length), then appends the given packetType and requestID.
60func (b *Buffer) StartPacket(packetType PacketType, requestID uint32) {
61	b.b, b.off = append(b.b[:0], make([]byte, 4)...), 0
62
63	b.AppendUint8(uint8(packetType))
64	b.AppendUint32(requestID)
65}
66
67// Packet finalizes the packet started from StartPacket.
68// It is expected that this will end the ownership of the underlying byte-slice,
69// and so the returned byte-slices may be reused the same as any other byte-slice,
70// the caller should not use this buffer after this call.
71//
72// It writes the packet body length into the first four bytes of the buffer in network byte order (big endian).
73// The packet body length is the length of this buffer less the 4-byte length itself, plus the length of payload.
74//
75// It is assumed that no Consume methods have been called on this buffer,
76// and so it returns the whole underlying slice.
77func (b *Buffer) Packet(payload []byte) (header, payloadPassThru []byte, err error) {
78	b.PutLength(len(b.b) - 4 + len(payload))
79
80	return b.b, payload, nil
81}
82
83// ConsumeUint8 consumes a single byte from the buffer.
84// If the buffer does not have enough data, it will return ErrShortPacket.
85func (b *Buffer) ConsumeUint8() (uint8, error) {
86	if b.Len() < 1 {
87		return 0, ErrShortPacket
88	}
89
90	var v uint8
91	v, b.off = b.b[b.off], b.off+1
92	return v, nil
93}
94
95// AppendUint8 appends a single byte into the buffer.
96func (b *Buffer) AppendUint8(v uint8) {
97	b.b = append(b.b, v)
98}
99
100// ConsumeBool consumes a single byte from the buffer, and returns true if that byte is non-zero.
101// If the buffer does not have enough data, it will return ErrShortPacket.
102func (b *Buffer) ConsumeBool() (bool, error) {
103	v, err := b.ConsumeUint8()
104	if err != nil {
105		return false, err
106	}
107
108	return v != 0, nil
109}
110
111// AppendBool appends a single bool into the buffer.
112// It encodes it as a single byte, with false as 0, and true as 1.
113func (b *Buffer) AppendBool(v bool) {
114	if v {
115		b.AppendUint8(1)
116	} else {
117		b.AppendUint8(0)
118	}
119}
120
121// ConsumeUint16 consumes a single uint16 from the buffer, in network byte order (big-endian).
122// If the buffer does not have enough data, it will return ErrShortPacket.
123func (b *Buffer) ConsumeUint16() (uint16, error) {
124	if b.Len() < 2 {
125		return 0, ErrShortPacket
126	}
127
128	v := binary.BigEndian.Uint16(b.b[b.off:])
129	b.off += 2
130	return v, nil
131}
132
133// AppendUint16 appends single uint16 into the buffer, in network byte order (big-endian).
134func (b *Buffer) AppendUint16(v uint16) {
135	b.b = append(b.b,
136		byte(v>>8),
137		byte(v>>0),
138	)
139}
140
// unmarshalUint32 decodes a big-endian uint32 from the start of b. It is used
// internally to read the packet length.
//
// It performs no length validation and panics if b cannot be resliced to four
// bytes. That is why it is unexported; even within this package, prefer the
// checked Consume methods.
func unmarshalUint32(b []byte) uint32 {
	// The reslice is bounded by cap(b), not len(b); this is deliberately kept.
	head := b[:4]
	return binary.BigEndian.Uint32(head)
}
147
148// ConsumeUint32 consumes a single uint32 from the buffer, in network byte order (big-endian).
149// If the buffer does not have enough data, it will return ErrShortPacket.
150func (b *Buffer) ConsumeUint32() (uint32, error) {
151	if b.Len() < 4 {
152		return 0, ErrShortPacket
153	}
154
155	v := binary.BigEndian.Uint32(b.b[b.off:])
156	b.off += 4
157	return v, nil
158}
159
160// AppendUint32 appends a single uint32 into the buffer, in network byte order (big-endian).
161func (b *Buffer) AppendUint32(v uint32) {
162	b.b = append(b.b,
163		byte(v>>24),
164		byte(v>>16),
165		byte(v>>8),
166		byte(v>>0),
167	)
168}
169
170// ConsumeUint64 consumes a single uint64 from the buffer, in network byte order (big-endian).
171// If the buffer does not have enough data, it will return ErrShortPacket.
172func (b *Buffer) ConsumeUint64() (uint64, error) {
173	if b.Len() < 8 {
174		return 0, ErrShortPacket
175	}
176
177	v := binary.BigEndian.Uint64(b.b[b.off:])
178	b.off += 8
179	return v, nil
180}
181
182// AppendUint64 appends a single uint64 into the buffer, in network byte order (big-endian).
183func (b *Buffer) AppendUint64(v uint64) {
184	b.b = append(b.b,
185		byte(v>>56),
186		byte(v>>48),
187		byte(v>>40),
188		byte(v>>32),
189		byte(v>>24),
190		byte(v>>16),
191		byte(v>>8),
192		byte(v>>0),
193	)
194}
195
196// ConsumeInt64 consumes a single int64 from the buffer, in network byte order (big-endian) with two’s complement.
197// If the buffer does not have enough data, it will return ErrShortPacket.
198func (b *Buffer) ConsumeInt64() (int64, error) {
199	u, err := b.ConsumeUint64()
200	if err != nil {
201		return 0, err
202	}
203
204	return int64(u), err
205}
206
207// AppendInt64 appends a single int64 into the buffer, in network byte order (big-endian) with two’s complement.
208func (b *Buffer) AppendInt64(v int64) {
209	b.AppendUint64(uint64(v))
210}
211
212// ConsumeByteSlice consumes a single string of raw binary data from the buffer.
213// A string is a uint32 length, followed by that number of raw bytes.
214// If the buffer does not have enough data, or defines a length larger than available, it will return ErrShortPacket.
215//
216// The returned slice aliases the buffer contents, and is valid only as long as the buffer is not reused
217// (that is, only until the next call to Reset, PutLength, StartPacket, or UnmarshalBinary).
218//
219// In no case will any Consume calls return overlapping slice aliases,
220// and Append calls are guaranteed to not disturb this slice alias.
221func (b *Buffer) ConsumeByteSlice() ([]byte, error) {
222	length, err := b.ConsumeUint32()
223	if err != nil {
224		return nil, err
225	}
226
227	if b.Len() < int(length) {
228		return nil, ErrShortPacket
229	}
230
231	v := b.b[b.off:]
232	if len(v) > int(length) {
233		v = v[:length:length]
234	}
235	b.off += int(length)
236	return v, nil
237}
238
239// AppendByteSlice appends a single string of raw binary data into the buffer.
240// A string is a uint32 length, followed by that number of raw bytes.
241func (b *Buffer) AppendByteSlice(v []byte) {
242	b.AppendUint32(uint32(len(v)))
243	b.b = append(b.b, v...)
244}
245
246// ConsumeString consumes a single string of binary data from the buffer.
247// A string is a uint32 length, followed by that number of raw bytes.
248// If the buffer does not have enough data, or defines a length larger than available, it will return ErrShortPacket.
249//
250// NOTE: Go implicitly assumes that strings contain UTF-8 encoded data.
251// All caveats on using arbitrary binary data in Go strings applies.
252func (b *Buffer) ConsumeString() (string, error) {
253	v, err := b.ConsumeByteSlice()
254	if err != nil {
255		return "", err
256	}
257
258	return string(v), nil
259}
260
261// AppendString appends a single string of binary data into the buffer.
262// A string is a uint32 length, followed by that number of raw bytes.
263func (b *Buffer) AppendString(v string) {
264	b.AppendByteSlice([]byte(v))
265}
266
267// PutLength writes the given size into the first four bytes of the buffer in network byte order (big endian).
268func (b *Buffer) PutLength(size int) {
269	if len(b.b) < 4 {
270		b.b = append(b.b, make([]byte, 4-len(b.b))...)
271	}
272
273	binary.BigEndian.PutUint32(b.b, uint32(size))
274}
275
276// MarshalBinary returns a clone of the full internal buffer.
277func (b *Buffer) MarshalBinary() ([]byte, error) {
278	clone := make([]byte, len(b.b))
279	n := copy(clone, b.b)
280	return clone[:n], nil
281}
282
283// UnmarshalBinary sets the internal buffer of b to be a clone of data, and zeros the internal offset.
284func (b *Buffer) UnmarshalBinary(data []byte) error {
285	if grow := len(data) - len(b.b); grow > 0 {
286		b.b = append(b.b, make([]byte, grow)...)
287	}
288
289	n := copy(b.b, data)
290	b.b = b.b[:n]
291	b.off = 0
292	return nil
293}
294