1package self_test
2
3import (
4	"bufio"
5	"bytes"
6	"crypto/rand"
7	"crypto/rsa"
8	"crypto/tls"
9	"crypto/x509"
10	"crypto/x509/pkix"
11	"flag"
12	"fmt"
13	"io"
14	"log"
15	"math/big"
16	mrand "math/rand"
17	"os"
18	"sync"
19	"testing"
20	"time"
21
22	"github.com/lucas-clemente/quic-go"
23	"github.com/lucas-clemente/quic-go/internal/utils"
24	"github.com/lucas-clemente/quic-go/logging"
25	"github.com/lucas-clemente/quic-go/metrics"
26	"github.com/lucas-clemente/quic-go/qlog"
27
28	. "github.com/onsi/ginkgo"
29	. "github.com/onsi/gomega"
30)
31
32const alpn = "quic-go integration tests"
33
34const (
35	dataLen     = 500 * 1024       // 500 KB
36	dataLenLong = 50 * 1024 * 1024 // 50 MB
37)
38
39var (
40	// PRData contains dataLen bytes of pseudo-random data.
41	PRData = GeneratePRData(dataLen)
42	// PRDataLong contains dataLenLong bytes of pseudo-random data.
43	PRDataLong = GeneratePRData(dataLenLong)
44)
45
// GeneratePRData returns l bytes of deterministic pseudo-random data.
// Byte i is the low-order byte of the (i+1)-th state of a Lehmer generator
// with multiplier 48271 and modulus 2^31-1, seeded with 1, so calls with the
// same l always return the same bytes.
// See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
func GeneratePRData(l int) []byte {
	const (
		multiplier = 48271
		modulus    = 2147483647 // 2^31 - 1
	)
	out := make([]byte, l)
	var state uint64 = 1
	for i := range out {
		state = state * multiplier % modulus
		out[i] = byte(state)
	}
	return out
}
56
// logBufSize is the initial capacity of the debug log buffer: 100 MiB.
const logBufSize = 100 << 20

// syncedBuffer wraps a bytes.Buffer so that Write, Bytes and Reset are
// serialized by a mutex, allowing it to be used as a log output that several
// goroutines write to concurrently.
// Methods promoted from the embedded *bytes.Buffer (String, Len, ...) are NOT
// synchronized.
type syncedBuffer struct {
	mutex sync.Mutex

	*bytes.Buffer
}

// Write appends p to the buffer while holding the mutex.
func (b *syncedBuffer) Write(p []byte) (int, error) {
	b.mutex.Lock()
	defer b.mutex.Unlock()
	return b.Buffer.Write(p)
}

// Bytes returns a snapshot (a copy) of the buffered data.
// bytes.Buffer.Bytes returns a slice aliasing the buffer's storage, which is
// only valid until the next modification; since other goroutines may Write
// (or a Reset may happen) as soon as the mutex is released, an aliased slice
// could change underneath the caller. Copying under the lock avoids that.
func (b *syncedBuffer) Bytes() []byte {
	b.mutex.Lock()
	defer b.mutex.Unlock()
	src := b.Buffer.Bytes()
	dst := make([]byte, len(src))
	copy(dst, src)
	return dst
}

