1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssh
6
7import (
8	"encoding/binary"
9	"errors"
10	"fmt"
11	"io"
12	"log"
13	"sync"
14)
15
const (
	// minPacketLength is the smallest MaxPacketSize accepted from the peer
	// when a channel open confirmation is handled.
	minPacketLength = 9
	// channelMaxPacket contains the maximum number of bytes that will be
	// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
	// the minimum. It is the value advertised to the peer as MaxPacketSize
	// when a channel is accepted.
	channelMaxPacket = 1 << 15
	// channelWindowSize is the initial size of the local flow-control
	// window of a newly created channel. We follow OpenSSH here.
	channelWindowSize = 64 * channelMaxPacket
)
25
// NewChannel represents an incoming request to a channel. It must either be
// accepted for use by calling Accept, or rejected by calling Reject.
type NewChannel interface {
	// Accept accepts the channel creation request. It returns the Channel
	// and a Go channel containing SSH requests. The Go channel must be
	// serviced otherwise the Channel will hang.
	Accept() (Channel, <-chan *Request, error)

	// Reject rejects the channel creation request, reporting the given
	// reason and human-readable message to the peer. After calling
	// this, no other methods on the Channel may be called.
	Reject(reason RejectionReason, message string) error

	// ChannelType returns the type of the channel, as supplied by the
	// client.
	ChannelType() string

	// ExtraData returns the arbitrary payload for this channel, as supplied
	// by the client. This data is specific to the channel type.
	ExtraData() []byte
}
46
// A Channel is an ordered, reliable, flow-controlled, duplex stream
// that is multiplexed over an SSH connection.
type Channel interface {
	// Read reads up to len(data) bytes from the channel.
	Read(data []byte) (int, error)

	// Write writes len(data) bytes to the channel.
	Write(data []byte) (int, error)

	// Close signals end of channel use. No data may be sent after this
	// call.
	Close() error

	// CloseWrite signals the end of sending in-band
	// data. Requests may still be sent, and the other side may
	// still send data.
	CloseWrite() error

	// SendRequest sends a channel request. If wantReply is true,
	// it will wait for a reply and return the result as a
	// boolean, otherwise the return value will be false. Channel
	// requests are out-of-band messages so they may be sent even
	// if the data stream is closed or blocked by flow control.
	SendRequest(name string, wantReply bool, payload []byte) (bool, error)

	// Stderr returns an io.ReadWriter that writes to this channel
	// with the extended data type set to stderr. Stderr may
	// safely be read and written from a different goroutine than
	// Read and Write respectively.
	Stderr() io.ReadWriter
}
78
// Request is a request sent outside of the normal stream of
// data. Requests can either be specific to an SSH channel, or they
// can be global.
type Request struct {
	// Type is the request name, WantReply reports whether the sender
	// expects an answer (see Reply), and Payload is the request-specific
	// data.
	Type      string
	WantReply bool
	Payload   []byte

	// ch is the channel the request arrived on. When it is nil, Reply
	// answers through mux instead of through a channel.
	ch  *channel
	mux *mux
}
90
91// Reply sends a response to a request. It must be called for all requests
92// where WantReply is true and is a no-op otherwise. The payload argument is
93// ignored for replies to channel-specific requests.
94func (r *Request) Reply(ok bool, payload []byte) error {
95	if !r.WantReply {
96		return nil
97	}
98
99	if r.ch == nil {
100		return r.mux.ackRequest(ok, payload)
101	}
102
103	return r.ch.ackRequest(ok)
104}
105
// RejectionReason is the reason code given when a channel creation
// request is rejected. See RFC 4254, section 5.1.
type RejectionReason uint32

// The known rejection reasons, numbered consecutively starting at 1.
const (
	Prohibited RejectionReason = iota + 1
	ConnectionFailed
	UnknownChannelType
	ResourceShortage
)

