1package jwt
2
3import (
4	"crypto/rsa"
5	"encoding/base64"
6	"encoding/json"
7	"errors"
8	"fmt"
9	"sync"
10	"time"
11
12	"github.com/cristalhq/jwt/v3"
13)
14
// ConnectToken is the result of a successful VerifyConnectToken call: the
// connection-related claims extracted from a verified JWT.
type ConnectToken struct {
	// UserID is the ID of the connecting user. It is taken from the token's
	// standard subject claim.
	UserID string
	// ExpireAt allows setting a time in the future when the connection must be
	// validated. Validation can be server-side or client-side using the Refresh
	// handler. It holds the token's exp claim as Unix seconds, or 0 when the
	// token carries no exp claim.
	ExpireAt int64
	// Info contains additional information about connection. It will be
	// included into Join/Leave messages, into Presence information, also
	// info becomes a part of published message if it was published from
	// client directly. In some cases having additional info can be an
	// overhead – but you are simply free to not use it.
	// It is filled from the "info" claim (raw JSON) or, when present, from the
	// base64-decoded "b64info" claim, which takes precedence over "info".
	Info []byte
	// Channels contains channels to subscribe the connection to on the
	// server side. It is taken from the "channels" claim.
	Channels []string
}
30
// SubscribeToken is the result of a successful VerifySubscribeToken call: the
// subscription-related claims extracted from a verified JWT.
type SubscribeToken struct {
	// Client is a unique client ID string set to each connection on server.
	// Will be compared with actual client ID. Taken from the "client" claim.
	Client string
	// Channel is the channel the client wants to subscribe to. Will be compared
	// with the channel in the subscribe command. Taken from the "channel" claim.
	Channel string
	// ExpireAt allows setting a time in the future when the subscription must
	// be validated. Validation can be server-side or client-side using the
	// SubRefresh handler. It holds the token's exp claim as Unix seconds, or 0
	// when the token carries no exp claim.
	ExpireAt int64
	// Info contains additional information about connection in channel.
	// It will be included into Join/Leave messages, into Presence information,
	// also channel info becomes a part of published message if it was published
	// from subscribed client directly.
	// It is filled from the "info" claim (raw JSON) or, when present, from the
	// base64-decoded "b64info" claim, which takes precedence over "info".
	Info []byte
	// ExpireTokenOnly used to indicate that library must only check token
	// expiration but not turn on Subscription expiration checks on server side.
	// This allows to implement one-time subscription tokens.
	// Taken from the "eto" claim.
	ExpireTokenOnly bool
}
51
// TokenVerifierConfig holds the keys used by NewTokenVerifier and
// TokenVerifier.Reload. When both fields are zero no signing algorithm is
// enabled, so every token is rejected during signature verification.
type TokenVerifierConfig struct {
	// HMACSecretKey is a secret key used to validate connection and subscription
	// tokens generated using HMAC (HS256, HS384 or HS512). Zero value means that
	// HMAC tokens won't be allowed.
	HMACSecretKey string
	// RSAPublicKey is a public key used to validate connection and subscription
	// tokens generated using RSA (RS256, RS384 or RS512). Zero value means that
	// RSA tokens won't be allowed.
	RSAPublicKey *rsa.PublicKey
}
60
61func NewTokenVerifier(config TokenVerifierConfig) *TokenVerifier {
62	verifier := &TokenVerifier{}
63	algorithms, err := newAlgorithms(config.HMACSecretKey, config.RSAPublicKey)
64	if err != nil {
65		panic(err)
66	}
67	verifier.algorithms = algorithms
68	return verifier
69}
70
// TokenVerifier checks the signature and the expiry / not-before claims of
// connection and subscription JWTs. Construct it with NewTokenVerifier; its
// keys can be replaced at runtime with Reload.
//
// mu guards algorithms: signature verification holds the read lock for the
// whole check, and Reload takes the write lock to swap in a new set.
type TokenVerifier struct {
	mu         sync.RWMutex
	algorithms *algorithms
}
75
// ErrTokenExpired is returned by VerifyConnectToken and VerifySubscribeToken
// when a token fails its expiry (exp) or not-before (nbf) time check. Note
// that a token which is not valid yet is reported with this error too.
var ErrTokenExpired = errors.New("token expired")

