1package quic
2
3import (
4	"crypto/hmac"
5	"crypto/rand"
6	"crypto/sha256"
7	"errors"
8	"hash"
9	"net"
10	"sync"
11	"time"
12
13	"github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/quic-go/internal/protocol"
14	"github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/quic-go/internal/utils"
15	"github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/quic-go/internal/wire"
16)
17
18// The packetHandlerMap stores packetHandlers, identified by connection ID.
19// It is used:
20// * by the server to store sessions
21// * when multiplexing outgoing connections to store clients
22type packetHandlerMap struct {
23	mutex sync.RWMutex
24
25	conn      net.PacketConn
26	connIDLen int
27
28	handlers    map[string] /* string(ConnectionID)*/ packetHandler
29	resetTokens map[[16]byte] /* stateless reset token */ packetHandler
30	server      unknownPacketHandler
31
32	listening chan struct{} // is closed when listen returns
33	closed    bool
34
35	deleteRetiredSessionsAfter time.Duration
36
37	statelessResetEnabled bool
38	statelessResetMutex   sync.Mutex
39	statelessResetHasher  hash.Hash
40
41	logger utils.Logger
42}
43
44var _ packetHandlerManager = &packetHandlerMap{}
45
46func newPacketHandlerMap(
47	conn net.PacketConn,
48	connIDLen int,
49	statelessResetKey []byte,
50	logger utils.Logger,
51) packetHandlerManager {
52	m := &packetHandlerMap{
53		conn:                       conn,
54		connIDLen:                  connIDLen,
55		listening:                  make(chan struct{}),
56		handlers:                   make(map[string]packetHandler),
57		resetTokens:                make(map[[16]byte]packetHandler),
58		deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout,
59		statelessResetEnabled:      len(statelessResetKey) > 0,
60		statelessResetHasher:       hmac.New(sha256.New, statelessResetKey),
61		logger:                     logger,
62	}
63	go m.listen()
64
65	if logger.Debug() {
66		go m.logUsage()
67	}
68
69	return m
70}
71
72func (h *packetHandlerMap) logUsage() {
73	ticker := time.NewTicker(2 * time.Second)
74	var printedZero bool
75	for {
76		select {
77		case <-h.listening:
78			return
79		case <-ticker.C:
80		}
81
82		h.mutex.Lock()
83		numHandlers := len(h.handlers)
84		numTokens := len(h.resetTokens)
85		h.mutex.Unlock()
86		// If the number tracked handlers and tokens is zero, only print it a single time.
87		hasZero := numHandlers == 0 && numTokens == 0
88		if !hasZero || (hasZero && !printedZero) {
89			h.logger.Debugf("Tracking %d connection IDs and %d reset tokens.\n", numHandlers, numTokens)
90			printedZero = false
91			if hasZero {
92				printedZero = true
93			}
94		}
95	}
96}
97
98func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) [16]byte {
99	h.mutex.Lock()
100	h.handlers[string(id)] = handler
101	h.mutex.Unlock()
102	return h.GetStatelessResetToken(id)
103}
104
105func (h *packetHandlerMap) AddIfNotTaken(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
106	sid := string(id)
107	h.mutex.Lock()
108	defer h.mutex.Unlock()
109
110	if _, ok := h.handlers[sid]; !ok {
111		h.handlers[sid] = handler
112		return true
113	}
114	return false
115}
116
117func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
118	h.mutex.Lock()
119	delete(h.handlers, string(id))
120	h.mutex.Unlock()
121}
122
123func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
124	time.AfterFunc(h.deleteRetiredSessionsAfter, func() {
125		h.mutex.Lock()
126		delete(h.handlers, string(id))
127		h.mutex.Unlock()
128	})
129}
130
131func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) {
132	h.mutex.Lock()
133	h.handlers[string(id)] = handler
134	h.mutex.Unlock()
135
136	time.AfterFunc(h.deleteRetiredSessionsAfter, func() {
137		h.mutex.Lock()
138		handler.Close()
139		delete(h.handlers, string(id))
140		h.mutex.Unlock()
141	})
142}
143
144func (h *packetHandlerMap) AddResetToken(token [16]byte, handler packetHandler) {
145	h.mutex.Lock()
146	h.resetTokens[token] = handler
147	h.mutex.Unlock()
148}
149
150func (h *packetHandlerMap) RemoveResetToken(token [16]byte) {
151	h.mutex.Lock()
152	delete(h.resetTokens, token)
153	h.mutex.Unlock()
154}
155
156func (h *packetHandlerMap) RetireResetToken(token [16]byte) {
157	time.AfterFunc(h.deleteRetiredSessionsAfter, func() {
158		h.mutex.Lock()
159		delete(h.resetTokens, token)
160		h.mutex.Unlock()
161	})
162}
163
164func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
165	h.mutex.Lock()
166	h.server = s
167	h.mutex.Unlock()
168}
169
170func (h *packetHandlerMap) CloseServer() {
171	h.mutex.Lock()
172	h.server = nil
173	var wg sync.WaitGroup
174	for _, handler := range h.handlers {
175		if handler.getPerspective() == protocol.PerspectiveServer {
176			wg.Add(1)
177			go func(handler packetHandler) {
178				// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
179				_ = handler.Close()
180				wg.Done()
181			}(handler)
182		}
183	}
184	h.mutex.Unlock()
185	wg.Wait()
186}
187
188// Close the underlying connection and wait until listen() has returned.
189func (h *packetHandlerMap) Close() error {
190	if err := h.conn.Close(); err != nil {
191		return err
192	}
193	<-h.listening // wait until listening returns
194	return nil
195}
196
197func (h *packetHandlerMap) close(e error) error {
198	h.mutex.Lock()
199	if h.closed {
200		h.mutex.Unlock()
201		return nil
202	}
203
204	var wg sync.WaitGroup
205	for _, handler := range h.handlers {
206		wg.Add(1)
207		go func(handler packetHandler) {
208			handler.destroy(e)
209			wg.Done()
210		}(handler)
211	}
212
213	// [Psiphon]
214	// Call h.server.setCloseError(e) outside of mutex to prevent deadlock
215	// See comment in psiphon/common/quic/gquic-go/packetHandlerMap.close
216
217	var server unknownPacketHandler
218	if h.server != nil {
219		server = h.server
220	}
221
222	h.mutex.Unlock()
223
224	if server != nil {
225		server.setCloseError(e)
226	}
227
228	h.mutex.Lock()
229	h.closed = true
230	h.mutex.Unlock()
231
232	wg.Wait()
233	return getMultiplexer().RemoveConn(h.conn)
234}
235
236func (h *packetHandlerMap) listen() {
237	defer close(h.listening)
238	for {
239		buffer := getPacketBuffer()
240		data := buffer.Slice
241		// The packet size should not exceed protocol.MaxReceivePacketSize bytes
242		// If it does, we only read a truncated packet, which will then end up undecryptable
243		n, addr, err := h.conn.ReadFrom(data)
244		if err != nil {
245			// [Psiphon]
246			// Do not unconditionally shutdown
247			if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() {
248				h.close(err)
249				return
250			}
251		}
252		h.handlePacket(addr, buffer, data[:n])
253	}
254}
255
256func (h *packetHandlerMap) handlePacket(
257	addr net.Addr,
258	buffer *packetBuffer,
259	data []byte,
260) {
261	connID, err := wire.ParseConnectionID(data, h.connIDLen)
262	if err != nil {
263		h.logger.Debugf("error parsing connection ID on packet from %s: %s", addr, err)
264		return
265	}
266	rcvTime := time.Now()
267
268	h.mutex.RLock()
269	defer h.mutex.RUnlock()
270
271	if isStatelessReset := h.maybeHandleStatelessReset(data); isStatelessReset {
272		return
273	}
274
275	handler, handlerFound := h.handlers[string(connID)]
276
277	p := &receivedPacket{
278		remoteAddr: addr,
279		rcvTime:    rcvTime,
280		buffer:     buffer,
281		data:       data,
282	}
283	if handlerFound { // existing session
284		handler.handlePacket(p)
285		return
286	}
287	if data[0]&0x80 == 0 {
288		go h.maybeSendStatelessReset(p, connID)
289		return
290	}
291	if h.server == nil { // no server set
292		h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
293		return
294	}
295	h.server.handlePacket(p)
296}
297
298func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
299	// stateless resets are always short header packets
300	if data[0]&0x80 != 0 {
301		return false
302	}
303	if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
304		return false
305	}
306
307	var token [16]byte
308	copy(token[:], data[len(data)-16:])
309	if sess, ok := h.resetTokens[token]; ok {
310		h.logger.Debugf("Received a stateless retry with token %#x. Closing session.", token)
311		go sess.destroy(errors.New("received a stateless reset"))
312		return true
313	}
314	return false
315}
316
317func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte {
318	var token [16]byte
319	if !h.statelessResetEnabled {
320		// Return a random stateless reset token.
321		// This token will be sent in the server's transport parameters.
322		// By using a random token, an off-path attacker won't be able to disrupt the connection.
323		rand.Read(token[:])
324		return token
325	}
326	h.statelessResetMutex.Lock()
327	h.statelessResetHasher.Write(connID.Bytes())
328	copy(token[:], h.statelessResetHasher.Sum(nil))
329	h.statelessResetHasher.Reset()
330	h.statelessResetMutex.Unlock()
331	return token
332}
333
334func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
335	defer p.buffer.Release()
336	if !h.statelessResetEnabled {
337		return
338	}
339	// Don't send a stateless reset in response to very small packets.
340	// This includes packets that could be stateless resets.
341	if len(p.data) <= protocol.MinStatelessResetSize {
342		return
343	}
344	token := h.GetStatelessResetToken(connID)
345	h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
346	data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
347	rand.Read(data)
348	data[0] = (data[0] & 0x7f) | 0x40
349	data = append(data, token[:]...)
350	if _, err := h.conn.WriteTo(data, p.remoteAddr); err != nil {
351		h.logger.Debugf("Error sending Stateless Reset: %s", err)
352	}
353}
354