// String converts the rejection reason to human readable form. Values
// outside the declared constants are rendered as "unknown reason N".
func (r RejectionReason) String() string {
	// Lookup table indexed by reason code; index 0 is unused.
	descriptions := [...]string{
		Prohibited:         "administratively prohibited",
		ConnectionFailed:   "connect failed",
		UnknownChannelType: "unknown channel type",
		ResourceShortage:   "resource shortage",
	}
	if r >= Prohibited && r <= ResourceShortage {
		return descriptions[r]
	}
	return fmt.Sprintf("unknown reason %d", int(r))
}
131
// min returns the smaller of a and b as a uint32. b is expected to be a
// non-negative length.
//
// The comparison is done in 64 bits: converting b to uint32 first would
// silently wrap lengths of 2^32 or more on 64-bit platforms (an exact
// multiple of 2^32 would become 0), producing a result unrelated to
// either argument.
func min(a uint32, b int) uint32 {
	if uint64(a) < uint64(b) {
		return a
	}
	// Here b <= a, so b fits in a uint32.
	return uint32(b)
}
138
// channelDirection records which side of the connection created a
// channel.
type channelDirection uint8

const (
	// channelInbound marks a channel created by the peer.
	channelInbound channelDirection = iota
	// channelOutbound marks a channel created locally.
	channelOutbound
)
145
// channel is an implementation of the Channel interface that works
// with the mux class.
type channel struct {
	// R/O after creation. localId is assigned when the channel is
	// registered with the mux's channel list; remoteId is additionally
	// assigned when the peer's channel open confirmation is handled.
	chanType          string
	extraData         []byte
	localId, remoteId uint32

	// maxIncomingPayload and maxRemotePayload are the maximum
	// payload sizes of normal and extended data packets for
	// receiving and sending, respectively. The wire packet will
	// be 9 or 13 bytes larger (excluding encryption overhead).
	maxIncomingPayload uint32
	maxRemotePayload   uint32

	// mux is the multiplexer this channel belongs to; packets are
	// written through mux.conn.
	mux *mux

	// decided is set to true if an accept or reject message has been sent
	// (for outbound channels) or received (for inbound channels).
	decided bool

	// direction contains either channelOutbound, for channels created
	// locally, or channelInbound, for channels created by the peer.
	direction channelDirection

	// Pending internal channel messages. handlePacket forwards open
	// confirmations/failures and otherwise unhandled decoded messages
	// here; SendRequest reads request replies from it.
	msg chan interface{}

	// Since requests have no ID, there can be only one request
	// with WantReply=true outstanding.  This lock is held by a
	// goroutine that has such an outgoing request pending.
	sentRequestMu sync.Mutex

	// incomingRequests carries channel requests received from the peer;
	// it is the Go channel handed out by Accept.
	incomingRequests chan *Request

	// sentEOF is set by CloseWrite; once set, WriteExtended returns
	// io.EOF. It is read and written without holding a lock.
	sentEOF bool

	// thread-safe data. remoteWin is the peer's flow-control window used
	// when sending; pending and extPending buffer received normal data
	// and extended data with code 1, respectively.
	remoteWin  window
	pending    *buffer
	extPending *buffer

	// windowMu protects myWindow, the flow-control window.
	windowMu sync.Mutex
	myWindow uint32

	// writeMu serializes calls to mux.conn.writePacket() and
	// protects sentClose and packetPool. This mutex must be
	// different from windowMu, as writePacket can block if there
	// is a key exchange pending.
	writeMu   sync.Mutex
	sentClose bool

	// packetPool has a buffer for each extended channel ID to
	// save allocations during writes.
	packetPool map[uint32][]byte
}
203
204// writePacket sends a packet. If the packet is a channel close, it updates
205// sentClose. This method takes the lock c.writeMu.
206func (c *channel) writePacket(packet []byte) error {
207	c.writeMu.Lock()
208	if c.sentClose {
209		c.writeMu.Unlock()
210		return io.EOF
211	}
212	c.sentClose = (packet[0] == msgChannelClose)
213	err := c.mux.conn.writePacket(packet)
214	c.writeMu.Unlock()
215	return err
216}
217
218func (c *channel) sendMessage(msg interface{}) error {
219	if debugMux {
220		log.Printf("send %d: %#v", c.mux.chanList.offset, msg)
221	}
222
223	p := Marshal(msg)
224	binary.BigEndian.PutUint32(p[1:], c.remoteId)
225	return c.writePacket(p)
226}
227
// WriteExtended writes data to a specific extended stream. These streams are
// used, for example, for stderr.
//
// An extendedCode of 0 sends ordinary channel data messages; any other
// value sends extended data messages carrying that code. The data is split
// into chunks of at most c.maxRemotePayload bytes, each further limited to
// the amount returned by c.remoteWin.reserve. It returns the number of bytes
// successfully passed to writePacket together with the first error
// encountered. Once sentEOF is set it returns io.EOF without writing.
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
	if c.sentEOF {
		return 0, io.EOF
	}
	// 1 byte message type, 4 bytes remoteId, 4 bytes data length
	opCode := byte(msgChannelData)
	headerLength := uint32(9)
	if extendedCode > 0 {
		// Extended data carries an additional 4-byte data type code.
		headerLength += 4
		opCode = msgChannelExtendedData
	}

	c.writeMu.Lock()
	packet := c.packetPool[extendedCode]
	// We don't remove the buffer from packetPool, so
	// WriteExtended calls from different goroutines will be
	// flagged as errors by the race detector.
	c.writeMu.Unlock()

	for len(data) > 0 {
		// Ask for as much as fits in one packet; reserve returns the
		// amount actually available for this chunk.
		space := min(c.maxRemotePayload, len(data))
		if space, err = c.remoteWin.reserve(space); err != nil {
			return n, err
		}
		// Reuse the pooled buffer when it is large enough, otherwise
		// allocate a new one of exactly the required size.
		if want := headerLength + space; uint32(cap(packet)) < want {
			packet = make([]byte, want)
		} else {
			packet = packet[:want]
		}

		todo := data[:space]

		// Header layout: type, recipient channel, optional extended
		// code, payload length; the payload follows.
		packet[0] = opCode
		binary.BigEndian.PutUint32(packet[1:], c.remoteId)
		if extendedCode > 0 {
			binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
		}
		binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
		copy(packet[headerLength:], todo)
		if err = c.writePacket(packet); err != nil {
			return n, err
		}

		n += len(todo)
		data = data[len(todo):]
	}

	// Keep the (possibly newly allocated) buffer for the next write with
	// the same extended code.
	c.writeMu.Lock()
	c.packetPool[extendedCode] = packet
	c.writeMu.Unlock()

	return n, err
}
283
284func (c *channel) handleData(packet []byte) error {
285	headerLen := 9
286	isExtendedData := packet[0] == msgChannelExtendedData
287	if isExtendedData {
288		headerLen = 13
289	}
290	if len(packet) < headerLen {
291		// malformed data packet
292		return parseError(packet[0])
293	}
294
295	var extended uint32
296	if isExtendedData {
297		extended = binary.BigEndian.Uint32(packet[5:])
298	}
299
300	length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
301	if length == 0 {
302		return nil
303	}
304	if length > c.maxIncomingPayload {
305		// TODO(hanwen): should send Disconnect?
306		return errors.New("ssh: incoming packet exceeds maximum payload size")
307	}
308
309	data := packet[headerLen:]
310	if length != uint32(len(data)) {
311		return errors.New("ssh: wrong packet length")
312	}
313
314	c.windowMu.Lock()
315	if c.myWindow < length {
316		c.windowMu.Unlock()
317		// TODO(hanwen): should send Disconnect with reason?
318		return errors.New("ssh: remote side wrote too much")
319	}
320	c.myWindow -= length
321	c.windowMu.Unlock()
322
323	if extended == 1 {
324		c.extPending.write(data)
325	} else if extended > 0 {
326		// discard other extended data.
327	} else {
328		c.pending.write(data)
329	}
330	return nil
331}
332
333func (c *channel) adjustWindow(n uint32) error {
334	c.windowMu.Lock()
335	// Since myWindow is managed on our side, and can never exceed
336	// the initial window setting, we don't worry about overflow.
337	c.myWindow += uint32(n)
338	c.windowMu.Unlock()
339	return c.sendMessage(windowAdjustMsg{
340		AdditionalBytes: uint32(n),
341	})
342}
343
344func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
345	switch extended {
346	case 1:
347		n, err = c.extPending.Read(data)
348	case 0:
349		n, err = c.pending.Read(data)
350	default:
351		return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
352	}
353
354	if n > 0 {
355		err = c.adjustWindow(uint32(n))
356		// sendWindowAdjust can return io.EOF if the remote
357		// peer has closed the connection, however we want to
358		// defer forwarding io.EOF to the caller of Read until
359		// the buffer has been drained.
360		if n > 0 && err == io.EOF {
361			err = nil
362		}
363	}
364
365	return n, err
366}
367
// close tears down the local state of the channel: it marks both receive
// buffers as EOF, closes the msg and incomingRequests Go channels, records
// the channel as closed so that writePacket refuses further packets, and
// closes the remote window.
func (c *channel) close() {
	c.pending.eof()
	c.extPending.eof()
	close(c.msg)
	close(c.incomingRequests)
	c.writeMu.Lock()
	// This is not necessary for a normal channel teardown, but if
	// there was another error, it is.
	c.sentClose = true
	c.writeMu.Unlock()
	// Unblock writers.
	c.remoteWin.close()
}
381
382// responseMessageReceived is called when a success or failure message is
383// received on a channel to check that such a message is reasonable for the
384// given channel.
385func (c *channel) responseMessageReceived() error {
386	if c.direction == channelInbound {
387		return errors.New("ssh: channel response message received on inbound channel")
388	}
389	if c.decided {
390		return errors.New("ssh: duplicate response received for channel")
391	}
392	c.decided = true
393	return nil
394}
395
// handlePacket dispatches a single incoming packet addressed to this
// channel, selecting the handler from the message type in packet[0].
// Data, close and EOF messages are handled on the raw packet; all other
// messages are decoded first. It returns an error for malformed or
// unexpected messages.
func (c *channel) handlePacket(packet []byte) error {
	switch packet[0] {
	case msgChannelData, msgChannelExtendedData:
		return c.handleData(packet)
	case msgChannelClose:
		// Answer with our own close message (its error is ignored),
		// unregister the channel from the mux and release local state.
		c.sendMessage(channelCloseMsg{PeersId: c.remoteId})
		c.mux.chanList.remove(c.localId)
		c.close()
		return nil
	case msgChannelEOF:
		// RFC 4254 is mute on how EOF affects dataExt messages but
		// it is logical to signal EOF at the same time.
		c.extPending.eof()
		c.pending.eof()
		return nil
	}

	decoded, err := decode(packet)
	if err != nil {
		return err
	}

	switch msg := decoded.(type) {
	case *channelOpenFailureMsg:
		if err := c.responseMessageReceived(); err != nil {
			return err
		}
		// The open was refused: drop the channel from the mux and pass
		// the failure on through c.msg.
		c.mux.chanList.remove(msg.PeersId)
		c.msg <- msg
	case *channelOpenConfirmMsg:
		if err := c.responseMessageReceived(); err != nil {
			return err
		}
		if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
			return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
		}
		// Record the peer's channel ID and limits, then pass the
		// confirmation on through c.msg.
		c.remoteId = msg.MyId
		c.maxRemotePayload = msg.MaxPacketSize
		c.remoteWin.add(msg.MyWindow)
		c.msg <- msg
	case *windowAdjustMsg:
		// add reports false when the adjustment is not acceptable.
		if !c.remoteWin.add(msg.AdditionalBytes) {
			return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
		}
	case *channelRequestMsg:
		req := Request{
			Type:      msg.Request,
			WantReply: msg.WantReply,
			Payload:   msg.RequestSpecificData,
			ch:        c,
		}

		// This send blocks while incomingRequests is full.
		c.incomingRequests <- &req
	default:
		// Everything else is forwarded on c.msg, which is where
		// SendRequest reads request replies from.
		c.msg <- msg
	}
	return nil
}
454
455func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
456	ch := &channel{
457		remoteWin:        window{Cond: newCond()},
458		myWindow:         channelWindowSize,
459		pending:          newBuffer(),
460		extPending:       newBuffer(),
461		direction:        direction,
462		incomingRequests: make(chan *Request, 16),
463		msg:              make(chan interface{}, 16),
464		chanType:         chanType,
465		extraData:        extraData,
466		mux:              m,
467		packetPool:       make(map[uint32][]byte),
468	}
469	ch.localId = m.chanList.add(ch)
470	return ch
471}
472
var (
	// errUndecided is returned by channel operations attempted before
	// the channel has been accepted or rejected.
	errUndecided = errors.New("ssh: must Accept or Reject channel")

	// errDecidedAlready is returned by Accept and Reject when the channel
	// has already been decided.
	errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
)
475
476type extChannel struct {
477	code uint32
478	ch   *channel
479}
480
481func (e *extChannel) Write(data []byte) (n int, err error) {
482	return e.ch.WriteExtended(data, e.code)
483}
484
485func (e *extChannel) Read(data []byte) (n int, err error) {
486	return e.ch.ReadExtended(data, e.code)
487}
488
489func (c *channel) Accept() (Channel, <-chan *Request, error) {
490	if c.decided {
491		return nil, nil, errDecidedAlready
492	}
493	c.maxIncomingPayload = channelMaxPacket
494	confirm := channelOpenConfirmMsg{
495		PeersId:       c.remoteId,
496		MyId:          c.localId,
497		MyWindow:      c.myWindow,
498		MaxPacketSize: c.maxIncomingPayload,
499	}
500	c.decided = true
501	if err := c.sendMessage(confirm); err != nil {
502		return nil, nil, err
503	}
504
505	return c, c.incomingRequests, nil
506}
507
508func (ch *channel) Reject(reason RejectionReason, message string) error {
509	if ch.decided {
510		return errDecidedAlready
511	}
512	reject := channelOpenFailureMsg{
513		PeersId:  ch.remoteId,
514		Reason:   reason,
515		Message:  message,
516		Language: "en",
517	}
518	ch.decided = true
519	return ch.sendMessage(reject)
520}
521
522func (ch *channel) Read(data []byte) (int, error) {
523	if !ch.decided {
524		return 0, errUndecided
525	}
526	return ch.ReadExtended(data, 0)
527}
528
529func (ch *channel) Write(data []byte) (int, error) {
530	if !ch.decided {
531		return 0, errUndecided
532	}
533	return ch.WriteExtended(data, 0)
534}
535
536func (ch *channel) CloseWrite() error {
537	if !ch.decided {
538		return errUndecided
539	}
540	ch.sentEOF = true
541	return ch.sendMessage(channelEOFMsg{
542		PeersId: ch.remoteId})
543}
544
545func (ch *channel) Close() error {
546	if !ch.decided {
547		return errUndecided
548	}
549
550	return ch.sendMessage(channelCloseMsg{
551		PeersId: ch.remoteId})
552}
553
554// Extended returns an io.ReadWriter that sends and receives data on the given,
555// SSH extended stream. Such streams are used, for example, for stderr.
556func (ch *channel) Extended(code uint32) io.ReadWriter {
557	if !ch.decided {
558		return nil
559	}
560	return &extChannel{code, ch}
561}
562
563func (ch *channel) Stderr() io.ReadWriter {
564	return ch.Extended(1)
565}
566
567func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
568	if !ch.decided {
569		return false, errUndecided
570	}
571
572	if wantReply {
573		ch.sentRequestMu.Lock()
574		defer ch.sentRequestMu.Unlock()
575	}
576
577	msg := channelRequestMsg{
578		PeersId:             ch.remoteId,
579		Request:             name,
580		WantReply:           wantReply,
581		RequestSpecificData: payload,
582	}
583
584	if err := ch.sendMessage(msg); err != nil {
585		return false, err
586	}
587
588	if wantReply {
589		m, ok := (<-ch.msg)
590		if !ok {
591			return false, io.EOF
592		}
593		switch m.(type) {
594		case *channelRequestFailureMsg:
595			return false, nil
596		case *channelRequestSuccessMsg:
597			return true, nil
598		default:
599			return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m)
600		}
601	}
602
603	return false, nil
604}
605
606// ackRequest either sends an ack or nack to the channel request.
607func (ch *channel) ackRequest(ok bool) error {
608	if !ch.decided {
609		return errUndecided
610	}
611
612	var msg interface{}
613	if !ok {
614		msg = channelRequestFailureMsg{
615			PeersId: ch.remoteId,
616		}
617	} else {
618		msg = channelRequestSuccessMsg{
619			PeersId: ch.remoteId,
620		}
621	}
622	return ch.sendMessage(msg)
623}
624
625func (ch *channel) ChannelType() string {
626	return ch.chanType
627}
628
629func (ch *channel) ExtraData() []byte {
630	return ch.extraData
631}
632