1package sftp 2 3import ( 4 "encoding" 5 "fmt" 6 "io" 7 "sync" 8) 9 10// conn implements a bidirectional channel on which client and server 11// connections are multiplexed. 12type conn struct { 13 io.Reader 14 io.WriteCloser 15 // this is the same allocator used in packet manager 16 alloc *allocator 17 sync.Mutex // used to serialise writes to sendPacket 18} 19 20// the orderID is used in server mode if the allocator is enabled. 21// For the client mode just pass 0 22func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) { 23 return recvPacket(c, c.alloc, orderID) 24} 25 26func (c *conn) sendPacket(m encoding.BinaryMarshaler) error { 27 c.Lock() 28 defer c.Unlock() 29 30 return sendPacket(c, m) 31} 32 33func (c *conn) Close() error { 34 c.Lock() 35 defer c.Unlock() 36 return c.WriteCloser.Close() 37} 38 39type clientConn struct { 40 conn 41 wg sync.WaitGroup 42 43 sync.Mutex // protects inflight 44 inflight map[uint32]chan<- result // outstanding requests 45 46 closed chan struct{} 47 err error 48} 49 50// Wait blocks until the conn has shut down, and return the error 51// causing the shutdown. It can be called concurrently from multiple 52// goroutines. 53func (c *clientConn) Wait() error { 54 <-c.closed 55 return c.err 56} 57 58// Close closes the SFTP session. 59func (c *clientConn) Close() error { 60 defer c.wg.Wait() 61 return c.conn.Close() 62} 63 64func (c *clientConn) loop() { 65 defer c.wg.Done() 66 err := c.recv() 67 if err != nil { 68 c.broadcastErr(err) 69 } 70} 71 72// recv continuously reads from the server and forwards responses to the 73// appropriate channel. 74func (c *clientConn) recv() error { 75 defer c.conn.Close() 76 77 for { 78 typ, data, err := c.recvPacket(0) 79 if err != nil { 80 return err 81 } 82 sid, _, err := unmarshalUint32Safe(data) 83 if err != nil { 84 return err 85 } 86 87 ch, ok := c.getChannel(sid) 88 if !ok { 89 // This is an unexpected occurrence. Send the error 90 // back to all listeners so that they terminate 91 // gracefully. 92 return fmt.Errorf("sid not found: %d", sid) 93 } 94 95 ch <- result{typ: typ, data: data} 96 } 97} 98 99func (c *clientConn) putChannel(ch chan<- result, sid uint32) bool { 100 c.Lock() 101 defer c.Unlock() 102 103 select { 104 case <-c.closed: 105 // already closed with broadcastErr, return error on chan. 106 ch <- result{err: ErrSSHFxConnectionLost} 107 return false 108 default: 109 } 110 111 c.inflight[sid] = ch 112 return true 113} 114 115func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) { 116 c.Lock() 117 defer c.Unlock() 118 119 ch, ok := c.inflight[sid] 120 delete(c.inflight, sid) 121 122 return ch, ok 123} 124 125// result captures the result of receiving the a packet from the server 126type result struct { 127 typ byte 128 data []byte 129 err error 130} 131 132type idmarshaler interface { 133 id() uint32 134 encoding.BinaryMarshaler 135} 136 137func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) { 138 if cap(ch) < 1 { 139 ch = make(chan result, 1) 140 } 141 142 c.dispatchRequest(ch, p) 143 s := <-ch 144 return s.typ, s.data, s.err 145} 146 147// dispatchRequest should ideally only be called by race-detection tests outside of this file, 148// where you have to ensure two packets are in flight sequentially after each other. 149func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) { 150 sid := p.id() 151 152 if !c.putChannel(ch, sid) { 153 // already closed. 154 return 155 } 156 157 if err := c.conn.sendPacket(p); err != nil { 158 if ch, ok := c.getChannel(sid); ok { 159 ch <- result{err: err} 160 } 161 } 162} 163 164// broadcastErr sends an error to all goroutines waiting for a response. 165func (c *clientConn) broadcastErr(err error) { 166 c.Lock() 167 defer c.Unlock() 168 169 bcastRes := result{err: ErrSSHFxConnectionLost} 170 for sid, ch := range c.inflight { 171 ch <- bcastRes 172 173 // Replace the chan in inflight, 174 // we have hijacked this chan, 175 // and this guarantees always-only-once sending. 176 c.inflight[sid] = make(chan<- result, 1) 177 } 178 179 c.err = err 180 close(c.closed) 181} 182 183type serverConn struct { 184 conn 185} 186 187func (s *serverConn) sendError(id uint32, err error) error { 188 return s.sendPacket(statusFromError(id, err)) 189} 190