1package handshake
2
3import (
4	"crypto/aes"
5	"crypto/cipher"
6	"crypto/sha256"
7	"fmt"
8	"io"
9
10	"golang.org/x/crypto/hkdf"
11)
12
// tokenProtector is used to create and verify tokens that carry opaque data.
type tokenProtector interface {
	// NewToken creates a new token that encodes the given data.
	NewToken([]byte) ([]byte, error)
	// DecodeToken decodes a token created by NewToken and returns the data
	// it encodes.
	DecodeToken([]byte) ([]byte, error)
}
20
// tokenSecretSize is the length, in bytes, of the long-lived random secret
// held by a tokenProtectorImpl and used as HKDF input keying material.
const tokenSecretSize = 32

// tokenNonceSize is the length, in bytes, of the random value prefixed to
// every token. Despite its name it is not the AEAD nonce: it is used as the
// HKDF salt from which the per-token key and AEAD nonce are derived.
const tokenNonceSize = 32
25
// tokenProtectorImpl implements tokenProtector. It derives a separate
// AES-GCM key and nonce for every token from a long-lived secret and a random
// per-token value.
type tokenProtectorImpl struct {
	rand   io.Reader // source of the random per-token values
	secret []byte    // long-lived secret; HKDF input keying material
}
31
32// newTokenProtector creates a source for source address tokens
33func newTokenProtector(rand io.Reader) (tokenProtector, error) {
34	secret := make([]byte, tokenSecretSize)
35	if _, err := rand.Read(secret); err != nil {
36		return nil, err
37	}
38	return &tokenProtectorImpl{
39		rand:   rand,
40		secret: secret,
41	}, nil
42}
43
44// NewToken encodes data into a new token.
45func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
46	nonce := make([]byte, tokenNonceSize)
47	if _, err := s.rand.Read(nonce); err != nil {
48		return nil, err
49	}
50	aead, aeadNonce, err := s.createAEAD(nonce)
51	if err != nil {
52		return nil, err
53	}
54	return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil
55}
56
57// DecodeToken decodes a token.
58func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
59	if len(p) < tokenNonceSize {
60		return nil, fmt.Errorf("token too short: %d", len(p))
61	}
62	nonce := p[:tokenNonceSize]
63	aead, aeadNonce, err := s.createAEAD(nonce)
64	if err != nil {
65		return nil, err
66	}
67	return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil)
68}
69
70func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
71	h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source"))
72	key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
73	if _, err := io.ReadFull(h, key); err != nil {
74		return nil, nil, err
75	}
76	aeadNonce := make([]byte, 12)
77	if _, err := io.ReadFull(h, aeadNonce); err != nil {
78		return nil, nil, err
79	}
80	c, err := aes.NewCipher(key)
81	if err != nil {
82		return nil, nil, err
83	}
84	aead, err := cipher.NewGCM(c)
85	if err != nil {
86		return nil, nil, err
87	}
88	return aead, aeadNonce, nil
89}
90