// Reset empties the buffer while holding the mutex; the allocated capacity is kept.
func (b *syncedBuffer) Reset() {
	b.mutex.Lock()
	b.Buffer.Reset()
	b.mutex.Unlock()
}
84
85var (
86	logFileName   string // the log file set in the ginkgo flags
87	logBufOnce    sync.Once
88	logBuf        *syncedBuffer
89	enableQlog    bool
90	enableMetrics bool
91
92	tlsConfig          *tls.Config
93	tlsConfigLongChain *tls.Config
94	tlsClientConfig    *tls.Config
95	tracer             logging.Tracer
96)
97
98// read the logfile command line flag
99// to set call ginkgo -- -logfile=log.txt
100func init() {
101	flag.StringVar(&logFileName, "logfile", "", "log file")
102	flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
103	// metrics won't be accessible anywhere, but it's useful to exercise the code
104	flag.BoolVar(&enableMetrics, "metrics", false, "enable metrics")
105}
106
107var _ = BeforeSuite(func() {
108	mrand.Seed(GinkgoRandomSeed())
109
110	ca, caPrivateKey, err := generateCA()
111	if err != nil {
112		panic(err)
113	}
114	leafCert, leafPrivateKey, err := generateLeafCert(ca, caPrivateKey)
115	if err != nil {
116		panic(err)
117	}
118	tlsConfig = &tls.Config{
119		Certificates: []tls.Certificate{{
120			Certificate: [][]byte{leafCert.Raw},
121			PrivateKey:  leafPrivateKey,
122		}},
123		NextProtos: []string{alpn},
124	}
125	tlsConfLongChain, err := generateTLSConfigWithLongCertChain(ca, caPrivateKey)
126	if err != nil {
127		panic(err)
128	}
129	tlsConfigLongChain = tlsConfLongChain
130
131	root := x509.NewCertPool()
132	root.AddCert(ca)
133	tlsClientConfig = &tls.Config{
134		RootCAs:    root,
135		NextProtos: []string{alpn},
136	}
137
138	var qlogTracer, metricsTracer logging.Tracer
139	if enableQlog {
140		qlogTracer = qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
141			role := "server"
142			if p == logging.PerspectiveClient {
143				role = "client"
144			}
145			filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role)
146			fmt.Fprintf(GinkgoWriter, "Creating %s.\n", filename)
147			f, err := os.Create(filename)
148			Expect(err).ToNot(HaveOccurred())
149			bw := bufio.NewWriter(f)
150			return utils.NewBufferedWriteCloser(bw, f)
151		})
152	}
153	if enableMetrics {
154		metricsTracer = metrics.NewTracer()
155	}
156
157	if enableQlog && enableMetrics {
158		tracer = logging.NewMultiplexedTracer(qlogTracer, metricsTracer)
159	} else if enableQlog {
160		tracer = qlogTracer
161	} else if enableMetrics {
162		tracer = metricsTracer
163	}
164})
165
// generateCA creates a self-signed CA certificate (serial number 2019, valid
// for 24 hours from now, with client- and server-auth extended key usages)
// for a newly generated 2048-bit RSA key. It returns the parsed certificate
// and that key.
func generateCA() (*x509.Certificate, *rsa.PrivateKey, error) {
	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(2019),
		Subject:               pkix.Name{},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(24 * time.Hour),
		IsCA:                  true,
		BasicConstraintsValid: true,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
	}
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}
	// Self-signed: the template acts as both the subject and the issuer.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		return nil, nil, err
	}
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		return nil, nil, err
	}
	return cert, key, nil
}
191
// generateLeafCert creates a non-CA certificate for "localhost" (serial
// number 1, valid for 24 hours from now) for a newly generated 2048-bit RSA
// key, signed by ca using caPrivateKey. It returns the parsed certificate and
// the new key.
func generateLeafCert(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, error) {
	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		DNSNames:     []string{"localhost"},
		NotBefore:    time.Now(),
		NotAfter:     time.Now().Add(24 * time.Hour),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
	}
	leafKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}
	der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, &leafKey.PublicKey, caPrivateKey)
	if err != nil {
		return nil, nil, err
	}
	leaf, err := x509.ParseCertificate(der)
	if err != nil {
		return nil, nil, err
	}
	return leaf, leafKey, nil
}
215
216// getTLSConfigWithLongCertChain generates a tls.Config that uses a long certificate chain.
217// The Root CA used is the same as for the config returned from getTLSConfig().
218func generateTLSConfigWithLongCertChain(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*tls.Config, error) {
219	const chainLen = 7
220	certTempl := &x509.Certificate{
221		SerialNumber:          big.NewInt(2019),
222		Subject:               pkix.Name{},
223		NotBefore:             time.Now(),
224		NotAfter:              time.Now().Add(24 * time.Hour),
225		IsCA:                  true,
226		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
227		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
228		BasicConstraintsValid: true,
229	}
230
231	lastCA := ca
232	lastCAPrivKey := caPrivateKey
233	privKey, err := rsa.GenerateKey(rand.Reader, 2048)
234	if err != nil {
235		return nil, err
236	}
237	certs := make([]*x509.Certificate, chainLen)
238	for i := 0; i < chainLen; i++ {
239		caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, lastCA, &privKey.PublicKey, lastCAPrivKey)
240		if err != nil {
241			return nil, err
242		}
243		ca, err := x509.ParseCertificate(caBytes)
244		if err != nil {
245			return nil, err
246		}
247		certs[i] = ca
248		lastCA = ca
249		lastCAPrivKey = privKey
250	}
251	leafCert, leafPrivateKey, err := generateLeafCert(lastCA, lastCAPrivKey)
252	if err != nil {
253		return nil, err
254	}
255
256	rawCerts := make([][]byte, chainLen+1)
257	for i, cert := range certs {
258		rawCerts[chainLen-i] = cert.Raw
259	}
260	rawCerts[0] = leafCert.Raw
261
262	return &tls.Config{
263		Certificates: []tls.Certificate{{
264			Certificate: rawCerts,
265			PrivateKey:  leafPrivateKey,
266		}},
267		NextProtos: []string{alpn},
268	}, nil
269}
270
271func getTLSConfig() *tls.Config {
272	return tlsConfig.Clone()
273}
274
275func getTLSConfigWithLongCertChain() *tls.Config {
276	return tlsConfigLongChain.Clone()
277}
278
279func getTLSClientConfig() *tls.Config {
280	return tlsClientConfig.Clone()
281}
282
283func getQuicConfig(conf *quic.Config) *quic.Config {
284	if conf == nil {
285		conf = &quic.Config{}
286	} else {
287		conf = conf.Clone()
288	}
289	conf.Tracer = tracer
290	return conf
291}
292
293var _ = BeforeEach(func() {
294	log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
295
296	if debugLog() {
297		logBufOnce.Do(func() {
298			logBuf = &syncedBuffer{Buffer: bytes.NewBuffer(make([]byte, 0, logBufSize))}
299		})
300		utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
301		log.SetOutput(logBuf)
302	}
303})
304
305var _ = AfterEach(func() {
306	if debugLog() {
307		logFile, err := os.Create(logFileName)
308		Expect(err).ToNot(HaveOccurred())
309		logFile.Write(logBuf.Bytes())
310		logFile.Close()
311		logBuf.Reset()
312	}
313})
314
315// Debug says if this test is being logged
316func debugLog() bool {
317	return len(logFileName) > 0
318}
319
// TestSelf is the `go test` entry point of the Ginkgo suite: it registers
// Ginkgo's Fail as Gomega's failure handler and runs the suite's specs.
func TestSelf(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Self integration tests")
}
324