1package handshake 2 3import ( 4 "crypto/aes" 5 "crypto/cipher" 6 "crypto/sha256" 7 "fmt" 8 "io" 9 10 "golang.org/x/crypto/hkdf" 11) 12 13// TokenProtector is used to create and verify a token 14type tokenProtector interface { 15 // NewToken creates a new token 16 NewToken([]byte) ([]byte, error) 17 // DecodeToken decodes a token 18 DecodeToken([]byte) ([]byte, error) 19} 20 21const ( 22 tokenSecretSize = 32 23 tokenNonceSize = 32 24) 25 26// tokenProtector is used to create and verify a token 27type tokenProtectorImpl struct { 28 rand io.Reader 29 secret []byte 30} 31 32// newTokenProtector creates a source for source address tokens 33func newTokenProtector(rand io.Reader) (tokenProtector, error) { 34 secret := make([]byte, tokenSecretSize) 35 if _, err := rand.Read(secret); err != nil { 36 return nil, err 37 } 38 return &tokenProtectorImpl{ 39 rand: rand, 40 secret: secret, 41 }, nil 42} 43 44// NewToken encodes data into a new token. 45func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) { 46 nonce := make([]byte, tokenNonceSize) 47 if _, err := s.rand.Read(nonce); err != nil { 48 return nil, err 49 } 50 aead, aeadNonce, err := s.createAEAD(nonce) 51 if err != nil { 52 return nil, err 53 } 54 return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil 55} 56 57// DecodeToken decodes a token. 58func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) { 59 if len(p) < tokenNonceSize { 60 return nil, fmt.Errorf("token too short: %d", len(p)) 61 } 62 nonce := p[:tokenNonceSize] 63 aead, aeadNonce, err := s.createAEAD(nonce) 64 if err != nil { 65 return nil, err 66 } 67 return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil) 68} 69 70func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { 71 h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source")) 72 key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 73 if _, err := io.ReadFull(h, key); err != nil { 74 return nil, nil, err 75 } 76 aeadNonce := make([]byte, 12) 77 if _, err := io.ReadFull(h, aeadNonce); err != nil { 78 return nil, nil, err 79 } 80 c, err := aes.NewCipher(key) 81 if err != nil { 82 return nil, nil, err 83 } 84 aead, err := cipher.NewGCM(c) 85 if err != nil { 86 return nil, nil, err 87 } 88 return aead, aeadNonce, nil 89} 90