1package sftp 2 3import ( 4 "context" 5 "encoding" 6 "io" 7 "path" 8 "path/filepath" 9 "strconv" 10 "sync" 11 "syscall" 12 13 "github.com/pkg/errors" 14) 15 16var maxTxPacket uint32 = 1 << 15 17 18// Handlers contains the 4 SFTP server request handlers. 19type Handlers struct { 20 FileGet FileReader 21 FilePut FileWriter 22 FileCmd FileCmder 23 FileList FileLister 24} 25 26// RequestServer abstracts the sftp protocol with an http request-like protocol 27type RequestServer struct { 28 *serverConn 29 Handlers Handlers 30 pktMgr *packetManager 31 openRequests map[string]*Request 32 openRequestLock sync.RWMutex 33 handleCount int 34} 35 36// NewRequestServer creates/allocates/returns new RequestServer. 37// Normally there there will be one server per user-session. 38func NewRequestServer(rwc io.ReadWriteCloser, h Handlers) *RequestServer { 39 svrConn := &serverConn{ 40 conn: conn{ 41 Reader: rwc, 42 WriteCloser: rwc, 43 }, 44 } 45 return &RequestServer{ 46 serverConn: svrConn, 47 Handlers: h, 48 pktMgr: newPktMgr(svrConn), 49 openRequests: make(map[string]*Request), 50 } 51} 52 53// New Open packet/Request 54func (rs *RequestServer) nextRequest(r *Request) string { 55 rs.openRequestLock.Lock() 56 defer rs.openRequestLock.Unlock() 57 rs.handleCount++ 58 handle := strconv.Itoa(rs.handleCount) 59 rs.openRequests[handle] = r 60 return handle 61} 62 63// Returns Request from openRequests, bool is false if it is missing 64// If the method is different, save/return a new Request w/ that Method. 65// 66// The Requests in openRequests work essentially as open file descriptors that 67// you can do different things with. What you are doing with it are denoted by 68// the first packet of that type (read/write/etc). We create a new Request when 69// it changes to set the request.Method attribute in a thread safe way. 70func (rs *RequestServer) getRequest(handle, method string) (*Request, bool) { 71 rs.openRequestLock.RLock() 72 r, ok := rs.openRequests[handle] 73 rs.openRequestLock.RUnlock() 74 if !ok || r.Method == method { 75 return r, ok 76 } 77 // if we make it here we need to replace the request 78 rs.openRequestLock.Lock() 79 defer rs.openRequestLock.Unlock() 80 r, ok = rs.openRequests[handle] 81 if !ok || r.Method == method { // re-check needed b/c lock race 82 return r, ok 83 } 84 r = r.copy() 85 r.Method = method 86 rs.openRequests[handle] = r 87 return r, ok 88} 89 90func (rs *RequestServer) closeRequest(handle string) error { 91 rs.openRequestLock.Lock() 92 defer rs.openRequestLock.Unlock() 93 if r, ok := rs.openRequests[handle]; ok { 94 delete(rs.openRequests, handle) 95 return r.close() 96 } 97 return syscall.EBADF 98} 99 100// Close the read/write/closer to trigger exiting the main server loop 101func (rs *RequestServer) Close() error { return rs.conn.Close() } 102 103// Serve requests for user session 104func (rs *RequestServer) Serve() error { 105 ctx, cancel := context.WithCancel(context.Background()) 106 defer cancel() 107 var wg sync.WaitGroup 108 runWorker := func(ch requestChan) { 109 wg.Add(1) 110 go func() { 111 defer wg.Done() 112 if err := rs.packetWorker(ctx, ch); err != nil { 113 rs.conn.Close() // shuts down recvPacket 114 } 115 }() 116 } 117 pktChan := rs.pktMgr.workerChan(runWorker) 118 119 var err error 120 var pkt requestPacket 121 var pktType uint8 122 var pktBytes []byte 123 for { 124 pktType, pktBytes, err = rs.recvPacket() 125 if err != nil { 126 break 127 } 128 129 pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) 130 if err != nil { 131 switch errors.Cause(err) { 132 case errUnknownExtendedPacket: 133 if err := rs.serverConn.sendError(pkt, ErrSshFxOpUnsupported); err != nil { 134 debug("failed to send err packet: %v", err) 135 rs.conn.Close() // shuts down recvPacket 136 break 137 } 138 default: 139 debug("makePacket err: %v", err) 140 rs.conn.Close() // shuts down recvPacket 141 break 142 } 143 } 144 145 pktChan <- pkt 146 } 147 148 close(pktChan) // shuts down sftpServerWorkers 149 wg.Wait() // wait for all workers to exit 150 151 // make sure all open requests are properly closed 152 // (eg. possible on dropped connections, client crashes, etc.) 153 for handle, req := range rs.openRequests { 154 delete(rs.openRequests, handle) 155 req.close() 156 } 157 158 return err 159} 160 161func (rs *RequestServer) packetWorker( 162 ctx context.Context, pktChan chan requestPacket, 163) error { 164 for pkt := range pktChan { 165 var rpkt responsePacket 166 switch pkt := pkt.(type) { 167 case *sshFxInitPacket: 168 rpkt = sshFxVersionPacket{sftpProtocolVersion, nil} 169 case *sshFxpClosePacket: 170 handle := pkt.getHandle() 171 rpkt = statusFromError(pkt, rs.closeRequest(handle)) 172 case *sshFxpRealpathPacket: 173 rpkt = cleanPacketPath(pkt) 174 case *sshFxpOpendirPacket: 175 request := requestFromPacket(ctx, pkt) 176 handle := rs.nextRequest(request) 177 rpkt = sshFxpHandlePacket{pkt.id(), handle} 178 case *sshFxpOpenPacket: 179 request := requestFromPacket(ctx, pkt) 180 handle := rs.nextRequest(request) 181 rpkt = sshFxpHandlePacket{pkt.id(), handle} 182 if pkt.hasPflags(ssh_FXF_CREAT) { 183 if p := request.call(rs.Handlers, pkt); !statusOk(p) { 184 rpkt = p // if error in write, return it 185 } 186 } 187 case hasHandle: 188 handle := pkt.getHandle() 189 request, ok := rs.getRequest(handle, requestMethod(pkt)) 190 if !ok { 191 rpkt = statusFromError(pkt, syscall.EBADF) 192 } else { 193 rpkt = request.call(rs.Handlers, pkt) 194 } 195 case hasPath: 196 request := requestFromPacket(ctx, pkt) 197 rpkt = request.call(rs.Handlers, pkt) 198 request.close() 199 default: 200 return errors.Errorf("unexpected packet type %T", pkt) 201 } 202 203 err := rs.sendPacket(rpkt) 204 if err != nil { 205 return err 206 } 207 } 208 return nil 209} 210 211// True is responsePacket is an OK status packet 212func statusOk(rpkt responsePacket) bool { 213 p, ok := rpkt.(sshFxpStatusPacket) 214 return ok && p.StatusError.Code == ssh_FX_OK 215} 216 217// clean and return name packet for file 218func cleanPacketPath(pkt *sshFxpRealpathPacket) responsePacket { 219 path := cleanPath(pkt.getPath()) 220 return &sshFxpNamePacket{ 221 ID: pkt.id(), 222 NameAttrs: []sshFxpNameAttr{{ 223 Name: path, 224 LongName: path, 225 Attrs: emptyFileStat, 226 }}, 227 } 228} 229 230// Makes sure we have a clean POSIX (/) absolute path to work with 231func cleanPath(p string) string { 232 p = filepath.ToSlash(p) 233 if !filepath.IsAbs(p) { 234 p = "/" + p 235 } 236 return path.Clean(p) 237} 238 239// Wrap underlying connection methods to use packetManager 240func (rs *RequestServer) sendPacket(m encoding.BinaryMarshaler) error { 241 if pkt, ok := m.(responsePacket); ok { 242 rs.pktMgr.readyPacket(pkt) 243 } else { 244 return errors.Errorf("unexpected packet type %T", m) 245 } 246 return nil 247} 248