1package sftp
2
3import (
4	"encoding"
5	"sort"
6	"sync"
7)
8
// The goal of the packetManager is to keep the outgoing packets in the same
// order as the incoming ones, as required by section 7 of the RFC.
11
12type packetManager struct {
13	requests    chan orderedPacket
14	responses   chan orderedPacket
15	fini        chan struct{}
16	incoming    orderedPackets
17	outgoing    orderedPackets
18	sender      packetSender // connection object
19	working     *sync.WaitGroup
20	packetCount uint32
21}
22
23type packetSender interface {
24	sendPacket(encoding.BinaryMarshaler) error
25}
26
27func newPktMgr(sender packetSender) *packetManager {
28	s := &packetManager{
29		requests:  make(chan orderedPacket, SftpServerWorkerCount),
30		responses: make(chan orderedPacket, SftpServerWorkerCount),
31		fini:      make(chan struct{}),
32		incoming:  make([]orderedPacket, 0, SftpServerWorkerCount),
33		outgoing:  make([]orderedPacket, 0, SftpServerWorkerCount),
34		sender:    sender,
35		working:   &sync.WaitGroup{},
36	}
37	go s.controller()
38	return s
39}
40
41//// packet ordering
42func (s *packetManager) newOrderId() uint32 {
43	s.packetCount++
44	return s.packetCount
45}
46
47type orderedRequest struct {
48	requestPacket
49	orderid uint32
50}
51
52func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
53	return orderedRequest{requestPacket: p, orderid: s.newOrderId()}
54}
55func (p orderedRequest) orderId() uint32       { return p.orderid }
56func (p orderedRequest) setOrderId(oid uint32) { p.orderid = oid }
57
58type orderedResponse struct {
59	responsePacket
60	orderid uint32
61}
62
63func (s *packetManager) newOrderedResponse(p responsePacket, id uint32,
64) orderedResponse {
65	return orderedResponse{responsePacket: p, orderid: id}
66}
67func (p orderedResponse) orderId() uint32       { return p.orderid }
68func (p orderedResponse) setOrderId(oid uint32) { p.orderid = oid }
69
// orderedPacket is anything the packetManager can queue: it exposes a packet
// id (id) and the order id the manager assigned to it (orderId).
type orderedPacket interface {
	id() uint32
	orderId() uint32
}

// orderedPackets is a queue of packets; controller re-sorts it by order id
// after every insert.
type orderedPackets []orderedPacket

// Sort orders the queue in place by ascending order id.
func (o orderedPackets) Sort() {
	byOrder := func(a, b int) bool {
		return o[a].orderId() < o[b].orderId()
	}
	sort.Slice(o, byOrder)
}
81
82//// packet registry
83// register incoming packets to be handled
84func (s *packetManager) incomingPacket(pkt orderedRequest) {
85	s.working.Add(1)
86	s.requests <- pkt
87}
88
89// register outgoing packets as being ready
90func (s *packetManager) readyPacket(pkt orderedResponse) {
91	s.responses <- pkt
92	s.working.Done()
93}
94
95// shut down packetManager controller
96func (s *packetManager) close() {
97	// pause until current packets are processed
98	s.working.Wait()
99	close(s.fini)
100}
101
102// Passed a worker function, returns a channel for incoming packets.
103// Keep process packet responses in the order they are received while
104// maximizing throughput of file transfers.
105func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
106) chan orderedRequest {
107
108	// multiple workers for faster read/writes
109	rwChan := make(chan orderedRequest, SftpServerWorkerCount)
110	for i := 0; i < SftpServerWorkerCount; i++ {
111		runWorker(rwChan)
112	}
113
114	// single worker to enforce sequential processing of everything else
115	cmdChan := make(chan orderedRequest)
116	runWorker(cmdChan)
117
118	pktChan := make(chan orderedRequest, SftpServerWorkerCount)
119	go func() {
120		for pkt := range pktChan {
121			switch pkt.requestPacket.(type) {
122			case *sshFxpReadPacket, *sshFxpWritePacket:
123				s.incomingPacket(pkt)
124				rwChan <- pkt
125				continue
126			case *sshFxpClosePacket:
127				// wait for reads/writes to finish when file is closed
128				// incomingPacket() call must occur after this
129				s.working.Wait()
130			}
131			s.incomingPacket(pkt)
132			// all non-RW use sequential cmdChan
133			cmdChan <- pkt
134		}
135		close(rwChan)
136		close(cmdChan)
137		s.close()
138	}()
139
140	return pktChan
141}
142
143// process packets
144func (s *packetManager) controller() {
145	for {
146		select {
147		case pkt := <-s.requests:
148			debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderId())
149			s.incoming = append(s.incoming, pkt)
150			s.incoming.Sort()
151		case pkt := <-s.responses:
152			debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderId())
153			s.outgoing = append(s.outgoing, pkt)
154			s.outgoing.Sort()
155		case <-s.fini:
156			return
157		}
158		s.maybeSendPackets()
159	}
160}
161
162// send as many packets as are ready
163func (s *packetManager) maybeSendPackets() {
164	for {
165		if len(s.outgoing) == 0 || len(s.incoming) == 0 {
166			debug("break! -- outgoing: %v; incoming: %v",
167				len(s.outgoing), len(s.incoming))
168			break
169		}
170		out := s.outgoing[0]
171		in := s.incoming[0]
172		// debug("incoming: %v", ids(s.incoming))
173		// debug("outgoing: %v", ids(s.outgoing))
174		if in.orderId() == out.orderId() {
175			debug("Sending packet: %v", out.id())
176			s.sender.sendPacket(out.(encoding.BinaryMarshaler))
177			// pop off heads
178			copy(s.incoming, s.incoming[1:])            // shift left
179			s.incoming[len(s.incoming)-1] = nil         // clear last
180			s.incoming = s.incoming[:len(s.incoming)-1] // remove last
181			copy(s.outgoing, s.outgoing[1:])            // shift left
182			s.outgoing[len(s.outgoing)-1] = nil         // clear last
183			s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
184		} else {
185			break
186		}
187	}
188}
189
190// func oids(o []orderedPacket) []uint32 {
191// 	res := make([]uint32, 0, len(o))
192// 	for _, v := range o {
193// 		res = append(res, v.orderId())
194// 	}
195// 	return res
196// }
197// func ids(o []orderedPacket) []uint32 {
198// 	res := make([]uint32, 0, len(o))
199// 	for _, v := range o {
200// 		res = append(res, v.id())
201// 	}
202// 	return res
203// }
204