1package jws
2
3import (
4	"bytes"
5	"encoding/json"
6	"net/http"
7	"strings"
8
9	"github.com/briankassouf/jose"
10	"github.com/briankassouf/jose/crypto"
11)
12
13// JWS implements a JWS per RFC 7515.
14type JWS interface {
15	// Payload Returns the payload.
16	Payload() interface{}
17
18	// SetPayload sets the payload with the given value.
19	SetPayload(p interface{})
20
21	// Protected returns the JWS' Protected Header.
22	Protected() jose.Protected
23
24	// ProtectedAt returns the JWS' Protected Header.
25	// i represents the index of the Protected Header.
26	ProtectedAt(i int) jose.Protected
27
28	// Header returns the JWS' unprotected Header.
29	Header() jose.Header
30
31	// HeaderAt returns the JWS' unprotected Header.
32	// i represents the index of the unprotected Header.
33	HeaderAt(i int) jose.Header
34
35	// Verify validates the current JWS' signature as-is. Refer to
36	// ValidateMulti for more information.
37	Verify(key interface{}, method crypto.SigningMethod) error
38
39	// ValidateMulti validates the current JWS' signature as-is. Since it's
40	// meant to be called after parsing a stream of bytes into a JWS, it
41	// shouldn't do any internal parsing like the Sign, Flat, Compact, or
42	// General methods do.
43	VerifyMulti(keys []interface{}, methods []crypto.SigningMethod, o *SigningOpts) error
44
45	// VerifyCallback validates the current JWS' signature as-is. It
46	// accepts a callback function that can be used to access header
47	// parameters to lookup needed information. For example, looking
48	// up the "kid" parameter.
49	// The return slice must be a slice of keys used in the verification
50	// of the JWS.
51	VerifyCallback(fn VerifyCallback, methods []crypto.SigningMethod, o *SigningOpts) error
52
53	// General serializes the JWS into its "general" form per
54	// https://tools.ietf.org/html/rfc7515#section-7.2.1
55	General(keys ...interface{}) ([]byte, error)
56
57	// Flat serializes the JWS to its "flattened" form per
58	// https://tools.ietf.org/html/rfc7515#section-7.2.2
59	Flat(key interface{}) ([]byte, error)
60
61	// Compact serializes the JWS into its "compact" form per
62	// https://tools.ietf.org/html/rfc7515#section-7.1
63	Compact(key interface{}) ([]byte, error)
64
65	// IsJWT returns true if the JWS is a JWT.
66	IsJWT() bool
67}
68
// jws is the package's concrete implementation of JWS. Values are built
// by New and by the Parse* functions in this file.
type jws struct {
	// payload holds the payload value (v) and, for parsed JWSs, the
	// optional caller-supplied json.Unmarshaler (u) used to decode it.
	payload *payload
	// plcache holds the payload's raw serialized bytes as read by the
	// Parse* functions; New leaves it empty.
	plcache rawBase64
	// clean is set to true by the Parse* functions. NOTE(review):
	// presumably it marks plcache as in sync with payload — confirm
	// against the signing/serialization code.
	clean   bool

	// sb holds one sigHead per signature; index 0 backs Protected and
	// Header.
	sb []sigHead

	// isJWT records the jwt argument passed to parseCompact; New and the
	// JSON parsers leave it false.
	isJWT bool
}
79
80// Payload returns the jws' payload.
81func (j *jws) Payload() interface{} {
82	return j.payload.v
83}
84
85// SetPayload sets the jws' raw, unexported payload.
86func (j *jws) SetPayload(val interface{}) {
87	j.payload.v = val
88}
89
90// Protected returns the JWS' Protected Header.
91func (j *jws) Protected() jose.Protected {
92	return j.sb[0].protected
93}
94
95// Protected returns the JWS' Protected Header.
96// i represents the index of the Protected Header.
97// Left empty, it defaults to 0.
98func (j *jws) ProtectedAt(i int) jose.Protected {
99	return j.sb[i].protected
100}
101
102// Header returns the JWS' unprotected Header.
103func (j *jws) Header() jose.Header {
104	return j.sb[0].unprotected
105}
106
107// HeaderAt returns the JWS' unprotected Header.
108// |i| is the index of the unprotected Header.
109func (j *jws) HeaderAt(i int) jose.Header {
110	return j.sb[i].unprotected
111}
112
// sigHead represents one entry of the 'signatures' member of the jws'
// "general" serialization form per
// https://tools.ietf.org/html/rfc7515#section-7.2.1
//
// It is also embedded inside the generic struct so that the top-level
// "protected", "header", and "signature" members of the "flat" form can
// be decoded directly into it.
//
// Protected and Unprotected hold the headers in their raw serialized
// form; the unexported fields hold the decoded headers and the signing
// method.
type sigHead struct {
	Protected   rawBase64        `json:"protected,omitempty"`
	Unprotected rawBase64        `json:"header,omitempty"`
	Signature   crypto.Signature `json:"signature"`

	// protected and unprotected are the decoded forms of Protected and
	// Unprotected (see unmarshal).
	protected   jose.Protected
	unprotected jose.Header
	// clean is set to true when a signature is parsed (e.g. by
	// parseFlat and parseCompact).
	clean       bool

	// method is the signing method for this signature: supplied to New,
	// or looked up from the protected "alg" parameter by assignMethod.
	method crypto.SigningMethod
}
130
131func (s *sigHead) unmarshal() error {
132	if err := s.protected.UnmarshalJSON(s.Protected); err != nil {
133		return err
134	}
135	return s.unprotected.UnmarshalJSON(s.Unprotected)
136}
137
138// New creates a JWS with the provided crypto.SigningMethods.
139func New(content interface{}, methods ...crypto.SigningMethod) JWS {
140	sb := make([]sigHead, len(methods))
141	for i := range methods {
142		sb[i] = sigHead{
143			protected: jose.Protected{
144				"alg": methods[i].Alg(),
145			},
146			unprotected: jose.Header{},
147			method:      methods[i],
148		}
149	}
150	return &jws{
151		payload: &payload{v: content},
152		sb:      sb,
153	}
154}
155
156func (s *sigHead) assignMethod(p jose.Protected) error {
157	alg, ok := p.Get("alg").(string)
158	if !ok {
159		return ErrNoAlgorithm
160	}
161
162	sm := GetSigningMethod(alg)
163	if sm == nil {
164		return ErrNoAlgorithm
165	}
166	s.method = sm
167	return nil
168}
169
// generic is the JSON decoding target shared by both JSON serializations
// of a JWS. Because sigHead is embedded, encoding/json promotes its
// "protected", "header", and "signature" members to the top level, which
// matches the "flat" form (RFC 7515 section 7.2.2). Signatures holds the
// "signatures" array of the "general" form (RFC 7515 section 7.2.1).
// Parse picks between the two by checking whether Signatures is nil.
type generic struct {
	// Payload is the raw "payload" member.
	Payload rawBase64 `json:"payload"`
	sigHead
	Signatures []sigHead `json:"signatures,omitempty"`
}
175
176// Parse parses any of the three serialized jws forms into a physical
177// jws per https://tools.ietf.org/html/rfc7515#section-5.2
178//
179// It accepts a json.Unmarshaler in order to properly parse
180// the payload. In order to keep the caller from having to do extra
181// parsing of the payload, a json.Unmarshaler can be passed
182// which will be then to unmarshal the payload however the caller
183// wishes. Do note that if json.Unmarshal returns an error the
184// original payload will be used as if no json.Unmarshaler was
185// passed.
186//
187// Internally, Parse applies some heuristics and then calls either
188// ParseGeneral, ParseFlat, or ParseCompact.
189// It should only be called if, for whatever reason, you do not
190// know which form the serialized JWT is in.
191//
192// It cannot parse a JWT.
193func Parse(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
194	// Try and unmarshal into a generic struct that'll
195	// hopefully hold either of the two JSON serialization
196	// formats.
197	var g generic
198
199	// Not valid JSON. Let's try compact.
200	if err := json.Unmarshal(encoded, &g); err != nil {
201		return ParseCompact(encoded, u...)
202	}
203
204	if g.Signatures == nil {
205		return g.parseFlat(u...)
206	}
207	return g.parseGeneral(u...)
208}
209
210// ParseGeneral parses a jws serialized into its "general" form per
211// https://tools.ietf.org/html/rfc7515#section-7.2.1
212// into a physical jws per
213// https://tools.ietf.org/html/rfc7515#section-5.2
214//
215// For information on the json.Unmarshaler parameter, see Parse.
216func ParseGeneral(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
217	var g generic
218	if err := json.Unmarshal(encoded, &g); err != nil {
219		return nil, err
220	}
221	return g.parseGeneral(u...)
222}
223
224func (g *generic) parseGeneral(u ...json.Unmarshaler) (JWS, error) {
225
226	var p payload
227	if len(u) > 0 {
228		p.u = u[0]
229	}
230
231	if err := p.UnmarshalJSON(g.Payload); err != nil {
232		return nil, err
233	}
234
235	for i := range g.Signatures {
236		if err := g.Signatures[i].unmarshal(); err != nil {
237			return nil, err
238		}
239		if err := checkHeaders(jose.Header(g.Signatures[i].protected), g.Signatures[i].unprotected); err != nil {
240			return nil, err
241		}
242
243		if err := g.Signatures[i].assignMethod(g.Signatures[i].protected); err != nil {
244			return nil, err
245		}
246	}
247
248	g.clean = len(g.Signatures) != 0
249
250	return &jws{
251		payload: &p,
252		plcache: g.Payload,
253		clean:   true,
254		sb:      g.Signatures,
255	}, nil
256}
257
258// ParseFlat parses a jws serialized into its "flat" form per
259// https://tools.ietf.org/html/rfc7515#section-7.2.2
260// into a physical jws per
261// https://tools.ietf.org/html/rfc7515#section-5.2
262//
263// For information on the json.Unmarshaler parameter, see Parse.
264func ParseFlat(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
265	var g generic
266	if err := json.Unmarshal(encoded, &g); err != nil {
267		return nil, err
268	}
269	return g.parseFlat(u...)
270}
271
272func (g *generic) parseFlat(u ...json.Unmarshaler) (JWS, error) {
273
274	var p payload
275	if len(u) > 0 {
276		p.u = u[0]
277	}
278
279	if err := p.UnmarshalJSON(g.Payload); err != nil {
280		return nil, err
281	}
282
283	if err := g.sigHead.unmarshal(); err != nil {
284		return nil, err
285	}
286	g.sigHead.clean = true
287
288	if err := checkHeaders(jose.Header(g.sigHead.protected), g.sigHead.unprotected); err != nil {
289		return nil, err
290	}
291
292	if err := g.sigHead.assignMethod(g.sigHead.protected); err != nil {
293		return nil, err
294	}
295
296	return &jws{
297		payload: &p,
298		plcache: g.Payload,
299		clean:   true,
300		sb:      []sigHead{g.sigHead},
301	}, nil
302}
303
304// ParseCompact parses a jws serialized into its "compact" form per
305// https://tools.ietf.org/html/rfc7515#section-7.1
306// into a physical jws per
307// https://tools.ietf.org/html/rfc7515#section-5.2
308//
309// For information on the json.Unmarshaler parameter, see Parse.
310func ParseCompact(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
311	return parseCompact(encoded, false, u...)
312}
313
314func parseCompact(encoded []byte, jwt bool, u ...json.Unmarshaler) (*jws, error) {
315
316	// This section loosely follows
317	// https://tools.ietf.org/html/rfc7519#section-7.2
318	// because it's used to parse _both_ jws and JWTs.
319
320	parts := bytes.Split(encoded, []byte{'.'})
321	if len(parts) != 3 {
322		return nil, ErrNotCompact
323	}
324
325	var p jose.Protected
326	if err := p.UnmarshalJSON(parts[0]); err != nil {
327		return nil, err
328	}
329
330	s := sigHead{
331		Protected: parts[0],
332		protected: p,
333		Signature: parts[2],
334		clean:     true,
335	}
336
337	if err := s.assignMethod(p); err != nil {
338		return nil, err
339	}
340
341	var pl payload
342	if len(u) > 0 {
343		pl.u = u[0]
344	}
345
346	j := jws{
347		payload: &pl,
348		plcache: parts[1],
349		sb:      []sigHead{s},
350		isJWT:   jwt,
351	}
352
353	if err := j.payload.UnmarshalJSON(parts[1]); err != nil {
354		return nil, err
355	}
356
357	j.clean = true
358
359	if err := j.sb[0].Signature.UnmarshalJSON(parts[2]); err != nil {
360		return nil, err
361	}
362
363	// https://tools.ietf.org/html/rfc7519#section-7.2.8
364	cty, ok := p.Get("cty").(string)
365	if ok && cty == "JWT" {
366		return &j, ErrHoldsJWE
367	}
368	return &j, nil
369}
370
// JWSFormKey is the form key under which ParseFromForm (and therefore
// ParseFromRequest) looks for the token.
var JWSFormKey = "access_token"

