1package quic
2
3import (
4	"bytes"
5	"fmt"
6	"net"
7	"sync"
8
9	"github.com/lucas-clemente/quic-go/internal/utils"
10	"github.com/lucas-clemente/quic-go/logging"
11)
12
var (
	// connMuxerOnce ensures that getMultiplexer creates connMuxer only once.
	connMuxerOnce sync.Once
	// connMuxer is the package-level multiplexer returned by getMultiplexer.
	// It is assigned inside connMuxerOnce.Do on the first call.
	connMuxer     multiplexer
)
17
// indexableConn is the minimal view of a connection that RemoveConn needs:
// the local address, from which the multiplexer derives its map key.
type indexableConn interface {
	// LocalAddr returns the local address of the connection.
	LocalAddr() net.Addr
}
21
// multiplexer maps packet conns to the packetHandlerManager that serves them.
// It is an interface so that the package-level connMuxer can hold any
// implementation; connMultiplexer is the one defined in this file.
type multiplexer interface {
	// AddConn registers c with the given connection ID length, stateless reset
	// key and tracer, and returns the packetHandlerManager for it.
	AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error)
	// RemoveConn removes the entry for the given conn and returns an error
	// if the removal is not possible.
	RemoveConn(indexableConn) error
}
26
// connManager is the per-conn record kept by connMultiplexer. It stores the
// packetHandlerManager created for a conn together with the parameters it was
// created with, so that AddConn can reject later registrations of the same
// conn that use different parameters.
type connManager struct {
	// connIDLen is the connection ID length passed on first registration.
	connIDLen         int
	// statelessResetKey is the stateless reset key passed on first registration.
	statelessResetKey []byte
	// tracer is the tracer passed on first registration.
	tracer            logging.Tracer
	// manager is the packetHandlerManager handed out for this conn.
	manager           packetHandlerManager
}
33
// The connMultiplexer listens on multiple net.PacketConns and dispatches
// incoming packets to the session handler.
//
// It keeps one connManager per conn, keyed by the conn's local address, and
// hands out the same packetHandlerManager to every caller that registers
// that conn.
type connMultiplexer struct {
	// mutex is held by AddConn and RemoveConn while they access conns.
	mutex sync.Mutex

	conns                   map[string] /* LocalAddr().Network() + " " + LocalAddr().String() */ connManager
	newPacketHandlerManager func(net.PacketConn, int, []byte, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests

	// logger is passed to newPacketHandlerManager when a manager is created.
	logger utils.Logger
}
44
45var _ multiplexer = &connMultiplexer{}
46
47func getMultiplexer() multiplexer {
48	connMuxerOnce.Do(func() {
49		connMuxer = &connMultiplexer{
50			conns:                   make(map[string]connManager),
51			logger:                  utils.DefaultLogger.WithPrefix("muxer"),
52			newPacketHandlerManager: newPacketHandlerMap,
53		}
54	})
55	return connMuxer
56}
57
58func (m *connMultiplexer) AddConn(
59	c net.PacketConn,
60	connIDLen int,
61	statelessResetKey []byte,
62	tracer logging.Tracer,
63) (packetHandlerManager, error) {
64	m.mutex.Lock()
65	defer m.mutex.Unlock()
66
67	addr := c.LocalAddr()
68	connIndex := addr.Network() + " " + addr.String()
69	p, ok := m.conns[connIndex]
70	if !ok {
71		manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger)
72		if err != nil {
73			return nil, err
74		}
75		p = connManager{
76			connIDLen:         connIDLen,
77			statelessResetKey: statelessResetKey,
78			manager:           manager,
79			tracer:            tracer,
80		}
81		m.conns[connIndex] = p
82	} else {
83		if p.connIDLen != connIDLen {
84			return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
85		}
86		if statelessResetKey != nil && !bytes.Equal(p.statelessResetKey, statelessResetKey) {
87			return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn")
88		}
89		if tracer != p.tracer {
90			return nil, fmt.Errorf("cannot use different tracers on the same packet conn")
91		}
92	}
93	return p.manager, nil
94}
95
96func (m *connMultiplexer) RemoveConn(c indexableConn) error {
97	m.mutex.Lock()
98	defer m.mutex.Unlock()
99
100	connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String()
101	if _, ok := m.conns[connIndex]; !ok {
102		return fmt.Errorf("cannote remove connection, connection is unknown")
103	}
104
105	delete(m.conns, connIndex)
106	return nil
107}
108