1package jwt 2 3import ( 4 "crypto/rsa" 5 "encoding/base64" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "sync" 10 "time" 11 12 "github.com/cristalhq/jwt/v3" 13) 14 15type ConnectToken struct { 16 // UserID tells library an ID of connecting user. 17 UserID string 18 // ExpireAt allows to set time in future when connection must be validated. 19 // Validation can be server-side or client-side using Refresh handler. 20 ExpireAt int64 21 // Info contains additional information about connection. It will be 22 // included into Join/Leave messages, into Presence information, also 23 // info becomes a part of published message if it was published from 24 // client directly. In some cases having additional info can be an 25 // overhead – but you are simply free to not use it. 26 Info []byte 27 // Channels slice contains channels to subscribe connection to on server-side. 28 Channels []string 29} 30 31type SubscribeToken struct { 32 // Client is a unique client ID string set to each connection on server. 33 // Will be compared with actual client ID. 34 Client string 35 // Channel client wants to subscribe. Will be compared with channel in 36 // subscribe command. 37 Channel string 38 // ExpireAt allows to set time in future when connection must be validated. 39 // Validation can be server-side or client-side using SubRefresh handler. 40 ExpireAt int64 41 // Info contains additional information about connection in channel. 42 // It will be included into Join/Leave messages, into Presence information, 43 // also channel info becomes a part of published message if it was published 44 // from subscribed client directly. 45 Info []byte 46 // ExpireTokenOnly used to indicate that library must only check token 47 // expiration but not turn on Subscription expiration checks on server side. 48 // This allows to implement one-time subscription tokens. 49 ExpireTokenOnly bool 50} 51 52type TokenVerifierConfig struct { 53 // HMACSecretKey is a secret key used to validate connection and subscription 54 // tokens generated using HMAC. Zero value means that HMAC tokens won't be allowed. 55 HMACSecretKey string 56 // RSAPublicKey is a public key used to validate connection and subscription 57 // tokens generated using RSA. Zero value means that RSA tokens won't be allowed. 58 RSAPublicKey *rsa.PublicKey 59} 60 61func NewTokenVerifier(config TokenVerifierConfig) *TokenVerifier { 62 verifier := &TokenVerifier{} 63 algorithms, err := newAlgorithms(config.HMACSecretKey, config.RSAPublicKey) 64 if err != nil { 65 panic(err) 66 } 67 verifier.algorithms = algorithms 68 return verifier 69} 70 71type TokenVerifier struct { 72 mu sync.RWMutex 73 algorithms *algorithms 74} 75 76var ( 77 ErrTokenExpired = errors.New("token expired") 78 errUnsupportedAlgorithm = errors.New("unsupported JWT algorithm") 79 errDisabledAlgorithm = errors.New("disabled JWT algorithm") 80) 81 82type connectTokenClaims struct { 83 Info json.RawMessage `json:"info,omitempty"` 84 Base64Info string `json:"b64info,omitempty"` 85 Channels []string `json:"channels,omitempty"` 86 jwt.StandardClaims 87} 88 89type subscribeTokenClaims struct { 90 Client string `json:"client,omitempty"` 91 Channel string `json:"channel,omitempty"` 92 Info json.RawMessage `json:"info,omitempty"` 93 Base64Info string `json:"b64info,omitempty"` 94 ExpireTokenOnly bool `json:"eto,omitempty"` 95 jwt.StandardClaims 96} 97 98type algorithms struct { 99 HS256 jwt.Verifier 100 HS384 jwt.Verifier 101 HS512 jwt.Verifier 102 RS256 jwt.Verifier 103 RS384 jwt.Verifier 104 RS512 jwt.Verifier 105} 106 107func newAlgorithms(tokenHMACSecretKey string, pubKey *rsa.PublicKey) (*algorithms, error) { 108 alg := &algorithms{} 109 110 // HMAC SHA. 111 if tokenHMACSecretKey != "" { 112 verifierHS256, err := jwt.NewVerifierHS(jwt.HS256, []byte(tokenHMACSecretKey)) 113 if err != nil { 114 return nil, err 115 } 116 verifierHS384, err := jwt.NewVerifierHS(jwt.HS384, []byte(tokenHMACSecretKey)) 117 if err != nil { 118 return nil, err 119 } 120 verifierHS512, err := jwt.NewVerifierHS(jwt.HS512, []byte(tokenHMACSecretKey)) 121 if err != nil { 122 return nil, err 123 } 124 alg.HS256 = verifierHS256 125 alg.HS384 = verifierHS384 126 alg.HS512 = verifierHS512 127 } 128 129 // RSA. 130 if pubKey != nil { 131 verifierRS256, err := jwt.NewVerifierRS(jwt.RS256, pubKey) 132 if err != nil { 133 return nil, err 134 } 135 verifierRS384, err := jwt.NewVerifierRS(jwt.RS384, pubKey) 136 if err != nil { 137 return nil, err 138 } 139 verifierRS512, err := jwt.NewVerifierRS(jwt.RS512, pubKey) 140 if err != nil { 141 return nil, err 142 } 143 alg.RS256 = verifierRS256 144 alg.RS384 = verifierRS384 145 alg.RS512 = verifierRS512 146 } 147 148 return alg, nil 149} 150 151func (s *algorithms) verify(token *jwt.Token) error { 152 var verifier jwt.Verifier 153 switch token.Header().Algorithm { 154 case jwt.HS256: 155 verifier = s.HS256 156 case jwt.HS384: 157 verifier = s.HS384 158 case jwt.HS512: 159 verifier = s.HS512 160 case jwt.RS256: 161 verifier = s.RS256 162 case jwt.RS384: 163 verifier = s.RS384 164 case jwt.RS512: 165 verifier = s.RS512 166 default: 167 return fmt.Errorf("%w: %s", errUnsupportedAlgorithm, string(token.Header().Algorithm)) 168 } 169 if verifier == nil { 170 return fmt.Errorf("%w: %s", errDisabledAlgorithm, string(token.Header().Algorithm)) 171 } 172 return verifier.Verify(token.Payload(), token.Signature()) 173} 174 175func (verifier *TokenVerifier) verifySignature(token *jwt.Token) error { 176 verifier.mu.RLock() 177 defer verifier.mu.RUnlock() 178 return verifier.algorithms.verify(token) 179} 180 181func (verifier *TokenVerifier) VerifyConnectToken(t string) (ConnectToken, error) { 182 token, err := jwt.Parse([]byte(t)) 183 if err != nil { 184 return ConnectToken{}, err 185 } 186 187 err = verifier.verifySignature(token) 188 if err != nil { 189 return ConnectToken{}, err 190 } 191 192 claims := &connectTokenClaims{} 193 err = json.Unmarshal(token.RawClaims(), claims) 194 if err != nil { 195 return ConnectToken{}, err 196 } 197 198 now := time.Now() 199 if !claims.IsValidExpiresAt(now) || !claims.IsValidNotBefore(now) { 200 return ConnectToken{}, ErrTokenExpired 201 } 202 203 ct := ConnectToken{ 204 UserID: claims.StandardClaims.Subject, 205 Info: claims.Info, 206 Channels: claims.Channels, 207 } 208 if claims.ExpiresAt != nil { 209 ct.ExpireAt = claims.ExpiresAt.Unix() 210 } 211 if claims.Base64Info != "" { 212 byteInfo, err := base64.StdEncoding.DecodeString(claims.Base64Info) 213 if err != nil { 214 return ConnectToken{}, err 215 } 216 ct.Info = byteInfo 217 } 218 return ct, nil 219} 220 221func (verifier *TokenVerifier) VerifySubscribeToken(t string) (SubscribeToken, error) { 222 token, err := jwt.Parse([]byte(t)) 223 if err != nil { 224 return SubscribeToken{}, err 225 } 226 227 err = verifier.verifySignature(token) 228 if err != nil { 229 return SubscribeToken{}, err 230 } 231 232 claims := &subscribeTokenClaims{} 233 err = json.Unmarshal(token.RawClaims(), claims) 234 if err != nil { 235 return SubscribeToken{}, err 236 } 237 238 now := time.Now() 239 if !claims.IsValidExpiresAt(now) || !claims.IsValidNotBefore(now) { 240 return SubscribeToken{}, ErrTokenExpired 241 } 242 243 st := SubscribeToken{ 244 Client: claims.Client, 245 Info: claims.Info, 246 Channel: claims.Channel, 247 ExpireTokenOnly: claims.ExpireTokenOnly, 248 } 249 if claims.ExpiresAt != nil { 250 st.ExpireAt = claims.ExpiresAt.Unix() 251 } 252 if claims.Base64Info != "" { 253 byteInfo, err := base64.StdEncoding.DecodeString(claims.Base64Info) 254 if err != nil { 255 return SubscribeToken{}, err 256 } 257 st.Info = byteInfo 258 } 259 return st, nil 260} 261 262func (verifier *TokenVerifier) Reload(config TokenVerifierConfig) error { 263 verifier.mu.Lock() 264 defer verifier.mu.Unlock() 265 alg, err := newAlgorithms(config.HMACSecretKey, config.RSAPublicKey) 266 if err != nil { 267 return err 268 } 269 verifier.algorithms = alg 270 return nil 271} 272