1package sftp 2 3import ( 4 "encoding" 5 "sort" 6 "sync" 7) 8 9// The goal of the packetManager is to keep the outgoing packets in the same 10// order as the incoming as is requires by section 7 of the RFC. 11 12type packetManager struct { 13 requests chan orderedPacket 14 responses chan orderedPacket 15 fini chan struct{} 16 incoming orderedPackets 17 outgoing orderedPackets 18 sender packetSender // connection object 19 working *sync.WaitGroup 20 packetCount uint32 21} 22 23type packetSender interface { 24 sendPacket(encoding.BinaryMarshaler) error 25} 26 27func newPktMgr(sender packetSender) *packetManager { 28 s := &packetManager{ 29 requests: make(chan orderedPacket, SftpServerWorkerCount), 30 responses: make(chan orderedPacket, SftpServerWorkerCount), 31 fini: make(chan struct{}), 32 incoming: make([]orderedPacket, 0, SftpServerWorkerCount), 33 outgoing: make([]orderedPacket, 0, SftpServerWorkerCount), 34 sender: sender, 35 working: &sync.WaitGroup{}, 36 } 37 go s.controller() 38 return s 39} 40 41//// packet ordering 42func (s *packetManager) newOrderId() uint32 { 43 s.packetCount++ 44 return s.packetCount 45} 46 47type orderedRequest struct { 48 requestPacket 49 orderid uint32 50} 51 52func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest { 53 return orderedRequest{requestPacket: p, orderid: s.newOrderId()} 54} 55func (p orderedRequest) orderId() uint32 { return p.orderid } 56func (p orderedRequest) setOrderId(oid uint32) { p.orderid = oid } 57 58type orderedResponse struct { 59 responsePacket 60 orderid uint32 61} 62 63func (s *packetManager) newOrderedResponse(p responsePacket, id uint32, 64) orderedResponse { 65 return orderedResponse{responsePacket: p, orderid: id} 66} 67func (p orderedResponse) orderId() uint32 { return p.orderid } 68func (p orderedResponse) setOrderId(oid uint32) { p.orderid = oid } 69 70type orderedPacket interface { 71 id() uint32 72 orderId() uint32 73} 74type orderedPackets []orderedPacket 75 76func (o orderedPackets) Sort() { 77 sort.Slice(o, func(i, j int) bool { 78 return o[i].orderId() < o[j].orderId() 79 }) 80} 81 82//// packet registry 83// register incoming packets to be handled 84func (s *packetManager) incomingPacket(pkt orderedRequest) { 85 s.working.Add(1) 86 s.requests <- pkt 87} 88 89// register outgoing packets as being ready 90func (s *packetManager) readyPacket(pkt orderedResponse) { 91 s.responses <- pkt 92 s.working.Done() 93} 94 95// shut down packetManager controller 96func (s *packetManager) close() { 97 // pause until current packets are processed 98 s.working.Wait() 99 close(s.fini) 100} 101 102// Passed a worker function, returns a channel for incoming packets. 103// Keep process packet responses in the order they are received while 104// maximizing throughput of file transfers. 105func (s *packetManager) workerChan(runWorker func(chan orderedRequest), 106) chan orderedRequest { 107 108 // multiple workers for faster read/writes 109 rwChan := make(chan orderedRequest, SftpServerWorkerCount) 110 for i := 0; i < SftpServerWorkerCount; i++ { 111 runWorker(rwChan) 112 } 113 114 // single worker to enforce sequential processing of everything else 115 cmdChan := make(chan orderedRequest) 116 runWorker(cmdChan) 117 118 pktChan := make(chan orderedRequest, SftpServerWorkerCount) 119 go func() { 120 for pkt := range pktChan { 121 switch pkt.requestPacket.(type) { 122 case *sshFxpReadPacket, *sshFxpWritePacket: 123 s.incomingPacket(pkt) 124 rwChan <- pkt 125 continue 126 case *sshFxpClosePacket: 127 // wait for reads/writes to finish when file is closed 128 // incomingPacket() call must occur after this 129 s.working.Wait() 130 } 131 s.incomingPacket(pkt) 132 // all non-RW use sequential cmdChan 133 cmdChan <- pkt 134 } 135 close(rwChan) 136 close(cmdChan) 137 s.close() 138 }() 139 140 return pktChan 141} 142 143// process packets 144func (s *packetManager) controller() { 145 for { 146 select { 147 case pkt := <-s.requests: 148 debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderId()) 149 s.incoming = append(s.incoming, pkt) 150 s.incoming.Sort() 151 case pkt := <-s.responses: 152 debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderId()) 153 s.outgoing = append(s.outgoing, pkt) 154 s.outgoing.Sort() 155 case <-s.fini: 156 return 157 } 158 s.maybeSendPackets() 159 } 160} 161 162// send as many packets as are ready 163func (s *packetManager) maybeSendPackets() { 164 for { 165 if len(s.outgoing) == 0 || len(s.incoming) == 0 { 166 debug("break! -- outgoing: %v; incoming: %v", 167 len(s.outgoing), len(s.incoming)) 168 break 169 } 170 out := s.outgoing[0] 171 in := s.incoming[0] 172 // debug("incoming: %v", ids(s.incoming)) 173 // debug("outgoing: %v", ids(s.outgoing)) 174 if in.orderId() == out.orderId() { 175 debug("Sending packet: %v", out.id()) 176 s.sender.sendPacket(out.(encoding.BinaryMarshaler)) 177 // pop off heads 178 copy(s.incoming, s.incoming[1:]) // shift left 179 s.incoming[len(s.incoming)-1] = nil // clear last 180 s.incoming = s.incoming[:len(s.incoming)-1] // remove last 181 copy(s.outgoing, s.outgoing[1:]) // shift left 182 s.outgoing[len(s.outgoing)-1] = nil // clear last 183 s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last 184 } else { 185 break 186 } 187 } 188} 189 190// func oids(o []orderedPacket) []uint32 { 191// res := make([]uint32, 0, len(o)) 192// for _, v := range o { 193// res = append(res, v.orderId()) 194// } 195// return res 196// } 197// func ids(o []orderedPacket) []uint32 { 198// res := make([]uint32, 0, len(o)) 199// for _, v := range o { 200// res = append(res, v.id()) 201// } 202// return res 203// } 204