1package quic
2
3import (
4	"fmt"
5	"net"
6	"reflect"
7	"time"
8
9	mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging"
10	"github.com/lucas-clemente/quic-go/internal/protocol"
11
12	. "github.com/onsi/ginkgo"
13	. "github.com/onsi/gomega"
14)
15
16var _ = Describe("Config", func() {
17	Context("validating", func() {
18		It("validates a nil config", func() {
19			Expect(validateConfig(nil)).To(Succeed())
20		})
21
22		It("validates a config with normal values", func() {
23			Expect(validateConfig(populateServerConfig(&Config{}))).To(Succeed())
24		})
25
26		It("errors on too large values for MaxIncomingStreams", func() {
27			Expect(validateConfig(&Config{MaxIncomingStreams: 1<<60 + 1})).To(MatchError("invalid value for Config.MaxIncomingStreams"))
28		})
29
30		It("errors on too large values for MaxIncomingUniStreams", func() {
31			Expect(validateConfig(&Config{MaxIncomingUniStreams: 1<<60 + 1})).To(MatchError("invalid value for Config.MaxIncomingUniStreams"))
32		})
33	})
34
35	configWithNonZeroNonFunctionFields := func() *Config {
36		c := &Config{}
37		v := reflect.ValueOf(c).Elem()
38
39		typ := v.Type()
40		for i := 0; i < typ.NumField(); i++ {
41			f := v.Field(i)
42			if !f.CanSet() {
43				// unexported field; not cloned.
44				continue
45			}
46
47			switch fn := typ.Field(i).Name; fn {
48			case "AcceptToken", "GetLogWriter":
49				// Can't compare functions.
50			case "Versions":
51				f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
52			case "ConnectionIDLength":
53				f.Set(reflect.ValueOf(8))
54			case "HandshakeIdleTimeout":
55				f.Set(reflect.ValueOf(time.Second))
56			case "MaxIdleTimeout":
57				f.Set(reflect.ValueOf(time.Hour))
58			case "TokenStore":
59				f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3)))
60			case "InitialStreamReceiveWindow":
61				f.Set(reflect.ValueOf(uint64(1234)))
62			case "MaxStreamReceiveWindow":
63				f.Set(reflect.ValueOf(uint64(9)))
64			case "InitialConnectionReceiveWindow":
65				f.Set(reflect.ValueOf(uint64(4321)))
66			case "MaxConnectionReceiveWindow":
67				f.Set(reflect.ValueOf(uint64(10)))
68			case "MaxIncomingStreams":
69				f.Set(reflect.ValueOf(int64(11)))
70			case "MaxIncomingUniStreams":
71				f.Set(reflect.ValueOf(int64(12)))
72			case "StatelessResetKey":
73				f.Set(reflect.ValueOf([]byte{1, 2, 3, 4}))
74			case "KeepAlive":
75				f.Set(reflect.ValueOf(true))
76			case "EnableDatagrams":
77				f.Set(reflect.ValueOf(true))
78			case "DisableVersionNegotiationPackets":
79				f.Set(reflect.ValueOf(true))
80			case "DisablePathMTUDiscovery":
81				f.Set(reflect.ValueOf(true))
82			case "Tracer":
83				f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl)))
84			default:
85				Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn))
86			}
87		}
88		return c
89	}
90
91	It("uses 10s handshake timeout for short handshake idle timeouts", func() {
92		c := &Config{HandshakeIdleTimeout: time.Second}
93		Expect(c.handshakeTimeout()).To(Equal(protocol.DefaultHandshakeTimeout))
94	})
95
96	It("uses twice the handshake idle timeouts for the handshake timeout, for long handshake idle timeouts", func() {
97		c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2}
98		Expect(c.handshakeTimeout()).To(Equal(11 * time.Second))
99	})
100
101	Context("cloning", func() {
102		It("clones function fields", func() {
103			var calledAcceptToken bool
104			c1 := &Config{
105				AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
106			}
107			c2 := c1.Clone()
108			c2.AcceptToken(&net.UDPAddr{}, &Token{})
109			Expect(calledAcceptToken).To(BeTrue())
110		})
111
112		It("clones non-function fields", func() {
113			c := configWithNonZeroNonFunctionFields()
114			Expect(c.Clone()).To(Equal(c))
115		})
116
117		It("returns a copy", func() {
118			c1 := &Config{
119				MaxIncomingStreams: 100,
120				AcceptToken:        func(_ net.Addr, _ *Token) bool { return true },
121			}
122			c2 := c1.Clone()
123			c2.MaxIncomingStreams = 200
124			c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
125
126			Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
127			Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue())
128		})
129	})
130
131	Context("populating", func() {
132		It("populates function fields", func() {
133			var calledAcceptToken bool
134			c1 := &Config{
135				AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
136			}
137			c2 := populateConfig(c1)
138			c2.AcceptToken(&net.UDPAddr{}, &Token{})
139			Expect(calledAcceptToken).To(BeTrue())
140		})
141
142		It("copies non-function fields", func() {
143			c := configWithNonZeroNonFunctionFields()
144			Expect(populateConfig(c)).To(Equal(c))
145		})
146
147		It("populates empty fields with default values", func() {
148			c := populateConfig(&Config{})
149			Expect(c.Versions).To(Equal(protocol.SupportedVersions))
150			Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
151			Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))
152			Expect(c.MaxStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow))
153			Expect(c.InitialConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxData))
154			Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow))
155			Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams))
156			Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams))
157			Expect(c.DisableVersionNegotiationPackets).To(BeFalse())
158			Expect(c.DisablePathMTUDiscovery).To(BeFalse())
159		})
160
161		It("populates empty fields with default values, for the server", func() {
162			c := populateServerConfig(&Config{})
163			Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
164			Expect(c.AcceptToken).ToNot(BeNil())
165		})
166
167		It("sets a default connection ID length if we didn't create the conn, for the client", func() {
168			c := populateClientConfig(&Config{}, false)
169			Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
170		})
171
172		It("doesn't set a default connection ID length if we created the conn, for the client", func() {
173			c := populateClientConfig(&Config{}, true)
174			Expect(c.ConnectionIDLength).To(BeZero())
175		})
176	})
177})
178