1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssh
6
7import (
8	"encoding/binary"
9	"errors"
10	"fmt"
11	"io"
12	"log"
13	"sync"
14)
15
const (
	// minPacketLength is the smallest MaxPacketSize this package accepts
	// from a peer in a channel open confirmation; smaller values are
	// rejected as invalid.
	minPacketLength = 9
	// channelMaxPacket contains the maximum number of bytes that will be
	// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
	// the minimum.
	channelMaxPacket = 1 << 15
	// channelWindowSize is the receive window every new channel starts
	// with. We follow OpenSSH here.
	channelWindowSize = 64 * channelMaxPacket
)
25
// NewChannel represents an incoming request to a channel. It must either be
// accepted for use by calling Accept, or rejected by calling Reject.
type NewChannel interface {
	// Accept accepts the channel creation request. It returns the Channel
	// and a Go channel containing SSH requests. The Go channel must be
	// serviced otherwise the Channel will hang.
	Accept() (Channel, <-chan *Request, error)

	// Reject rejects the channel creation request, reporting the given
	// reason and human-readable message to the peer. After calling
	// this, no other methods on the Channel may be called.
	Reject(reason RejectionReason, message string) error

	// ChannelType returns the type of the channel, as supplied by the
	// client.
	ChannelType() string

	// ExtraData returns the arbitrary payload for this channel, as supplied
	// by the client. This data is specific to the channel type.
	ExtraData() []byte
}
46
// A Channel is an ordered, reliable, flow-controlled, duplex stream
// that is multiplexed over an SSH connection.
type Channel interface {
	// Read reads up to len(data) bytes from the channel.
	Read(data []byte) (int, error)

	// Write writes len(data) bytes to the channel.
	Write(data []byte) (int, error)

	// Close signals end of channel use. No data may be sent after this
	// call.
	Close() error

	// CloseWrite signals the end of sending in-band
	// data. Requests may still be sent, and the other side may
	// still send data.
	CloseWrite() error

	// SendRequest sends a channel request.  If wantReply is true,
	// it will wait for a reply and return the result as a
	// boolean, otherwise the return value will be false. Channel
	// requests are out-of-band messages so they may be sent even
	// if the data stream is closed or blocked by flow control.
	// If the channel is closed before a reply is returned, io.EOF
	// is returned.
	SendRequest(name string, wantReply bool, payload []byte) (bool, error)

	// Stderr returns an io.ReadWriter that writes to this channel
	// with the extended data type set to stderr. Stderr may
	// safely be read and written from a different goroutine than
	// Read and Write respectively.
	Stderr() io.ReadWriter
}
80
// Request is a request sent outside of the normal stream of
// data. Requests can either be specific to an SSH channel, or they
// can be global.
type Request struct {
	// Type is the name of the request.
	Type string
	// WantReply reports whether the sender expects an answer; Reply is a
	// no-op when it is false.
	WantReply bool
	// Payload is the request-specific data.
	Payload []byte

	// ch is the channel the request arrived on. Reply treats a nil ch as
	// a global request and answers through mux instead.
	ch  *channel
	mux *mux
}
92
93// Reply sends a response to a request. It must be called for all requests
94// where WantReply is true and is a no-op otherwise. The payload argument is
95// ignored for replies to channel-specific requests.
96func (r *Request) Reply(ok bool, payload []byte) error {
97	if !r.WantReply {
98		return nil
99	}
100
101	if r.ch == nil {
102		return r.mux.ackRequest(ok, payload)
103	}
104
105	return r.ch.ackRequest(ok)
106}
107
// RejectionReason is an enumeration used when rejecting channel creation
// requests. See RFC 4254, section 5.1.
type RejectionReason uint32

const (
	// Prohibited indicates the channel was administratively prohibited.
	Prohibited RejectionReason = 1
	// ConnectionFailed indicates that connecting to the target failed.
	ConnectionFailed RejectionReason = 2
	// UnknownChannelType indicates the requested channel type is not
	// recognized.
	UnknownChannelType RejectionReason = 3
	// ResourceShortage indicates a lack of resources to open the channel.
	ResourceShortage RejectionReason = 4
)

