1// This file is auto-generated by jwt/internal/cmd/gentoken/main.go. DO NOT EDIT
2
3package jwt
4
5import (
6	"bytes"
7	"context"
8	"sort"
9	"sync"
10	"time"
11
12	"github.com/lestrrat-go/iter/mapiter"
13	"github.com/lestrrat-go/jwx/internal/base64"
14	"github.com/lestrrat-go/jwx/internal/iter"
15	"github.com/lestrrat-go/jwx/internal/json"
16	"github.com/lestrrat-go/jwx/internal/pool"
17	"github.com/lestrrat-go/jwx/jwt/internal/types"
18	"github.com/pkg/errors"
19)
20
// Names of the registered JWT claims that have dedicated, type-aware
// accessors on Token (see https://tools.ietf.org/html/rfc7519#section-4.1).
const (
	// AudienceKey is the name of the audience claim.
	AudienceKey = "aud"
	// ExpirationKey is the name of the expiration time claim.
	ExpirationKey = "exp"
	// IssuedAtKey is the name of the issued-at claim.
	IssuedAtKey = "iat"
	// IssuerKey is the name of the issuer claim.
	IssuerKey = "iss"
	// JwtIDKey is the name of the JWT ID claim.
	JwtIDKey = "jti"
	// NotBeforeKey is the name of the not-before claim.
	NotBeforeKey = "nbf"
	// SubjectKey is the name of the subject claim.
	SubjectKey = "sub"
)
30
31// Token represents a generic JWT token.
32// which are type-aware (to an extent). Other claims may be accessed via the `Get`/`Set`
33// methods but their types are not taken into consideration at all. If you have non-standard
34// claims that you must frequently access, consider creating accessors functions
35// like the following
36//
37// func SetFoo(tok jwt.Token) error
38// func GetFoo(tok jwt.Token) (*Customtyp, error)
39//
40// Embedding jwt.Token into another struct is not recommended, because
41// jwt.Token needs to handle private claims, and this really does not
42// work well when it is embedded in other structure
43type Token interface {
44	Audience() []string
45	Expiration() time.Time
46	IssuedAt() time.Time
47	Issuer() string
48	JwtID() string
49	NotBefore() time.Time
50	Subject() string
51	PrivateClaims() map[string]interface{}
52	Get(string) (interface{}, bool)
53	Set(string, interface{}) error
54	Remove(string) error
55	Clone() (Token, error)
56	Iterate(context.Context) Iterator
57	Walk(context.Context, Visitor) error
58	AsMap(context.Context) (map[string]interface{}, error)
59}
60type stdToken struct {
61	mu            *sync.RWMutex
62	dc            DecodeCtx          // per-object context for decoding
63	audience      types.StringList   // https://tools.ietf.org/html/rfc7519#section-4.1.3
64	expiration    *types.NumericDate // https://tools.ietf.org/html/rfc7519#section-4.1.4
65	issuedAt      *types.NumericDate // https://tools.ietf.org/html/rfc7519#section-4.1.6
66	issuer        *string            // https://tools.ietf.org/html/rfc7519#section-4.1.1
67	jwtID         *string            // https://tools.ietf.org/html/rfc7519#section-4.1.7
68	notBefore     *types.NumericDate // https://tools.ietf.org/html/rfc7519#section-4.1.5
69	subject       *string            // https://tools.ietf.org/html/rfc7519#section-4.1.2
70	privateClaims map[string]interface{}
71}
72
73// New creates a standard token, with minimal knowledge of
74// possible claims. Standard claims include"aud", "exp", "iat", "iss", "jti", "nbf" and "sub".
75// Convenience accessors are provided for these standard claims
76func New() Token {
77	return &stdToken{
78		mu:            &sync.RWMutex{},
79		privateClaims: make(map[string]interface{}),
80	}
81}
82
83func (t *stdToken) Get(name string) (interface{}, bool) {
84	t.mu.RLock()
85	defer t.mu.RUnlock()
86	switch name {
87	case AudienceKey:
88		if t.audience == nil {
89			return nil, false
90		}
91		v := t.audience.Get()
92		return v, true
93	case ExpirationKey:
94		if t.expiration == nil {
95			return nil, false
96		}
97		v := t.expiration.Get()
98		return v, true
99	case IssuedAtKey:
100		if t.issuedAt == nil {
101			return nil, false
102		}
103		v := t.issuedAt.Get()
104		return v, true
105	case IssuerKey:
106		if t.issuer == nil {
107			return nil, false
108		}
109		v := *(t.issuer)
110		return v, true
111	case JwtIDKey:
112		if t.jwtID == nil {
113			return nil, false
114		}
115		v := *(t.jwtID)
116		return v, true
117	case NotBeforeKey:
118		if t.notBefore == nil {
119			return nil, false
120		}
121		v := t.notBefore.Get()
122		return v, true
123	case SubjectKey:
124		if t.subject == nil {
125			return nil, false
126		}
127		v := *(t.subject)
128		return v, true
129	default:
130		v, ok := t.privateClaims[name]
131		return v, ok
132	}
133}
134
135func (t *stdToken) Remove(key string) error {
136	t.mu.Lock()
137	defer t.mu.Unlock()
138	switch key {
139	case AudienceKey:
140		t.audience = nil
141	case ExpirationKey:
142		t.expiration = nil
143	case IssuedAtKey:
144		t.issuedAt = nil
145	case IssuerKey:
146		t.issuer = nil
147	case JwtIDKey:
148		t.jwtID = nil
149	case NotBeforeKey:
150		t.notBefore = nil
151	case SubjectKey:
152		t.subject = nil
153	default:
154		delete(t.privateClaims, key)
155	}
156	return nil
157}
158
159func (t *stdToken) Set(name string, value interface{}) error {
160	t.mu.Lock()
161	defer t.mu.Unlock()
162	return t.setNoLock(name, value)
163}
164
165func (t *stdToken) DecodeCtx() DecodeCtx {
166	t.mu.RLock()
167	defer t.mu.RUnlock()
168	return t.dc
169}
170
171func (t *stdToken) SetDecodeCtx(v DecodeCtx) {
172	t.mu.Lock()
173	defer t.mu.Unlock()
174	t.dc = v
175}
176
177func (t *stdToken) setNoLock(name string, value interface{}) error {
178	switch name {
179	case AudienceKey:
180		var acceptor types.StringList
181		if err := acceptor.Accept(value); err != nil {
182			return errors.Wrapf(err, `invalid value for %s key`, AudienceKey)
183		}
184		t.audience = acceptor
185		return nil
186	case ExpirationKey:
187		var acceptor types.NumericDate
188		if err := acceptor.Accept(value); err != nil {
189			return errors.Wrapf(err, `invalid value for %s key`, ExpirationKey)
190		}
191		t.expiration = &acceptor
192		return nil
193	case IssuedAtKey:
194		var acceptor types.NumericDate
195		if err := acceptor.Accept(value); err != nil {
196			return errors.Wrapf(err, `invalid value for %s key`, IssuedAtKey)
197		}
198		t.issuedAt = &acceptor
199		return nil
200	case IssuerKey:
201		if v, ok := value.(string); ok {
202			t.issuer = &v
203			return nil
204		}
205		return errors.Errorf(`invalid value for %s key: %T`, IssuerKey, value)
206	case JwtIDKey:
207		if v, ok := value.(string); ok {
208			t.jwtID = &v
209			return nil
210		}
211		return errors.Errorf(`invalid value for %s key: %T`, JwtIDKey, value)
212	case NotBeforeKey:
213		var acceptor types.NumericDate
214		if err := acceptor.Accept(value); err != nil {
215			return errors.Wrapf(err, `invalid value for %s key`, NotBeforeKey)
216		}
217		t.notBefore = &acceptor
218		return nil
219	case SubjectKey:
220		if v, ok := value.(string); ok {
221			t.subject = &v
222			return nil
223		}
224		return errors.Errorf(`invalid value for %s key: %T`, SubjectKey, value)
225	default:
226		if t.privateClaims == nil {
227			t.privateClaims = map[string]interface{}{}
228		}
229		t.privateClaims[name] = value
230	}
231	return nil
232}
233
234func (t *stdToken) Audience() []string {
235	t.mu.RLock()
236	defer t.mu.RUnlock()
237	if t.audience != nil {
238		return t.audience.Get()
239	}
240	return nil
241}
242
243func (t *stdToken) Expiration() time.Time {
244	t.mu.RLock()
245	defer t.mu.RUnlock()
246	if t.expiration != nil {
247		return t.expiration.Get()
248	}
249	return time.Time{}
250}
251
252func (t *stdToken) IssuedAt() time.Time {
253	t.mu.RLock()
254	defer t.mu.RUnlock()
255	if t.issuedAt != nil {
256		return t.issuedAt.Get()
257	}
258	return time.Time{}
259}
260
261func (t *stdToken) Issuer() string {
262	t.mu.RLock()
263	defer t.mu.RUnlock()
264	if t.issuer != nil {
265		return *(t.issuer)
266	}
267	return ""
268}
269
270func (t *stdToken) JwtID() string {
271	t.mu.RLock()
272	defer t.mu.RUnlock()
273	if t.jwtID != nil {
274		return *(t.jwtID)
275	}
276	return ""
277}
278
279func (t *stdToken) NotBefore() time.Time {
280	t.mu.RLock()
281	defer t.mu.RUnlock()
282	if t.notBefore != nil {
283		return t.notBefore.Get()
284	}
285	return time.Time{}
286}
287
288func (t *stdToken) Subject() string {
289	t.mu.RLock()
290	defer t.mu.RUnlock()
291	if t.subject != nil {
292		return *(t.subject)
293	}
294	return ""
295}
296
297func (t *stdToken) PrivateClaims() map[string]interface{} {
298	t.mu.RLock()
299	defer t.mu.RUnlock()
300	return t.privateClaims
301}
302
303func (t *stdToken) makePairs() []*ClaimPair {
304	t.mu.RLock()
305	defer t.mu.RUnlock()
306
307	pairs := make([]*ClaimPair, 0, 7)
308	if t.audience != nil {
309		v := t.audience.Get()
310		pairs = append(pairs, &ClaimPair{Key: AudienceKey, Value: v})
311	}
312	if t.expiration != nil {
313		v := t.expiration.Get()
314		pairs = append(pairs, &ClaimPair{Key: ExpirationKey, Value: v})
315	}
316	if t.issuedAt != nil {
317		v := t.issuedAt.Get()
318		pairs = append(pairs, &ClaimPair{Key: IssuedAtKey, Value: v})
319	}
320	if t.issuer != nil {
321		v := *(t.issuer)
322		pairs = append(pairs, &ClaimPair{Key: IssuerKey, Value: v})
323	}
324	if t.jwtID != nil {
325		v := *(t.jwtID)
326		pairs = append(pairs, &ClaimPair{Key: JwtIDKey, Value: v})
327	}
328	if t.notBefore != nil {
329		v := t.notBefore.Get()
330		pairs = append(pairs, &ClaimPair{Key: NotBeforeKey, Value: v})
331	}
332	if t.subject != nil {
333		v := *(t.subject)
334		pairs = append(pairs, &ClaimPair{Key: SubjectKey, Value: v})
335	}
336	for k, v := range t.privateClaims {
337		pairs = append(pairs, &ClaimPair{Key: k, Value: v})
338	}
339	sort.Slice(pairs, func(i, j int) bool {
340		return pairs[i].Key.(string) < pairs[j].Key.(string)
341	})
342	return pairs
343}
344
345func (t *stdToken) UnmarshalJSON(buf []byte) error {
346	t.mu.Lock()
347	defer t.mu.Unlock()
348	t.audience = nil
349	t.expiration = nil
350	t.issuedAt = nil
351	t.issuer = nil
352	t.jwtID = nil
353	t.notBefore = nil
354	t.subject = nil
355	dec := json.NewDecoder(bytes.NewReader(buf))
356LOOP:
357	for {
358		tok, err := dec.Token()
359		if err != nil {
360			return errors.Wrap(err, `error reading token`)
361		}
362		switch tok := tok.(type) {
363		case json.Delim:
364			// Assuming we're doing everything correctly, we should ONLY
365			// get either '{' or '}' here.
366			if tok == '}' { // End of object
367				break LOOP
368			} else if tok != '{' {
369				return errors.Errorf(`expected '{', but got '%c'`, tok)
370			}
371		case string: // Objects can only have string keys
372			switch tok {
373			case AudienceKey:
374				var decoded types.StringList
375				if err := dec.Decode(&decoded); err != nil {
376					return errors.Wrapf(err, `failed to decode value for key %s`, AudienceKey)
377				}
378				t.audience = decoded
379			case ExpirationKey:
380				var decoded types.NumericDate
381				if err := dec.Decode(&decoded); err != nil {
382					return errors.Wrapf(err, `failed to decode value for key %s`, ExpirationKey)
383				}
384				t.expiration = &decoded
385			case IssuedAtKey:
386				var decoded types.NumericDate
387				if err := dec.Decode(&decoded); err != nil {
388					return errors.Wrapf(err, `failed to decode value for key %s`, IssuedAtKey)
389				}
390				t.issuedAt = &decoded
391			case IssuerKey:
392				if err := json.AssignNextStringToken(&t.issuer, dec); err != nil {
393					return errors.Wrapf(err, `failed to decode value for key %s`, IssuerKey)
394				}
395			case JwtIDKey:
396				if err := json.AssignNextStringToken(&t.jwtID, dec); err != nil {
397					return errors.Wrapf(err, `failed to decode value for key %s`, JwtIDKey)
398				}
399			case NotBeforeKey:
400				var decoded types.NumericDate
401				if err := dec.Decode(&decoded); err != nil {
402					return errors.Wrapf(err, `failed to decode value for key %s`, NotBeforeKey)
403				}
404				t.notBefore = &decoded
405			case SubjectKey:
406				if err := json.AssignNextStringToken(&t.subject, dec); err != nil {
407					return errors.Wrapf(err, `failed to decode value for key %s`, SubjectKey)
408				}
409			default:
410				if dc := t.dc; dc != nil {
411					if localReg := dc.Registry(); localReg != nil {
412						decoded, err := localReg.Decode(dec, tok)
413						if err == nil {
414							t.setNoLock(tok, decoded)
415							continue
416						}
417					}
418				}
419				decoded, err := registry.Decode(dec, tok)
420				if err == nil {
421					t.setNoLock(tok, decoded)
422					continue
423				}
424				return errors.Wrapf(err, `could not decode field %s`, tok)
425			}
426		default:
427			return errors.Errorf(`invalid token %T`, tok)
428		}
429	}
430	return nil
431}
432
433func (t stdToken) MarshalJSON() ([]byte, error) {
434	t.mu.RLock()
435	defer t.mu.RUnlock()
436	buf := pool.GetBytesBuffer()
437	defer pool.ReleaseBytesBuffer(buf)
438	buf.WriteByte('{')
439	enc := json.NewEncoder(buf)
440	for i, pair := range t.makePairs() {
441		f := pair.Key.(string)
442		if i > 0 {
443			buf.WriteByte(',')
444		}
445		buf.WriteRune('"')
446		buf.WriteString(f)
447		buf.WriteString(`":`)
448		switch f {
449		case AudienceKey:
450			if err := json.EncodeAudience(enc, pair.Value.([]string)); err != nil {
451				return nil, errors.Wrap(err, `failed to encode "aud"`)
452			}
453			continue
454		case ExpirationKey, IssuedAtKey, NotBeforeKey:
455			enc.Encode(pair.Value.(time.Time).Unix())
456			continue
457		}
458		switch v := pair.Value.(type) {
459		case []byte:
460			buf.WriteRune('"')
461			buf.WriteString(base64.EncodeToString(v))
462			buf.WriteRune('"')
463		default:
464			if err := enc.Encode(v); err != nil {
465				return nil, errors.Wrapf(err, `failed to marshal field %s`, f)
466			}
467			buf.Truncate(buf.Len() - 1)
468		}
469	}
470	buf.WriteByte('}')
471	ret := make([]byte, buf.Len())
472	copy(ret, buf.Bytes())
473	return ret, nil
474}
475
476func (t *stdToken) Iterate(ctx context.Context) Iterator {
477	pairs := t.makePairs()
478	ch := make(chan *ClaimPair, len(pairs))
479	go func(ctx context.Context, ch chan *ClaimPair, pairs []*ClaimPair) {
480		defer close(ch)
481		for _, pair := range pairs {
482			select {
483			case <-ctx.Done():
484				return
485			case ch <- pair:
486			}
487		}
488	}(ctx, ch, pairs)
489	return mapiter.New(ch)
490}
491
492func (t *stdToken) Walk(ctx context.Context, visitor Visitor) error {
493	return iter.WalkMap(ctx, t, visitor)
494}
495
496func (t *stdToken) AsMap(ctx context.Context) (map[string]interface{}, error) {
497	return iter.AsMap(ctx, t)
498}
499