// MaxMemory is the maxMemory argument that ParseFromForm passes to
// http.Request.ParseMultipartForm when it parses the request body: up to
// this many bytes of a multipart form's file parts are kept in memory,
// and the remainder is stored on disk in temporary files.
var MaxMemory int64 = 10e6
381
// Format specifies which serialization a JWS is in: Flat, General, or
// Compact. Unknown means the caller does not know the format; the
// ParseFrom* functions then fall back to Parse, which detects it.
type Format uint8

const (
	// Unknown format; detected automatically by Parse.
	Unknown Format = iota

	// Flat is the flattened JSON serialization (RFC 7515 section 7.2.2).
	Flat

	// General is the general JSON serialization (RFC 7515 section 7.2.1).
	General

	// Compact is the compact serialization (RFC 7515 section 7.1).
	Compact
)
399
400var parseJumpTable = [...]func([]byte, ...json.Unmarshaler) (JWS, error){
401	Unknown:  Parse,
402	Flat:     ParseFlat,
403	General:  ParseGeneral,
404	Compact:  ParseCompact,
405	1<<8 - 1: Parse, // Max uint8.
406}
407
408func init() {
409	for i := range parseJumpTable {
410		if parseJumpTable[i] == nil {
411			parseJumpTable[i] = Parse
412		}
413	}
414}
415
// fromHeader extracts the token from an "Authorization: Bearer <token>"
// request header. The "Bearer " prefix (including its trailing space) is
// matched case-insensitively and must be followed by at least one more
// byte. The boolean result reports whether such a token was found; the
// byte slice is nil when it was not.
func fromHeader(req *http.Request) ([]byte, bool) {
	const scheme = "BEARER "
	auth := req.Header.Get("Authorization")
	if len(auth) <= len(scheme) || !strings.EqualFold(auth[:len(scheme)], scheme) {
		return nil, false
	}
	return []byte(auth[len(scheme):]), true
}
422
423func fromForm(req *http.Request) ([]byte, bool) {
424	if err := req.ParseMultipartForm(MaxMemory); err != nil {
425		return nil, false
426	}
427	if tokStr := req.Form.Get(JWSFormKey); tokStr != "" {
428		return []byte(tokStr), true
429	}
430	return nil, false
431}
432
433// ParseFromHeader tries to find the JWS in an http.Request header.
434func ParseFromHeader(req *http.Request, format Format, u ...json.Unmarshaler) (JWS, error) {
435	if b, ok := fromHeader(req); ok {
436		return parseJumpTable[format](b, u...)
437	}
438	return nil, ErrNoTokenInRequest
439}
440
441// ParseFromForm tries to find the JWS in an http.Request form request.
442func ParseFromForm(req *http.Request, format Format, u ...json.Unmarshaler) (JWS, error) {
443	if b, ok := fromForm(req); ok {
444		return parseJumpTable[format](b, u...)
445	}
446	return nil, ErrNoTokenInRequest
447}
448
449// ParseFromRequest tries to find the JWS in an http.Request.
450// This method will call ParseMultipartForm if there's no token in the header.
451func ParseFromRequest(req *http.Request, format Format, u ...json.Unmarshaler) (JWS, error) {
452	token, err := ParseFromHeader(req, format, u...)
453	if err == nil {
454		return token, nil
455	}
456
457	token, err = ParseFromForm(req, format, u...)
458	if err == nil {
459		return token, nil
460	}
461
462	return nil, err
463}
464
// IgnoreDupes controls how checkHeaders treats duplicate header
// parameters while parsing. A parameter is a duplicate when both the
// protected and the unprotected header contain it; see
// https://tools.ietf.org/html/rfc7515#section-5.2
//
// When IgnoreDupes is false (the default), such duplicates are reported
// as ErrDuplicateHeaderParameter. When it is true, they are tolerated.
var IgnoreDupes bool
475
476// checkHeaders returns an error per the constraints described in
477// IgnoreDupes' comment.
478func checkHeaders(a, b jose.Header) error {
479	if len(a)+len(b) == 0 {
480		return ErrTwoEmptyHeaders
481	}
482	for key := range a {
483		if b.Has(key) && !IgnoreDupes {
484			return ErrDuplicateHeaderParameter
485		}
486	}
487	return nil
488}
489
// Compile-time assertion that *jws implements the JWS interface.
var _ JWS = (*jws)(nil)
491