// String converts the rejection reason to human readable form. Values
// outside the defined set are rendered as "unknown reason N".
func (r RejectionReason) String() string {
	var text string
	switch r {
	case Prohibited:
		text = "administratively prohibited"
	case ConnectionFailed:
		text = "connect failed"
	case UnknownChannelType:
		text = "unknown channel type"
	case ResourceShortage:
		text = "resource shortage"
	default:
		text = fmt.Sprintf("unknown reason %d", int(r))
	}
	return text
}
133
// min returns the smaller of a and b as a uint32.
//
// The comparison is done in 64 bits so that a value of b that does not fit
// in a uint32 (for example the length of a slice larger than 4 GiB) is not
// truncated before being compared; in that case a is returned. A negative b
// is treated as zero.
func min(a uint32, b int) uint32 {
	if b <= 0 {
		return 0
	}
	if uint64(b) < uint64(a) {
		// b < a <= MaxUint32, so the conversion cannot truncate.
		return uint32(b)
	}
	return a
}
140
// channelDirection records which side of the connection created a channel.
type channelDirection uint8

const (
	// channelInbound marks a channel created by the peer.
	channelInbound channelDirection = iota
	// channelOutbound marks a channel created locally.
	channelOutbound
)
147
// channel is an implementation of the Channel interface that works
// with the mux type. It also provides the NewChannel methods (Accept,
// Reject, ChannelType, ExtraData).
type channel struct {
	// R/O after creation
	chanType          string
	extraData         []byte
	localId, remoteId uint32

	// maxIncomingPayload and maxRemotePayload are the maximum
	// payload sizes of normal and extended data packets for
	// receiving and sending, respectively. The wire packet will
	// be 9 or 13 bytes larger (excluding encryption overhead).
	maxIncomingPayload uint32
	maxRemotePayload   uint32

	// mux is the multiplexer this channel belongs to; packets are
	// written through mux.conn.
	mux *mux

	// decided is set to true if an accept or reject message has been sent
	// (for outbound channels) or received (for inbound channels).
	decided bool

	// direction contains either channelOutbound, for channels created
	// locally, or channelInbound, for channels created by the peer.
	direction channelDirection

	// Pending internal channel messages. SendRequest receives the reply
	// to a request from here; the Go channel is closed by close().
	msg chan interface{}

	// Since requests have no ID, there can be only one request
	// with WantReply=true outstanding.  This lock is held by a
	// goroutine that has such an outgoing request pending.
	sentRequestMu sync.Mutex

	// incomingRequests carries channel requests received from the peer.
	// It is handed to the caller by Accept and closed by close().
	incomingRequests chan *Request

	// sentEOF is set by CloseWrite; once set, WriteExtended refuses to
	// send more data and returns io.EOF.
	sentEOF bool

	// thread-safe data
	//
	// remoteWin is the send window, pending buffers incoming normal data,
	// and extPending buffers incoming extended data with code 1 (stderr).
	remoteWin  window
	pending    *buffer
	extPending *buffer

	// windowMu protects myWindow, the flow-control window.
	windowMu sync.Mutex
	myWindow uint32

	// writeMu serializes calls to mux.conn.writePacket() and
	// protects sentClose and packetPool. This mutex must be
	// different from windowMu, as writePacket can block if there
	// is a key exchange pending.
	writeMu   sync.Mutex
	sentClose bool

	// packetPool has a buffer for each extended channel ID to
	// save allocations during writes.
	packetPool map[uint32][]byte
}
205
206// writePacket sends a packet. If the packet is a channel close, it updates
207// sentClose. This method takes the lock c.writeMu.
208func (ch *channel) writePacket(packet []byte) error {
209	ch.writeMu.Lock()
210	if ch.sentClose {
211		ch.writeMu.Unlock()
212		return io.EOF
213	}
214	ch.sentClose = (packet[0] == msgChannelClose)
215	err := ch.mux.conn.writePacket(packet)
216	ch.writeMu.Unlock()
217	return err
218}
219
220func (ch *channel) sendMessage(msg interface{}) error {
221	if debugMux {
222		log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
223	}
224
225	p := Marshal(msg)
226	binary.BigEndian.PutUint32(p[1:], ch.remoteId)
227	return ch.writePacket(p)
228}
229
// WriteExtended writes data to a specific extended stream. These streams are
// used, for example, for stderr. An extendedCode of 0 selects the normal data
// stream.
//
// The data is sent as a sequence of packets, each carrying at most
// maxRemotePayload bytes and at most what remoteWin.reserve grants. It
// returns the number of bytes for which writePacket succeeded, together with
// the first error encountered. Once CloseWrite has set sentEOF, it returns
// io.EOF without sending anything.
func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
	if ch.sentEOF {
		return 0, io.EOF
	}
	// 1 byte message type, 4 bytes remoteId, 4 bytes data length
	opCode := byte(msgChannelData)
	headerLength := uint32(9)
	if extendedCode > 0 {
		// Extended data packets carry an additional 4-byte data type code.
		headerLength += 4
		opCode = msgChannelExtendedData
	}

	ch.writeMu.Lock()
	packet := ch.packetPool[extendedCode]
	// We don't remove the buffer from packetPool, so
	// WriteExtended calls from different goroutines will be
	// flagged as errors by the race detector.
	ch.writeMu.Unlock()

	for len(data) > 0 {
		space := min(ch.maxRemotePayload, len(data))
		// reserve may grant less than requested; the granted amount
		// replaces space. NOTE(review): presumably this blocks until the
		// peer's window has room (close() calls remoteWin.close() to
		// "unblock writers") - confirm against the window type.
		if space, err = ch.remoteWin.reserve(space); err != nil {
			return n, err
		}
		// Reuse the pooled buffer when it is large enough, otherwise
		// allocate one that fits this packet.
		if want := headerLength + space; uint32(cap(packet)) < want {
			packet = make([]byte, want)
		} else {
			packet = packet[:want]
		}

		todo := data[:space]

		packet[0] = opCode
		binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
		if extendedCode > 0 {
			binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
		}
		// The length field occupies the last 4 bytes of the header.
		binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
		copy(packet[headerLength:], todo)
		if err = ch.writePacket(packet); err != nil {
			return n, err
		}

		n += len(todo)
		data = data[len(todo):]
	}

	// Store the (possibly grown) buffer for the next write to this stream.
	// This is skipped on the error returns above.
	ch.writeMu.Lock()
	ch.packetPool[extendedCode] = packet
	ch.writeMu.Unlock()

	return n, err
}
285
286func (ch *channel) handleData(packet []byte) error {
287	headerLen := 9
288	isExtendedData := packet[0] == msgChannelExtendedData
289	if isExtendedData {
290		headerLen = 13
291	}
292	if len(packet) < headerLen {
293		// malformed data packet
294		return parseError(packet[0])
295	}
296
297	var extended uint32
298	if isExtendedData {
299		extended = binary.BigEndian.Uint32(packet[5:])
300	}
301
302	length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
303	if length == 0 {
304		return nil
305	}
306	if length > ch.maxIncomingPayload {
307		// TODO(hanwen): should send Disconnect?
308		return errors.New("ssh: incoming packet exceeds maximum payload size")
309	}
310
311	data := packet[headerLen:]
312	if length != uint32(len(data)) {
313		return errors.New("ssh: wrong packet length")
314	}
315
316	ch.windowMu.Lock()
317	if ch.myWindow < length {
318		ch.windowMu.Unlock()
319		// TODO(hanwen): should send Disconnect with reason?
320		return errors.New("ssh: remote side wrote too much")
321	}
322	ch.myWindow -= length
323	ch.windowMu.Unlock()
324
325	if extended == 1 {
326		ch.extPending.write(data)
327	} else if extended > 0 {
328		// discard other extended data.
329	} else {
330		ch.pending.write(data)
331	}
332	return nil
333}
334
335func (c *channel) adjustWindow(n uint32) error {
336	c.windowMu.Lock()
337	// Since myWindow is managed on our side, and can never exceed
338	// the initial window setting, we don't worry about overflow.
339	c.myWindow += uint32(n)
340	c.windowMu.Unlock()
341	return c.sendMessage(windowAdjustMsg{
342		AdditionalBytes: uint32(n),
343	})
344}
345
346func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
347	switch extended {
348	case 1:
349		n, err = c.extPending.Read(data)
350	case 0:
351		n, err = c.pending.Read(data)
352	default:
353		return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
354	}
355
356	if n > 0 {
357		err = c.adjustWindow(uint32(n))
358		// sendWindowAdjust can return io.EOF if the remote
359		// peer has closed the connection, however we want to
360		// defer forwarding io.EOF to the caller of Read until
361		// the buffer has been drained.
362		if n > 0 && err == io.EOF {
363			err = nil
364		}
365	}
366
367	return n, err
368}
369
370func (c *channel) close() {
371	c.pending.eof()
372	c.extPending.eof()
373	close(c.msg)
374	close(c.incomingRequests)
375	c.writeMu.Lock()
376	// This is not necessary for a normal channel teardown, but if
377	// there was another error, it is.
378	c.sentClose = true
379	c.writeMu.Unlock()
380	// Unblock writers.
381	c.remoteWin.close()
382}
383
384// responseMessageReceived is called when a success or failure message is
385// received on a channel to check that such a message is reasonable for the
386// given channel.
387func (ch *channel) responseMessageReceived() error {
388	if ch.direction == channelInbound {
389		return errors.New("ssh: channel response message received on inbound channel")
390	}
391	if ch.decided {
392		return errors.New("ssh: duplicate response received for channel")
393	}
394	ch.decided = true
395	return nil
396}
397
// handlePacket dispatches a packet addressed to this channel. Data, close
// and EOF packets are recognized by their first byte and handled directly;
// every other packet is decoded and handled according to its message type.
// A non-nil error is returned for packets that fail to decode or that are
// invalid for the current state of the channel.
func (ch *channel) handlePacket(packet []byte) error {
	switch packet[0] {
	case msgChannelData, msgChannelExtendedData:
		return ch.handleData(packet)
	case msgChannelClose:
		// Answer with our own close message (the send error is
		// deliberately ignored), then drop the channel from the mux and
		// release local state.
		ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
		ch.mux.chanList.remove(ch.localId)
		ch.close()
		return nil
	case msgChannelEOF:
		// RFC 4254 is mute on how EOF affects dataExt messages but
		// it is logical to signal EOF at the same time.
		ch.extPending.eof()
		ch.pending.eof()
		return nil
	}

	decoded, err := decode(packet)
	if err != nil {
		return err
	}

	switch msg := decoded.(type) {
	case *channelOpenFailureMsg:
		if err := ch.responseMessageReceived(); err != nil {
			return err
		}
		// NOTE(review): PeersID in a message from the peer is presumably
		// our local channel ID - confirm against the message definition.
		ch.mux.chanList.remove(msg.PeersID)
		ch.msg <- msg
	case *channelOpenConfirmMsg:
		if err := ch.responseMessageReceived(); err != nil {
			return err
		}
		if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
			return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
		}
		// Record the peer's channel ID, packet size limit and initial
		// window before handing the confirmation on.
		ch.remoteId = msg.MyID
		ch.maxRemotePayload = msg.MaxPacketSize
		ch.remoteWin.add(msg.MyWindow)
		ch.msg <- msg
	case *windowAdjustMsg:
		if !ch.remoteWin.add(msg.AdditionalBytes) {
			return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
		}
	case *channelRequestMsg:
		req := Request{
			Type:      msg.Request,
			WantReply: msg.WantReply,
			Payload:   msg.RequestSpecificData,
			ch:        ch,
		}

		// Delivered on the Go channel that Accept returns to the caller.
		ch.incomingRequests <- &req
	default:
		// Everything else (for example request success/failure replies)
		// is queued on msg, where SendRequest waits for its reply.
		ch.msg <- msg
	}
	return nil
}
456
457func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
458	ch := &channel{
459		remoteWin:        window{Cond: newCond()},
460		myWindow:         channelWindowSize,
461		pending:          newBuffer(),
462		extPending:       newBuffer(),
463		direction:        direction,
464		incomingRequests: make(chan *Request, chanSize),
465		msg:              make(chan interface{}, chanSize),
466		chanType:         chanType,
467		extraData:        extraData,
468		mux:              m,
469		packetPool:       make(map[uint32][]byte),
470	}
471	ch.localId = m.chanList.add(ch)
472	return ch
473}
474
var (
	// errUndecided is returned by channel operations attempted before the
	// channel has been accepted or rejected.
	errUndecided = errors.New("ssh: must Accept or Reject channel")
	// errDecidedAlready is returned by Accept and Reject when a decision
	// has already been made for the channel.
	errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
)
477
478type extChannel struct {
479	code uint32
480	ch   *channel
481}
482
483func (e *extChannel) Write(data []byte) (n int, err error) {
484	return e.ch.WriteExtended(data, e.code)
485}
486
487func (e *extChannel) Read(data []byte) (n int, err error) {
488	return e.ch.ReadExtended(data, e.code)
489}
490
491func (ch *channel) Accept() (Channel, <-chan *Request, error) {
492	if ch.decided {
493		return nil, nil, errDecidedAlready
494	}
495	ch.maxIncomingPayload = channelMaxPacket
496	confirm := channelOpenConfirmMsg{
497		PeersID:       ch.remoteId,
498		MyID:          ch.localId,
499		MyWindow:      ch.myWindow,
500		MaxPacketSize: ch.maxIncomingPayload,
501	}
502	ch.decided = true
503	if err := ch.sendMessage(confirm); err != nil {
504		return nil, nil, err
505	}
506
507	return ch, ch.incomingRequests, nil
508}
509
510func (ch *channel) Reject(reason RejectionReason, message string) error {
511	if ch.decided {
512		return errDecidedAlready
513	}
514	reject := channelOpenFailureMsg{
515		PeersID:  ch.remoteId,
516		Reason:   reason,
517		Message:  message,
518		Language: "en",
519	}
520	ch.decided = true
521	return ch.sendMessage(reject)
522}
523
524func (ch *channel) Read(data []byte) (int, error) {
525	if !ch.decided {
526		return 0, errUndecided
527	}
528	return ch.ReadExtended(data, 0)
529}
530
531func (ch *channel) Write(data []byte) (int, error) {
532	if !ch.decided {
533		return 0, errUndecided
534	}
535	return ch.WriteExtended(data, 0)
536}
537
538func (ch *channel) CloseWrite() error {
539	if !ch.decided {
540		return errUndecided
541	}
542	ch.sentEOF = true
543	return ch.sendMessage(channelEOFMsg{
544		PeersID: ch.remoteId})
545}
546
547func (ch *channel) Close() error {
548	if !ch.decided {
549		return errUndecided
550	}
551
552	return ch.sendMessage(channelCloseMsg{
553		PeersID: ch.remoteId})
554}
555
556// Extended returns an io.ReadWriter that sends and receives data on the given,
557// SSH extended stream. Such streams are used, for example, for stderr.
558func (ch *channel) Extended(code uint32) io.ReadWriter {
559	if !ch.decided {
560		return nil
561	}
562	return &extChannel{code, ch}
563}
564
// Stderr returns an io.ReadWriter for extended data stream 1, the stream
// used for stderr. Like Extended, it returns nil if the channel has not been
// accepted or rejected yet.
func (ch *channel) Stderr() io.ReadWriter {
	return ch.Extended(1)
}
568
569func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
570	if !ch.decided {
571		return false, errUndecided
572	}
573
574	if wantReply {
575		ch.sentRequestMu.Lock()
576		defer ch.sentRequestMu.Unlock()
577	}
578
579	msg := channelRequestMsg{
580		PeersID:             ch.remoteId,
581		Request:             name,
582		WantReply:           wantReply,
583		RequestSpecificData: payload,
584	}
585
586	if err := ch.sendMessage(msg); err != nil {
587		return false, err
588	}
589
590	if wantReply {
591		m, ok := (<-ch.msg)
592		if !ok {
593			return false, io.EOF
594		}
595		switch m.(type) {
596		case *channelRequestFailureMsg:
597			return false, nil
598		case *channelRequestSuccessMsg:
599			return true, nil
600		default:
601			return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m)
602		}
603	}
604
605	return false, nil
606}
607
608// ackRequest either sends an ack or nack to the channel request.
609func (ch *channel) ackRequest(ok bool) error {
610	if !ch.decided {
611		return errUndecided
612	}
613
614	var msg interface{}
615	if !ok {
616		msg = channelRequestFailureMsg{
617			PeersID: ch.remoteId,
618		}
619	} else {
620		msg = channelRequestSuccessMsg{
621			PeersID: ch.remoteId,
622		}
623	}
624	return ch.sendMessage(msg)
625}
626
// ChannelType returns the channel type string the channel was created with.
func (ch *channel) ChannelType() string {
	return ch.chanType
}
630
// ExtraData returns the extra payload the channel was created with. The
// returned slice is the channel's own copy of the data, not a clone.
func (ch *channel) ExtraData() []byte {
	return ch.extraData
}
634