// Sentinels wrapped by signature verification, with the offending algorithm
// name appended to the message.
var (
	// errUnsupportedAlgorithm: the token header names an algorithm that has
	// no verifier slot at all.
	errUnsupportedAlgorithm = errors.New("unsupported JWT algorithm")
	// errDisabledAlgorithm: the algorithm is supported, but no key for it was
	// configured.
	errDisabledAlgorithm = errors.New("disabled JWT algorithm")
)
81
// connectTokenClaims is the JSON claim set decoded from a connection token.
//
// Info keeps the "info" claim as raw JSON. Base64Info ("b64info") holds
// standard-base64 bytes and, when non-empty, overrides Info in
// VerifyConnectToken. Channels is the "channels" list. The embedded
// jwt.StandardClaims supplies the standard claims used during verification
// (subject, expiry, not-before).
type connectTokenClaims struct {
	Info       json.RawMessage `json:"info,omitempty"`
	Base64Info string          `json:"b64info,omitempty"`
	Channels   []string        `json:"channels,omitempty"`
	jwt.StandardClaims
}
88
// subscribeTokenClaims is the JSON claim set decoded from a subscription token.
//
// Client and Channel identify the subscription. Info keeps the "info" claim as
// raw JSON, and Base64Info ("b64info") holds standard-base64 bytes that, when
// non-empty, override Info in VerifySubscribeToken. ExpireTokenOnly is carried
// in the short "eto" claim. The embedded jwt.StandardClaims supplies the
// standard claims used during verification (expiry, not-before).
type subscribeTokenClaims struct {
	Client          string          `json:"client,omitempty"`
	Channel         string          `json:"channel,omitempty"`
	Info            json.RawMessage `json:"info,omitempty"`
	Base64Info      string          `json:"b64info,omitempty"`
	ExpireTokenOnly bool            `json:"eto,omitempty"`
	jwt.StandardClaims
}
97
98type algorithms struct {
99	HS256 jwt.Verifier
100	HS384 jwt.Verifier
101	HS512 jwt.Verifier
102	RS256 jwt.Verifier
103	RS384 jwt.Verifier
104	RS512 jwt.Verifier
105}
106
107func newAlgorithms(tokenHMACSecretKey string, pubKey *rsa.PublicKey) (*algorithms, error) {
108	alg := &algorithms{}
109
110	// HMAC SHA.
111	if tokenHMACSecretKey != "" {
112		verifierHS256, err := jwt.NewVerifierHS(jwt.HS256, []byte(tokenHMACSecretKey))
113		if err != nil {
114			return nil, err
115		}
116		verifierHS384, err := jwt.NewVerifierHS(jwt.HS384, []byte(tokenHMACSecretKey))
117		if err != nil {
118			return nil, err
119		}
120		verifierHS512, err := jwt.NewVerifierHS(jwt.HS512, []byte(tokenHMACSecretKey))
121		if err != nil {
122			return nil, err
123		}
124		alg.HS256 = verifierHS256
125		alg.HS384 = verifierHS384
126		alg.HS512 = verifierHS512
127	}
128
129	// RSA.
130	if pubKey != nil {
131		verifierRS256, err := jwt.NewVerifierRS(jwt.RS256, pubKey)
132		if err != nil {
133			return nil, err
134		}
135		verifierRS384, err := jwt.NewVerifierRS(jwt.RS384, pubKey)
136		if err != nil {
137			return nil, err
138		}
139		verifierRS512, err := jwt.NewVerifierRS(jwt.RS512, pubKey)
140		if err != nil {
141			return nil, err
142		}
143		alg.RS256 = verifierRS256
144		alg.RS384 = verifierRS384
145		alg.RS512 = verifierRS512
146	}
147
148	return alg, nil
149}
150
151func (s *algorithms) verify(token *jwt.Token) error {
152	var verifier jwt.Verifier
153	switch token.Header().Algorithm {
154	case jwt.HS256:
155		verifier = s.HS256
156	case jwt.HS384:
157		verifier = s.HS384
158	case jwt.HS512:
159		verifier = s.HS512
160	case jwt.RS256:
161		verifier = s.RS256
162	case jwt.RS384:
163		verifier = s.RS384
164	case jwt.RS512:
165		verifier = s.RS512
166	default:
167		return fmt.Errorf("%w: %s", errUnsupportedAlgorithm, string(token.Header().Algorithm))
168	}
169	if verifier == nil {
170		return fmt.Errorf("%w: %s", errDisabledAlgorithm, string(token.Header().Algorithm))
171	}
172	return verifier.Verify(token.Payload(), token.Signature())
173}
174
175func (verifier *TokenVerifier) verifySignature(token *jwt.Token) error {
176	verifier.mu.RLock()
177	defer verifier.mu.RUnlock()
178	return verifier.algorithms.verify(token)
179}
180
181func (verifier *TokenVerifier) VerifyConnectToken(t string) (ConnectToken, error) {
182	token, err := jwt.Parse([]byte(t))
183	if err != nil {
184		return ConnectToken{}, err
185	}
186
187	err = verifier.verifySignature(token)
188	if err != nil {
189		return ConnectToken{}, err
190	}
191
192	claims := &connectTokenClaims{}
193	err = json.Unmarshal(token.RawClaims(), claims)
194	if err != nil {
195		return ConnectToken{}, err
196	}
197
198	now := time.Now()
199	if !claims.IsValidExpiresAt(now) || !claims.IsValidNotBefore(now) {
200		return ConnectToken{}, ErrTokenExpired
201	}
202
203	ct := ConnectToken{
204		UserID:   claims.StandardClaims.Subject,
205		Info:     claims.Info,
206		Channels: claims.Channels,
207	}
208	if claims.ExpiresAt != nil {
209		ct.ExpireAt = claims.ExpiresAt.Unix()
210	}
211	if claims.Base64Info != "" {
212		byteInfo, err := base64.StdEncoding.DecodeString(claims.Base64Info)
213		if err != nil {
214			return ConnectToken{}, err
215		}
216		ct.Info = byteInfo
217	}
218	return ct, nil
219}
220
221func (verifier *TokenVerifier) VerifySubscribeToken(t string) (SubscribeToken, error) {
222	token, err := jwt.Parse([]byte(t))
223	if err != nil {
224		return SubscribeToken{}, err
225	}
226
227	err = verifier.verifySignature(token)
228	if err != nil {
229		return SubscribeToken{}, err
230	}
231
232	claims := &subscribeTokenClaims{}
233	err = json.Unmarshal(token.RawClaims(), claims)
234	if err != nil {
235		return SubscribeToken{}, err
236	}
237
238	now := time.Now()
239	if !claims.IsValidExpiresAt(now) || !claims.IsValidNotBefore(now) {
240		return SubscribeToken{}, ErrTokenExpired
241	}
242
243	st := SubscribeToken{
244		Client:          claims.Client,
245		Info:            claims.Info,
246		Channel:         claims.Channel,
247		ExpireTokenOnly: claims.ExpireTokenOnly,
248	}
249	if claims.ExpiresAt != nil {
250		st.ExpireAt = claims.ExpiresAt.Unix()
251	}
252	if claims.Base64Info != "" {
253		byteInfo, err := base64.StdEncoding.DecodeString(claims.Base64Info)
254		if err != nil {
255			return SubscribeToken{}, err
256		}
257		st.Info = byteInfo
258	}
259	return st, nil
260}
261
262func (verifier *TokenVerifier) Reload(config TokenVerifierConfig) error {
263	verifier.mu.Lock()
264	defer verifier.mu.Unlock()
265	alg, err := newAlgorithms(config.HMACSecretKey, config.RSAPublicKey)
266	if err != nil {
267		return err
268	}
269	verifier.algorithms = alg
270	return nil
271}
272